From 9356cb24a040f9f953afa031504165bf8b950b09 Mon Sep 17 00:00:00 2001
From: Will Brickner
Date: Sat, 18 Nov 2023 20:29:00 -0600
Subject: [PATCH 1/2] Added quiet_softmax

---
 backend-comparison/benches/binary.rs | 52 +-
 backend-comparison/benches/custom_gelu.rs | 142 +-
 backend-comparison/benches/data.rs | 104 +-
 backend-comparison/benches/matmul.rs | 86 +-
 backend-comparison/benches/unary.rs | 50 +-
 backend-comparison/src/lib.rs | 108 +-
 burn-autodiff/src/backend.rs | 114 +-
 burn-autodiff/src/grads.rs | 140 +-
 burn-autodiff/src/graph/backward.rs | 39 +-
 burn-autodiff/src/graph/base.rs | 130 +-
 burn-autodiff/src/graph/node.rs | 44 +-
 burn-autodiff/src/graph/requirement.rs | 46 +-
 burn-autodiff/src/graph/traversal.rs | 62 +-
 burn-autodiff/src/ops/activation.rs | 94 +-
 burn-autodiff/src/ops/backward.rs | 118 +-
 burn-autodiff/src/ops/base.rs | 218 +-
 burn-autodiff/src/ops/bool_tensor.rs | 173 +-
 burn-autodiff/src/ops/int_tensor.rs | 624 +--
 burn-autodiff/src/ops/maxmin.rs | 18 +-
 burn-autodiff/src/ops/module.rs | 1687 ++++---
 burn-autodiff/src/ops/tensor.rs | 2782 ++++++-----
 burn-autodiff/src/tensor.rs | 150 +-
 burn-autodiff/src/tests/abs.rs | 40 +-
 burn-autodiff/src/tests/adaptive_avgpool1d.rs | 78 +-
 burn-autodiff/src/tests/adaptive_avgpool2d.rs | 110 +-
 burn-autodiff/src/tests/add.rs | 90 +-
 burn-autodiff/src/tests/aggregation.rs | 236 +-
 burn-autodiff/src/tests/avgpool1d.rs | 164 +-
 burn-autodiff/src/tests/avgpool2d.rs | 214 +-
 burn-autodiff/src/tests/backward.rs | 44 +-
 burn-autodiff/src/tests/broadcast.rs | 106 +-
 burn-autodiff/src/tests/cat.rs | 116 +-
 burn-autodiff/src/tests/complex.rs | 156 +-
 burn-autodiff/src/tests/conv1d.rs | 448 +-
 burn-autodiff/src/tests/conv2d.rs | 1458 +++---
 burn-autodiff/src/tests/conv_transpose1d.rs | 476 +-
 burn-autodiff/src/tests/conv_transpose2d.rs | 1250 +++--
 burn-autodiff/src/tests/cos.rs | 44 +-
 burn-autodiff/src/tests/cross_entropy.rs | 45 +-
 burn-autodiff/src/tests/div.rs | 172 +-
 burn-autodiff/src/tests/erf.rs | 40 +-
 burn-autodiff/src/tests/exp.rs | 40 +-
 burn-autodiff/src/tests/gather_scatter.rs | 108 +-
 burn-autodiff/src/tests/gelu.rs | 36 +-
 burn-autodiff/src/tests/gradients.rs | 31 +-
 burn-autodiff/src/tests/log.rs | 40 +-
 burn-autodiff/src/tests/log1p.rs | 42 +-
 burn-autodiff/src/tests/mask.rs | 101 +-
 burn-autodiff/src/tests/matmul.rs | 150 +-
 burn-autodiff/src/tests/maxmin.rs | 84 +-
 burn-autodiff/src/tests/maxpool1d.rs | 215 +-
 burn-autodiff/src/tests/maxpool2d.rs | 336 +-
 burn-autodiff/src/tests/mod.rs | 106 +-
 burn-autodiff/src/tests/mul.rs | 92 +-
 burn-autodiff/src/tests/multithread.rs | 122 +-
 burn-autodiff/src/tests/neg.rs | 32 +-
 burn-autodiff/src/tests/pow.rs | 40 +-
 burn-autodiff/src/tests/recip.rs | 27 +-
 burn-autodiff/src/tests/relu.rs | 34 +-
 burn-autodiff/src/tests/reshape.rs | 32 +-
 burn-autodiff/src/tests/select.rs | 100 +-
 burn-autodiff/src/tests/sin.rs | 42 +-
 burn-autodiff/src/tests/slice.rs | 148 +-
 burn-autodiff/src/tests/softmax.rs | 115 +-
 burn-autodiff/src/tests/sqrt.rs | 40 +-
 burn-autodiff/src/tests/sub.rs | 86 +-
 burn-autodiff/src/tests/tanh.rs | 40 +-
 burn-autodiff/src/tests/transpose.rs | 94 +-
 burn-autodiff/src/utils.rs | 22 +-
 burn-candle/src/backend.rs | 84 +-
 burn-candle/src/lib.rs | 234 +-
 burn-candle/src/ops/activation.rs | 16 +-
 burn-candle/src/ops/base.rs | 82 +-
 burn-candle/src/ops/bool_tensor.rs | 209 +-
 burn-candle/src/ops/candle_utils.rs | 28 +-
 burn-candle/src/ops/int_tensor.rs | 709 +--
 burn-candle/src/ops/module.rs | 401 +-
 burn-candle/src/ops/tensor.rs | 881 ++--
 burn-candle/src/tensor.rs | 56 +-
burn-common/src/benchmark.rs | 204 +- burn-common/src/id.rs | 90 +- burn-common/src/rand.rs | 26 +- burn-common/src/reader.rs | 144 +- burn-common/src/stub.rs | 64 +- burn-compute/src/channel/base.rs | 20 +- burn-compute/src/channel/cell.rs | 59 +- burn-compute/src/channel/mpsc.rs | 228 +- burn-compute/src/channel/mutex.rs | 54 +- burn-compute/src/client.rs | 105 +- burn-compute/src/compute.rs | 116 +- burn-compute/src/id.rs | 76 +- burn-compute/src/memory_management/base.rs | 62 +- burn-compute/src/memory_management/simple.rs | 736 +-- burn-compute/src/server.rs | 70 +- burn-compute/src/storage/base.rs | 46 +- burn-compute/src/storage/bytes_cpu.rs | 160 +- burn-compute/src/tune/operation.rs | 32 +- burn-compute/src/tune/tune_benchmark.rs | 36 +- burn-compute/src/tune/tune_cache.rs | 44 +- burn-compute/src/tune/tuner.rs | 148 +- burn-compute/tests/dummy/compute.rs | 18 +- burn-compute/tests/dummy/kernel.rs | 22 +- burn-compute/tests/dummy/server.rs | 70 +- .../tests/dummy/tune/autotune_operations.rs | 34 +- burn-compute/tests/dummy/tune/kernels.rs | 130 +- .../tests/dummy/tune/operation_sets.rs | 260 +- burn-compute/tests/integration_test.rs | 182 +- burn-core/src/config.rs | 121 +- burn-core/src/data/dataloader/base.rs | 16 +- burn-core/src/data/dataloader/batch.rs | 407 +- burn-core/src/data/dataloader/batcher.rs | 26 +- burn-core/src/data/dataloader/builder.rs | 194 +- burn-core/src/data/dataloader/multithread.rs | 192 +- burn-core/src/data/dataloader/strategy.rs | 110 +- burn-core/src/data/mod.rs | 2 +- burn-core/src/grad_clipping/base.rs | 212 +- burn-core/src/lr_scheduler/base.rs | 18 +- burn-core/src/lr_scheduler/constant.rs | 40 +- burn-core/src/lr_scheduler/noam.rs | 116 +- burn-core/src/module/base.rs | 390 +- burn-core/src/module/param/base.rs | 44 +- burn-core/src/module/param/constant.rs | 300 +- burn-core/src/module/param/id.rs | 46 +- burn-core/src/module/param/primitive.rs | 233 +- burn-core/src/module/param/running.rs | 252 +- burn-core/src/module/param/tensor.rs | 160 +- burn-core/src/module/param/visitor.rs | 28 +- burn-core/src/nn/attention/mask.rs | 218 +- burn-core/src/nn/attention/mha.rs | 759 +-- burn-core/src/nn/cache/autoregressive.rs | 86 +- burn-core/src/nn/cache/base.rs | 24 +- burn-core/src/nn/conv/checks.rs | 10 +- burn-core/src/nn/conv/conv1d.rs | 240 +- burn-core/src/nn/conv/conv2d.rs | 237 +- burn-core/src/nn/conv/conv_transpose1d.rs | 250 +- burn-core/src/nn/conv/conv_transpose2d.rs | 253 +- burn-core/src/nn/dropout.rs | 88 +- burn-core/src/nn/embedding.rs | 128 +- burn-core/src/nn/gelu.rs | 26 +- burn-core/src/nn/initializer.rs | 626 ++- burn-core/src/nn/linear.rs | 246 +- burn-core/src/nn/loss/binary_cross_entropy.rs | 288 +- burn-core/src/nn/loss/cross_entropy.rs | 687 ++- burn-core/src/nn/loss/mse.rs | 104 +- burn-core/src/nn/loss/reduction.rs | 12 +- burn-core/src/nn/norm/batch.rs | 678 +-- burn-core/src/nn/norm/layer.rs | 202 +- burn-core/src/nn/padding.rs | 86 +- burn-core/src/nn/pool/adaptive_avg_pool1d.rs | 34 +- burn-core/src/nn/pool/adaptive_avg_pool2d.rs | 34 +- burn-core/src/nn/pool/avg_pool1d.rs | 84 +- burn-core/src/nn/pool/avg_pool2d.rs | 85 +- burn-core/src/nn/pool/max_pool1d.rs | 72 +- burn-core/src/nn/pool/max_pool2d.rs | 73 +- burn-core/src/nn/pos_encoding.rs | 326 +- burn-core/src/nn/relu.rs | 26 +- burn-core/src/nn/rnn/gate_controller.rs | 120 +- burn-core/src/nn/rnn/gru.rs | 424 +- burn-core/src/nn/rnn/lstm.rs | 585 ++- burn-core/src/nn/transformer/decoder.rs | 736 +-- burn-core/src/nn/transformer/encoder.rs | 600 +-- 
burn-core/src/nn/transformer/pwff.rs | 110 +- burn-core/src/nn/unfold.rs | 68 +- burn-core/src/optim/adagrad.rs | 446 +- burn-core/src/optim/adam.rs | 595 ++- burn-core/src/optim/adamw.rs | 637 ++- burn-core/src/optim/base.rs | 22 +- burn-core/src/optim/decay.rs | 78 +- burn-core/src/optim/grad_accum.rs | 179 +- burn-core/src/optim/grads.rs | 214 +- burn-core/src/optim/momentum.rs | 127 +- burn-core/src/optim/rmsprop.rs | 899 ++-- burn-core/src/optim/sgd.rs | 262 +- burn-core/src/optim/simple/adaptor.rs | 228 +- burn-core/src/optim/simple/base.rs | 38 +- burn-core/src/optim/simple/record/base.rs | 92 +- burn-core/src/optim/simple/record/v1.rs | 278 +- burn-core/src/optim/visitor.rs | 42 +- burn-core/src/record/base.rs | 12 +- burn-core/src/record/file.rs | 569 ++- burn-core/src/record/memory.rs | 128 +- burn-core/src/record/primitive.rs | 165 +- burn-core/src/record/recorder.rs | 376 +- burn-core/src/record/settings.rs | 22 +- burn-core/src/record/tensor.rs | 148 +- burn-core/tests/derive_config.rs | 88 +- burn-core/tests/derive_module.rs | 202 +- burn-core/tests/derive_record.rs | 4 +- burn-core/tests/record_resilience.rs | 580 ++- burn-dataset/examples/speech_commands.rs | 24 +- burn-dataset/src/audio/speech_commands.rs | 269 +- burn-dataset/src/dataset/base.rs | 82 +- burn-dataset/src/dataset/fake.rs | 44 +- burn-dataset/src/dataset/in_memory.rs | 270 +- burn-dataset/src/dataset/iterator.rs | 34 +- burn-dataset/src/dataset/sqlite.rs | 1387 +++--- burn-dataset/src/lib.rs | 16 +- .../src/source/huggingface/downloader.rs | 406 +- burn-dataset/src/source/huggingface/mnist.rs | 96 +- burn-dataset/src/transform/composed.rs | 36 +- burn-dataset/src/transform/mapper.rs | 74 +- burn-dataset/src/transform/partial.rs | 204 +- burn-dataset/src/transform/random.rs | 66 +- burn-dataset/src/transform/sampler.rs | 198 +- burn-derive/src/config/analyzer.rs | 110 +- burn-derive/src/config/analyzer_enum.rs | 268 +- burn-derive/src/config/analyzer_struct.rs | 450 +- burn-derive/src/config/base.rs | 34 +- burn-derive/src/lib.rs | 12 +- burn-derive/src/module/base.rs | 260 +- burn-derive/src/module/codegen.rs | 14 +- burn-derive/src/module/codegen_struct.rs | 234 +- burn-derive/src/module/display.rs | 10 +- burn-derive/src/module/record.rs | 4 +- burn-derive/src/module/record_struct.rs | 34 +- burn-derive/src/record/base.rs | 120 +- burn-derive/src/record/codegen.rs | 12 +- burn-derive/src/record/codegen_struct.rs | 98 +- burn-derive/src/shared/attribute.rs | 80 +- burn-derive/src/shared/field.rs | 158 +- burn-fusion/src/backend.rs | 182 +- burn-fusion/src/client/base.rs | 112 +- burn-fusion/src/client/mutex.rs | 267 +- burn-fusion/src/fusion.rs | 96 +- burn-fusion/src/graph/base.rs | 146 +- burn-fusion/src/graph/execution.rs | 106 +- burn-fusion/src/graph/ops.rs | 2442 +++++----- burn-fusion/src/handle.rs | 232 +- burn-fusion/src/ops/binary.rs | 106 +- burn-fusion/src/ops/boolean.rs | 731 ++- burn-fusion/src/ops/float.rs | 3265 ++++++------- burn-fusion/src/ops/int.rs | 2689 +++++------ burn-fusion/src/ops/module.rs | 1695 ++++--- burn-fusion/src/ops/unary.rs | 166 +- burn-fusion/src/server.rs | 289 +- burn-fusion/src/tensor.rs | 222 +- burn-import/build.rs | 18 +- burn-import/onnx-tests/build.rs | 194 +- burn-import/onnx-tests/tests/onnx_tests.rs | 1076 ++--- .../onnx-tests/tests/record_type_tests.rs | 100 +- burn-import/src/burn/codegen.rs | 94 +- burn-import/src/burn/graph.rs | 1156 ++--- burn-import/src/burn/imports.rs | 49 +- burn-import/src/burn/node/avg_pool2d.rs | 246 +- 
burn-import/src/burn/node/base.rs | 671 ++- burn-import/src/burn/node/batch_norm.rs | 306 +- burn-import/src/burn/node/binary.rs | 590 +-- burn-import/src/burn/node/clip.rs | 274 +- burn-import/src/burn/node/concat.rs | 159 +- burn-import/src/burn/node/constant.rs | 280 +- burn-import/src/burn/node/conv1d.rs | 338 +- burn-import/src/burn/node/conv2d.rs | 336 +- burn-import/src/burn/node/dropout.rs | 230 +- burn-import/src/burn/node/gather.rs | 160 +- burn-import/src/burn/node/global_avg_pool.rs | 334 +- burn-import/src/burn/node/linear.rs | 294 +- burn-import/src/burn/node/matmul.rs | 144 +- burn-import/src/burn/node/max_pool2d.rs | 254 +- burn-import/src/burn/node/reshape.rs | 126 +- burn-import/src/burn/node/test.rs | 6 +- burn-import/src/burn/node/unary.rs | 716 +-- burn-import/src/burn/scope.rs | 114 +- burn-import/src/burn/ty.rs | 202 +- burn-import/src/formatter.rs | 12 +- burn-import/src/logger.rs | 26 +- burn-import/src/main.rs | 22 +- burn-import/src/onnx/coalesce.rs | 262 +- burn-import/src/onnx/dim_inference.rs | 637 +-- burn-import/src/onnx/from_onnx.rs | 667 ++- burn-import/src/onnx/ir.rs | 1246 ++--- burn-import/src/onnx/node_remap.rs | 48 +- burn-import/src/onnx/op_configuration.rs | 920 ++-- burn-import/src/onnx/proto_conversion.rs | 396 +- burn-import/src/onnx/protos/mod.rs | 2 +- burn-import/src/onnx/to_burn.rs | 1262 +++-- burn-ndarray/build.rs | 8 +- burn-ndarray/src/backend.rs | 50 +- burn-ndarray/src/element.rs | 238 +- burn-ndarray/src/lib.rs | 22 +- burn-ndarray/src/ops/activations.rs | 22 +- burn-ndarray/src/ops/adaptive_avgpool.rs | 158 +- burn-ndarray/src/ops/avgpool.rs | 213 +- burn-ndarray/src/ops/base.rs | 792 ++-- burn-ndarray/src/ops/bool_tensor.rs | 224 +- burn-ndarray/src/ops/conv.rs | 445 +- burn-ndarray/src/ops/int_tensor.rs | 704 ++- burn-ndarray/src/ops/macros.rs | 40 +- burn-ndarray/src/ops/matmul.rs | 156 +- burn-ndarray/src/ops/maxpool.rs | 280 +- burn-ndarray/src/ops/module.rs | 172 +- burn-ndarray/src/ops/padding.rs | 46 +- burn-ndarray/src/ops/tensor.rs | 829 ++-- burn-ndarray/src/parallel.rs | 48 +- burn-ndarray/src/sharing.rs | 16 +- burn-ndarray/src/tensor.rs | 198 +- burn-no-std-tests/src/conv.rs | 58 +- burn-no-std-tests/src/mlp.rs | 84 +- burn-no-std-tests/src/model.rs | 78 +- burn-no-std-tests/tests/integration_test.rs | 28 +- burn-tch/src/backend.rs | 100 +- burn-tch/src/lib.rs | 12 +- burn-tch/src/ops/activation.rs | 34 +- burn-tch/src/ops/base.rs | 799 ++-- burn-tch/src/ops/bool_tensor.rs | 217 +- burn-tch/src/ops/int_tensor.rs | 742 ++- burn-tch/src/ops/module.rs | 560 +-- burn-tch/src/ops/tensor.rs | 867 ++-- burn-tch/src/tensor.rs | 428 +- burn-tensor-testgen/src/lib.rs | 30 +- burn-tensor/src/tensor/activation/base.rs | 62 +- burn-tensor/src/tensor/api/base.rs | 2456 +++++----- burn-tensor/src/tensor/api/bool.rs | 42 +- burn-tensor/src/tensor/api/check.rs | 1218 +++-- burn-tensor/src/tensor/api/float.rs | 616 +-- burn-tensor/src/tensor/api/int.rs | 140 +- burn-tensor/src/tensor/api/kind.rs | 32 +- burn-tensor/src/tensor/api/numeric.rs | 4084 ++++++++--------- burn-tensor/src/tensor/backend/base.rs | 255 +- burn-tensor/src/tensor/container.rs | 111 +- burn-tensor/src/tensor/data.rs | 739 ++- burn-tensor/src/tensor/element.rs | 146 +- burn-tensor/src/tensor/loss/mod.rs | 12 +- burn-tensor/src/tensor/module.rs | 246 +- burn-tensor/src/tensor/named/base.rs | 100 +- burn-tensor/src/tensor/named/dims.rs | 112 +- burn-tensor/src/tensor/named/matmul.rs | 74 +- burn-tensor/src/tensor/named/swap_dims.rs | 74 +- 
burn-tensor/src/tensor/ops/activation.rs | 190 +- burn-tensor/src/tensor/ops/bool_tensor.rs | 461 +- burn-tensor/src/tensor/ops/int_tensor.rs | 1677 ++++--- burn-tensor/src/tensor/ops/modules/base.rs | 741 ++- burn-tensor/src/tensor/ops/modules/conv.rs | 1316 +++--- burn-tensor/src/tensor/ops/modules/pool.rs | 246 +- burn-tensor/src/tensor/ops/modules/unfold.rs | 105 +- burn-tensor/src/tensor/ops/tensor.rs | 2123 +++++---- burn-tensor/src/tensor/shape.rs | 88 +- burn-tensor/src/tensor/stats/mod.rs | 38 +- burn-tensor/src/tests/activation/gelu.rs | 30 +- .../src/tests/activation/quiet_softmax.rs | 16 + burn-tensor/src/tests/activation/relu.rs | 20 +- burn-tensor/src/tests/activation/sigmoid.rs | 36 +- burn-tensor/src/tests/activation/silu.rs | 20 +- burn-tensor/src/tests/activation/softmax.rs | 20 +- .../src/tests/activation/tanh_activation.rs | 20 +- burn-tensor/src/tests/clone_invariance.rs | 1394 +++--- burn-tensor/src/tests/mod.rs | 142 +- .../src/tests/module/adaptive_avgpool1d.rs | 120 +- .../src/tests/module/adaptive_avgpool2d.rs | 180 +- burn-tensor/src/tests/module/avgpool1d.rs | 150 +- burn-tensor/src/tests/module/avgpool2d.rs | 200 +- burn-tensor/src/tests/module/conv1d.rs | 240 +- burn-tensor/src/tests/module/conv2d.rs | 304 +- .../src/tests/module/conv_transpose1d.rs | 262 +- .../src/tests/module/conv_transpose2d.rs | 632 +-- burn-tensor/src/tests/module/forward.rs | 30 +- burn-tensor/src/tests/module/maxpool1d.rs | 226 +- burn-tensor/src/tests/module/maxpool2d.rs | 594 +-- burn-tensor/src/tests/module/unfold4d.rs | 230 +- burn-tensor/src/tests/ops/abs.rs | 36 +- burn-tensor/src/tests/ops/add.rs | 160 +- burn-tensor/src/tests/ops/aggregation.rs | 172 +- burn-tensor/src/tests/ops/arange.rs | 30 +- burn-tensor/src/tests/ops/arange_step.rs | 86 +- burn-tensor/src/tests/ops/arg.rs | 100 +- burn-tensor/src/tests/ops/cast.rs | 80 +- burn-tensor/src/tests/ops/cat.rs | 164 +- burn-tensor/src/tests/ops/clamp.rs | 114 +- burn-tensor/src/tests/ops/cos.rs | 20 +- burn-tensor/src/tests/ops/create_like.rs | 71 +- burn-tensor/src/tests/ops/div.rs | 164 +- burn-tensor/src/tests/ops/erf.rs | 42 +- burn-tensor/src/tests/ops/exp.rs | 20 +- burn-tensor/src/tests/ops/flatten.rs | 98 +- burn-tensor/src/tests/ops/full.rs | 38 +- burn-tensor/src/tests/ops/gather_scatter.rs | 340 +- burn-tensor/src/tests/ops/init.rs | 96 +- burn-tensor/src/tests/ops/iter_dim.rs | 78 +- burn-tensor/src/tests/ops/log.rs | 26 +- burn-tensor/src/tests/ops/log1p.rs | 26 +- burn-tensor/src/tests/ops/map_comparison.rs | 610 +-- burn-tensor/src/tests/ops/mask.rs | 104 +- burn-tensor/src/tests/ops/matmul.rs | 207 +- burn-tensor/src/tests/ops/maxmin.rs | 68 +- burn-tensor/src/tests/ops/mul.rs | 164 +- burn-tensor/src/tests/ops/neg.rs | 20 +- burn-tensor/src/tests/ops/one_hot.rs | 46 +- burn-tensor/src/tests/ops/powf.rs | 68 +- burn-tensor/src/tests/ops/random.rs | 36 +- burn-tensor/src/tests/ops/recip.rs | 20 +- burn-tensor/src/tests/ops/repeat.rs | 30 +- burn-tensor/src/tests/ops/reshape.rs | 140 +- burn-tensor/src/tests/ops/select.rs | 248 +- burn-tensor/src/tests/ops/sin.rs | 20 +- burn-tensor/src/tests/ops/slice.rs | 218 +- burn-tensor/src/tests/ops/sqrt.rs | 22 +- burn-tensor/src/tests/ops/squeeze.rs | 66 +- burn-tensor/src/tests/ops/sub.rs | 160 +- burn-tensor/src/tests/ops/tanh.rs | 20 +- burn-tensor/src/tests/ops/transpose.rs | 184 +- burn-tensor/src/tests/stats/cov.rs | 116 +- burn-tensor/src/tests/stats/diagonal.rs | 24 +- burn-tensor/src/tests/stats/display.rs | 236 +- burn-tensor/src/tests/stats/var.rs | 74 +- 
burn-train/src/checkpoint/async_checkpoint.rs | 166 +- burn-train/src/checkpoint/base.rs | 50 +- burn-train/src/checkpoint/file.rs | 95 +- burn-train/src/checkpoint/strategy/base.rs | 34 +- .../src/checkpoint/strategy/composed.rs | 227 +- burn-train/src/checkpoint/strategy/lastn.rs | 72 +- burn-train/src/checkpoint/strategy/metric.rs | 224 +- burn-train/src/components.rs | 106 +- burn-train/src/learner/base.rs | 180 +- burn-train/src/learner/builder.rs | 570 ++- burn-train/src/learner/classification.rs | 24 +- burn-train/src/learner/early_stopping.rs | 352 +- burn-train/src/learner/epoch.rs | 426 +- burn-train/src/learner/log.rs | 59 +- burn-train/src/learner/regression.rs | 18 +- burn-train/src/learner/step/train.rs | 206 +- burn-train/src/learner/train_val.rs | 302 +- burn-train/src/logger/async_logger.rs | 115 +- burn-train/src/logger/base.rs | 36 +- burn-train/src/logger/file.rs | 48 +- burn-train/src/logger/in_memory.rs | 10 +- burn-train/src/logger/metric.rs | 268 +- burn-train/src/metric/acc.rs | 202 +- burn-train/src/metric/base.rs | 94 +- burn-train/src/metric/cpu_temp.rs | 58 +- burn-train/src/metric/cpu_use.rs | 78 +- burn-train/src/metric/cuda.rs | 154 +- burn-train/src/metric/learning_rate.rs | 49 +- burn-train/src/metric/loss.rs | 41 +- burn-train/src/metric/memory_use.rs | 94 +- burn-train/src/metric/processor/base.rs | 50 +- burn-train/src/metric/processor/full.rs | 144 +- burn-train/src/metric/processor/metrics.rs | 282 +- burn-train/src/metric/processor/minimal.rs | 80 +- burn-train/src/metric/processor/mod.rs | 66 +- burn-train/src/metric/state.rs | 130 +- burn-train/src/metric/store/aggregate.rs | 248 +- burn-train/src/metric/store/base.rs | 72 +- burn-train/src/metric/store/client.rs | 241 +- burn-train/src/metric/store/log.rs | 170 +- burn-train/src/renderer/base.rs | 94 +- burn-train/src/renderer/cli.rs | 24 +- burn-train/src/renderer/mod.rs | 12 +- burn-train/src/renderer/tui/base.rs | 64 +- burn-train/src/renderer/tui/controls.rs | 72 +- burn-train/src/renderer/tui/full_history.rs | 340 +- burn-train/src/renderer/tui/metric_numeric.rs | 361 +- burn-train/src/renderer/tui/metric_text.rs | 148 +- burn-train/src/renderer/tui/plot_utils.rs | 64 +- burn-train/src/renderer/tui/popup.rs | 200 +- burn-train/src/renderer/tui/progress.rs | 422 +- burn-train/src/renderer/tui/recent_history.rs | 356 +- burn-train/src/renderer/tui/renderer.rs | 258 +- burn-train/src/renderer/tui/status.rs | 126 +- burn-wgpu/benches/fused_elemwise.rs | 88 +- burn-wgpu/benches/matmul.rs | 168 +- burn-wgpu/benches/reduction.rs | 130 +- burn-wgpu/src/backend.rs | 66 +- burn-wgpu/src/compute/base.rs | 358 +- burn-wgpu/src/compute/kernel.rs | 113 +- burn-wgpu/src/compute/server.rs | 564 ++- burn-wgpu/src/compute/storage.rs | 170 +- burn-wgpu/src/compute/tune_key.rs | 24 +- burn-wgpu/src/device.rs | 56 +- burn-wgpu/src/element.rs | 64 +- burn-wgpu/src/fusion/base.rs | 204 +- burn-wgpu/src/fusion/codegen/body.rs | 22 +- burn-wgpu/src/fusion/codegen/function.rs | 22 +- burn-wgpu/src/fusion/codegen/operator.rs | 238 +- burn-wgpu/src/fusion/codegen/shader.rs | 252 +- burn-wgpu/src/fusion/codegen/variable.rs | 22 +- burn-wgpu/src/fusion/elemwise/ops.rs | 817 ++-- burn-wgpu/src/fusion/kernel.rs | 468 +- burn-wgpu/src/graphics.rs | 52 +- burn-wgpu/src/kernel/base.rs | 350 +- burn-wgpu/src/kernel/binary_elemwise.rs | 254 +- burn-wgpu/src/kernel/cast.rs | 99 +- burn-wgpu/src/kernel/cat.rs | 154 +- burn-wgpu/src/kernel/clamp.rs | 153 +- burn-wgpu/src/kernel/comparison/base.rs | 168 +- 
burn-wgpu/src/kernel/comparison/binary.rs | 269 +- burn-wgpu/src/kernel/comparison/elem.rs | 205 +- burn-wgpu/src/kernel/conv/conv2d.rs | 162 +- burn-wgpu/src/kernel/conv/conv_transpose2d.rs | 213 +- burn-wgpu/src/kernel/index/gather.rs | 133 +- burn-wgpu/src/kernel/index/scatter.rs | 251 +- burn-wgpu/src/kernel/index/select.rs | 313 +- burn-wgpu/src/kernel/index/slice.rs | 204 +- burn-wgpu/src/kernel/mask/base.rs | 34 +- burn-wgpu/src/kernel/mask/mask_fill.rs | 210 +- burn-wgpu/src/kernel/mask/mask_where.rs | 266 +- burn-wgpu/src/kernel/matmul/mem_coalescing.rs | 326 +- burn-wgpu/src/kernel/matmul/naive.rs | 298 +- burn-wgpu/src/kernel/matmul/tiling2d/base.rs | 120 +- .../src/kernel/matmul/tiling2d/padding.rs | 531 ++- .../src/kernel/matmul/tiling2d/unpadded.rs | 332 +- burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs | 242 +- .../src/kernel/matmul/tiling2d/vec4_lhs.rs | 244 +- burn-wgpu/src/kernel/matmul/tune/base.rs | 246 +- burn-wgpu/src/kernel/matmul/tune/key.rs | 172 +- burn-wgpu/src/kernel/matmul/utils.rs | 123 +- .../src/kernel/pool/adaptive_avg_pool2d.rs | 132 +- burn-wgpu/src/kernel/pool/avg_pool2d.rs | 269 +- burn-wgpu/src/kernel/pool/base.rs | 107 +- burn-wgpu/src/kernel/pool/max_pool2d.rs | 335 +- burn-wgpu/src/kernel/prng/base.rs | 136 +- burn-wgpu/src/kernel/prng/bernoulli.rs | 225 +- burn-wgpu/src/kernel/prng/normal.rs | 189 +- burn-wgpu/src/kernel/prng/uniform.rs | 265 +- burn-wgpu/src/kernel/reduce/base.rs | 30 +- burn-wgpu/src/kernel/reduce/reduction.rs | 296 +- .../kernel/reduce/reduction_shared_memory.rs | 248 +- burn-wgpu/src/kernel/reduce/tune/base.rs | 42 +- burn-wgpu/src/kernel/reduce/tune/key.rs | 62 +- burn-wgpu/src/kernel/reduce/tune/mean_dim.rs | 144 +- burn-wgpu/src/kernel/reduce/tune/sum_dim.rs | 144 +- burn-wgpu/src/kernel/source.rs | 98 +- burn-wgpu/src/kernel/unary.rs | 274 +- burn-wgpu/src/kernel/unary_scalar.rs | 272 +- burn-wgpu/src/lib.rs | 16 +- burn-wgpu/src/ops/activation_ops.rs | 28 +- burn-wgpu/src/ops/base.rs | 84 +- burn-wgpu/src/ops/bool_ops.rs | 227 +- burn-wgpu/src/ops/float_ops.rs | 882 ++-- burn-wgpu/src/ops/int_ops.rs | 625 ++- burn-wgpu/src/ops/module_ops.rs | 181 +- burn-wgpu/src/ops/numeric.rs | 214 +- burn-wgpu/src/tensor/base.rs | 232 +- burn/src/lib.rs | 2 +- .../examples/custom-renderer.rs | 2 +- examples/custom-renderer/src/lib.rs | 108 +- .../examples/custom-training-loop.rs | 2 +- examples/custom-training-loop/src/lib.rs | 244 +- .../examples/custom-wgpu-kernel.rs | 94 +- examples/custom-wgpu-kernel/src/backward.rs | 185 +- examples/custom-wgpu-kernel/src/forward.rs | 176 +- examples/custom-wgpu-kernel/src/lib.rs | 38 +- examples/guide/examples/guide.rs | 32 +- examples/guide/src/data.rs | 54 +- examples/guide/src/inference.rs | 34 +- examples/guide/src/model.rs | 121 +- examples/guide/src/training.rs | 154 +- examples/image-classification-web/build.rs | 76 +- .../src/model/normalizer.rs | 42 +- .../src/model/squeezenet.rs | 2 +- examples/image-classification-web/src/web.rs | 256 +- examples/mnist-inference-web/src/model.rs | 130 +- examples/mnist-inference-web/src/state.rs | 14 +- examples/mnist-inference-web/src/web.rs | 104 +- examples/mnist/examples/mnist.rs | 96 +- examples/mnist/src/data.rs | 54 +- examples/mnist/src/model.rs | 180 +- examples/mnist/src/training.rs | 130 +- .../named-tensor/examples/named-tensor.rs | 2 +- examples/named-tensor/src/lib.rs | 74 +- examples/onnx-inference/build.rs | 32 +- .../onnx-inference/src/bin/mnist_inference.rs | 74 +- examples/onnx-inference/src/model/mod.rs | 2 +- 
.../examples/ag-news-infer.rs | 92 +- .../examples/ag-news-train.rs | 114 +- .../examples/db-pedia-infer.rs | 92 +- .../examples/db-pedia-train.rs | 114 +- .../text-classification/src/data/batcher.rs | 108 +- .../text-classification/src/data/dataset.rs | 221 +- .../text-classification/src/data/tokenizer.rs | 74 +- examples/text-classification/src/inference.rs | 106 +- examples/text-classification/src/model.rs | 281 +- examples/text-classification/src/training.rs | 171 +- .../examples/text-generation.rs | 31 +- examples/text-generation/src/data/batcher.rs | 74 +- examples/text-generation/src/data/dataset.rs | 47 +- .../text-generation/src/data/tokenizer.rs | 144 +- examples/text-generation/src/model.rs | 157 +- examples/text-generation/src/training.rs | 144 +- xtask/src/main.rs | 34 +- xtask/src/publish.rs | 178 +- xtask/src/runchecks.rs | 578 +-- 580 files changed, 67666 insertions(+), 67986 deletions(-) create mode 100644 burn-tensor/src/tests/activation/quiet_softmax.rs diff --git a/backend-comparison/benches/binary.rs b/backend-comparison/benches/binary.rs index cb5b3264f5..43ae124d91 100644 --- a/backend-comparison/benches/binary.rs +++ b/backend-comparison/benches/binary.rs @@ -2,46 +2,46 @@ use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; pub struct BinaryBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for BinaryBenchmark { - type Args = (Tensor, Tensor); + type Args = (Tensor, Tensor); - fn name(&self) -> String { - "Binary Ops".into() - } + fn name(&self) -> String { + "Binary Ops".into() + } - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.num_repeats { - // Choice of add is arbitrary - B::add(lhs.clone().into_primitive(), rhs.clone().into_primitive()); - } + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.num_repeats { + // Choice of add is arbitrary + B::add(lhs.clone().into_primitive(), rhs.clone().into_primitive()); } + } - fn prepare(&self) -> Self::Args { - let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] fn bench(device: &B::Device) { - run_benchmark(BinaryBenchmark:: { - shape: [32, 512, 1024].into(), - num_repeats: 10, - device: device.clone(), - }) + run_benchmark(BinaryBenchmark:: { + shape: [32, 512, 1024].into(), + num_repeats: 10, + device: device.clone(), + }) } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/custom_gelu.rs b/backend-comparison/benches/custom_gelu.rs index 71db646a97..63c52eefc0 100644 --- a/backend-comparison/benches/custom_gelu.rs +++ b/backend-comparison/benches/custom_gelu.rs @@ -5,62 +5,62 @@ use derive_new::new; #[derive(Debug)] enum GeluKind { - Reference, - WithReferenceErf, - WithCustomErf, + Reference, + WithReferenceErf, + WithCustomErf, } /// Benchmark how well a backend executes a custom activation function with a lot of basic tensor /// operations. 
#[derive(new)] struct CustomGeluBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, - kind: GeluKind, + shape: Shape, + num_repeats: usize, + device: B::Device, + kind: GeluKind, } impl Benchmark for CustomGeluBenchmark { - type Args = Tensor; - - fn name(&self) -> String { - format!("Gelu {:?}", self.kind) + type Args = Tensor; + + fn name(&self) -> String { + format!("Gelu {:?}", self.kind) + } + + fn execute(&self, args: Self::Args) { + for _ in 0..self.num_repeats { + match self.kind { + GeluKind::Reference => burn::tensor::activation::gelu(args.clone()), + GeluKind::WithReferenceErf => gelu_custom(args.clone(), Tensor::erf), + GeluKind::WithCustomErf => gelu_custom(args.clone(), erf_custom), + }; } + } - fn execute(&self, args: Self::Args) { - for _ in 0..self.num_repeats { - match self.kind { - GeluKind::Reference => burn::tensor::activation::gelu(args.clone()), - GeluKind::WithReferenceErf => gelu_custom(args.clone(), Tensor::erf), - GeluKind::WithCustomErf => gelu_custom(args.clone(), erf_custom), - }; - } - } + fn prepare(&self) -> Self::Args { + Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } - fn prepare(&self) -> Self::Args { - Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } - - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } fn gelu_custom(x: Tensor, erf: Erf) -> Tensor where - B: Backend, - Erf: Fn(Tensor) -> Tensor, + B: Backend, + Erf: Fn(Tensor) -> Tensor, { - let x = x.clone() * (erf(x / SQRT_2) + 1); - x / 2 + let x = x.clone() * (erf(x / SQRT_2) + 1); + x / 2 } fn erf_custom(x: Tensor) -> Tensor { - let x1 = -erf_positive(-x.clone()); - let x2 = erf_positive(x.clone()); - let mask = x.greater_elem(0); + let x1 = -erf_positive(-x.clone()); + let x2 = erf_positive(x.clone()); + let mask = x.greater_elem(0); - x1.mask_where(mask, x2) + x1.mask_where(mask, x2) } /// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations @@ -68,47 +68,47 @@ fn erf_custom(x: Tensor) -> Tensor { /// > (maximum error: 1.5×10−7) /// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x). 
fn erf_positive(x: Tensor) -> Tensor { - let p = 0.3275911; - let a1 = 0.254829592; - let a2 = -0.284496736; - let a3 = 1.421413741; - let a4 = -1.453152027; - let a5 = 1.061405429; - - let x1 = x.clone().abs() * p + 1; - let t = x1.recip(); - let tmp = (((((t.clone() * a5) + a4) * t.clone()) + a3) * t.clone() + a2) * t.clone() + a1; - - -(tmp * t * (-x.clone() * x).exp()) + 1.0 + let p = 0.3275911; + let a1 = 0.254829592; + let a2 = -0.284496736; + let a3 = 1.421413741; + let a4 = -1.453152027; + let a5 = 1.061405429; + + let x1 = x.clone().abs() * p + 1; + let t = x1.recip(); + let tmp = (((((t.clone() * a5) + a4) * t.clone()) + a3) * t.clone() + a2) * t.clone() + a1; + + -(tmp * t * (-x.clone() * x).exp()) + 1.0 } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let shape: Shape = [32, 512, 2048].into(); - let num_repeats = 1; - - println!("Backend {}", B::name()); - run_benchmark(CustomGeluBenchmark::::new( - shape.clone(), - num_repeats, - device.clone(), - GeluKind::Reference, - )); - run_benchmark(CustomGeluBenchmark::::new( - shape.clone(), - num_repeats, - device.clone(), - GeluKind::WithReferenceErf, - )); - run_benchmark(CustomGeluBenchmark::::new( - shape, - num_repeats, - device.clone(), - GeluKind::WithCustomErf, - )); + const D: usize = 3; + let shape: Shape = [32, 512, 2048].into(); + let num_repeats = 1; + + println!("Backend {}", B::name()); + run_benchmark(CustomGeluBenchmark::::new( + shape.clone(), + num_repeats, + device.clone(), + GeluKind::Reference, + )); + run_benchmark(CustomGeluBenchmark::::new( + shape.clone(), + num_repeats, + device.clone(), + GeluKind::WithReferenceErf, + )); + run_benchmark(CustomGeluBenchmark::::new( + shape, + num_repeats, + device.clone(), + GeluKind::WithCustomErf, + )); } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/data.rs b/backend-comparison/benches/data.rs index e9379b3933..df1f439e1e 100644 --- a/backend-comparison/benches/data.rs +++ b/backend-comparison/benches/data.rs @@ -4,83 +4,83 @@ use derive_new::new; #[derive(new)] struct ToDataBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for ToDataBenchmark { - type Args = Tensor; + type Args = Tensor; - fn name(&self) -> String { - format!("to-data-{:?}-{}", self.shape.dims, self.num_repeats) - } + fn name(&self) -> String { + format!("to-data-{:?}-{}", self.shape.dims, self.num_repeats) + } - fn execute(&self, args: Self::Args) { - for _ in 0..self.num_repeats { - let _data = args.to_data(); - } + fn execute(&self, args: Self::Args) { + for _ in 0..self.num_repeats { + let _data = args.to_data(); } + } - fn prepare(&self) -> Self::Args { - Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } + fn prepare(&self) -> Self::Args { + Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[derive(new)] struct FromDataBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for FromDataBenchmark { - type Args = (Data, B::Device); - - fn name(&self) -> String { - format!("from-data-{:?}-{}", self.shape.dims, self.num_repeats) - } - - fn execute(&self, (data, device): Self::Args) { - for _ in 0..self.num_repeats { - let _data = 
Tensor::::from_data_device(data.clone(), &device); - } - } + type Args = (Data, B::Device); - fn prepare(&self) -> Self::Args { - ( - Data::random( - self.shape.clone(), - Distribution::Default, - &mut rand::thread_rng(), - ), - self.device.clone(), - ) - } + fn name(&self) -> String { + format!("from-data-{:?}-{}", self.shape.dims, self.num_repeats) + } - fn sync(&self) { - B::sync(&self.device) + fn execute(&self, (data, device): Self::Args) { + for _ in 0..self.num_repeats { + let _data = Tensor::::from_data_device(data.clone(), &device); } + } + + fn prepare(&self) -> Self::Args { + ( + Data::random( + self.shape.clone(), + Distribution::Default, + &mut rand::thread_rng(), + ), + self.device.clone(), + ) + } + + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let shape: Shape = [32, 512, 1024].into(); - let num_repeats = 10; + const D: usize = 3; + let shape: Shape = [32, 512, 1024].into(); + let num_repeats = 10; - let to_benchmark = ToDataBenchmark::::new(shape.clone(), num_repeats, device.clone()); - let from_benchmark = FromDataBenchmark::::new(shape, num_repeats, device.clone()); + let to_benchmark = ToDataBenchmark::::new(shape.clone(), num_repeats, device.clone()); + let from_benchmark = FromDataBenchmark::::new(shape, num_repeats, device.clone()); - println!("Backend {}", B::name()); - run_benchmark(to_benchmark); - run_benchmark(from_benchmark) + println!("Backend {}", B::name()); + run_benchmark(to_benchmark); + run_benchmark(from_benchmark) } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index 7574e21970..5114300afa 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -4,62 +4,60 @@ use derive_new::new; #[derive(new)] struct MatmulBenchmark { - shape_lhs: Shape, - shape_rhs: Shape, - num_repeats: usize, - device: B::Device, + shape_lhs: Shape, + shape_rhs: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for MatmulBenchmark { - type Args = (Tensor, Tensor); - - fn name(&self) -> String { - format!( - "Matmul {:?} x {:?}", - self.shape_lhs.dims, self.shape_rhs.dims - ) - } - - fn num_samples(&self) -> usize { - 10 - } - - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.num_repeats { - lhs.clone().matmul(rhs.clone()); - } + type Args = (Tensor, Tensor); + + fn name(&self) -> String { + format!( + "Matmul {:?} x {:?}", + self.shape_lhs.dims, self.shape_rhs.dims + ) + } + + fn num_samples(&self) -> usize { + 10 + } + + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.num_repeats { + lhs.clone().matmul(rhs.clone()); } + } - fn prepare(&self) -> Self::Args { - let lhs = - Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); - let rhs = - Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + let lhs = Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); + let rhs = Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let num_repeats = 3; - let batch_size = 3; - let m = 1024; - let k = 2048; - let n = 1024; - let shape_lhs 
= [batch_size, m, k].into(); - let shape_rhs = [batch_size, k, n].into(); - - let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, num_repeats, device.clone()); - println!("Backend {}", B::name()); - run_benchmark(benchmark); + const D: usize = 3; + let num_repeats = 3; + let batch_size = 3; + let m = 1024; + let k = 2048; + let n = 1024; + let shape_lhs = [batch_size, m, k].into(); + let shape_rhs = [batch_size, k, n].into(); + + let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, num_repeats, device.clone()); + println!("Backend {}", B::name()); + run_benchmark(benchmark); } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/unary.rs b/backend-comparison/benches/unary.rs index 4befcdd2ad..b836f0a845 100644 --- a/backend-comparison/benches/unary.rs +++ b/backend-comparison/benches/unary.rs @@ -4,46 +4,46 @@ use derive_new::new; #[derive(new)] struct UnaryBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for UnaryBenchmark { - type Args = Tensor; + type Args = Tensor; - fn name(&self) -> String { - "Unary Ops".into() - } + fn name(&self) -> String { + "Unary Ops".into() + } - fn execute(&self, args: Self::Args) { - for _ in 0..self.num_repeats { - // Choice of tanh is arbitrary - B::tanh(args.clone().into_primitive()); - } + fn execute(&self, args: Self::Args) { + for _ in 0..self.num_repeats { + // Choice of tanh is arbitrary + B::tanh(args.clone().into_primitive()); } + } - fn prepare(&self) -> Self::Args { - Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } + fn prepare(&self) -> Self::Args { + Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let shape: Shape = [32, 512, 1024].into(); - let num_repeats = 10; + const D: usize = 3; + let shape: Shape = [32, 512, 1024].into(); + let num_repeats = 10; - let benchmark = UnaryBenchmark::::new(shape, num_repeats, device.clone()); + let benchmark = UnaryBenchmark::::new(shape, num_repeats, device.clone()); - println!("Backend {}", B::name()); - run_benchmark(benchmark) + println!("Backend {}", B::name()); + run_benchmark(benchmark) } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 065b50f418..c82e5e351f 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -1,70 +1,70 @@ #[macro_export] macro_rules! 
bench_on_backend { - () => { - #[cfg(feature = "wgpu-fusion")] - { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::Fusion; + () => { + #[cfg(feature = "wgpu-fusion")] + { + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Fusion; - bench::>>(&WgpuDevice::default()); - } + bench::>>(&WgpuDevice::default()); + } - #[cfg(feature = "wgpu")] - { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + #[cfg(feature = "wgpu")] + { + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - bench::>(&WgpuDevice::default()); - } + bench::>(&WgpuDevice::default()); + } - #[cfg(feature = "tch-gpu")] - { - use burn::backend::{libtorch::LibTorchDevice, LibTorch}; + #[cfg(feature = "tch-gpu")] + { + use burn::backend::{libtorch::LibTorchDevice, LibTorch}; - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; - bench::(&device); - } + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + bench::(&device); + } - #[cfg(feature = "tch-cpu")] - { - use burn::backend::{libtorch::LibTorchDevice, LibTorch}; + #[cfg(feature = "tch-cpu")] + { + use burn::backend::{libtorch::LibTorchDevice, LibTorch}; - let device = LibTorchDevice::Cpu; - bench::(&device); - } + let device = LibTorchDevice::Cpu; + bench::(&device); + } - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - { - use burn::backend::ndarray::NdArrayDevice; - use burn::backend::NdArray; + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + { + use burn::backend::ndarray::NdArrayDevice; + use burn::backend::NdArray; - let device = NdArrayDevice::Cpu; - bench::(&device); - } + let device = NdArrayDevice::Cpu; + bench::(&device); + } - #[cfg(feature = "candle-cpu")] - { - use burn::backend::candle::CandleDevice; - use burn::backend::Candle; + #[cfg(feature = "candle-cpu")] + { + use burn::backend::candle::CandleDevice; + use burn::backend::Candle; - let device = CandleDevice::Cpu; - bench::(&device); - } + let device = CandleDevice::Cpu; + bench::(&device); + } - #[cfg(feature = "candle-cuda")] - { - use burn::backend::candle::CandleDevice; - use burn::backend::Candle; + #[cfg(feature = "candle-cuda")] + { + use burn::backend::candle::CandleDevice; + use burn::backend::Candle; - let device = CandleDevice::Cuda(0); - bench::(&device); - } - }; + let device = CandleDevice::Cuda(0); + bench::(&device); + } + }; } diff --git a/burn-autodiff/src/backend.rs b/burn-autodiff/src/backend.rs index e0039ae290..302fa2d7d2 100644 --- a/burn-autodiff/src/backend.rs +++ b/burn-autodiff/src/backend.rs @@ -8,75 +8,75 @@ use core::marker::PhantomData; /// backpropagation. 
#[derive(Clone, Copy, Debug, Default)] pub struct Autodiff { - _b: PhantomData, + _b: PhantomData, } impl Backend for Autodiff { - type Device = B::Device; + type Device = B::Device; - type FullPrecisionElem = B::FullPrecisionElem; - type FullPrecisionBackend = Autodiff; + type FullPrecisionElem = B::FullPrecisionElem; + type FullPrecisionBackend = Autodiff; - type TensorPrimitive = AutodiffTensor; - type FloatElem = B::FloatElem; + type TensorPrimitive = AutodiffTensor; + type FloatElem = B::FloatElem; - type IntTensorPrimitive = B::IntTensorPrimitive; - type IntElem = B::IntElem; + type IntTensorPrimitive = B::IntTensorPrimitive; + type IntElem = B::IntElem; - type BoolTensorPrimitive = B::BoolTensorPrimitive; + type BoolTensorPrimitive = B::BoolTensorPrimitive; - fn ad_enabled() -> bool { - true - } + fn ad_enabled() -> bool { + true + } - fn name() -> String { - format!("autodiff<{}>", B::name()) - } + fn name() -> String { + format!("autodiff<{}>", B::name()) + } - fn seed(seed: u64) { - B::seed(seed) - } + fn seed(seed: u64) { + B::seed(seed) + } - fn sync(device: &B::Device) { - B::sync(device); - } + fn sync(device: &B::Device) { + B::sync(device); + } } impl AutodiffBackend for Autodiff { - type InnerBackend = B; - type Gradients = Gradients; - - fn backward(tensor: AutodiffTensor) -> Gradients { - backward(tensor) - } - - fn grad( - tensor: &AutodiffTensor, - grads: &Gradients, - ) -> Option> { - grads.get(tensor) - } - - fn grad_remove( - tensor: &AutodiffTensor, - grads: &mut Gradients, - ) -> Option> { - grads.remove(tensor) - } - fn inner(tensor: AutodiffTensor) -> B::TensorPrimitive { - tensor.primitive - } - - fn from_inner(tensor: B::TensorPrimitive) -> AutodiffTensor { - AutodiffTensor::new(tensor) - } - - fn grad_replace( - tensor: &AutodiffTensor, - grads: &mut Self::Gradients, - grad: B::TensorPrimitive, - ) { - grads.remove(tensor); - grads.register::(tensor.node.clone(), grad); - } + type InnerBackend = B; + type Gradients = Gradients; + + fn backward(tensor: AutodiffTensor) -> Gradients { + backward(tensor) + } + + fn grad( + tensor: &AutodiffTensor, + grads: &Gradients, + ) -> Option> { + grads.get(tensor) + } + + fn grad_remove( + tensor: &AutodiffTensor, + grads: &mut Gradients, + ) -> Option> { + grads.remove(tensor) + } + fn inner(tensor: AutodiffTensor) -> B::TensorPrimitive { + tensor.primitive + } + + fn from_inner(tensor: B::TensorPrimitive) -> AutodiffTensor { + AutodiffTensor::new(tensor) + } + + fn grad_replace( + tensor: &AutodiffTensor, + grads: &mut Self::Gradients, + grad: B::TensorPrimitive, + ) { + grads.remove(tensor); + grads.register::(tensor.node.clone(), grad); + } } diff --git a/burn-autodiff/src/grads.rs b/burn-autodiff/src/grads.rs index 9e2c5b3e8c..1e090e6ea4 100644 --- a/burn-autodiff/src/grads.rs +++ b/burn-autodiff/src/grads.rs @@ -1,8 +1,8 @@ use burn_tensor::{backend::Backend, container::TensorContainer, Tensor}; use crate::{ - graph::{NodeRef, Requirement}, - tensor::AutodiffTensor, + graph::{NodeRef, Requirement}, + tensor::AutodiffTensor, }; /// Gradient identifier. @@ -10,81 +10,85 @@ pub type GradID = u64; /// Gradients container used during the backward pass. pub struct Gradients { - container: TensorContainer, + container: TensorContainer, } type TensorPrimitive = ::TensorPrimitive; impl Gradients { - /// Creates a new gradients container. 
- pub fn new( - root_node: NodeRef, - root_tensor: TensorPrimitive, - ) -> Self { - let mut gradients = Self { - container: TensorContainer::new(), - }; - gradients.register::( - root_node, - B::ones(B::shape(&root_tensor), &B::device(&root_tensor)), - ); - gradients - } + /// Creates a new gradients container. + pub fn new( + root_node: NodeRef, + root_tensor: TensorPrimitive, + ) -> Self { + let mut gradients = Self { + container: TensorContainer::new(), + }; + gradients.register::( + root_node, + B::ones(B::shape(&root_tensor), &B::device(&root_tensor)), + ); + gradients + } - /// Consumes the gradients for a given tensor. - /// - /// Each tensor should be consumed exactly 1 time if its gradients are only required during the - /// backward pass, otherwise, it may be consume multiple times. - pub fn consume(&mut self, node: &NodeRef) -> TensorPrimitive { - match node.requirement { - Requirement::Grad => self - .container - .get::(&node.id.value) - .map(|tensor| tensor.into_primitive()) - .expect("Can't consume the gradients before they are registered at least once."), - Requirement::GradInBackward => self - .container - .remove::(&node.id.value) - .map(|tensor| tensor.into_primitive()) - .expect("Can't consume the gradients before they are registered at least once."), - Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"), - } + /// Consumes the gradients for a given tensor. + /// + /// Each tensor should be consumed exactly 1 time if its gradients are only required during the + /// backward pass, otherwise, it may be consume multiple times. + pub fn consume(&mut self, node: &NodeRef) -> TensorPrimitive { + match node.requirement { + Requirement::Grad => self + .container + .get::(&node.id.value) + .map(|tensor| tensor.into_primitive()) + .expect("Can't consume the gradients before they are registered at least once."), + Requirement::GradInBackward => self + .container + .remove::(&node.id.value) + .map(|tensor| tensor.into_primitive()) + .expect("Can't consume the gradients before they are registered at least once."), + Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"), } + } - /// Removes a grad tensor from the container. - pub fn remove( - &mut self, - tensor: &AutodiffTensor, - ) -> Option> { - self.container - .remove::(&tensor.node.id.value) - .map(|tensor| tensor.into_primitive()) - } + /// Removes a grad tensor from the container. + pub fn remove( + &mut self, + tensor: &AutodiffTensor, + ) -> Option> { + self + .container + .remove::(&tensor.node.id.value) + .map(|tensor| tensor.into_primitive()) + } - /// Gets a grad tensor from the container. - pub fn get( - &self, - tensor: &AutodiffTensor, - ) -> Option> { - self.container - .get::(&tensor.node.id.value) - .map(|tensor| tensor.into_primitive()) - } + /// Gets a grad tensor from the container. + pub fn get( + &self, + tensor: &AutodiffTensor, + ) -> Option> { + self + .container + .get::(&tensor.node.id.value) + .map(|tensor| tensor.into_primitive()) + } - /// Register a grad tensor in the container. - /// - /// If the tensor already exists, add both tensors together before saving the result. 
- pub fn register( - &mut self, - node: NodeRef, - value: TensorPrimitive, - ) { - if let Some(tensor_old) = self.container.remove::(&node.id.value) { - self.container - .register(node.id.value, Tensor::from_primitive(value).add(tensor_old)); - } else { - self.container - .register::(node.id.value, Tensor::from_primitive(value)); - } + /// Register a grad tensor in the container. + /// + /// If the tensor already exists, add both tensors together before saving the result. + pub fn register( + &mut self, + node: NodeRef, + value: TensorPrimitive, + ) { + if let Some(tensor_old) = self.container.remove::(&node.id.value) { + self + .container + .register(node.id.value, Tensor::from_primitive(value).add(tensor_old)); + } else { + self + .container + .register::(node.id.value, Tensor::from_primitive(value)); } + } } diff --git a/burn-autodiff/src/graph/backward.rs b/burn-autodiff/src/graph/backward.rs index ea1edf7cc0..ea1517c81c 100644 --- a/burn-autodiff/src/graph/backward.rs +++ b/burn-autodiff/src/graph/backward.rs @@ -5,33 +5,34 @@ use crate::{grads::Gradients, tensor::AutodiffTensor}; use super::{traversal::BreadthFirstSearch, Graph, NodeRef, StepBoxed}; pub fn backward(root: AutodiffTensor) -> Gradients { - let grads = Gradients::new::(root.node.clone(), root.primitive); - let tape = build_tape(root.node, root.graph); + let grads = Gradients::new::(root.node.clone(), root.primitive); + let tape = build_tape(root.node, root.graph); - execute_steps(tape, grads) + execute_steps(tape, grads) } fn build_tape(root: NodeRef, graph: Graph) -> Vec> { - let mut tape = (0..root.order) - .map(|_| Vec::with_capacity(1)) - .collect::>(); + let mut tape = (0..root.order) + .map(|_| Vec::with_capacity(1)) + .collect::>(); - BreadthFirstSearch.traverse(root, graph, |node, step| { - if node.order == 0 { - return; - } + BreadthFirstSearch.traverse(root, graph, |node, step| { + if node.order == 0 { + return; + } - if let Some(steps) = tape.get_mut(node.order - 1) { - steps.push(step) - }; - }); + if let Some(steps) = tape.get_mut(node.order - 1) { + steps.push(step) + }; + }); - tape + tape } fn execute_steps(tape: Vec>, mut grads: Gradients) -> Gradients { - tape.into_iter() - .rev() - .for_each(|steps| steps.into_iter().for_each(|step| step.step(&mut grads))); - grads + tape + .into_iter() + .rev() + .for_each(|steps| steps.into_iter().for_each(|step| step.step(&mut grads))); + grads } diff --git a/burn-autodiff/src/graph/base.rs b/burn-autodiff/src/graph/base.rs index 1e8e989f32..5c3830a57e 100644 --- a/burn-autodiff/src/graph/base.rs +++ b/burn-autodiff/src/graph/base.rs @@ -7,10 +7,10 @@ use super::{NodeID, NodeRef}; /// Backward step for reverse mode autodiff. pub trait Step: Send + Sync + std::fmt::Debug { - /// Executes the step and consumes it. - fn step(self: Box, grads: &mut Gradients); - /// The node associated to the step. - fn node(&self) -> NodeRef; + /// Executes the step and consumes it. + fn step(self: Box, grads: &mut Gradients); + /// The node associated to the step. + fn node(&self) -> NodeRef; } pub type StepBoxed = Box; @@ -21,76 +21,76 @@ pub type NodeSteps = HashMap; /// The graph contains the [node steps](Step), which can be access by [node id](NodeID). #[derive(Default, Clone, Debug)] pub struct Graph { - steps: Arc>, + steps: Arc>, } impl Graph { - /// Create a new graph. - pub fn new() -> Self { - Self::default() - } + /// Create a new graph. + pub fn new() -> Self { + Self::default() + } - /// Get all the steps for the graph. 
- /// - /// # Notes - /// - /// This is a owned method, so the current graph will be freed. However, the steps can - /// be shared with other graphs, therefore they are going to be cleared. - /// - /// This is useful, since the graph is supposed to be consumed only once for backprop, and - /// keeping all the tensors alive for multiple backward call is a heavy waste of resources. - pub fn steps(self) -> NodeSteps { - let mut map_drain = HashMap::new(); - self.execute_mut(|map| { - std::mem::swap(&mut *map, &mut map_drain); - }); - map_drain - } - - /// Register a new step into the graph. - pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self { - self.execute_mut(|map| { - map.insert(id.clone(), ops); - }) - } + /// Get all the steps for the graph. + /// + /// # Notes + /// + /// This is a owned method, so the current graph will be freed. However, the steps can + /// be shared with other graphs, therefore they are going to be cleared. + /// + /// This is useful, since the graph is supposed to be consumed only once for backprop, and + /// keeping all the tensors alive for multiple backward call is a heavy waste of resources. + pub fn steps(self) -> NodeSteps { + let mut map_drain = HashMap::new(); + self.execute_mut(|map| { + std::mem::swap(&mut *map, &mut map_drain); + }); + map_drain + } - /// Merge two graphs. - pub fn merge(self, other: Self) -> Self { - if Arc::ptr_eq(&self.steps, &other.steps) { - return self; - } + /// Register a new step into the graph. + pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self { + self.execute_mut(|map| { + map.insert(id.clone(), ops); + }) + } - self.merge_different(other) + /// Merge two graphs. + pub fn merge(self, other: Self) -> Self { + if Arc::ptr_eq(&self.steps, &other.steps) { + return self; } - fn execute_mut(mut self, func: F) -> Self { - match Arc::get_mut(&mut self.steps) { - Some(mutex) => { - let map = mutex.get_mut(); - func(map); - } - None => { - // Only lock when there are multiple references to the graph. - let mut map = self.steps.lock(); - func(&mut map); - } - }; + self.merge_different(other) + } - self - } + fn execute_mut(mut self, func: F) -> Self { + match Arc::get_mut(&mut self.steps) { + Some(mutex) => { + let map = mutex.get_mut(); + func(map); + } + None => { + // Only lock when there are multiple references to the graph. + let mut map = self.steps.lock(); + func(&mut map); + } + }; - fn merge_different(self, other: Self) -> Self { - let mut map2 = other.steps(); + self + } - self.execute_mut(|map1| { - if map1.len() > map2.len() { - map1.extend(map2); - } else { - let mut map_drain = HashMap::new(); - std::mem::swap(map1, &mut map_drain); - map2.extend(map_drain); - std::mem::swap(map1, &mut map2); - } - }) - } + fn merge_different(self, other: Self) -> Self { + let mut map2 = other.steps(); + + self.execute_mut(|map1| { + if map1.len() > map2.len() { + map1.extend(map2); + } else { + let mut map_drain = HashMap::new(); + std::mem::swap(map1, &mut map_drain); + map2.extend(map_drain); + std::mem::swap(map1, &mut map2); + } + }) + } } diff --git a/burn-autodiff/src/graph/node.rs b/burn-autodiff/src/graph/node.rs index 38665408ec..7c448742b3 100644 --- a/burn-autodiff/src/graph/node.rs +++ b/burn-autodiff/src/graph/node.rs @@ -6,43 +6,43 @@ use super::Requirement; /// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning. 
#[derive(new, Debug)] pub struct Node { - pub parents: Vec, - pub order: usize, - pub id: NodeID, - pub requirement: Requirement, + pub parents: Vec, + pub order: usize, + pub id: NodeID, + pub requirement: Requirement, } pub type NodeRef = Arc; impl Node { - /// Returns the [node](Node) only if gradients are required. - pub fn clone_if_require_grad(self: &Arc) -> Option { - match self.requirement.is_none() { - true => None, - false => Some(self.clone()), - } + /// Returns the [node](Node) only if gradients are required. + pub fn clone_if_require_grad(self: &Arc) -> Option { + match self.requirement.is_none() { + true => None, + false => Some(self.clone()), } + } } /// Unique identifier generated for each [node](Node). #[derive(Clone, Hash, PartialEq, Eq, Debug)] pub struct NodeID { - pub value: u64, + pub value: u64, } impl NodeID { - /// Create a unique [node id](NodeID). - pub fn new() -> Self { - static COUNTER: AtomicU64 = AtomicU64::new(0); - let value = COUNTER.fetch_add(1, Ordering::Relaxed); - if value == u64::MAX { - panic!("NodeID overflowed"); - } - Self { value } + /// Create a unique [node id](NodeID). + pub fn new() -> Self { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let value = COUNTER.fetch_add(1, Ordering::Relaxed); + if value == u64::MAX { + panic!("NodeID overflowed"); } + Self { value } + } } impl Default for NodeID { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } diff --git a/burn-autodiff/src/graph/requirement.rs b/burn-autodiff/src/graph/requirement.rs index c2825ff131..9d405b9562 100644 --- a/burn-autodiff/src/graph/requirement.rs +++ b/burn-autodiff/src/graph/requirement.rs @@ -3,32 +3,32 @@ use super::NodeRef; /// Requirement for each tensor in the graph. #[derive(Debug, Clone, Copy)] pub enum Requirement { - /// Operations that require gradients. - Grad, - /// Operations that require gradients only for backprop. - GradInBackward, - /// Operations that don't need gradients, therefore not to be included in the graph. - None, + /// Operations that require gradients. + Grad, + /// Operations that require gradients only for backprop. + GradInBackward, + /// Operations that don't need gradients, therefore not to be included in the graph. + None, } impl Requirement { - /// Returns true if gradients are not required. - pub fn is_none(&self) -> bool { - matches!(self, Self::None) - } - /// Returns the right requirement from a list of nodes. - pub fn from_nodes(nodes: &[NodeRef]) -> Self { - nodes - .iter() - .map(|node| node.requirement) - .reduce(|acc, requirement| requirement.infer(&acc)) - .unwrap_or(Requirement::None) - } + /// Returns true if gradients are not required. + pub fn is_none(&self) -> bool { + matches!(self, Self::None) + } + /// Returns the right requirement from a list of nodes. 
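The rule implemented by `from_nodes` and `infer` just below can be exercised directly; a minimal sketch, assuming the `Node`, `NodeID`, and `Requirement` items above are in scope and that `demo` is only a hypothetical wrapper:

use std::sync::Arc;

fn demo() {
    let tracked = Arc::new(Node::new(vec![], 0, NodeID::new(), Requirement::Grad));
    let untracked = Arc::new(Node::new(vec![], 1, NodeID::new(), Requirement::None));

    // Any tracked parent promotes the result to GradInBackward: the gradient is
    // needed to reach an ancestor even if the output itself is not a leaf.
    assert!(matches!(
        Requirement::from_nodes(&[tracked, untracked.clone()]),
        Requirement::GradInBackward
    ));

    // Only when every parent is None does the output stay out of the graph.
    assert!(matches!(
        Requirement::from_nodes(&[untracked]),
        Requirement::None
    ));
}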
+ pub fn from_nodes(nodes: &[NodeRef]) -> Self { + nodes + .iter() + .map(|node| node.requirement) + .reduce(|acc, requirement| requirement.infer(&acc)) + .unwrap_or(Requirement::None) + } - fn infer(&self, other: &Self) -> Self { - match self.is_none() && other.is_none() { - true => Self::None, - false => Self::GradInBackward, - } + fn infer(&self, other: &Self) -> Self { + match self.is_none() && other.is_none() { + true => Self::None, + false => Self::GradInBackward, } + } } diff --git a/burn-autodiff/src/graph/traversal.rs b/burn-autodiff/src/graph/traversal.rs index eefebe2b78..de9e10ed3b 100644 --- a/burn-autodiff/src/graph/traversal.rs +++ b/burn-autodiff/src/graph/traversal.rs @@ -6,45 +6,45 @@ use super::{Graph, NodeRef, StepBoxed}; pub struct BreadthFirstSearch; impl BreadthFirstSearch { - /// Traverse the graph of backward steps from a root node. - pub fn traverse( - &self, - root: NodeRef, - graph: Graph, - mut callback: F, - ) { - let mut visited = HashSet::with_capacity(root.order); - let mut parents = Vec::with_capacity(root.order); - let mut steps = graph.steps(); - let root_step = steps + /// Traverse the graph of backward steps from a root node. + pub fn traverse( + &self, + root: NodeRef, + graph: Graph, + mut callback: F, + ) { + let mut visited = HashSet::with_capacity(root.order); + let mut parents = Vec::with_capacity(root.order); + let mut steps = graph.steps(); + let root_step = steps .remove(&root.id) .expect("Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?"); - visited.insert(root.id.clone()); - parents.append(&mut root.parents.clone()); - callback(root, root_step); + visited.insert(root.id.clone()); + parents.append(&mut root.parents.clone()); + callback(root, root_step); - while let Some(id) = parents.pop() { - let step = match steps.remove(&id) { - Some(step) => step, - None => continue, - }; + while let Some(id) = parents.pop() { + let step = match steps.remove(&id) { + Some(step) => step, + None => continue, + }; - let node = step.node(); + let node = step.node(); - if visited.contains(&node.id) { - continue; - } + if visited.contains(&node.id) { + continue; + } - visited.insert(node.id.clone()); + visited.insert(node.id.clone()); - for id in node.parents.iter() { - if !visited.contains(id) { - parents.push(id.clone()); - } - } - - callback(node, step); + for id in node.parents.iter() { + if !visited.contains(id) { + parents.push(id.clone()); } + } + + callback(node, step); } + } } diff --git a/burn-autodiff/src/ops/activation.rs b/burn-autodiff/src/ops/activation.rs index 0a34499f28..965f13a64e 100644 --- a/burn-autodiff/src/ops/activation.rs +++ b/burn-autodiff/src/ops/activation.rs @@ -1,57 +1,57 @@ use crate::{ - grads::Gradients, - ops::{unary, Backward, Ops, OpsKind}, - Autodiff, + grads::Gradients, + ops::{unary, Backward, Ops, OpsKind}, + Autodiff, }; use burn_tensor::{ - backend::Backend, - ops::{ActivationOps, FloatTensor}, + backend::Backend, + ops::{ActivationOps, FloatTensor}, }; impl ActivationOps> for Autodiff { - fn gelu(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Gelu; - - impl Backward for Gelu { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let input = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - B::gelu_backward(input, grad) - }); - } - } - - match Gelu::.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = 
B::gelu(tensor.primitive.clone()); - prep.finish(tensor.primitive, output) - } - OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)), - } + fn gelu(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Gelu; + + impl Backward for Gelu { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let input = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + B::gelu_backward(input, grad) + }); + } + } + + match Gelu::.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::gelu(tensor.primitive.clone()); + prep.finish(tensor.primitive, output) + } + OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)), + } + } + + fn relu(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Relu; + + impl Backward for Relu { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::relu_backward(ops.state, grad) + }); + } } + let output = B::relu(tensor.primitive); - fn relu(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Relu; - - impl Backward for Relu { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::relu_backward(ops.state, grad) - }); - } - } - let output = B::relu(tensor.primitive); - - match Relu.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(output.clone(), output), - OpsKind::UnTracked(prep) => prep.finish(output), - } + match Relu.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(output.clone(), output), + OpsKind::UnTracked(prep) => prep.finish(output), } + } } diff --git a/burn-autodiff/src/ops/backward.rs b/burn-autodiff/src/ops/backward.rs index 53e251c4f7..dd2a71954f 100644 --- a/burn-autodiff/src/ops/backward.rs +++ b/burn-autodiff/src/ops/backward.rs @@ -1,8 +1,8 @@ use super::{Ops, OpsPrep}; use crate::{ - grads::Gradients, - graph::{Graph, NodeRef, Requirement}, - utils::duplicate, + grads::Gradients, + graph::{Graph, NodeRef, Requirement}, + utils::duplicate, }; use burn_tensor::backend::Backend; @@ -15,88 +15,84 @@ use burn_tensor::backend::Backend; /// they should be declared with the associated type 'State'. pub trait Backward: Send + Sync + std::fmt::Debug where - Self: Sized + 'static, - B: Backend, + Self: Sized + 'static, + B: Backend, { - /// Associated type to compute the backward pass. - type State: Clone + Send + Sync + std::fmt::Debug + 'static; + /// Associated type to compute the backward pass. + type State: Clone + Send + Sync + std::fmt::Debug + 'static; - /// The backward pass. - fn backward(self, ops: Ops, grads: &mut Gradients); + /// The backward pass. + fn backward(self, ops: Ops, grads: &mut Gradients); - /// Prepare the backward ops. - fn prepare( - self, - nodes: [NodeRef; N], - graphs: [Graph; N], - ) -> OpsPrep { - let requirement = Requirement::from_nodes(&nodes); - OpsPrep::new(nodes, graphs, requirement, self) - } + /// Prepare the backward ops. + fn prepare(self, nodes: [NodeRef; N], graphs: [Graph; N]) -> OpsPrep { + let requirement = Requirement::from_nodes(&nodes); + OpsPrep::new(nodes, graphs, requirement, self) + } } /// Execute a binary operation during the backward step. 
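The helper declared next is usually driven from inside a `Backward` implementation; the sketch below does so for a subtraction op. The const-generic arguments (elided by the stripped angle brackets in this listing) are assumed to be backend first, then the output/lhs/rhs ranks.

// Sketch only: `Sub` stands in for a stateless binary op.
#[derive(Debug)]
struct Sub;

impl<B: Backend, const D: usize> Backward<B, D, 2> for Sub {
    // Subtraction needs nothing from the forward pass.
    type State = ();

    fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
        // d(lhs - rhs)/d(lhs) = 1 and d(lhs - rhs)/d(rhs) = -1: pass the incoming
        // gradient through to the lhs and negate it for the rhs.
        binary::<B, D, D, D, _, _>(
            ops.parents,
            ops.node,
            grads,
            |grad| grad,
            |grad| B::neg(grad),
        );
    }
}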
pub fn binary( - parents: [Option; 2], - node: NodeRef, - grads: &mut Gradients, - func_lhs: FLhs, - func_rhs: FRhs, + parents: [Option; 2], + node: NodeRef, + grads: &mut Gradients, + func_lhs: FLhs, + func_rhs: FRhs, ) where - B: Backend, - FLhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, - FRhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, + B: Backend, + FLhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, + FRhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, { - let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::(&node))); - let [node_lhs, node_rhs] = parents; + let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::(&node))); + let [node_lhs, node_rhs] = parents; - if let Some(node) = node_lhs { - let grad = func_lhs(grad_4lhs.unwrap()); - grads.register::(node, grad) - } + if let Some(node) = node_lhs { + let grad = func_lhs(grad_4lhs.unwrap()); + grads.register::(node, grad) + } - if let Some(node) = node_rhs { - let grad = func_rhs(grad_4rhs.unwrap()); - grads.register::(node, grad) - } + if let Some(node) = node_rhs { + let grad = func_rhs(grad_4rhs.unwrap()); + grads.register::(node, grad) + } } /// Execute a unary operation during the backward step. pub fn unary( - parents: [Option; 1], - node: NodeRef, - grads: &mut Gradients, - func: F, + parents: [Option; 1], + node: NodeRef, + grads: &mut Gradients, + func: F, ) where - B: Backend, - F: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, + B: Backend, + F: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, { - let [parent_node] = parents; - let grad = grads.consume::(&node); + let [parent_node] = parents; + let grad = grads.consume::(&node); - if let Some(node) = parent_node { - let grad = func(grad); - grads.register::(node, grad) - } + if let Some(node) = parent_node { + let grad = func(grad); + grads.register::(node, grad) + } } /// Execute a unary operation during the backward step where the input backend /// is different from the output backend. pub fn unary_different_backend( - parents: [Option; 1], - node: NodeRef, - grads: &mut Gradients, - func: F, + parents: [Option; 1], + node: NodeRef, + grads: &mut Gradients, + func: F, ) where - BIn: Backend, - BOut: Backend, - F: FnOnce(BOut::TensorPrimitive) -> BIn::TensorPrimitive, + BIn: Backend, + BOut: Backend, + F: FnOnce(BOut::TensorPrimitive) -> BIn::TensorPrimitive, { - let [parent_node] = parents; - let grad = grads.consume::(&node); + let [parent_node] = parents; + let grad = grads.consume::(&node); - if let Some(node) = parent_node { - let grad = func(grad); - grads.register::(node, grad) - } + if let Some(node) = parent_node { + let grad = func(grad); + grads.register::(node, grad) + } } diff --git a/burn-autodiff/src/ops/base.rs b/burn-autodiff/src/ops/base.rs index d212525943..2081d49473 100644 --- a/burn-autodiff/src/ops/base.rs +++ b/burn-autodiff/src/ops/base.rs @@ -1,10 +1,10 @@ use super::Backward; use crate::{ - grads::Gradients, - graph::{ - NodeRef, Requirement, {Graph, Step}, - }, - tensor::AutodiffTensor, + grads::Gradients, + graph::{ + NodeRef, Requirement, {Graph, Step}, + }, + tensor::AutodiffTensor, }; use burn_tensor::{backend::Backend, Shape}; use std::marker::PhantomData; @@ -15,13 +15,13 @@ use std::marker::PhantomData; /// Each mode has its own set of functions to minimize cloning for unused backward states. 
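In practice that split looks like the `gelu` and `relu` implementations above; the sketch below shows the same pattern for a hypothetical `MyOp`/`B::my_op` pair, with the clone of the forward input confined to the tracked arm:

// Hypothetical stateful op: `MyOp` is assumed to implement `Backward` with the
// forward input as its state, and `B::my_op` stands in for any unary backend
// function; the generic order on `AutodiffTensor` is likewise assumed.
fn my_op<B: Backend, const D: usize>(tensor: AutodiffTensor<B, D>) -> AutodiffTensor<B, D> {
    match MyOp.prepare([tensor.node], [tensor.graph]).stateful() {
        OpsKind::Tracked(prep) => {
            // Tracked: clone the input, because the backward pass will need it.
            let output = B::my_op(tensor.primitive.clone());
            prep.finish(tensor.primitive, output)
        }
        // Untracked: nothing is saved, so the input moves into the op without a clone.
        OpsKind::UnTracked(prep) => prep.finish(B::my_op(tensor.primitive)),
    }
}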
#[derive(new)] pub struct OpsPrep { - nodes: [NodeRef; N], - graphs: [Graph; N], - requirement: Requirement, - backward: Backward, - phantom_backend: PhantomData, - phantom_state: PhantomData, - marker: PhantomData, + nodes: [NodeRef; N], + graphs: [Graph; N], + requirement: Requirement, + backward: Backward, + phantom_backend: PhantomData, + phantom_state: PhantomData, + marker: PhantomData, } /// Init operation tag. @@ -33,130 +33,130 @@ pub struct UnTracked; impl OpsPrep where - B: Backend, - BO: Backward, + B: Backend, + BO: Backward, { - /// Prepare a stateless operation. - pub fn stateless(self, output: ::TensorPrimitive) -> AutodiffTensor { - match self.stateful() { - OpsKind::Tracked(prep) => prep.finish((), output), - OpsKind::UnTracked(prep) => prep.finish(output), - } + /// Prepare a stateless operation. + pub fn stateless(self, output: ::TensorPrimitive) -> AutodiffTensor { + match self.stateful() { + OpsKind::Tracked(prep) => prep.finish((), output), + OpsKind::UnTracked(prep) => prep.finish(output), } + } } impl OpsPrep where - B: Backend, - S: Clone + Send + Sync + std::fmt::Debug + 'static, - BO: Backward, + B: Backend, + S: Clone + Send + Sync + std::fmt::Debug + 'static, + BO: Backward, { - /// Prepare an operation that requires a state during the backward pass. - pub fn stateful(self) -> OpsKind { - match self.requirement.is_none() { - false => OpsKind::Tracked(OpsPrep::new( - self.nodes, - self.graphs, - self.requirement, - self.backward, - )), - true => OpsKind::UnTracked(OpsPrep::new( - self.nodes, - self.graphs, - self.requirement, - self.backward, - )), - } + /// Prepare an operation that requires a state during the backward pass. + pub fn stateful(self) -> OpsKind { + match self.requirement.is_none() { + false => OpsKind::Tracked(OpsPrep::new( + self.nodes, + self.graphs, + self.requirement, + self.backward, + )), + true => OpsKind::UnTracked(OpsPrep::new( + self.nodes, + self.graphs, + self.requirement, + self.backward, + )), } + } } impl OpsPrep where - B: Backend, - S: Clone + Send + Sync + std::fmt::Debug + 'static, - BO: Backward, + B: Backend, + S: Clone + Send + Sync + std::fmt::Debug + 'static, + BO: Backward, { - /// Finish the preparation of an untracked operation and returns the output tensor. - pub fn finish(self, output: ::TensorPrimitive) -> AutodiffTensor { - AutodiffTensor::from_parents( - output, - &self.nodes, - self.graphs.into_iter(), - self.requirement, - ) - } + /// Finish the preparation of an untracked operation and returns the output tensor. + pub fn finish(self, output: ::TensorPrimitive) -> AutodiffTensor { + AutodiffTensor::from_parents( + output, + &self.nodes, + self.graphs.into_iter(), + self.requirement, + ) + } } impl OpsPrep where - B: Backend, - S: Clone + Send + Sync + std::fmt::Debug + 'static, - BO: Backward, + B: Backend, + S: Clone + Send + Sync + std::fmt::Debug + 'static, + BO: Backward, { - /// Finish the preparation of a tracked operation and returns the output tensor. - pub fn finish( - self, - state: S, - output: ::TensorPrimitive, - ) -> AutodiffTensor { - let output = AutodiffTensor::from_parents( - output, - &self.nodes, - self.graphs.into_iter(), - self.requirement, - ); - let parents = self.nodes.map(|node| node.clone_if_require_grad()); - let ops = Ops::new(parents, output.node.clone(), state); + /// Finish the preparation of a tracked operation and returns the output tensor. 
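Together with the tracked `finish` defined next, these methods cover both modes; for an op whose `State` is `()` the call site collapses entirely into `stateless`, as in this sketch (the `Neg` struct, its `Backward` impl, and the generic order on `AutodiffTensor` are assumptions here):

// Hypothetical stateless call site: `Neg` is assumed to implement `Backward`
// with `State = ()`, in the same shape as the `Sub` sketch earlier.
fn neg<B: Backend, const D: usize>(tensor: AutodiffTensor<B, D>) -> AutodiffTensor<B, D> {
    Neg.prepare([tensor.node], [tensor.graph])
        .stateless(B::neg(tensor.primitive))
}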
+ pub fn finish( + self, + state: S, + output: ::TensorPrimitive, + ) -> AutodiffTensor { + let output = AutodiffTensor::from_parents( + output, + &self.nodes, + self.graphs.into_iter(), + self.requirement, + ); + let parents = self.nodes.map(|node| node.clone_if_require_grad()); + let ops = Ops::new(parents, output.node.clone(), state); - output.register_step(OpsStep::new(ops, self.backward)) - } + output.register_step(OpsStep::new(ops, self.backward)) + } } /// Enum used before finishing tracked and untracked operations. pub enum OpsKind { - /// Tracked operation preparation. - Tracked(OpsPrep), - /// Untracked operation preparation. - UnTracked(OpsPrep), + /// Tracked operation preparation. + Tracked(OpsPrep), + /// Untracked operation preparation. + UnTracked(OpsPrep), } /// Operation containing its parent nodes, its own node and the backward step state. #[derive(new, Debug)] pub struct Ops { - /// Parents nodes. - pub parents: [Option; N], - /// The node. - pub node: NodeRef, - /// The state. - pub state: S, + /// Parents nodes. + pub parents: [Option; N], + /// The node. + pub node: NodeRef, + /// The state. + pub state: S, } /// Operation implementing backward [step](Step) with type erasing. #[derive(new, Debug)] struct OpsStep where - B: Backend, - T: Backward, - SB: Clone + Send + Sync + std::fmt::Debug + 'static, + B: Backend, + T: Backward, + SB: Clone + Send + Sync + std::fmt::Debug + 'static, { - ops: Ops, - backward: T, - phantom: PhantomData, + ops: Ops, + backward: T, + phantom: PhantomData, } impl Step for OpsStep where - B: Backend, - T: Backward, - SB: Clone + Send + Sync + std::fmt::Debug + 'static, + B: Backend, + T: Backward, + SB: Clone + Send + Sync + std::fmt::Debug + 'static, { - fn step(self: Box, grads: &mut Gradients) { - self.backward.backward(self.ops, grads); - } + fn step(self: Box, grads: &mut Gradients) { + self.backward.backward(self.ops, grads); + } - fn node(&self) -> NodeRef { - self.ops.node.clone() - } + fn node(&self) -> NodeRef { + self.ops.node.clone() + } } /// Make sure the grad tensor has the given shape. @@ -164,22 +164,22 @@ where /// If broadcasting happened during the forward pass, the gradients will be sum along the /// broadcasted dimension. pub fn broadcast_shape( - mut grad: B::TensorPrimitive, - shape: &Shape, + mut grad: B::TensorPrimitive, + shape: &Shape, ) -> B::TensorPrimitive { - let shape_grad = B::shape(&grad); + let shape_grad = B::shape(&grad); - for i in 0..D { - if shape_grad.dims[i] != shape.dims[i] { - if shape.dims[i] != 1 { - panic!( - "Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}", - shape.dims, shape_grad.dims, "Expected the shape of the next grad to be 1." - ); - } - grad = B::sum_dim(grad, i); - } + for i in 0..D { + if shape_grad.dims[i] != shape.dims[i] { + if shape.dims[i] != 1 { + panic!( + "Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}", + shape.dims, shape_grad.dims, "Expected the shape of the next grad to be 1." 
+ ); + } + grad = B::sum_dim(grad, i); } + } - grad + grad } diff --git a/burn-autodiff/src/ops/bool_tensor.rs b/burn-autodiff/src/ops/bool_tensor.rs index 59b342f840..bc5685d642 100644 --- a/burn-autodiff/src/ops/bool_tensor.rs +++ b/burn-autodiff/src/ops/bool_tensor.rs @@ -1,95 +1,92 @@ use crate::{tensor::AutodiffTensor, Autodiff}; use burn_tensor::{ - backend::Backend, - ops::{BoolTensor, BoolTensorOps, IntTensor}, - Data, Device, Reader, Shape, + backend::Backend, + ops::{BoolTensor, BoolTensorOps, IntTensor}, + Data, Device, Reader, Shape, }; impl BoolTensorOps for Autodiff { - fn bool_from_data(data: Data, device: &Device) -> BoolTensor { - B::bool_from_data(data, device) - } - - fn bool_shape(tensor: &BoolTensor) -> Shape { - B::bool_shape(tensor) - } - - fn bool_to_data(tensor: &BoolTensor) -> Reader> { - B::bool_to_data(tensor) - } - - fn bool_into_data(tensor: BoolTensor) -> Reader> { - B::bool_into_data(tensor) - } - - fn bool_into_int(tensor: BoolTensor) -> IntTensor { - B::bool_into_int(tensor) - } - - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - B::bool_to_device(tensor, device) - } - - fn bool_device(tensor: &BoolTensor) -> Device { - B::bool_device(tensor) - } - - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - B::bool_reshape(tensor, shape) - } - - fn bool_slice( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - ) -> BoolTensor { - B::bool_slice(tensor, ranges) - } - - fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - B::bool_empty(shape, device) - } - - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - value: BoolTensor, - ) -> BoolTensor { - B::bool_slice_assign(tensor, ranges, value) - } - - fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { - B::bool_cat(tensors, dim) - } - - fn bool_equal( - lhs: BoolTensor, - rhs: BoolTensor, - ) -> BoolTensor { - B::bool_equal(lhs, rhs) - } - - fn bool_not(tensor: BoolTensor) -> BoolTensor { - B::bool_not(tensor) - } - - fn bool_into_float( - tensor: BoolTensor, - ) -> as Backend>::TensorPrimitive { - AutodiffTensor::new(B::bool_into_float(tensor)) - } - - fn bool_swap_dims( - tensor: as Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::BoolTensorPrimitive { - B::bool_swap_dims(tensor, dim1, dim2) - } + fn bool_from_data(data: Data, device: &Device) -> BoolTensor { + B::bool_from_data(data, device) + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + B::bool_shape(tensor) + } + + fn bool_to_data(tensor: &BoolTensor) -> Reader> { + B::bool_to_data(tensor) + } + + fn bool_into_data(tensor: BoolTensor) -> Reader> { + B::bool_into_data(tensor) + } + + fn bool_into_int(tensor: BoolTensor) -> IntTensor { + B::bool_into_int(tensor) + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + B::bool_to_device(tensor, device) + } + + fn bool_device(tensor: &BoolTensor) -> Device { + B::bool_device(tensor) + } + + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor { + B::bool_reshape(tensor, shape) + } + + fn bool_slice( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + ) -> BoolTensor { + B::bool_slice(tensor, ranges) + } + + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + B::bool_empty(shape, device) + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + value: BoolTensor, + ) -> BoolTensor { + B::bool_slice_assign(tensor, ranges, value) + } + + fn bool_cat(tensors: Vec>, dim: usize) 
-> BoolTensor { + B::bool_cat(tensors, dim) + } + + fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + B::bool_equal(lhs, rhs) + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + B::bool_not(tensor) + } + + fn bool_into_float( + tensor: BoolTensor, + ) -> as Backend>::TensorPrimitive { + AutodiffTensor::new(B::bool_into_float(tensor)) + } + + fn bool_swap_dims( + tensor: as Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::BoolTensorPrimitive { + B::bool_swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-autodiff/src/ops/int_tensor.rs b/burn-autodiff/src/ops/int_tensor.rs index a7af0420e1..48d6ab460c 100644 --- a/burn-autodiff/src/ops/int_tensor.rs +++ b/burn-autodiff/src/ops/int_tensor.rs @@ -1,319 +1,319 @@ use crate::{tensor::AutodiffTensor, Autodiff}; use burn_tensor::{ - backend::Backend, - ops::{BoolTensor, IntTensor, IntTensorOps}, - Data, Device, Reader, Shape, + backend::Backend, + ops::{BoolTensor, IntTensor, IntTensorOps}, + Data, Device, Reader, Shape, }; impl IntTensorOps> for Autodiff { - fn int_from_data( - data: Data, - device: &Device, - ) -> IntTensor { - B::int_from_data(data, device) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - B::int_shape(tensor) - } - - fn int_to_data(tensor: &IntTensor) -> Reader> { - B::int_to_data(tensor) - } - - fn int_into_data(tensor: IntTensor) -> Reader> { - B::int_into_data(tensor) - } - - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - B::int_to_device(tensor, device) - } - - fn int_device(tensor: &IntTensor) -> Device { - B::int_device(tensor) - } - - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - B::int_reshape(tensor, shape) - } - - fn int_slice( - tensor: IntTensor, - ranges: [std::ops::Range; D2], - ) -> IntTensor { - B::int_slice(tensor, ranges) - } - - fn int_empty( - shape: Shape, - device: & as Backend>::Device, - ) -> IntTensor { - B::int_empty(shape, device) - } - - fn int_slice_assign( - tensor: IntTensor, - ranges: [std::ops::Range; D2], - value: IntTensor, - ) -> IntTensor { - B::int_slice_assign(tensor, ranges, value) - } - - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - B::int_cat(tensors, dim) - } - - fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - B::int_equal(lhs, rhs) - } - - fn int_equal_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { - B::int_equal_elem(lhs, rhs) - } - - fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_add(lhs, rhs) - } - - fn int_add_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_add_scalar(lhs, rhs) - } - - fn int_clamp_min(tensor: IntTensor, min: B::IntElem) -> IntTensor { - B::int_clamp_min(tensor, min) - } - - fn int_clamp_max(tensor: IntTensor, max: B::IntElem) -> IntTensor { - B::int_clamp_max(tensor, max) - } - - fn int_clamp( - tensor: IntTensor, - min: B::IntElem, - max: B::IntElem, - ) -> IntTensor { - B::int_clamp(tensor, min, max) - } - - fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_sub(lhs, rhs) - } - - fn int_sub_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_sub_scalar(lhs, rhs) - } - - fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_mul(lhs, rhs) - } - - fn int_mul_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_mul_scalar(lhs, rhs) - } - - fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_div(lhs, rhs) - } - - fn int_div_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_div_scalar(lhs, rhs) - } - - 
fn int_neg(tensor: IntTensor) -> IntTensor { - B::int_neg(tensor) - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - B::int_zeros(shape, device) - } - - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - B::int_ones(shape, device) - } - - fn int_full( - shape: Shape, - fill_value: B::IntElem, - device: &Device, - ) -> IntTensor { - B::int_full(shape, fill_value, device) - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - B::int_sum(tensor) - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_sum_dim(tensor, dim) - } - - fn int_mean(tensor: IntTensor) -> IntTensor { - B::int_mean(tensor) - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_mean_dim(tensor, dim) - } - - fn int_repeat( - tensor: IntTensor, - dim: usize, - times: usize, - ) -> IntTensor { - B::int_repeat(tensor, dim, times) - } - - fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - B::int_greater(lhs, rhs) - } - - fn int_greater_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { - B::int_greater_elem(lhs, rhs) - } - - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - B::int_greater_equal(lhs, rhs) - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: B::IntElem, - ) -> BoolTensor { - B::int_greater_equal_elem(lhs, rhs) - } - - fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - B::int_lower(lhs, rhs) - } - - fn int_lower_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { - B::int_lower_elem(lhs, rhs) - } - - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - B::int_lower_equal(lhs, rhs) - } - - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: B::IntElem, - ) -> BoolTensor { - B::int_lower_equal_elem(lhs, rhs) - } - - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - B::int_gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - B::int_scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - B::int_select(tensor, dim, indices) - } - - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - B::int_select_assign(tensor, dim, indices, value) - } - - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - value: IntTensor, - ) -> as Backend>::IntTensorPrimitive { - B::int_mask_where(tensor, mask, value) - } - - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: B::IntElem, - ) -> as Backend>::IntTensorPrimitive { - B::int_mask_fill(tensor, mask, value) - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_argmax(tensor, dim) - } - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_argmin(tensor, dim) - } - fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { - B::int_max(tensor) - } - fn int_max_dim( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> B::IntTensorPrimitive { - B::int_max_dim(tensor, dim) - } - fn int_max_dim_with_indices( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { - B::int_max_dim_with_indices(tensor, dim) - } - fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { - B::int_min(tensor) - } - fn int_min_dim( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> B::IntTensorPrimitive { - 
B::int_min_dim(tensor, dim) - } - fn int_min_dim_with_indices( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { - B::int_min_dim_with_indices(tensor, dim) - } - fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive { - B::int_abs(tensor) - } - fn int_into_float( - tensor: as Backend>::IntTensorPrimitive, - ) -> as Backend>::TensorPrimitive { - AutodiffTensor::new(B::int_into_float(tensor)) - } - - fn int_swap_dims( - tensor: as Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::IntTensorPrimitive { - B::int_swap_dims(tensor, dim1, dim2) - } + fn int_from_data( + data: Data, + device: &Device, + ) -> IntTensor { + B::int_from_data(data, device) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + B::int_shape(tensor) + } + + fn int_to_data(tensor: &IntTensor) -> Reader> { + B::int_to_data(tensor) + } + + fn int_into_data(tensor: IntTensor) -> Reader> { + B::int_into_data(tensor) + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + B::int_to_device(tensor, device) + } + + fn int_device(tensor: &IntTensor) -> Device { + B::int_device(tensor) + } + + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + B::int_reshape(tensor, shape) + } + + fn int_slice( + tensor: IntTensor, + ranges: [std::ops::Range; D2], + ) -> IntTensor { + B::int_slice(tensor, ranges) + } + + fn int_empty( + shape: Shape, + device: & as Backend>::Device, + ) -> IntTensor { + B::int_empty(shape, device) + } + + fn int_slice_assign( + tensor: IntTensor, + ranges: [std::ops::Range; D2], + value: IntTensor, + ) -> IntTensor { + B::int_slice_assign(tensor, ranges, value) + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + B::int_cat(tensors, dim) + } + + fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + B::int_equal(lhs, rhs) + } + + fn int_equal_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { + B::int_equal_elem(lhs, rhs) + } + + fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_add(lhs, rhs) + } + + fn int_add_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_add_scalar(lhs, rhs) + } + + fn int_clamp_min(tensor: IntTensor, min: B::IntElem) -> IntTensor { + B::int_clamp_min(tensor, min) + } + + fn int_clamp_max(tensor: IntTensor, max: B::IntElem) -> IntTensor { + B::int_clamp_max(tensor, max) + } + + fn int_clamp( + tensor: IntTensor, + min: B::IntElem, + max: B::IntElem, + ) -> IntTensor { + B::int_clamp(tensor, min, max) + } + + fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_sub(lhs, rhs) + } + + fn int_sub_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_sub_scalar(lhs, rhs) + } + + fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_mul(lhs, rhs) + } + + fn int_mul_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_mul_scalar(lhs, rhs) + } + + fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_div(lhs, rhs) + } + + fn int_div_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_div_scalar(lhs, rhs) + } + + fn int_neg(tensor: IntTensor) -> IntTensor { + B::int_neg(tensor) + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + B::int_zeros(shape, device) + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + B::int_ones(shape, device) + } + + fn int_full( + shape: Shape, + fill_value: B::IntElem, + device: &Device, + ) -> IntTensor { + B::int_full(shape, fill_value, device) + } + + fn int_sum(tensor: 
IntTensor) -> IntTensor { + B::int_sum(tensor) + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_sum_dim(tensor, dim) + } + + fn int_mean(tensor: IntTensor) -> IntTensor { + B::int_mean(tensor) + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_mean_dim(tensor, dim) + } + + fn int_repeat( + tensor: IntTensor, + dim: usize, + times: usize, + ) -> IntTensor { + B::int_repeat(tensor, dim, times) + } + + fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + B::int_greater(lhs, rhs) + } + + fn int_greater_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { + B::int_greater_elem(lhs, rhs) + } + + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + B::int_greater_equal(lhs, rhs) + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: B::IntElem, + ) -> BoolTensor { + B::int_greater_equal_elem(lhs, rhs) + } + + fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + B::int_lower(lhs, rhs) + } + + fn int_lower_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { + B::int_lower_elem(lhs, rhs) + } + + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + B::int_lower_equal(lhs, rhs) + } + + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: B::IntElem, + ) -> BoolTensor { + B::int_lower_equal_elem(lhs, rhs) + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + B::int_gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + B::int_scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + B::int_select(tensor, dim, indices) + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + B::int_select_assign(tensor, dim, indices, value) + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + value: IntTensor, + ) -> as Backend>::IntTensorPrimitive { + B::int_mask_where(tensor, mask, value) + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: B::IntElem, + ) -> as Backend>::IntTensorPrimitive { + B::int_mask_fill(tensor, mask, value) + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_argmax(tensor, dim) + } + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_argmin(tensor, dim) + } + fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { + B::int_max(tensor) + } + fn int_max_dim( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> B::IntTensorPrimitive { + B::int_max_dim(tensor, dim) + } + fn int_max_dim_with_indices( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { + B::int_max_dim_with_indices(tensor, dim) + } + fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { + B::int_min(tensor) + } + fn int_min_dim( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> B::IntTensorPrimitive { + B::int_min_dim(tensor, dim) + } + fn int_min_dim_with_indices( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { + B::int_min_dim_with_indices(tensor, dim) + } + fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive { + B::int_abs(tensor) + } + fn int_into_float( + tensor: as Backend>::IntTensorPrimitive, + ) -> as Backend>::TensorPrimitive { + 
AutodiffTensor::new(B::int_into_float(tensor)) + } + + fn int_swap_dims( + tensor: as Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::IntTensorPrimitive { + B::int_swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-autodiff/src/ops/maxmin.rs b/burn-autodiff/src/ops/maxmin.rs index 3371c03eb2..4e788e3681 100644 --- a/burn-autodiff/src/ops/maxmin.rs +++ b/burn-autodiff/src/ops/maxmin.rs @@ -6,15 +6,15 @@ use burn_tensor::{backend::Backend, Shape}; pub(crate) struct MaxMinDim; impl Backward for MaxMinDim { - type State = (B::IntTensorPrimitive, Shape); + type State = (B::IntTensorPrimitive, Shape); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let (indices, shape) = ops.state; - let device = B::device(&grad); - let zeros = B::zeros(shape, &device); + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let (indices, shape) = ops.state; + let device = B::device(&grad); + let zeros = B::zeros(shape, &device); - B::scatter(D - 1, zeros, indices, grad) - }); - } + B::scatter(D - 1, zeros, indices, grad) + }); + } } diff --git a/burn-autodiff/src/ops/module.rs b/burn-autodiff/src/ops/module.rs index 9da1440f3f..7a4d3bd56e 100644 --- a/burn-autodiff/src/ops/module.rs +++ b/burn-autodiff/src/ops/module.rs @@ -9,948 +9,899 @@ use burn_tensor::ops::*; use super::OpsKind; impl ModuleOps> for Autodiff { - fn embedding(weights: AutodiffTensor, indices: IntTensor) -> AutodiffTensor { - #[derive(Debug)] - struct Embedding; + fn embedding(weights: AutodiffTensor, indices: IntTensor) -> AutodiffTensor { + #[derive(Debug)] + struct Embedding; - impl Backward for Embedding { - type State = (B::TensorPrimitive<2>, IntTensor); + impl Backward for Embedding { + type State = (B::TensorPrimitive<2>, IntTensor); - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (weights, indices) = ops.state; + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (weights, indices) = ops.state; - unary::(ops.parents, ops.node, grads, |grad| { - B::embedding_backward(weights, grad, indices) - }); - } + unary::(ops.parents, ops.node, grads, |grad| { + B::embedding_backward(weights, grad, indices) + }); + } + } + + match Embedding + .prepare([weights.node], [weights.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + (weights.primitive.clone(), indices.clone()), + B::embedding(weights.primitive, indices), + ), + OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)), + } + } + + fn embedding_backward( + _weights: AutodiffTensor, + _output: AutodiffTensor, + _indices: IntTensor, + ) -> AutodiffTensor { + panic!("Can't differentiate embedding backward."); + } + + fn conv2d( + x: AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct Conv2DWithBias; + #[derive(Debug)] + struct Conv2DNoBias; + + impl Backward for Conv2DWithBias { + type State = ( + B::TensorPrimitive<4>, + B::TensorPrimitive<4>, + B::TensorPrimitive<1>, + ConvOptions<2>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, bias, options) = ops.state; + let backward = B::conv2d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, 
backward.weights_grad) } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) + } + } + } + + impl Backward for Conv2DNoBias { + type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, ConvOptions<2>); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv2d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + } + } - match Embedding - .prepare([weights.node], [weights.graph]) - .stateful() + match bias { + Some(bias) => { + match Conv2DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() { - OpsKind::Tracked(prep) => prep.finish( - (weights.primitive.clone(), indices.clone()), - B::embedding(weights.primitive, indices), + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), ), - OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)), + B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv2d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), } + } + None => { + match Conv2DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv2d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::conv2d(x.primitive, weight.primitive, None, options)) + } + } + } } + } + + fn conv_transpose2d( + x: AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct ConvTranspose2DWithBias; + #[derive(Debug)] + struct ConvTranspose2DNoBias; + + impl Backward for ConvTranspose2DWithBias { + type State = ( + B::TensorPrimitive<4>, + B::TensorPrimitive<4>, + B::TensorPrimitive<1>, + ConvTransposeOptions<2>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); - fn embedding_backward( - _weights: AutodiffTensor, - _output: AutodiffTensor, - _indices: IntTensor, - ) -> AutodiffTensor { - panic!("Can't differentiate embedding backward."); + let (x, weight, bias, options) = ops.state; + let backward = B::conv_transpose2d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) + } + } } - fn conv2d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct Conv2DWithBias; - #[derive(Debug)] - struct Conv2DNoBias; - - impl Backward for Conv2DWithBias { - type State = ( - B::TensorPrimitive<4>, - B::TensorPrimitive<4>, - B::TensorPrimitive<1>, - ConvOptions<2>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = 
grads.consume::(&ops.node); - - let (x, weight, bias, options) = ops.state; - let backward = B::conv2d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) - } - } - } - - impl Backward for Conv2DNoBias { - type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, ConvOptions<2>); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = ops.state; - let backward = B::conv2d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - } - } - - match bias { - Some(bias) => { - match Conv2DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), - ), - B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv2d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), - } - } - None => { - match Conv2DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - B::conv2d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::conv2d(x.primitive, weight.primitive, None, options)) - } - } - } + impl Backward for ConvTranspose2DNoBias { + type State = ( + B::TensorPrimitive<4>, + B::TensorPrimitive<4>, + ConvTransposeOptions<2>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv_transpose2d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + } } - fn conv_transpose2d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct ConvTranspose2DWithBias; - #[derive(Debug)] - struct ConvTranspose2DNoBias; - - impl Backward for ConvTranspose2DWithBias { - type State = ( - B::TensorPrimitive<4>, - B::TensorPrimitive<4>, - B::TensorPrimitive<1>, - ConvTransposeOptions<2>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, bias, options) = ops.state; - let backward = B::conv_transpose2d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) - } - } - } - - impl Backward for ConvTranspose2DNoBias { - type State = ( - B::TensorPrimitive<4>, - 
B::TensorPrimitive<4>, - ConvTransposeOptions<2>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = ops.state; - let backward = B::conv_transpose2d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - } - } - - match bias { - Some(bias) => { - match ConvTranspose2DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), - ), - B::conv_transpose2d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - ), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), - } - } - None => { - match ConvTranspose2DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - B::conv_transpose2d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( - x.primitive, - weight.primitive, - None, - options, - )), - } - } + match bias { + Some(bias) => { + match ConvTranspose2DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), + ), + B::conv_transpose2d(x.primitive, weight.primitive, Some(bias.primitive), options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), + } + } + None => { + match ConvTranspose2DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv_transpose2d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( + x.primitive, + weight.primitive, + None, + options, + )), } + } } + } + + fn conv1d( + x: AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct Conv1DWithBias; + #[derive(Debug)] + struct Conv1DNoBias; + + impl Backward for Conv1DWithBias { + type State = ( + B::TensorPrimitive<3>, + B::TensorPrimitive<3>, + B::TensorPrimitive<1>, + ConvOptions<1>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); - fn conv1d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct Conv1DWithBias; - #[derive(Debug)] - struct Conv1DNoBias; - - impl Backward for Conv1DWithBias { - type State = ( - B::TensorPrimitive<3>, - B::TensorPrimitive<3>, - B::TensorPrimitive<1>, - ConvOptions<1>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = grads.consume::(&ops.node); 
- - let (x, weight, bias, options) = ops.state; - let backward = B::conv1d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) - } - } - } - - impl Backward for Conv1DNoBias { - type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, ConvOptions<1>); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = ops.state; - let backward = B::conv1d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - } - } - match bias { - Some(bias) => { - match Conv1DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), - ), - B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv1d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), - } - } - None => { - match Conv1DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - B::conv1d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::conv1d(x.primitive, weight.primitive, None, options)) - } - } - } + let (x, weight, bias, options) = ops.state; + let backward = B::conv1d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) } + } } - fn conv_transpose1d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct ConvTranspose1DWithBias; - #[derive(Debug)] - struct ConvTranspose1DNoBias; - - impl Backward for ConvTranspose1DWithBias { - type State = ( - B::TensorPrimitive<3>, - B::TensorPrimitive<3>, - B::TensorPrimitive<1>, - ConvTransposeOptions<1>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, bias, options) = ops.state; - let backward = B::conv_transpose1d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) - } - } - } - - impl Backward for ConvTranspose1DNoBias { - type State = ( - B::TensorPrimitive<3>, - B::TensorPrimitive<3>, - ConvTransposeOptions<1>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = 
ops.state; - let backward = B::conv_transpose1d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - } - } - - match bias { - Some(bias) => { - match ConvTranspose1DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), - ), - B::conv_transpose1d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - ), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), - } - } - None => { - match ConvTranspose1DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - B::conv_transpose1d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( - x.primitive, - weight.primitive, - None, - options, - )), - } - } + impl Backward for Conv1DNoBias { + type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, ConvOptions<1>); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv1d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + } + } + match bias { + Some(bias) => { + match Conv1DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), + ), + B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv1d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), + } + } + None => { + match Conv1DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv1d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::conv1d(x.primitive, weight.primitive, None, options)) + } } + } } + } + + fn conv_transpose1d( + x: AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct ConvTranspose1DWithBias; + #[derive(Debug)] + struct ConvTranspose1DNoBias; + + impl Backward for ConvTranspose1DWithBias { + type State = ( + B::TensorPrimitive<3>, + B::TensorPrimitive<3>, + B::TensorPrimitive<1>, + ConvTransposeOptions<1>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); - // TODO: Support a custom unfold4d operation by overriding the default implementation. - // - // We don't override it now because the fold operation isn't available for the backward pass. 
- // This implies that when autodiff is enabled, custom unfold operations defined by backends - // won't be used. Instead, the conv2d operation with custom weights matrix will be used. - // Therefore, the conv2d backward pass will be used for the unfold4d backward pass. - // - // fn unfold4d( - // x: AutodiffTensor, - // kernel_size: [usize; 2], - // options: UnfoldOptions, - // ) -> AutodiffTensor { - // todo!() - // } - - fn avg_pool1d( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> AutodiffTensor { - #[derive(Debug)] - struct AvgPool1D; - - impl Backward for AvgPool1D { - type State = (B::TensorPrimitive<3>, usize, usize, usize, bool); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, kernel_size, stride, padding, count_include_pad) = ops.state; - - if let Some(node) = node_parent { - let grad = B::avg_pool1d_backward( - x, - grad, - kernel_size, - stride, - padding, - count_include_pad, - ); - grads.register::(node, grad); - } - } - } - - match AvgPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::avg_pool1d( - x.primitive.clone(), - kernel_size, - stride, - padding, - count_include_pad, - ); - prep.finish( - (x.primitive, kernel_size, stride, padding, count_include_pad), - output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )), + let (x, weight, bias, options) = ops.state; + let backward = B::conv_transpose1d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) } + } } - fn avg_pool2d( - x: AutodiffTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> AutodiffTensor { - #[derive(Debug)] - struct AvgPool2D; - - impl Backward for AvgPool2D { - type State = ( - B::TensorPrimitive<4>, - [usize; 2], - [usize; 2], - [usize; 2], - bool, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, kernel_size, stride, padding, count_include_pad) = ops.state; - - if let Some(node) = node_parent { - let grad = B::avg_pool2d_backward( - x, - grad, - kernel_size, - stride, - padding, - count_include_pad, - ); - grads.register::(node, grad); - } - } - } - - match AvgPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::avg_pool2d( - x.primitive.clone(), - kernel_size, - stride, - padding, - count_include_pad, - ); - prep.finish( - (x.primitive, kernel_size, stride, padding, count_include_pad), - output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::avg_pool2d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )), + impl Backward for ConvTranspose1DNoBias { + type State = ( + B::TensorPrimitive<3>, + B::TensorPrimitive<3>, + ConvTransposeOptions<1>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv_transpose1d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { 
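// Editorial note, not part of the patch: a gradient is registered only for a parent that is
// still tracked (the `if let Some(node)` guards), so untracked inputs and weights are skipped.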
+ grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) } + } } - fn avg_pool2d_backward( - _x: AutodiffTensor, - _grad: AutodiffTensor, - _kernel_size: [usize; 2], - _stride: [usize; 2], - _padding: [usize; 2], - _count_include_pad: bool, - ) -> AutodiffTensor { - panic!("Can't differentiate avg pool 2d backward."); + match bias { + Some(bias) => { + match ConvTranspose1DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), + ), + B::conv_transpose1d(x.primitive, weight.primitive, Some(bias.primitive), options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), + } + } + None => { + match ConvTranspose1DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv_transpose1d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( + x.primitive, + weight.primitive, + None, + options, + )), + } + } } + } + + // TODO: Support a custom unfold4d operation by overriding the default implementation. + // + // We don't override it now because the fold operation isn't available for the backward pass. + // This implies that when autodiff is enabled, custom unfold operations defined by backends + // won't be used. Instead, the conv2d operation with custom weights matrix will be used. + // Therefore, the conv2d backward pass will be used for the unfold4d backward pass. 
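Illustration only, not part of this patch: the "conv2d operation with custom weights matrix" mentioned in the comment above amounts to running conv2d with a one-hot unfolding kernel, so once unfold4d is routed through conv2d under autodiff, conv2d's backward pass covers it for free. A minimal sketch of that weight layout in plain Rust (the helper name is hypothetical and simplified relative to the real default implementation):

/// Builds the one-hot "unfolding" weight with logical shape
/// `[channels * kh * kw, channels, kh, kw]`, flattened row-major:
/// output channel `c * kh * kw + i * kw + j` copies input channel `c`
/// at kernel position `(i, j)`.
fn unfolding_weight(channels: usize, kernel_size: [usize; 2]) -> Vec<f32> {
    let [kh, kw] = kernel_size;
    let mut weight = vec![0.0_f32; (channels * kh * kw) * channels * kh * kw];
    for c in 0..channels {
        for i in 0..kh {
            for j in 0..kw {
                let out_channel = c * kh * kw + i * kw + j;
                // Row-major index into [out_channel][c][i][j].
                let idx = ((out_channel * channels + c) * kh + i) * kw + j;
                weight[idx] = 1.0;
            }
        }
    }
    weight
}

Applying conv2d with this weight, no bias, and the stride/padding/dilation taken from UnfoldOptions yields the unfolded patches, which is why no dedicated unfold4d backward is needed here.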
+ // + // fn unfold4d( + // x: AutodiffTensor, + // kernel_size: [usize; 2], + // options: UnfoldOptions, + // ) -> AutodiffTensor { + // todo!() + // } + + fn avg_pool1d( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> AutodiffTensor { + #[derive(Debug)] + struct AvgPool1D; + + impl Backward for AvgPool1D { + type State = (B::TensorPrimitive<3>, usize, usize, usize, bool); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x, kernel_size, stride, padding, count_include_pad) = ops.state; - fn max_pool1d( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> AutodiffTensor { - match MaxPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::max_pool1d_with_indices( - x.primitive.clone(), - kernel_size, - stride, - padding, - dilation, - ); - prep.finish( - ( - x.primitive, - output.indices, - kernel_size, - stride, - padding, - dilation, - ), - output.output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )), + if let Some(node) = node_parent { + let grad = + B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad); + grads.register::(node, grad); } + } } - fn max_pool1d_with_indices( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices> { - match MaxPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::max_pool1d_with_indices( - x.primitive.clone(), - kernel_size, - stride, - padding, - dilation, - ); - - let output_tensor = prep.finish( - ( - x.primitive, - output.indices.clone(), - kernel_size, - stride, - padding, - dilation, - ), - output.output, - ); - - MaxPool1dWithIndices::new(output_tensor, output.indices) - } - OpsKind::UnTracked(prep) => { - let output = - B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - let output_tensor = prep.finish(output.output); - - MaxPool1dWithIndices::new(output_tensor, output.indices) - } + match AvgPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::avg_pool1d( + x.primitive.clone(), + kernel_size, + stride, + padding, + count_include_pad, + ); + prep.finish( + (x.primitive, kernel_size, stride, padding, count_include_pad), + output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )), + } + } + + fn avg_pool2d( + x: AutodiffTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> AutodiffTensor { + #[derive(Debug)] + struct AvgPool2D; + + impl Backward for AvgPool2D { + type State = ( + B::TensorPrimitive<4>, + [usize; 2], + [usize; 2], + [usize; 2], + bool, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x, kernel_size, stride, padding, count_include_pad) = ops.state; + + if let Some(node) = node_parent { + let grad = + B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad); + grads.register::(node, grad); } + } } - fn max_pool1d_with_indices_backward( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - 
output_grad: AutodiffTensor, - indices: IntTensor, - ) -> MaxPool1dBackward> { - let output = B::max_pool1d_with_indices_backward( + match AvgPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::avg_pool2d( + x.primitive.clone(), + kernel_size, + stride, + padding, + count_include_pad, + ); + prep.finish( + (x.primitive, kernel_size, stride, padding, count_include_pad), + output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::avg_pool2d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )), + } + } + + fn avg_pool2d_backward( + _x: AutodiffTensor, + _grad: AutodiffTensor, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + _count_include_pad: bool, + ) -> AutodiffTensor { + panic!("Can't differentiate avg pool 2d backward."); + } + + fn max_pool1d( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> AutodiffTensor { + match MaxPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = + B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); + prep.finish( + ( x.primitive, + output.indices, kernel_size, stride, padding, dilation, - output_grad.primitive, - indices, - ); - MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad)) + ), + output.output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )), } + } + + fn max_pool1d_with_indices( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices> { + match MaxPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = + B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); + + let output_tensor = prep.finish( + ( + x.primitive, + output.indices.clone(), + kernel_size, + stride, + padding, + dilation, + ), + output.output, + ); - fn max_pool2d( - x: AutodiffTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> AutodiffTensor { - match MaxPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::max_pool2d_with_indices( - x.primitive.clone(), - kernel_size, - stride, - padding, - dilation, - ); - prep.finish( - ( - x.primitive, - output.indices, - kernel_size, - stride, - padding, - dilation, - ), - output.output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )), - } - } + MaxPool1dWithIndices::new(output_tensor, output.indices) + } + OpsKind::UnTracked(prep) => { + let output = + B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output_tensor = prep.finish(output.output); - fn max_pool2d_with_indices( - x: AutodiffTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - match MaxPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::max_pool2d_with_indices( - x.primitive.clone(), - kernel_size, - stride, - padding, - dilation, - ); - - let output_tensor = prep.finish( - ( - x.primitive, - output.indices.clone(), - kernel_size, - stride, - padding, - dilation, - ), - output.output, - ); - - MaxPool2dWithIndices::new(output_tensor, output.indices) - } - OpsKind::UnTracked(prep) => 
{ - let output = - B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - let output_tensor = prep.finish(output.output); - - MaxPool2dWithIndices::new(output_tensor, output.indices) - } - } + MaxPool1dWithIndices::new(output_tensor, output.indices) + } } - - fn max_pool2d_with_indices_backward( - _x: AutodiffTensor, - _kernel_size: [usize; 2], - _stride: [usize; 2], - _padding: [usize; 2], - _dilation: [usize; 2], - _output_grad: AutodiffTensor, - _indices: IntTensor, - ) -> MaxPool2dBackward> { - panic!("Can't differentiate max pool2d with indices backward."); + } + + fn max_pool1d_with_indices_backward( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: AutodiffTensor, + indices: IntTensor, + ) -> MaxPool1dBackward> { + let output = B::max_pool1d_with_indices_backward( + x.primitive, + kernel_size, + stride, + padding, + dilation, + output_grad.primitive, + indices, + ); + MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad)) + } + + fn max_pool2d( + x: AutodiffTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> AutodiffTensor { + match MaxPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = + B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); + prep.finish( + ( + x.primitive, + output.indices, + kernel_size, + stride, + padding, + dilation, + ), + output.output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )), } - fn adaptive_avg_pool1d(x: AutodiffTensor, output_size: usize) -> AutodiffTensor { - #[derive(Debug)] - struct AdaptiveAvgPool1D; - - impl Backward for AdaptiveAvgPool1D { - type State = B::TensorPrimitive<3>; + } + + fn max_pool2d_with_indices( + x: AutodiffTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + match MaxPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = + B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); + + let output_tensor = prep.finish( + ( + x.primitive, + output.indices.clone(), + kernel_size, + stride, + padding, + dilation, + ), + output.output, + ); - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); + MaxPool2dWithIndices::new(output_tensor, output.indices) + } + OpsKind::UnTracked(prep) => { + let output = + B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output_tensor = prep.finish(output.output); - if let Some(node) = node_parent { - let grad = B::adaptive_avg_pool1d_backward(ops.state, grad); - grads.register::(node, grad); - } - } - } + MaxPool2dWithIndices::new(output_tensor, output.indices) + } + } + } + + fn max_pool2d_with_indices_backward( + _x: AutodiffTensor, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + _dilation: [usize; 2], + _output_grad: AutodiffTensor, + _indices: IntTensor, + ) -> MaxPool2dBackward> { + panic!("Can't differentiate max pool2d with indices backward."); + } + fn adaptive_avg_pool1d(x: AutodiffTensor, output_size: usize) -> AutodiffTensor { + #[derive(Debug)] + struct AdaptiveAvgPool1D; + + impl Backward for AdaptiveAvgPool1D { + type State = B::TensorPrimitive<3>; + + fn backward(self, ops: Ops, 
grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); - match AdaptiveAvgPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - x.primitive.clone(), - B::adaptive_avg_pool1d(x.primitive, output_size), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size)) - } + if let Some(node) = node_parent { + let grad = B::adaptive_avg_pool1d_backward(ops.state, grad); + grads.register::(node, grad); } + } } - fn adaptive_avg_pool2d( - x: AutodiffTensor, - output_size: [usize; 2], - ) -> AutodiffTensor { - #[derive(Debug)] - struct AdaptiveAvgPool2D; + match AdaptiveAvgPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + x.primitive.clone(), + B::adaptive_avg_pool1d(x.primitive, output_size), + ), + OpsKind::UnTracked(prep) => prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size)), + } + } - impl Backward for AdaptiveAvgPool2D { - type State = B::TensorPrimitive<4>; + fn adaptive_avg_pool2d(x: AutodiffTensor, output_size: [usize; 2]) -> AutodiffTensor { + #[derive(Debug)] + struct AdaptiveAvgPool2D; - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); + impl Backward for AdaptiveAvgPool2D { + type State = B::TensorPrimitive<4>; - if let Some(node) = node_parent { - let grad = B::adaptive_avg_pool2d_backward(ops.state, grad); - grads.register::(node, grad); - } - } - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); - match AdaptiveAvgPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - x.primitive.clone(), - B::adaptive_avg_pool2d(x.primitive, output_size), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size)) - } + if let Some(node) = node_parent { + let grad = B::adaptive_avg_pool2d_backward(ops.state, grad); + grads.register::(node, grad); } + } } - fn adaptive_avg_pool2d_backward( - _x: AutodiffTensor, - _grad: AutodiffTensor, - ) -> as Backend>::TensorPrimitive<4> { - panic!("Can't differentiate adaptive avg pool2d backward."); + match AdaptiveAvgPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + x.primitive.clone(), + B::adaptive_avg_pool2d(x.primitive, output_size), + ), + OpsKind::UnTracked(prep) => prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size)), } + } + + fn adaptive_avg_pool2d_backward( + _x: AutodiffTensor, + _grad: AutodiffTensor, + ) -> as Backend>::TensorPrimitive<4> { + panic!("Can't differentiate adaptive avg pool2d backward."); + } } #[derive(Debug)] struct MaxPool1D; impl Backward for MaxPool1D { - type State = ( - B::TensorPrimitive<3>, - IntTensor, - usize, - usize, - usize, - usize, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, indices, kernel_size, stride, padding, dilation) = ops.state; - - if let Some(node) = node_parent { - let grad = B::max_pool1d_with_indices_backward( - x, - kernel_size, - stride, - padding, - dilation, - grad, - indices, - ); - - grads.register::(node, grad.x_grad); - } + type State = ( + B::TensorPrimitive<3>, + IntTensor, + usize, + usize, + usize, + usize, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = 
grads.consume::(&ops.node); + let (x, indices, kernel_size, stride, padding, dilation) = ops.state; + + if let Some(node) = node_parent { + let grad = B::max_pool1d_with_indices_backward( + x, + kernel_size, + stride, + padding, + dilation, + grad, + indices, + ); + + grads.register::(node, grad.x_grad); } + } } #[derive(Debug)] struct MaxPool2D; impl Backward for MaxPool2D { - type State = ( - B::TensorPrimitive<4>, - IntTensor, - [usize; 2], - [usize; 2], - [usize; 2], - [usize; 2], - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, indices, kernel_size, stride, padding, dilation) = ops.state; - - if let Some(node) = node_parent { - let grad = B::max_pool2d_with_indices_backward( - x, - kernel_size, - stride, - padding, - dilation, - grad, - indices, - ); - - grads.register::(node, grad.x_grad); - } + type State = ( + B::TensorPrimitive<4>, + IntTensor, + [usize; 2], + [usize; 2], + [usize; 2], + [usize; 2], + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x, indices, kernel_size, stride, padding, dilation) = ops.state; + + if let Some(node) = node_parent { + let grad = B::max_pool2d_with_indices_backward( + x, + kernel_size, + stride, + padding, + dilation, + grad, + indices, + ); + + grads.register::(node, grad.x_grad); } + } } diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs index e4cd4755e7..d5ae02ecba 100644 --- a/burn-autodiff/src/ops/tensor.rs +++ b/burn-autodiff/src/ops/tensor.rs @@ -1,1554 +1,1534 @@ use std::marker::PhantomData; use crate::{ - grads::Gradients, - graph::{NodeRef, Requirement, Step}, - ops::{binary, broadcast_shape, unary, unary_different_backend, Backward, Ops, OpsKind}, - tensor::AutodiffTensor, - utils::duplicate, - Autodiff, + grads::Gradients, + graph::{NodeRef, Requirement, Step}, + ops::{binary, broadcast_shape, unary, unary_different_backend, Backward, Ops, OpsKind}, + tensor::AutodiffTensor, + utils::duplicate, + Autodiff, }; use burn_tensor::{ - backend::Backend, - ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, - Data, Device, ElementConversion, Reader, Shape, Tensor, + backend::Backend, + ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, + Data, Device, ElementConversion, Reader, Shape, Tensor, }; use super::maxmin::MaxMinDim; impl TensorOps for Autodiff { - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor { - AutodiffTensor::new(B::from_data(data, device)) - } - - fn random( - shape: Shape, - distribution: burn_tensor::Distribution>, - device: &Device, - ) -> FloatTensor { - AutodiffTensor::new(B::random(shape, distribution, device)) - } - - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::zeros(shape), device) + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor { + AutodiffTensor::new(B::from_data(data, device)) + } + + fn random( + shape: Shape, + distribution: burn_tensor::Distribution>, + device: &Device, + ) -> FloatTensor { + AutodiffTensor::new(B::random(shape, distribution, device)) + } + + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::zeros(shape), device) + } + + fn ones(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::ones(shape), device) + } + + fn shape(tensor: &FloatTensor) -> Shape { + B::shape(&tensor.primitive) 
+ } + + fn to_data(tensor: &FloatTensor) -> Reader, D>> { + B::to_data(&tensor.primitive) + } + + fn into_data(tensor: FloatTensor) -> Reader, D>> { + B::into_data(tensor.primitive) + } + + fn device(tensor: &FloatTensor) -> Device { + B::device(&tensor.primitive) + } + + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor { + #[derive(Debug)] + struct ToDevice; + + impl Backward for ToDevice { + type State = B::Device; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::to_device(grad, &ops.state) + }); + } + } + + match ToDevice.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let device_old = B::device(&tensor.primitive); + prep.finish(device_old, B::to_device(tensor.primitive, device)) + } + OpsKind::UnTracked(prep) => prep.finish(B::to_device(tensor.primitive, device)), + } + } + + fn arange(range: std::ops::Range, device: &Device) -> IntTensor { + B::arange(range, device) + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + AutodiffTensor::new(B::empty(shape, device)) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Add; + + impl Backward for Add { + type State = (Shape, Shape); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape_lhs, shape_rhs) = ops.state; + + binary::( + ops.parents, + ops.node, + grads, + |grad| broadcast_shape::(grad, &shape_lhs), + |grad| broadcast_shape::(grad, &shape_rhs), + ); + } + } + + match Add + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(preps) => preps.finish( + (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), + B::add(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(preps) => preps.finish(B::add(lhs.primitive, rhs.primitive)), + } + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct AddScalar; + + impl Backward for AddScalar { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| grad); + } + } + + AddScalar + .prepare([lhs.node], [lhs.graph]) + .stateless(B::add_scalar(lhs.primitive, rhs)) + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Sub; + + impl Backward for Sub { + type State = (Shape, Shape); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape_lhs, shape_rhs) = ops.state; + + binary::( + ops.parents, + ops.node, + grads, + |grad| broadcast_shape::(grad, &shape_lhs), + |grad| broadcast_shape::(B::neg(grad), &shape_rhs), + ); + } + } + + match Sub + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(preps) => preps.finish( + (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), + B::sub(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(preps) => preps.finish(B::sub(lhs.primitive, rhs.primitive)), + } + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct SubScalar; + + impl Backward for SubScalar { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| grad); + } + } + + SubScalar + .prepare([lhs.node], [lhs.graph]) + .stateless(B::sub_scalar(lhs.primitive, rhs)) + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Mul; + + impl Backward for Mul { + type State = ( + 
Option>, + Option>, + BinaryOpsBroadcast, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (lhs, rhs, broadcast) = ops.state; + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let grad = B::mul(grad, rhs.unwrap()); + broadcast.backward_lhs::(grad) + }, + |grad| { + let grad = B::mul(grad, lhs.unwrap()); + broadcast.backward_rhs::(grad) + }, + ); + } + } + + let lhs_tracked = lhs.is_tracked(); + let rhs_tracked = rhs.is_tracked(); + let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); + + match Mul + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + rhs_tracked.then(|| lhs.primitive.clone()), + lhs_tracked.then(|| rhs.primitive.clone()), + broadcast, + ), + B::mul(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::mul(lhs.primitive, rhs.primitive)), + } + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct MulScalar; + + impl Backward for MulScalar { + type State = FloatElem; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::mul_scalar(grad, ops.state) + }); + } + } + + match MulScalar.prepare([lhs.node], [lhs.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(rhs, B::mul_scalar(lhs.primitive, rhs)), + OpsKind::UnTracked(prep) => prep.finish(B::mul_scalar(lhs.primitive, rhs)), + } + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Div; + + impl Backward for Div { + type State = ( + Option>, + Option>, + BinaryOpsBroadcast, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (lhs, rhs, broadcast) = ops.state; + let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, rhs); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let rhs = rhs_4lhs.unwrap(); + let value = B::powf(rhs, -1.0); + let grad = B::mul(grad, value); + + broadcast.backward_lhs::(grad) + }, + |grad| { + let rhs = rhs_4rhs.unwrap(); + let lhs = lhs.unwrap(); + let value = B::div(B::neg(lhs), B::powf(rhs, 2.0)); + let grad = B::mul(grad, value); + + broadcast.backward_rhs::(grad) + }, + ); + } + } + + let lhs_tracked = lhs.is_tracked(); + let rhs_tracked = rhs.is_tracked(); + let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); + + match Div + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + rhs_tracked.then(|| lhs.primitive.clone()), + (lhs_tracked || rhs_tracked).then(|| rhs.primitive.clone()), + broadcast, + ), + B::div(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::div(lhs.primitive, rhs.primitive)), + } + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct DivScalar; + + impl Backward for DivScalar { + type State = FloatElem; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let tmp = 1.0 / ops.state.elem::(); + B::mul_scalar(grad, tmp.elem()) + }); + } + } + + match DivScalar.prepare([lhs.node], [lhs.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(rhs, B::div_scalar(lhs.primitive, rhs)), + OpsKind::UnTracked(prep) => prep.finish(B::div_scalar(lhs.primitive, rhs)), + } + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Matmul; + + impl Backward for Matmul { + type 
State = ( + Option>, + Option>, + BinaryOpsBroadcast, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (lhs, rhs, broadcast) = ops.state; + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let rhs = B::transpose(rhs.unwrap()); + let grad = B::matmul(grad, rhs); + + broadcast.backward_lhs::(grad) + }, + |grad| { + let lhs = B::transpose(lhs.unwrap()); + let grad = B::matmul(lhs, grad); + + broadcast.backward_rhs::(grad) + }, + ); + } + } + + let lhs_tracked = lhs.is_tracked(); + let rhs_tracked = rhs.is_tracked(); + let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); + + match Matmul + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + rhs_tracked.then(|| lhs.primitive.clone()), + lhs_tracked.then(|| rhs.primitive.clone()), + broadcast, + ), + B::matmul(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::matmul(lhs.primitive, rhs.primitive)), + } + } + + fn neg(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Neg; + + impl Backward for Neg { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| B::neg(grad)); + } + } + + Neg + .prepare([tensor.node], [tensor.graph]) + .stateless(B::neg(tensor.primitive)) + } + + fn recip(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Recip; + + impl Backward for Recip { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let tensor = ops.state; + unary::(ops.parents, ops.node, grads, |grad| { + let tmp = B::powf(tensor, -2.0); + let value = B::neg(tmp); + + B::mul(grad, value) + }); + } } - fn ones(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::ones(shape), device) + match Recip.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)), } + } - fn shape(tensor: &FloatTensor) -> Shape { - B::shape(&tensor.primitive) - } + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + #[derive(Debug)] + struct SwapDim; - fn to_data(tensor: &FloatTensor) -> Reader, D>> { - B::to_data(&tensor.primitive) - } + impl Backward for SwapDim { + type State = (usize, usize); - fn into_data(tensor: FloatTensor) -> Reader, D>> { - B::into_data(tensor.primitive) - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim1, dim2) = ops.state; - fn device(tensor: &FloatTensor) -> Device { - B::device(&tensor.primitive) + unary::(ops.parents, ops.node, grads, |grad| { + B::swap_dims(grad, dim2, dim1) + }); + } } - fn to_device( - tensor: FloatTensor, - device: &Device, - ) -> FloatTensor { - #[derive(Debug)] - struct ToDevice; - - impl Backward for ToDevice { - type State = B::Device; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::to_device(grad, &ops.state) - }); - } - } - - match ToDevice.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let device_old = B::device(&tensor.primitive); - prep.finish(device_old, B::to_device(tensor.primitive, device)) - } - OpsKind::UnTracked(prep) => prep.finish(B::to_device(tensor.primitive, device)), - } - } + let output = B::swap_dims(tensor.primitive, dim1, dim2); - fn arange(range: std::ops::Range, device: &Device) -> IntTensor { 
- B::arange(range, device) + match SwapDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish((dim1, dim2), output), + OpsKind::UnTracked(prep) => prep.finish(output), } + } - fn empty(shape: Shape, device: &Device) -> FloatTensor { - AutodiffTensor::new(B::empty(shape, device)) - } + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + #[derive(Debug)] + struct ReshapeDim; - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Add; + impl Backward for ReshapeDim { + type State = (Shape, Shape); - impl Backward for Add { - type State = (Shape, Shape); + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape_original, shape) = ops.state; - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape_lhs, shape_rhs) = ops.state; + unary::(ops.parents, ops.node, grads, |grad| { + let shape_grad = B::shape(&grad); + let mut grad = grad; - binary::( - ops.parents, - ops.node, - grads, - |grad| broadcast_shape::(grad, &shape_lhs), - |grad| broadcast_shape::(grad, &shape_rhs), - ); - } - } - - match Add - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(preps) => preps.finish( - (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), - B::add(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(preps) => preps.finish(B::add(lhs.primitive, rhs.primitive)), - } - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct AddScalar; - - impl Backward for AddScalar { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| grad); + for i in 0..D2 { + if shape.dims[i] == 1 && shape_grad.dims[i] != 1 { + grad = B::sum_dim(grad, i); } - } + } - AddScalar - .prepare([lhs.node], [lhs.graph]) - .stateless(B::add_scalar(lhs.primitive, rhs)) + B::reshape(grad, shape_original) + }); + } } - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Sub; - - impl Backward for Sub { - type State = (Shape, Shape); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape_lhs, shape_rhs) = ops.state; - - binary::( - ops.parents, - ops.node, - grads, - |grad| broadcast_shape::(grad, &shape_lhs), - |grad| broadcast_shape::(B::neg(grad), &shape_rhs), - ); - } - } - - match Sub - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(preps) => preps.finish( - (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), - B::sub(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(preps) => preps.finish(B::sub(lhs.primitive, rhs.primitive)), - } - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct SubScalar; - - impl Backward for SubScalar { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| grad); - } - } - - SubScalar - .prepare([lhs.node], [lhs.graph]) - .stateless(B::sub_scalar(lhs.primitive, rhs)) - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Mul; - - impl Backward for Mul { - type State = ( - Option>, - Option>, - BinaryOpsBroadcast, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (lhs, rhs, broadcast) = ops.state; - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let grad = B::mul(grad, rhs.unwrap()); - broadcast.backward_lhs::(grad) - 
}, - |grad| { - let grad = B::mul(grad, lhs.unwrap()); - broadcast.backward_rhs::(grad) - }, - ); - } - } - - let lhs_tracked = lhs.is_tracked(); - let rhs_tracked = rhs.is_tracked(); - let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); - - match Mul - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - rhs_tracked.then(|| lhs.primitive.clone()), - lhs_tracked.then(|| rhs.primitive.clone()), - broadcast, - ), - B::mul(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::mul(lhs.primitive, rhs.primitive)), - } - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct MulScalar; - - impl Backward for MulScalar { - type State = FloatElem; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::mul_scalar(grad, ops.state) - }); - } - } - - match MulScalar.prepare([lhs.node], [lhs.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(rhs, B::mul_scalar(lhs.primitive, rhs)), - OpsKind::UnTracked(prep) => prep.finish(B::mul_scalar(lhs.primitive, rhs)), - } - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Div; - - impl Backward for Div { - type State = ( - Option>, - Option>, - BinaryOpsBroadcast, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (lhs, rhs, broadcast) = ops.state; - let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, rhs); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let rhs = rhs_4lhs.unwrap(); - let value = B::powf(rhs, -1.0); - let grad = B::mul(grad, value); - - broadcast.backward_lhs::(grad) - }, - |grad| { - let rhs = rhs_4rhs.unwrap(); - let lhs = lhs.unwrap(); - let value = B::div(B::neg(lhs), B::powf(rhs, 2.0)); - let grad = B::mul(grad, value); - - broadcast.backward_rhs::(grad) - }, - ); - } - } - - let lhs_tracked = lhs.is_tracked(); - let rhs_tracked = rhs.is_tracked(); - let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); - - match Div - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - rhs_tracked.then(|| lhs.primitive.clone()), - (lhs_tracked || rhs_tracked).then(|| rhs.primitive.clone()), - broadcast, - ), - B::div(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::div(lhs.primitive, rhs.primitive)), - } - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct DivScalar; - - impl Backward for DivScalar { - type State = FloatElem; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let tmp = 1.0 / ops.state.elem::(); - B::mul_scalar(grad, tmp.elem()) - }); - } - } - - match DivScalar.prepare([lhs.node], [lhs.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(rhs, B::div_scalar(lhs.primitive, rhs)), - OpsKind::UnTracked(prep) => prep.finish(B::div_scalar(lhs.primitive, rhs)), - } - } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Matmul; - - impl Backward for Matmul { - type State = ( - Option>, - Option>, - BinaryOpsBroadcast, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (lhs, rhs, broadcast) = ops.state; - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let rhs = B::transpose(rhs.unwrap()); - let grad = 
B::matmul(grad, rhs); - - broadcast.backward_lhs::(grad) - }, - |grad| { - let lhs = B::transpose(lhs.unwrap()); - let grad = B::matmul(lhs, grad); - - broadcast.backward_rhs::(grad) - }, - ); - } - } - - let lhs_tracked = lhs.is_tracked(); - let rhs_tracked = rhs.is_tracked(); - let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); - - match Matmul - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - rhs_tracked.then(|| lhs.primitive.clone()), - lhs_tracked.then(|| rhs.primitive.clone()), - broadcast, - ), - B::matmul(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::matmul(lhs.primitive, rhs.primitive)), - } - } - - fn neg(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Neg; - - impl Backward for Neg { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| B::neg(grad)); - } - } - - Neg.prepare([tensor.node], [tensor.graph]) - .stateless(B::neg(tensor.primitive)) + match ReshapeDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + (B::shape(&tensor.primitive), shape.clone()), + B::reshape(tensor.primitive, shape), + ), + OpsKind::UnTracked(prep) => prep.finish(B::reshape(tensor.primitive, shape)), } + } - fn recip(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Recip; + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Gather; - impl Backward for Recip { - type State = B::TensorPrimitive; + impl Backward for Gather { + type State = (usize, IntTensor, Shape, B::Device); - fn backward(self, ops: Ops, grads: &mut Gradients) { - let tensor = ops.state; - unary::(ops.parents, ops.node, grads, |grad| { - let tmp = B::powf(tensor, -2.0); - let value = B::neg(tmp); - - B::mul(grad, value) - }); - } - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape, device) = ops.state; - match Recip.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)), - } + unary::(ops.parents, ops.node, grads, |grad| { + let zeros = B::zeros(shape, &device); + B::scatter(dim, zeros, indices, grad) + }); + } + } + + match Gather.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::device(&tensor.primitive), + ), + B::gather(dim, tensor.primitive, indices), + ), + OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indices)), + } + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Scatter; + + impl Backward for Scatter { + type State = (usize, IntTensor, Shape, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape_lhs, shape_rhs, device) = ops.state; + let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_lhs, &device); + B::scatter(dim, grad, indices_4lhs.unwrap(), zeros) + }, + |grad| { + let zeros = B::zeros(shape_rhs, &device); + B::scatter(dim, zeros, indices_4rhs.unwrap(), grad) + }, + ); + } + } + + match 
Scatter + .prepare([tensor.node, value.node], [tensor.graph, value.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::shape(&value.primitive), + B::device(&value.primitive), + ), + B::scatter(dim, tensor.primitive, indices, value.primitive), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::scatter(dim, tensor.primitive, indices, value.primitive)) + } + } + } + + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct IndexSelectDim; + + impl Backward for IndexSelectDim { + type State = (usize, IntTensor, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape, device) = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + let zeros = B::zeros(shape, &device); + B::select_assign(zeros, dim, indices, grad) + }); + } + } + + match IndexSelectDim + .prepare([tensor.node], [tensor.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::device(&tensor.primitive), + ), + B::select(tensor.primitive, dim, indices), + ), + OpsKind::UnTracked(prep) => prep.finish(B::select(tensor.primitive, dim, indices)), + } + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct IndexSelectDimAssign; + + impl Backward for IndexSelectDimAssign { + type State = (usize, IntTensor, Shape, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape_lhs, shape_rhs, device) = ops.state; + let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_lhs, &device); + B::select_assign(grad, dim, indices_4lhs.unwrap(), zeros) + }, + |grad| { + let zeros = B::zeros(shape_rhs, &device); + B::select_assign(zeros, dim, indices_4rhs.unwrap(), grad) + }, + ); + } + } + + match IndexSelectDimAssign:: + .prepare([tensor.node, value.node], [tensor.graph, value.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::shape(&value.primitive), + B::device(&value.primitive), + ), + B::select_assign(tensor.primitive, dim, indices, value.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::select_assign( + tensor.primitive, + dim, + indices, + value.primitive, + )), + } + } + + fn slice( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + ) -> FloatTensor { + #[derive(Debug)] + struct Index; + + impl Backward for Index { + type State = ([std::ops::Range; D2], Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (ranges, shape, device) = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + let zeros = B::zeros(shape, &device); + B::slice_assign(zeros, ranges, grad) + }); + } + } + + match Index.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + ( + ranges.clone(), + B::shape(&tensor.primitive), + B::device(&tensor.primitive), + ), + B::slice(tensor.primitive, ranges), + ), + OpsKind::UnTracked(prep) => prep.finish(B::slice(tensor.primitive, ranges)), + } + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + value: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct IndexAssign; + + impl Backward for 
IndexAssign { + type State = ([std::ops::Range; D2], Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (ranges, shape_rhs, device) = ops.state; + let [ranges_4lhs, ranges_4rhs] = duplicate(&ops.parents, Some(ranges)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_rhs, &device); + B::slice_assign(grad, ranges_4lhs.unwrap(), zeros) + }, + |grad| B::slice(grad, ranges_4rhs.unwrap()), + ); + } + } + + match IndexAssign + .prepare([tensor.node, value.node], [tensor.graph, value.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + ranges.clone(), + B::shape(&value.primitive), + B::device(&value.primitive), + ), + B::slice_assign(tensor.primitive, ranges, value.primitive), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::slice_assign(tensor.primitive, ranges, value.primitive)) + } + } + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + source: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct MaskWhere; + + impl Backward for MaskWhere { + type State = (BoolTensor, Shape, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (mask, shape_lhs, shape_rhs, device) = ops.state; + let [mask_4lhs, mask_4rhs] = duplicate(&ops.parents, Some(mask)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_lhs.clone(), &device); + let grad = B::mask_where(grad, mask_4lhs.unwrap(), zeros); + + broadcast_shape::(grad, &shape_lhs) + }, + |grad| { + let zeros = B::zeros(shape_rhs.clone(), &device); + let grad = B::mask_where(zeros, mask_4rhs.unwrap(), grad); + + broadcast_shape::(grad, &shape_rhs) + }, + ); + } + } + + match MaskWhere + .prepare([tensor.node, source.node], [tensor.graph, source.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + mask.clone(), + B::shape(&tensor.primitive), + B::shape(&source.primitive), + B::device(&source.primitive), + ), + B::mask_where(tensor.primitive, mask, source.primitive), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::mask_where(tensor.primitive, mask, source.primitive)) + } + } + } + + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct MaskFill; + + impl Backward for MaskFill { + type State = BoolTensor; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::mask_fill(grad, ops.state, 0.elem()) + }); + } + } + + match MaskFill.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(mask.clone(), B::mask_fill(tensor.primitive, mask, value)) + } + OpsKind::UnTracked(prep) => prep.finish(B::mask_fill(tensor.primitive, mask, value)), + } + } + + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::equal(lhs.primitive, rhs.primitive) + } + + fn equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { + B::equal_elem(lhs.primitive, rhs) + } + + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::greater(lhs.primitive, rhs.primitive) + } + + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + B::greater_elem(lhs.primitive, rhs) + } + + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::greater_equal(lhs.primitive, rhs.primitive) + } + + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + B::greater_equal_elem(lhs.primitive, rhs) + } + + fn lower( + lhs: 
FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::lower(lhs.primitive, rhs.primitive) + } + + fn lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { + B::lower_elem(lhs.primitive, rhs) + } + + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::lower_equal(lhs.primitive, rhs.primitive) + } + + fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + B::lower_equal_elem(lhs.primitive, rhs) + } + + fn detach(tensor: FloatTensor) -> FloatTensor { + // When we detach a tensor, we remove it from the graph, but we still want to keep the + // `require_grad` setting. + let is_require_grad = Self::is_require_grad(&tensor); + let tensor = AutodiffTensor::new(tensor.primitive); + + match is_require_grad { + true => tensor.require_grad(), + false => tensor, + } + } + + fn set_require_grad( + tensor: FloatTensor, + require_grad: bool, + ) -> FloatTensor { + if require_grad { + return tensor.require_grad(); + } + + AutodiffTensor::new(tensor.primitive) + } + + fn is_require_grad(tensor: &FloatTensor) -> bool { + matches!(tensor.node.requirement, Requirement::Grad) + } + + fn mean(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Mean; + + impl Backward for Mean { + type State = Shape; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let shape = ops.state; + let val = 1_f64 / shape.num_elements() as f64; + let ones = B::ones(shape, &B::device(&grad)); + let val = B::mul_scalar(ones, val.elem()); + + let grad: Tensor = Tensor::from_primitive(grad); + let val: Tensor = Tensor::from_primitive(val); + + val.mul(grad.unsqueeze()).into_primitive() + }); + } } - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor { - #[derive(Debug)] - struct SwapDim; - - impl Backward for SwapDim { - type State = (usize, usize); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim1, dim2) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - B::swap_dims(grad, dim2, dim1) - }); - } - } - - let output = B::swap_dims(tensor.primitive, dim1, dim2); - - match SwapDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish((dim1, dim2), output), - OpsKind::UnTracked(prep) => prep.finish(output), - } + match Mean.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(B::shape(&tensor.primitive), B::mean(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::mean(tensor.primitive)), } + } - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - #[derive(Debug)] - struct ReshapeDim; - - impl Backward for ReshapeDim { - type State = (Shape, Shape); + fn sum(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Sum; - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape_original, shape) = ops.state; + impl Backward for Sum { + type State = Shape; - unary::(ops.parents, ops.node, grads, |grad| { - let shape_grad = B::shape(&grad); - let mut grad = grad; + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let val = B::ones(ops.state, &B::device(&grad)); - for i in 0..D2 { - if shape.dims[i] == 1 && shape_grad.dims[i] != 1 { - grad = B::sum_dim(grad, i); - } - } + let grad: Tensor = Tensor::from_primitive(grad); + let val: Tensor = Tensor::from_primitive(val); - B::reshape(grad, shape_original) - }); - } - } - - match ReshapeDim.prepare([tensor.node], 
[tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (B::shape(&tensor.primitive), shape.clone()), - B::reshape(tensor.primitive, shape), - ), - OpsKind::UnTracked(prep) => prep.finish(B::reshape(tensor.primitive, shape)), - } - } - - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Gather; - - impl Backward for Gather { - type State = (usize, IntTensor, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape, device) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - let zeros = B::zeros(shape, &device); - B::scatter(dim, zeros, indices, grad) - }); - } - } - - match Gather.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::device(&tensor.primitive), - ), - B::gather(dim, tensor.primitive, indices), - ), - OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indices)), - } - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Scatter; - - impl Backward for Scatter { - type State = (usize, IntTensor, Shape, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape_lhs, shape_rhs, device) = ops.state; - let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_lhs, &device); - B::scatter(dim, grad, indices_4lhs.unwrap(), zeros) - }, - |grad| { - let zeros = B::zeros(shape_rhs, &device); - B::scatter(dim, zeros, indices_4rhs.unwrap(), grad) - }, - ); - } - } - - match Scatter - .prepare([tensor.node, value.node], [tensor.graph, value.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::shape(&value.primitive), - B::device(&value.primitive), - ), - B::scatter(dim, tensor.primitive, indices, value.primitive), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::scatter(dim, tensor.primitive, indices, value.primitive)) - } - } + val.mul(grad.unsqueeze()).into_primitive() + }); + } } - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct IndexSelectDim; - - impl Backward for IndexSelectDim { - type State = (usize, IntTensor, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape, device) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - let zeros = B::zeros(shape, &device); - B::select_assign(zeros, dim, indices, grad) - }); - } - } - - match IndexSelectDim - .prepare([tensor.node], [tensor.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::device(&tensor.primitive), - ), - B::select(tensor.primitive, dim, indices), - ), - OpsKind::UnTracked(prep) => prep.finish(B::select(tensor.primitive, dim, indices)), - } - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct IndexSelectDimAssign; - - impl Backward for IndexSelectDimAssign { - type State = (usize, IntTensor, Shape, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape_lhs, shape_rhs, 
device) = ops.state; - let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_lhs, &device); - B::select_assign(grad, dim, indices_4lhs.unwrap(), zeros) - }, - |grad| { - let zeros = B::zeros(shape_rhs, &device); - B::select_assign(zeros, dim, indices_4rhs.unwrap(), grad) - }, - ); - } - } - - match IndexSelectDimAssign:: - .prepare([tensor.node, value.node], [tensor.graph, value.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::shape(&value.primitive), - B::device(&value.primitive), - ), - B::select_assign(tensor.primitive, dim, indices, value.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::select_assign( - tensor.primitive, - dim, - indices, - value.primitive, - )), - } - } - - fn slice( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - ) -> FloatTensor { - #[derive(Debug)] - struct Index; - - impl Backward for Index { - type State = ([std::ops::Range; D2], Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (ranges, shape, device) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - let zeros = B::zeros(shape, &device); - B::slice_assign(zeros, ranges, grad) - }); - } - } - - match Index.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - ( - ranges.clone(), - B::shape(&tensor.primitive), - B::device(&tensor.primitive), - ), - B::slice(tensor.primitive, ranges), - ), - OpsKind::UnTracked(prep) => prep.finish(B::slice(tensor.primitive, ranges)), - } - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - value: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct IndexAssign; - - impl Backward for IndexAssign { - type State = ([std::ops::Range; D2], Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (ranges, shape_rhs, device) = ops.state; - let [ranges_4lhs, ranges_4rhs] = duplicate(&ops.parents, Some(ranges)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_rhs, &device); - B::slice_assign(grad, ranges_4lhs.unwrap(), zeros) - }, - |grad| B::slice(grad, ranges_4rhs.unwrap()), - ); - } - } - - match IndexAssign - .prepare([tensor.node, value.node], [tensor.graph, value.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - ranges.clone(), - B::shape(&value.primitive), - B::device(&value.primitive), - ), - B::slice_assign(tensor.primitive, ranges, value.primitive), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::slice_assign(tensor.primitive, ranges, value.primitive)) - } - } - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - source: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct MaskWhere; - - impl Backward for MaskWhere { - type State = (BoolTensor, Shape, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (mask, shape_lhs, shape_rhs, device) = ops.state; - let [mask_4lhs, mask_4rhs] = duplicate(&ops.parents, Some(mask)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_lhs.clone(), &device); - let grad = B::mask_where(grad, mask_4lhs.unwrap(), zeros); - - broadcast_shape::(grad, &shape_lhs) - }, - |grad| { - let zeros = B::zeros(shape_rhs.clone(), &device); - let grad = B::mask_where(zeros, mask_4rhs.unwrap(), grad); - - broadcast_shape::(grad, &shape_rhs) - 
}, - ); - } - } - - match MaskWhere - .prepare([tensor.node, source.node], [tensor.graph, source.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - mask.clone(), - B::shape(&tensor.primitive), - B::shape(&source.primitive), - B::device(&source.primitive), - ), - B::mask_where(tensor.primitive, mask, source.primitive), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::mask_where(tensor.primitive, mask, source.primitive)) - } - } + match Sum.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(B::shape(&tensor.primitive), B::sum(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::sum(tensor.primitive)), } + } - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct MaskFill; + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[derive(Debug)] + struct MeamDim; - impl Backward for MaskFill { - type State = BoolTensor; + impl Backward for MeamDim { + type State = (Shape, usize); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::mask_fill(grad, ops.state, 0.elem()) - }); - } - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape, dim) = ops.state; - match MaskFill.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(mask.clone(), B::mask_fill(tensor.primitive, mask, value)) - } - OpsKind::UnTracked(prep) => prep.finish(B::mask_fill(tensor.primitive, mask, value)), - } - } + unary::(ops.parents, ops.node, grads, |grad| { + let val = 1_f64 / shape.dims[dim] as f64; + let ones = B::ones(shape, &B::device(&grad)); + let val = B::mul_scalar(ones, B::FloatElem::from_elem(val)); - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::equal(lhs.primitive, rhs.primitive) + let grad = B::sum_dim(grad, dim); + B::mul(val, grad) + }); + } } - fn equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::equal_elem(lhs.primitive, rhs) + match MeamDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + (B::shape(&tensor.primitive), dim), + B::mean_dim(tensor.primitive, dim), + ), + OpsKind::UnTracked(prep) => prep.finish(B::mean_dim(tensor.primitive, dim)), } + } - fn greater( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::greater(lhs.primitive, rhs.primitive) - } + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[derive(Debug)] + struct SumDim; - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::greater_elem(lhs.primitive, rhs) - } + impl Backward for SumDim { + type State = (Shape, usize); - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::greater_equal(lhs.primitive, rhs.primitive) - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape, dim) = ops.state; - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::greater_equal_elem(lhs.primitive, rhs) - } + unary::(ops.parents, ops.node, grads, |grad| { + let ones = B::ones(shape, &B::device(&grad)); + let grad = B::sum_dim(grad, dim); - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::lower(lhs.primitive, rhs.primitive) + B::mul(ones, grad) + }); + } } - fn lower_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::lower_elem(lhs.primitive, rhs) + match SumDim.prepare([tensor.node], [tensor.graph]).stateful() { + 
OpsKind::Tracked(prep) => prep.finish( + (B::shape(&tensor.primitive), dim), + B::sum_dim(tensor.primitive, dim), + ), + OpsKind::UnTracked(prep) => prep.finish(B::sum_dim(tensor.primitive, dim)), } + } - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::lower_equal(lhs.primitive, rhs.primitive) + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D> { + #[derive(Debug)] + struct ToFullPrecision { + phantom: PhantomData, } - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::lower_equal_elem(lhs.primitive, rhs) + impl Backward for ToFullPrecision { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary_different_backend::( + ops.parents, + ops.node, + grads, + |grad| B::from_full_precision(grad), + ); + } } - fn detach(tensor: FloatTensor) -> FloatTensor { - // When we detach a tensor, we remove it from the graph, but we still want to keep the - // `require_grad` setting. - let is_require_grad = Self::is_require_grad(&tensor); - let tensor = AutodiffTensor::new(tensor.primitive); + let ops = ToFullPrecision:: { + phantom: PhantomData, + }; + ops + .prepare([tensor.node.clone()], [tensor.graph.clone()]) + .stateless(B::to_full_precision(&tensor.primitive)) + } - match is_require_grad { - true => tensor.require_grad(), - false => tensor, - } + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + #[derive(Debug)] + struct FromFullPrecision { + phantom: PhantomData, } - fn set_require_grad( - tensor: FloatTensor, - require_grad: bool, - ) -> FloatTensor { - if require_grad { - return tensor.require_grad(); - } + impl Backward for FromFullPrecision { + type State = (); - AutodiffTensor::new(tensor.primitive) + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary_different_backend::( + ops.parents, + ops.node, + grads, + |grad| B::to_full_precision(&grad), + ); + } } - fn is_require_grad(tensor: &FloatTensor) -> bool { - matches!(tensor.node.requirement, Requirement::Grad) - } + let ops = FromFullPrecision:: { + phantom: PhantomData, + }; - fn mean(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Mean; + ops + .prepare([tensor.node.clone()], [tensor.graph]) + .stateless(B::from_full_precision(tensor.primitive)) + } - impl Backward for Mean { - type State = Shape; + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + B::argmax(tensor.primitive, dim) + } - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let shape = ops.state; - let val = 1_f64 / shape.num_elements() as f64; - let ones = B::ones(shape, &B::device(&grad)); - let val = B::mul_scalar(ones, val.elem()); + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + B::argmin(tensor.primitive, dim) + } - let grad: Tensor = Tensor::from_primitive(grad); - let val: Tensor = Tensor::from_primitive(val); + fn exp(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Exp; - val.mul(grad.unsqueeze()).into_primitive() - }); - } - } + impl Backward for Exp { + type State = B::TensorPrimitive; - match Mean.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(B::shape(&tensor.primitive), B::mean(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::mean(tensor.primitive)), - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); + } } - fn sum(tensor: FloatTensor) -> FloatTensor { - 
#[derive(Debug)] - struct Sum; - - impl Backward for Sum { - type State = Shape; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let val = B::ones(ops.state, &B::device(&grad)); - - let grad: Tensor = Tensor::from_primitive(grad); - let val: Tensor = Tensor::from_primitive(val); + let output = B::exp(tensor.primitive); - val.mul(grad.unsqueeze()).into_primitive() - }); - } - } - - match Sum.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(B::shape(&tensor.primitive), B::sum(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::sum(tensor.primitive)), - } + match Exp.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(output.clone(), output), + OpsKind::UnTracked(prep) => prep.finish(output), } + } - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[derive(Debug)] - struct MeamDim; - - impl Backward for MeamDim { - type State = (Shape, usize); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape, dim) = ops.state; + fn log(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Log; - unary::(ops.parents, ops.node, grads, |grad| { - let val = 1_f64 / shape.dims[dim] as f64; - let ones = B::ones(shape, &B::device(&grad)); - let val = B::mul_scalar(ones, B::FloatElem::from_elem(val)); + impl Backward for Log { + type State = B::TensorPrimitive; - let grad = B::sum_dim(grad, dim); - B::mul(val, grad) - }); - } - } - - match MeamDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (B::shape(&tensor.primitive), dim), - B::mean_dim(tensor.primitive, dim), - ), - OpsKind::UnTracked(prep) => prep.finish(B::mean_dim(tensor.primitive, dim)), - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::powf(ops.state, -1.0); + B::mul(grad, value) + }); + } } - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[derive(Debug)] - struct SumDim; - - impl Backward for SumDim { - type State = (Shape, usize); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape, dim) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - let ones = B::ones(shape, &B::device(&grad)); - let grad = B::sum_dim(grad, dim); + match Log.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::log(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::log(tensor.primitive)), + } + } - B::mul(ones, grad) - }); - } - } - - match SumDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (B::shape(&tensor.primitive), dim), - B::sum_dim(tensor.primitive, dim), - ), - OpsKind::UnTracked(prep) => prep.finish(B::sum_dim(tensor.primitive, dim)), - } - } - - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - #[derive(Debug)] - struct ToFullPrecision { - phantom: PhantomData, - } - - impl Backward for ToFullPrecision { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary_different_backend::( - ops.parents, - ops.node, - grads, - |grad| B::from_full_precision(grad), - ); - } - } - - let ops = ToFullPrecision:: { - phantom: PhantomData, - }; - ops.prepare([tensor.node.clone()], [tensor.graph.clone()]) - .stateless(B::to_full_precision(&tensor.primitive)) - } - - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> 
FloatTensor { - #[derive(Debug)] - struct FromFullPrecision { - phantom: PhantomData, - } - - impl Backward for FromFullPrecision { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary_different_backend::( - ops.parents, - ops.node, - grads, - |grad| B::to_full_precision(&grad), - ); - } - } + fn log1p(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Log1P; - let ops = FromFullPrecision:: { - phantom: PhantomData, - }; + impl Backward for Log1P { + type State = B::TensorPrimitive; - ops.prepare([tensor.node.clone()], [tensor.graph]) - .stateless(B::from_full_precision(tensor.primitive)) - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::add_scalar(ops.state, 1.elem()); + let value = B::powf(value, -1.0); - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - B::argmax(tensor.primitive, dim) + B::mul(grad, value) + }); + } } - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - B::argmin(tensor.primitive, dim) + match Log1P.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::log1p(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::log1p(tensor.primitive)), } + } - fn exp(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Exp; + fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { + #[derive(Debug)] + struct PowF; - impl Backward for Exp { - type State = B::TensorPrimitive; + impl Backward for PowF { + type State = (B::TensorPrimitive, f32); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); - } - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (tensor, value) = ops.state; - let output = B::exp(tensor.primitive); + unary::(ops.parents, ops.node, grads, |grad| { + let tmp = B::powf(tensor, value - 1.0); + let value = B::mul_scalar(tmp, value.elem()); - match Exp.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(output.clone(), output), - OpsKind::UnTracked(prep) => prep.finish(output), - } + B::mul(grad, value) + }); + } } - fn log(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Log; - - impl Backward for Log { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::powf(ops.state, -1.0); - B::mul(grad, value) - }); - } - } - - match Log.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(tensor.primitive.clone(), B::log(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::log(tensor.primitive)), - } + match PowF.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + (tensor.primitive.clone(), value), + B::powf(tensor.primitive, value), + ), + OpsKind::UnTracked(prep) => prep.finish(B::powf(tensor.primitive, value)), } + } - fn log1p(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Log1P; - - impl Backward for Log1P { - type State = B::TensorPrimitive; + fn sqrt(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Sqrt; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::add_scalar(ops.state, 1.elem()); - let value = B::powf(value, -1.0); + impl Backward for Sqrt { + type State = B::TensorPrimitive; - B::mul(grad, 
value) - }); - } - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let input = ops.state; + let value = B::div_scalar(B::powf(input, -0.5), 2.elem()); - match Log1P.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(tensor.primitive.clone(), B::log1p(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::log1p(tensor.primitive)), - } + B::mul(grad, value) + }); + } } - fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { - #[derive(Debug)] - struct PowF; - - impl Backward for PowF { - type State = (B::TensorPrimitive, f32); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (tensor, value) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - let tmp = B::powf(tensor, value - 1.0); - let value = B::mul_scalar(tmp, value.elem()); - - B::mul(grad, value) - }); - } - } - - match PowF.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (tensor.primitive.clone(), value), - B::powf(tensor.primitive, value), - ), - OpsKind::UnTracked(prep) => prep.finish(B::powf(tensor.primitive, value)), - } + match Sqrt.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::sqrt(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::sqrt(tensor.primitive)), } + } - fn sqrt(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Sqrt; - - impl Backward for Sqrt { - type State = B::TensorPrimitive; + fn abs(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Abs; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let input = ops.state; - let value = B::div_scalar(B::powf(input, -0.5), 2.elem()); + impl Backward for Abs { + type State = B::TensorPrimitive; - B::mul(grad, value) - }); - } - } - - match Sqrt.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(tensor.primitive.clone(), B::sqrt(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::sqrt(tensor.primitive)), - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); + } } - fn abs(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Abs; - - impl Backward for Abs { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); - } - } - - match Abs.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::abs(tensor.primitive.clone()); - let state = B::div(tensor.primitive, output.clone()); - prep.finish(state, output) - } - OpsKind::UnTracked(prep) => prep.finish(B::abs(tensor.primitive)), - } + match Abs.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::abs(tensor.primitive.clone()); + let state = B::div(tensor.primitive, output.clone()); + prep.finish(state, output) + } + OpsKind::UnTracked(prep) => prep.finish(B::abs(tensor.primitive)), } + } - fn cos(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Cos; + fn cos(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Cos; - impl Backward for Cos { - type State = B::TensorPrimitive; + impl Backward for Cos { + type State = B::TensorPrimitive; - fn backward(self, ops: Ops, grads: &mut Gradients) { - 
unary::(ops.parents, ops.node, grads, |grad| { - let input = ops.state; - let value = B::neg(B::sin(input)); - - B::mul(grad, value) - }); - } - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let input = ops.state; + let value = B::neg(B::sin(input)); - match Cos.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(tensor.primitive.clone(), B::cos(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::cos(tensor.primitive)), - } + B::mul(grad, value) + }); + } } - fn sin(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Sin; - - impl Backward for Sin { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::cos(ops.state); - B::mul(grad, value) - }); - } - } - - match Sin.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(tensor.primitive.clone(), B::sin(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::sin(tensor.primitive)), - } + match Cos.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::cos(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::cos(tensor.primitive)), } + } - fn tanh(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Tanh; + fn sin(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Sin; - impl Backward for Tanh { - type State = B::TensorPrimitive; + impl Backward for Sin { + type State = B::TensorPrimitive; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::add_scalar(B::neg(B::powf(ops.state, 2.0)), 1.elem()); - B::mul(grad, value) - }); - } - } - - match Tanh.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::tanh(tensor.primitive); - prep.finish(output.clone(), output) - } - OpsKind::UnTracked(prep) => prep.finish(B::tanh(tensor.primitive)), - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::cos(ops.state); + B::mul(grad, value) + }); + } } - fn erf(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Erf; + match Sin.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::sin(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::sin(tensor.primitive)), + } + } - impl Backward for Erf { - type State = B::TensorPrimitive; + fn tanh(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Tanh; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let exponent = B::neg(B::powf(ops.state, 2.0)); - let numerator = B::mul_scalar(B::exp(exponent), 2.0.elem()); - let denominator = std::f64::consts::PI.sqrt().elem(); - let value = B::div_scalar(numerator, denominator); + impl Backward for Tanh { + type State = B::TensorPrimitive; - B::mul(grad, value) - }); - } - } - - match Erf.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(tensor.primitive.clone(), B::erf(tensor.primitive)) - } - OpsKind::UnTracked(prep) => prep.finish(B::erf(tensor.primitive)), - } - } - - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - #[derive(new, Debug)] - struct CatStep { - nodes: Vec>, - // The dimension of 
each tensor along the dim dimension. - // This indicates the number of dimension concatenated for each tensor. - dim_sizes: Vec, - output: NodeRef, - phantom: PhantomData, - dim: usize, - } - - impl Step for CatStep { - fn step(self: Box, grads: &mut Gradients) { - let grad = grads.consume::(&self.output); - let ranges: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect(); - let ranges: [std::ops::Range; D] = ranges.try_into().unwrap(); - - let mut current_index = 0; - - self.nodes - .into_iter() - .zip(self.dim_sizes) - .filter_map(|(node, dim_size)| node.map(|node| (node, dim_size))) - .for_each(|(node, dim_size)| { - let mut ranges = ranges.clone(); - ranges[self.dim] = current_index..dim_size + current_index; - current_index += dim_size; - grads.register::(node, B::slice(grad.clone(), ranges)); - }); - } - - fn node(&self) -> NodeRef { - self.output.clone() - } - } - - let mut nodes = Vec::with_capacity(tensors.len()); - let mut graphs = Vec::with_capacity(tensors.len()); - let mut primitives = Vec::with_capacity(tensors.len()); - let mut dim_sizes = Vec::with_capacity(tensors.len()); - - tensors.into_iter().for_each(|tensor| { - dim_sizes.push(B::shape(&tensor.primitive).dims[dim]); - nodes.push(tensor.node); - primitives.push(tensor.primitive); - graphs.push(tensor.graph); + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::add_scalar(B::neg(B::powf(ops.state, 2.0)), 1.elem()); + B::mul(grad, value) }); + } + } - let requirement = Requirement::from_nodes(&nodes); - - let output = B::cat(primitives, dim); - if requirement.is_none() { - return AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); - } - - let output = AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); - let nodes = nodes - .into_iter() - .map(|node| node.clone_if_require_grad()) - .collect::>(); - - let ops = CatStep::::new(nodes, dim_sizes, output.node.clone(), dim); - output.register_step(ops) + match Tanh.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::tanh(tensor.primitive); + prep.finish(output.clone(), output) + } + OpsKind::UnTracked(prep) => prep.finish(B::tanh(tensor.primitive)), } + } - fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); - prep.finish((index, shape), tensor) - } - OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)), - } - } - fn max_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish((index.clone(), shape), tensor); - - (tensor, index) - } - OpsKind::UnTracked(prep) => { - let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish(tensor); + fn erf(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Erf; - (tensor, index) - } - } - } - fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let 
(tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); - prep.finish((index, shape), tensor) - } - OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)), - } - } - fn min_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish((index.clone(), shape), tensor); - - (tensor, index) - } - OpsKind::UnTracked(prep) => { - let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish(tensor); + impl Backward for Erf { + type State = B::TensorPrimitive; - (tensor, index) - } - } - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let exponent = B::neg(B::powf(ops.state, 2.0)); + let numerator = B::mul_scalar(B::exp(exponent), 2.0.elem()); + let denominator = std::f64::consts::PI.sqrt().elem(); + let value = B::div_scalar(numerator, denominator); - fn into_int( - tensor: FloatTensor, - ) -> as Backend>::IntTensorPrimitive { - B::into_int(tensor.primitive) - } + B::mul(grad, value) + }); + } + } + + match Erf.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::erf(tensor.primitive)), + OpsKind::UnTracked(prep) => prep.finish(B::erf(tensor.primitive)), + } + } + + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + #[derive(new, Debug)] + struct CatStep { + nodes: Vec>, + // The dimension of each tensor along the dim dimension. + // This indicates the number of dimension concatenated for each tensor. 
+ dim_sizes: Vec, + output: NodeRef, + phantom: PhantomData, + dim: usize, + } + + impl Step for CatStep { + fn step(self: Box, grads: &mut Gradients) { + let grad = grads.consume::(&self.output); + let ranges: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect(); + let ranges: [std::ops::Range; D] = ranges.try_into().unwrap(); + + let mut current_index = 0; + + self + .nodes + .into_iter() + .zip(self.dim_sizes) + .filter_map(|(node, dim_size)| node.map(|node| (node, dim_size))) + .for_each(|(node, dim_size)| { + let mut ranges = ranges.clone(); + ranges[self.dim] = current_index..dim_size + current_index; + current_index += dim_size; + grads.register::(node, B::slice(grad.clone(), ranges)); + }); + } + + fn node(&self) -> NodeRef { + self.output.clone() + } + } + + let mut nodes = Vec::with_capacity(tensors.len()); + let mut graphs = Vec::with_capacity(tensors.len()); + let mut primitives = Vec::with_capacity(tensors.len()); + let mut dim_sizes = Vec::with_capacity(tensors.len()); + + tensors.into_iter().for_each(|tensor| { + dim_sizes.push(B::shape(&tensor.primitive).dims[dim]); + nodes.push(tensor.node); + primitives.push(tensor.primitive); + graphs.push(tensor.graph); + }); + + let requirement = Requirement::from_nodes(&nodes); + + let output = B::cat(primitives, dim); + if requirement.is_none() { + return AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); + } + + let output = AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); + let nodes = nodes + .into_iter() + .map(|node| node.clone_if_require_grad()) + .collect::>(); + + let ops = CatStep::::new(nodes, dim_sizes, output.node.clone(), dim); + output.register_step(ops) + } + + fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); + prep.finish((index, shape), tensor) + } + OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)), + } + } + fn max_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish((index.clone(), shape), tensor); + + (tensor, index) + } + OpsKind::UnTracked(prep) => { + let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish(tensor); + + (tensor, index) + } + } + } + fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); + prep.finish((index, shape), tensor) + } + OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)), + } + } + fn min_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish((index.clone(), shape), tensor); + + (tensor, index) + } + OpsKind::UnTracked(prep) => { + let (tensor, 
index) = B::min_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish(tensor); + + (tensor, index) + } + } + } + + fn into_int( + tensor: FloatTensor, + ) -> as Backend>::IntTensorPrimitive { + B::into_int(tensor.primitive) + } } #[derive(Debug, Clone)] enum BinaryOpsBroadcast { - Broadcasted(Shape, Shape), - None, + Broadcasted(Shape, Shape), + None, } impl BinaryOpsBroadcast { - fn new(lhs: &B::TensorPrimitive, rhs: &B::TensorPrimitive) -> Self { - let shape_lhs = B::shape(lhs); - let shape_rhs = B::shape(rhs); - - for i in 0..D { - if shape_rhs.dims[i] != shape_lhs.dims[i] { - return Self::Broadcasted(shape_lhs, shape_rhs); - } - } + fn new(lhs: &B::TensorPrimitive, rhs: &B::TensorPrimitive) -> Self { + let shape_lhs = B::shape(lhs); + let shape_rhs = B::shape(rhs); - Self::None + for i in 0..D { + if shape_rhs.dims[i] != shape_lhs.dims[i] { + return Self::Broadcasted(shape_lhs, shape_rhs); + } } - fn backward_lhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { - match self { - BinaryOpsBroadcast::Broadcasted(lhs, _rhs) => broadcast_shape::(grad, lhs), - BinaryOpsBroadcast::None => grad, - } + Self::None + } + + fn backward_lhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { + match self { + BinaryOpsBroadcast::Broadcasted(lhs, _rhs) => broadcast_shape::(grad, lhs), + BinaryOpsBroadcast::None => grad, } + } - fn backward_rhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { - match self { - BinaryOpsBroadcast::Broadcasted(_lhs, rhs) => broadcast_shape::(grad, rhs), - BinaryOpsBroadcast::None => grad, - } + fn backward_rhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { + match self { + BinaryOpsBroadcast::Broadcasted(_lhs, rhs) => broadcast_shape::(grad, rhs), + BinaryOpsBroadcast::None => grad, } + } } diff --git a/burn-autodiff/src/tensor.rs b/burn-autodiff/src/tensor.rs index 84c6c80b73..ab5465a5fd 100644 --- a/burn-autodiff/src/tensor.rs +++ b/burn-autodiff/src/tensor.rs @@ -1,106 +1,106 @@ use burn_tensor::backend::Backend; use crate::{ - grads::Gradients, - graph::{ - Node, NodeID, NodeRef, Requirement, {Graph, Step}, - }, + grads::Gradients, + graph::{ + Node, NodeID, NodeRef, Requirement, {Graph, Step}, + }, }; #[derive(Debug, Clone)] pub struct AutodiffTensor { - pub primitive: B::TensorPrimitive, - pub node: NodeRef, - pub graph: Graph, + pub primitive: B::TensorPrimitive, + pub node: NodeRef, + pub graph: Graph, } #[derive(new, Debug)] struct RootStep { - node: NodeRef, + node: NodeRef, } impl Step for RootStep { - fn step(self: Box, _grads: &mut Gradients) { - // Nothing to do - } + fn step(self: Box, _grads: &mut Gradients) { + // Nothing to do + } - fn node(&self) -> NodeRef { - self.node.clone() - } + fn node(&self) -> NodeRef { + self.node.clone() + } } impl AutodiffTensor { - /// Create a new leaf tensor. - pub fn new(primitive: B::TensorPrimitive) -> Self { - let id = NodeID::new(); - let node = Node::new(vec![], 0, id, Requirement::None); + /// Create a new leaf tensor. + pub fn new(primitive: B::TensorPrimitive) -> Self { + let id = NodeID::new(); + let node = Node::new(vec![], 0, id, Requirement::None); - Self { - primitive, - node: node.into(), - graph: Graph::new(), - } + Self { + primitive, + node: node.into(), + graph: Graph::new(), } + } - pub fn is_tracked(&self) -> bool { - !self.node.requirement.is_none() - } + pub fn is_tracked(&self) -> bool { + !self.node.requirement.is_none() + } - /// Mark the tensor as requirering gradients. - /// - /// # Panics - /// - /// It panics if the tensor is non a leaf. 
- pub fn require_grad(mut self) -> Self { - match self.node.requirement { - Requirement::Grad => self, - Requirement::GradInBackward => { - panic!("Can't convert a non leaf tensor into a tracked tensor") - } - Requirement::None => { - self.node = Node::new(vec![], 0, self.node.id.clone(), Requirement::Grad).into(); - let ops = RootStep::new(self.node.clone()); + /// Mark the tensor as requirering gradients. + /// + /// # Panics + /// + /// It panics if the tensor is non a leaf. + pub fn require_grad(mut self) -> Self { + match self.node.requirement { + Requirement::Grad => self, + Requirement::GradInBackward => { + panic!("Can't convert a non leaf tensor into a tracked tensor") + } + Requirement::None => { + self.node = Node::new(vec![], 0, self.node.id.clone(), Requirement::Grad).into(); + let ops = RootStep::new(self.node.clone()); - self.register_step(ops) - } - } + self.register_step(ops) + } } + } - /// Create a tensor from parent infos. - pub fn from_parents>( - output: B::TensorPrimitive, - parent_nodes: &[NodeRef], - parent_graphs: I, - requirement: Requirement, - ) -> Self { - let graph = parent_graphs - .reduce(|acc, graph| acc.merge(graph)) - .unwrap_or_else(Graph::new); + /// Create a tensor from parent infos. + pub fn from_parents>( + output: B::TensorPrimitive, + parent_nodes: &[NodeRef], + parent_graphs: I, + requirement: Requirement, + ) -> Self { + let graph = parent_graphs + .reduce(|acc, graph| acc.merge(graph)) + .unwrap_or_else(Graph::new); - let order = parent_nodes - .iter() - .map(|node| node.order) - .reduce(usize::max) - .unwrap_or(0) - + 1; + let order = parent_nodes + .iter() + .map(|node| node.order) + .reduce(usize::max) + .unwrap_or(0) + + 1; - let node = Node::new( - parent_nodes.iter().map(|node| node.id.clone()).collect(), - order, - NodeID::new(), - requirement, - ); + let node = Node::new( + parent_nodes.iter().map(|node| node.id.clone()).collect(), + order, + NodeID::new(), + requirement, + ); - Self { - primitive: output, - node: node.into(), - graph, - } + Self { + primitive: output, + node: node.into(), + graph, } + } - /// Register a step into a graph for that tensor. - pub fn register_step(mut self, ops: O) -> Self { - self.graph = self.graph.register(&self.node.id, Box::new(ops)); - self - } + /// Register a step into a graph for that tensor. 
+ pub fn register_step(mut self, ops: O) -> Self { + self.graph = self.graph.register(&self.node.id, Box::new(ops)); + self + } } diff --git a/burn-autodiff/src/tests/abs.rs b/burn-autodiff/src/tests/abs.rs index 02c40d135b..275761a7e1 100644 --- a/burn-autodiff/src/tests/abs.rs +++ b/burn-autodiff/src/tests/abs.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_abs)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_abs() { - let data_1 = Data::::from([[0.0, -1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, -10.0]]); + #[test] + fn should_diff_abs() { + let data_1 = Data::::from([[0.0, -1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, -10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[71.0, 107.0], [71.0, 107.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[84.0, 42.0], [90.0, 54.0]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[71.0, 107.0], [71.0, 107.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[84.0, 42.0], [90.0, 54.0]]), 3); + } } diff --git a/burn-autodiff/src/tests/adaptive_avgpool1d.rs b/burn-autodiff/src/tests/adaptive_avgpool1d.rs index 60caee893b..aaec2a2c1b 100644 --- a/burn-autodiff/src/tests/adaptive_avgpool1d.rs +++ b/burn-autodiff/src/tests/adaptive_avgpool1d.rs @@ -1,48 +1,48 @@ #[burn_tensor_testgen::testgen(ad_adaptive_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool1d_simple() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 5, - output_size: 3, - }; + #[test] + fn test_avg_pool1d_simple() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 5, + output_size: 3, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], - [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], + [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], + ]])); + } - struct AdaptiveAvgPool1dTestCase { - batch_size: usize, - channels: usize, - length: usize, - output_size: usize, - } + struct AdaptiveAvgPool1dTestCase { + batch_size: usize, + channels: usize, + length: usize, + output_size: usize, + } - impl AdaptiveAvgPool1dTestCase { - fn assert_output(self, x_grad: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - 
.reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = adaptive_avg_pool1d(x.clone(), self.output_size); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AdaptiveAvgPool1dTestCase { + fn assert_output(self, x_grad: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = adaptive_avg_pool1d(x.clone(), self.output_size); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); - } + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/adaptive_avgpool2d.rs b/burn-autodiff/src/tests/adaptive_avgpool2d.rs index 4e09a63891..a77974fe2f 100644 --- a/burn-autodiff/src/tests/adaptive_avgpool2d.rs +++ b/burn-autodiff/src/tests/adaptive_avgpool2d.rs @@ -1,64 +1,64 @@ #[burn_tensor_testgen::testgen(ad_adaptive_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool2d_simple() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 5, - width: 3, - output_size_1: 3, - output_size_2: 2, - }; + #[test] + fn test_avg_pool2d_simple() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 5, + width: 3, + output_size_1: 3, + output_size_2: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [0.2500, 0.5000, 0.2500], - [0.4167, 0.8333, 0.4167], - [0.1667, 0.3333, 0.1667], - [0.4167, 0.8333, 0.4167], - [0.2500, 0.5000, 0.2500], - ], - [ - [0.2500, 0.5000, 0.2500], - [0.4167, 0.8333, 0.4167], - [0.1667, 0.3333, 0.1667], - [0.4167, 0.8333, 0.4167], - [0.2500, 0.5000, 0.2500], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [0.2500, 0.5000, 0.2500], + [0.4167, 0.8333, 0.4167], + [0.1667, 0.3333, 0.1667], + [0.4167, 0.8333, 0.4167], + [0.2500, 0.5000, 0.2500], + ], + [ + [0.2500, 0.5000, 0.2500], + [0.4167, 0.8333, 0.4167], + [0.1667, 0.3333, 0.1667], + [0.4167, 0.8333, 0.4167], + [0.2500, 0.5000, 0.2500], + ], + ]])); + } - struct AdaptiveAvgPool2dTestCase { - batch_size: usize, - channels: usize, - height: usize, - width: usize, - output_size_1: usize, - output_size_2: usize, - } + struct AdaptiveAvgPool2dTestCase { + batch_size: usize, + channels: usize, + height: usize, + width: usize, + output_size_1: usize, + output_size_2: usize, + } - impl AdaptiveAvgPool2dTestCase { - fn assert_output(self, x_grad: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AdaptiveAvgPool2dTestCase { + fn assert_output(self, x_grad: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = TestAutodiffTensor::from_data( + 
TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); - } + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/add.rs b/burn-autodiff/src/tests/add.rs index 884ced38af..b21b6e5fd1 100644 --- a/burn-autodiff/src/tests/add.rs +++ b/burn-autodiff/src/tests/add.rs @@ -1,62 +1,62 @@ #[burn_tensor_testgen::testgen(ad_add)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_diff_add() { - let tensor_1 = TestAutodiffTensor::from_floats([2.0, 5.0]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0]).require_grad(); + #[test] + fn should_diff_add() { + let tensor_1 = TestAutodiffTensor::from_floats([2.0, 5.0]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0]).require_grad(); - let tensor_3 = tensor_1.clone() + tensor_2.clone(); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone() + tensor_2.clone(); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); - assert_eq!(grad_2.to_data(), Data::from([1.0, 1.0])); - assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0])); - } + assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); + assert_eq!(grad_2.to_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0])); + } - #[test] - fn should_diff_add_scalar() { - let data = Data::from([2.0, 10.0]); + #[test] + fn should_diff_add_scalar() { + let data = Data::from([2.0, 10.0]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().add_scalar(5.0); - let grads = tensor_out.backward(); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().add_scalar(5.0); + let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(grad.to_data(), Data::from([1.0, 1.0])); - assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0])); - } + assert_eq!(grad.to_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0])); + } - #[test] - fn test_add_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + #[test] + fn test_add_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = 
TestAutodiffTensor::from_data(data_3).require_grad(); - let tensor_4 = tensor_1.clone().add(tensor_2.clone()); - let tensor_5 = tensor_4 - .add(tensor_3) - .add_scalar(5.0) - .add(tensor_1.clone()) - .add(tensor_2.clone()); - let tensor_6 = tensor_1.clone().add(tensor_5); + let tensor_4 = tensor_1.clone().add(tensor_2.clone()); + let tensor_5 = tensor_4 + .add(tensor_3) + .add_scalar(5.0) + .add(tensor_1.clone()) + .add(tensor_2.clone()); + let tensor_6 = tensor_1.clone().add(tensor_5); - let grads = tensor_6.backward(); + let grads = tensor_6.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[3.0, 3.0], [3.0, 3.0]])); - assert_eq!(grad_2.to_data(), Data::from([[2.0, 2.0], [2.0, 2.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[3.0, 3.0], [3.0, 3.0]])); + assert_eq!(grad_2.to_data(), Data::from([[2.0, 2.0], [2.0, 2.0]])); + } } diff --git a/burn-autodiff/src/tests/aggregation.rs b/burn-autodiff/src/tests/aggregation.rs index a546a01469..d57b182051 100644 --- a/burn-autodiff/src/tests/aggregation.rs +++ b/burn-autodiff/src/tests/aggregation.rs @@ -1,121 +1,121 @@ #[burn_tensor_testgen::testgen(ad_aggregation)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_mean() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[3.5, 9.5], [3.5, 9.5]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-0.75, -0.75], [3.0, 3.0]]), 5); - } - - #[test] - fn should_diff_sum_1() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[14.0, 38.0], [14.0, 38.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-3.0, -3.0], [12.0, 12.0]]), 5); - } - - #[test] - fn should_diff_sum_2() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.clone().sum_dim(1); - let tensor_5 = tensor_4.mul(tensor_3); - - let grads = tensor_5.sum().backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[494.0, 722.0], [2990.0, 4370.0]]), 3); - grad_2 - .to_data() - 
.assert_approx_eq(&Data::from([[690.0, 690.0], [958.0, 958.0]]), 3); - } - - #[test] - fn should_diff_mean_dim() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[4.0, 36.0], [3.0, -17.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[9.0, 9.0], [35.5, 35.5]]), 5); - } - - #[test] - fn should_diff_sum_dim() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[8.0, 72.0], [6.0, -34.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[18.0, 18.0], [71.0, 71.0]]), 5); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_mean() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[3.5, 9.5], [3.5, 9.5]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-0.75, -0.75], [3.0, 3.0]]), 5); + } + + #[test] + fn should_diff_sum_1() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[14.0, 38.0], [14.0, 38.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-3.0, -3.0], [12.0, 12.0]]), 5); + } + + #[test] + fn should_diff_sum_2() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.clone().sum_dim(1); + let tensor_5 = tensor_4.mul(tensor_3); + + let grads = 
tensor_5.sum().backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[494.0, 722.0], [2990.0, 4370.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[690.0, 690.0], [958.0, 958.0]]), 3); + } + + #[test] + fn should_diff_mean_dim() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[4.0, 36.0], [3.0, -17.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[9.0, 9.0], [35.5, 35.5]]), 5); + } + + #[test] + fn should_diff_sum_dim() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[8.0, 72.0], [6.0, -34.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[18.0, 18.0], [71.0, 71.0]]), 5); + } } diff --git a/burn-autodiff/src/tests/avgpool1d.rs b/burn-autodiff/src/tests/avgpool1d.rs index a0224cf11f..feb9891175 100644 --- a/burn-autodiff/src/tests/avgpool1d.rs +++ b/burn-autodiff/src/tests/avgpool1d.rs @@ -1,95 +1,95 @@ #[burn_tensor_testgen::testgen(ad_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool1d_simple() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 1, - kernel_size: 3, - padding: 0, - stride: 1, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_simple() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 1, + kernel_size: 3, + padding: 0, + stride: 1, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - 0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333, - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + 0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333, + ]]])); + } - #[test] - fn test_avg_pool1d_complex() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_complex() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[ - [0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333], - [0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.3333, 0.6667, 0.3333, 0.6667, 
0.3333, 0.3333], + [0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333], + ]])); + } - #[test] - fn test_avg_pool1d_complex_dont_count_pad() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool1d_complex_dont_count_pad() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], - [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], + [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], + ]])); + } - struct AvgPool1dTestCase { - batch_size: usize, - channels: usize, - kernel_size: usize, - padding: usize, - stride: usize, - length: usize, - count_include_pad: bool, - } + struct AvgPool1dTestCase { + batch_size: usize, + channels: usize, + kernel_size: usize, + padding: usize, + stride: usize, + length: usize, + count_include_pad: bool, + } - impl AvgPool1dTestCase { - fn assert_output(self, x_grad: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = avg_pool1d( - x.clone(), - self.kernel_size, - self.stride, - self.padding, - self.count_include_pad, - ); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AvgPool1dTestCase { + fn assert_output(self, x_grad: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = avg_pool1d( + x.clone(), + self.kernel_size, + self.stride, + self.padding, + self.count_include_pad, + ); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); - } + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/avgpool2d.rs b/burn-autodiff/src/tests/avgpool2d.rs index 5ad2aa50a3..aba6936a52 100644 --- a/burn-autodiff/src/tests/avgpool2d.rs +++ b/burn-autodiff/src/tests/avgpool2d.rs @@ -1,120 +1,120 @@ #[burn_tensor_testgen::testgen(ad_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool2d_simple() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - height: 6, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_simple() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + height: 6, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], - [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], - 
[0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], - [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], - [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], - [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], + [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], + [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], + [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], + [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], + [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], + ]]])); + } - #[test] - fn test_avg_pool2d_complex() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_complex() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], - [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], - [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], - [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], + [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], + [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], + [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], + ]]])); + } - #[test] - fn test_avg_pool2d_complex_dont_include_pad() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool2d_complex_dont_include_pad() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], - [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], - [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], - [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], + [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], + [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], + [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], + ]]])); + } - struct AvgPool2dTestCase { - batch_size: usize, - channels: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - stride_1: usize, - stride_2: usize, - height: usize, - width: usize, - count_include_pad: bool, - } + struct AvgPool2dTestCase { + batch_size: usize, + channels: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + height: usize, + width: usize, + count_include_pad: bool, + } - impl AvgPool2dTestCase { - fn assert_output(self, x_grad: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - 
.reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = avg_pool2d( - x.clone(), - [self.kernel_size_1, self.kernel_size_2], - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - self.count_include_pad, - ); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AvgPool2dTestCase { + fn assert_output(self, x_grad: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = avg_pool2d( + x.clone(), + [self.kernel_size_1, self.kernel_size_2], + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + self.count_include_pad, + ); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); - } + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/backward.rs b/burn-autodiff/src/tests/backward.rs index ca25e71da9..e75bc30da6 100644 --- a/burn-autodiff/src/tests/backward.rs +++ b/burn-autodiff/src/tests/backward.rs @@ -1,29 +1,27 @@ #[burn_tensor_testgen::testgen(module_backward)] mod tests { - use super::*; - use burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; + use super::*; + use burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; - #[test] - fn test_embedding_backward() { - let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = Data::from([[0, 1], [1, 1]]); - let x = Data::from([ - [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]], - [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]], - ]); - let weights = Tensor::::from_data(weights).require_grad(); - let indices = Tensor::::from_data(indices); - let x = Tensor::::from_data(x).require_grad(); + #[test] + fn test_embedding_backward() { + let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = Data::from([[0, 1], [1, 1]]); + let x = Data::from([ + [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]], + [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]], + ]); + let weights = Tensor::::from_data(weights).require_grad(); + let indices = Tensor::::from_data(indices); + let x = Tensor::::from_data(x).require_grad(); - let output = embedding(weights.clone(), indices); - let output = output.matmul(x); - let grads = output.backward(); + let output = embedding(weights.clone(), indices); + let output = output.matmul(x); + let grads = output.backward(); - let grad = weights.grad(&grads).unwrap(); - let expected = Data::<::FloatElem, 2>::from([ - [3., 9., 7.], - [21., 35., 27.], - ]); - assert_eq!(grad.to_data(), expected); - } + let grad = weights.grad(&grads).unwrap(); + let expected = + Data::<::FloatElem, 2>::from([[3., 9., 7.], [21., 35., 27.]]); + assert_eq!(grad.to_data(), expected); + } } diff --git a/burn-autodiff/src/tests/broadcast.rs b/burn-autodiff/src/tests/broadcast.rs index 324c538d9c..428df71176 100644 --- a/burn-autodiff/src/tests/broadcast.rs +++ b/burn-autodiff/src/tests/broadcast.rs @@ -1,56 +1,56 @@ #[burn_tensor_testgen::testgen(ad_broadcast)] mod tests { - use super::*; - use burn_tensor::{Data, Distribution, Int, Shape, Tensor}; - - #[test] - fn mul_broadcast() { - test_ops_broadcast_backward(|x, y| x * y); - } - - #[test] - fn div_broadcast() { - test_ops_broadcast_backward(|x, y| x / y); - } - - #[test] - fn sub_broadcast() { - 
test_ops_broadcast_backward(|x, y| x - y); - } - - #[test] - fn add_broadcast() { - test_ops_broadcast_backward(|x, y| x + y); - } - - #[test] - fn matmul_broadcast() { - test_ops_broadcast_backward(|x, y| x.matmul(y)); - } - - #[test] - fn mask_where_broadcast() { - test_ops_broadcast_backward(|x, y| x.mask_where(y.clone().equal_elem(4), y)); - } - - fn test_ops_broadcast_backward(func: F) - where - F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>, - { - let w = TestAutodiffTensor::zeros([16, 5, 5]).require_grad(); - let x = TestAutodiffTensor::zeros([4, 5, 5]).require_grad(); - - // Slice isn't a broadcastable operation, so it will fail when the previous backward pass - // of an operation that support broadcast doesn't support it during the backward pass. - let y = func(w.clone().slice([0..1]), x.clone()); - - // Will panic if broadcast isn't supported! - let grads = y.backward(); - - let w_grad = w.grad(&grads).unwrap(); - let x_grad = x.grad(&grads).unwrap(); - - assert_eq!(w_grad.shape(), w.shape()); - assert_eq!(x_grad.shape(), x.shape()); - } + use super::*; + use burn_tensor::{Data, Distribution, Int, Shape, Tensor}; + + #[test] + fn mul_broadcast() { + test_ops_broadcast_backward(|x, y| x * y); + } + + #[test] + fn div_broadcast() { + test_ops_broadcast_backward(|x, y| x / y); + } + + #[test] + fn sub_broadcast() { + test_ops_broadcast_backward(|x, y| x - y); + } + + #[test] + fn add_broadcast() { + test_ops_broadcast_backward(|x, y| x + y); + } + + #[test] + fn matmul_broadcast() { + test_ops_broadcast_backward(|x, y| x.matmul(y)); + } + + #[test] + fn mask_where_broadcast() { + test_ops_broadcast_backward(|x, y| x.mask_where(y.clone().equal_elem(4), y)); + } + + fn test_ops_broadcast_backward(func: F) + where + F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>, + { + let w = TestAutodiffTensor::zeros([16, 5, 5]).require_grad(); + let x = TestAutodiffTensor::zeros([4, 5, 5]).require_grad(); + + // Slice isn't a broadcastable operation, so it will fail when the previous backward pass + // of an operation that support broadcast doesn't support it during the backward pass. + let y = func(w.clone().slice([0..1]), x.clone()); + + // Will panic if broadcast isn't supported! 
+ let grads = y.backward(); + + let w_grad = w.grad(&grads).unwrap(); + let x_grad = x.grad(&grads).unwrap(); + + assert_eq!(w_grad.shape(), w.shape()); + assert_eq!(x_grad.shape(), x.shape()); + } } diff --git a/burn-autodiff/src/tests/cat.rs b/burn-autodiff/src/tests/cat.rs index 3a27c42135..3668fcbbd5 100644 --- a/burn-autodiff/src/tests/cat.rs +++ b/burn-autodiff/src/tests/cat.rs @@ -1,76 +1,76 @@ #[burn_tensor_testgen::testgen(ad_cat)] mod tests { - use super::*; - use burn_tensor::{Data, Float}; + use super::*; + use burn_tensor::{Data, Float}; - #[test] - fn should_diff_cat() { - let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]]).require_grad(); + #[test] + fn should_diff_cat() { + let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]]).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - let mut tensor_1_list = Vec::new(); - let mut tensor_2_list = Vec::new(); + let mut tensor_1_list = Vec::new(); + let mut tensor_2_list = Vec::new(); - for i in 0..2 { - tensor_1_list.push(tensor_1.clone().slice([i..i + 1])); - tensor_2_list.push(tensor_2.clone().slice([i..i + 1])); - } + for i in 0..2 { + tensor_1_list.push(tensor_1.clone().slice([i..i + 1])); + tensor_2_list.push(tensor_2.clone().slice([i..i + 1])); + } - let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0); - let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0); + let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0); + let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0); - let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone()); - let grads = tensor_3_cat.backward(); + let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone()); + let grads = tensor_3_cat.backward(); - let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]); - let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]); + let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]); + let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]); - let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]); - let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]); + let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]); + let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]); - grad_1 - .clone() - .slice([0..1]) - .to_data() - .assert_approx_eq(&grad_1_slice_1.to_data(), 3); - grad_1 - .slice([1..2]) - .to_data() - .assert_approx_eq(&grad_1_slice_2.to_data(), 3); + grad_1 + .clone() + .slice([0..1]) + .to_data() + .assert_approx_eq(&grad_1_slice_1.to_data(), 3); + grad_1 + .slice([1..2]) + .to_data() + .assert_approx_eq(&grad_1_slice_2.to_data(), 3); - grad_2 - .clone() - .slice([0..1]) - .to_data() - .assert_approx_eq(&grad_2_slice_1.to_data(), 3); - grad_2 - .slice([1..2]) - .to_data() - .assert_approx_eq(&grad_2_slice_2.to_data(), 3); - } + grad_2 + .clone() + .slice([0..1]) + .to_data() + .assert_approx_eq(&grad_2_slice_1.to_data(), 3); + grad_2 + 
.slice([1..2]) + .to_data() + .assert_approx_eq(&grad_2_slice_2.to_data(), 3); + } - #[test] - fn should_diff_cat_more_than_1_dim() { - let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); - let tensor_2 = - TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]]).require_grad(); + #[test] + fn should_diff_cat_more_than_1_dim() { + let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); + let tensor_2 = + TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]]).require_grad(); - // Concat a tensor [2, 2] with another tensor [3, 2] along dim 0. - // The resulting tensor should be [5, 2] - let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0); - assert_eq!(tensor_3.dims(), [5, 2]); - let grads = tensor_3.backward(); + // Concat a tensor [2, 2] with another tensor [3, 2] along dim 0. + // The resulting tensor should be [5, 2] + let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0); + assert_eq!(tensor_3.dims(), [5, 2]); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(tensor_1.dims(), grad_1.dims()); - assert_eq!(tensor_2.dims(), grad_2.dims()); - } + assert_eq!(tensor_1.dims(), grad_1.dims()); + assert_eq!(tensor_2.dims(), grad_2.dims()); + } } diff --git a/burn-autodiff/src/tests/complex.rs b/burn-autodiff/src/tests/complex.rs index aa15d53213..40f7db9fa5 100644 --- a/burn-autodiff/src/tests/complex.rs +++ b/burn-autodiff/src/tests/complex.rs @@ -1,81 +1,81 @@ #[burn_tensor_testgen::testgen(ad_complex)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_full_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.matmul(tensor_1.clone()); - let tensor_5 = tensor_4.mul(tensor_2.clone()); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[593., 463.0], [487.0, 539.0]]) - ); - assert_eq!( - grad_2.to_data(), - Data::from([[734.0, 294.0], [1414.0, 242.0]]) - ); - } - - #[test] - fn should_diff_full_complex_2() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.matmul(tensor_1.clone()); - let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone()); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[166.0, 110.0], [212.0, 156.0]]) - ); - assert_eq!(grad_2.to_data(), Data::from([[113.0, 141.0], [33.0, 41.0]])); - } - - #[test] - fn should_diff_full_complex_3() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], 
[2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.matmul(tensor_1.clone()); - let tensor_5 = tensor_4.clone().sub(tensor_2.clone()); - let tensor_6 = tensor_5.add(tensor_4); - - let grads = tensor_6.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[332.0, 220.0], [424.0, 312.0]]) - ); - assert_eq!(grad_2.to_data(), Data::from([[223.0, 279.0], [63.0, 79.0]])); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_full_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.matmul(tensor_1.clone()); + let tensor_5 = tensor_4.mul(tensor_2.clone()); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[593., 463.0], [487.0, 539.0]]) + ); + assert_eq!( + grad_2.to_data(), + Data::from([[734.0, 294.0], [1414.0, 242.0]]) + ); + } + + #[test] + fn should_diff_full_complex_2() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.matmul(tensor_1.clone()); + let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone()); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[166.0, 110.0], [212.0, 156.0]]) + ); + assert_eq!(grad_2.to_data(), Data::from([[113.0, 141.0], [33.0, 41.0]])); + } + + #[test] + fn should_diff_full_complex_3() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.matmul(tensor_1.clone()); + let tensor_5 = tensor_4.clone().sub(tensor_2.clone()); + let tensor_6 = tensor_5.add(tensor_4); + + let grads = tensor_6.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[332.0, 220.0], [424.0, 312.0]]) + ); + assert_eq!(grad_2.to_data(), Data::from([[223.0, 279.0], [63.0, 79.0]])); + } } diff --git a/burn-autodiff/src/tests/conv1d.rs b/burn-autodiff/src/tests/conv1d.rs index 3ff44aa0d0..55a2d473bb 100644 --- a/burn-autodiff/src/tests/conv1d.rs +++ b/burn-autodiff/src/tests/conv1d.rs @@ -1,240 +1,240 @@ #[burn_tensor_testgen::testgen(ad_conv1d)] mod tests { - use super::*; - use burn_tensor::{module::conv1d, ops::ConvOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv1d, ops::ConvOptions, Data, 
Shape}; - #[test] - fn test_conv1d_basic() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[14., 24., 24., 18.], [26., 42., 42., 30.]], - [[14., 24., 24., 18.], [26., 42., 42., 30.]], - ]), - weight: TestTensor::from_floats([ - [[30., 44., 36.], [54., 76., 60.]], - [[30., 44., 36.], [54., 76., 60.]], - ]), - bias: TestTensor::from_floats([8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_basic() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[14., 24., 24., 18.], [26., 42., 42., 30.]], + [[14., 24., 24., 18.], [26., 42., 42., 30.]], + ]), + weight: TestTensor::from_floats([ + [[30., 44., 36.], [54., 76., 60.]], + [[30., 44., 36.], [54., 76., 60.]], + ]), + bias: TestTensor::from_floats([8., 8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_different_channels() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 3, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[39., 63., 63., 45.], [57., 90., 90., 63.]], - [[39., 63., 63., 45.], [57., 90., 90., 63.]], - ]), - weight: TestTensor::from_floats([ - [[30., 44., 36.], [54., 76., 60.]], - [[30., 44., 36.], [54., 76., 60.]], - [[30., 44., 36.], [54., 76., 60.]], - ]), - bias: TestTensor::from_floats([8., 8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_different_channels() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 3, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[39., 63., 63., 45.], [57., 90., 90., 63.]], + [[39., 63., 63., 45.], [57., 90., 90., 63.]], + ]), + weight: TestTensor::from_floats([ + [[30., 44., 36.], [54., 76., 60.]], + [[30., 44., 36.], [54., 76., 60.]], + [[30., 44., 36.], [54., 76., 60.]], + ]), + bias: TestTensor::from_floats([8., 8., 8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_with_padding() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 2, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[24., 24., 24., 24.], [42., 42., 42., 42.]], - [[24., 24., 24., 24.], [42., 42., 42., 42.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [76., 76., 76.]], - [[44., 44., 44.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([12., 12.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_with_padding() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 2, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[24., 24., 24., 24.], [42., 42., 42., 42.]], + [[24., 24., 24., 24.], [42., 42., 42., 42.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [76., 76., 76.]], + [[44., 44., 44.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([12., 12.]), + }; + test.assert_grads(grads); + } - #[test] - fn 
test_conv1d_with_stride() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 2, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[8., 16., 8., 10.], [14., 28., 14., 16.]], - [[8., 16., 8., 10.], [14., 28., 14., 16.]], - ]), - weight: TestTensor::from_floats([ - [[10., 20., 24.], [18., 36., 40.]], - [[10., 20., 24.], [18., 36., 40.]], - ]), - bias: TestTensor::from_floats([4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_with_stride() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 2, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[8., 16., 8., 10.], [14., 28., 14., 16.]], + [[8., 16., 8., 10.], [14., 28., 14., 16.]], + ]), + weight: TestTensor::from_floats([ + [[10., 20., 24.], [18., 36., 40.]], + [[10., 20., 24.], [18., 36., 40.]], + ]), + bias: TestTensor::from_floats([4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_dilation() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 2, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[6., 8., 8., 10.], [12., 14., 14., 16.]], - [[6., 8., 8., 10.], [12., 14., 14., 16.]], - ]), - weight: TestTensor::from_floats([ - [[8., 22., 14.], [16., 38., 22.]], - [[8., 22., 14.], [16., 38., 22.]], - ]), - bias: TestTensor::from_floats([4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_dilation() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 2, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[6., 8., 8., 10.], [12., 14., 14., 16.]], + [[6., 8., 8., 10.], [12., 14., 14., 16.]], + ]), + weight: TestTensor::from_floats([ + [[8., 22., 14.], [16., 38., 22.]], + [[8., 22., 14.], [16., 38., 22.]], + ]), + bias: TestTensor::from_floats([4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_groups() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 2, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[1., 3., 3., 3.], [7., 12., 12., 9.]], - [[1., 3., 3., 3.], [7., 12., 12., 9.]], - ]), - weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]]), - bias: TestTensor::from_floats([8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_groups() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 2, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[1., 3., 3., 3.], [7., 12., 12., 9.]], + [[1., 3., 3., 3.], [7., 12., 12., 9.]], + ]), + weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]]), + bias: TestTensor::from_floats([8., 8.]), + }; + test.assert_grads(grads); + } - struct Conv1dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size: usize, - padding: usize, - stride: usize, - dilation: usize, - groups: usize, - length: usize, - } + struct Conv1dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: 
usize, + kernel_size: usize, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + length: usize, + } - struct Grads { - x: TestTensor<3>, - weight: TestTensor<3>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<3>, + weight: TestTensor<3>, + bias: TestTensor<1>, + } - impl Conv1dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size, - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); + impl Conv1dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size, + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); - let output = conv1d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), - ); - let grads = output.backward(); + let output = conv1d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), + ); + let grads = output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - .bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .weight + .to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/conv2d.rs b/burn-autodiff/src/tests/conv2d.rs index 7e6d2a801a..9afe7a86f4 100644 --- a/burn-autodiff/src/tests/conv2d.rs +++ b/burn-autodiff/src/tests/conv2d.rs @@ -1,750 +1,750 @@ #[burn_tensor_testgen::testgen(ad_conv2d)] mod tests { - use super::*; - use burn_tensor::{module::conv2d, ops::ConvOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv2d, ops::ConvOptions, Data, Shape}; - #[test] - fn 
test_conv2d_basic() { - let test = Conv2dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [88., 138., 138., 96.], - [150., 234., 234., 162.], - [150., 234., 234., 162.], - [112., 174., 174., 120.], - ], - [ - [160., 246., 246., 168.], - [258., 396., 396., 270.], - [258., 396., 396., 270.], - [184., 282., 282., 192.], - ], - ], - [ - [ - [88., 138., 138., 96.], - [150., 234., 234., 162.], - [150., 234., 234., 162.], - [112., 174., 174., 120.], - ], - [ - [160., 246., 246., 168.], - [258., 396., 396., 270.], - [258., 396., 396., 270.], - [184., 282., 282., 192.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - ]), - bias: TestTensor::from_floats([32., 32.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_basic() { + let test = Conv2dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [88., 138., 138., 96.], + [150., 234., 234., 162.], + [150., 234., 234., 162.], + [112., 174., 174., 120.], + ], + [ + [160., 246., 246., 168.], + [258., 396., 396., 270.], + [258., 396., 396., 270.], + [184., 282., 282., 192.], + ], + ], + [ + [ + [88., 138., 138., 96.], + [150., 234., 234., 162.], + [150., 234., 234., 162.], + [112., 174., 174., 120.], + ], + [ + [160., 246., 246., 168.], + [258., 396., 396., 270.], + [258., 396., 396., 270.], + [184., 282., 282., 192.], + ], + ], + ]), + weight: TestTensor::from_floats([ + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + ]), + bias: TestTensor::from_floats([32., 32.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_channels() { - let test = Conv2dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 3, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [240., 369., 369., 252.], - [387., 594., 594., 405.], - [387., 594., 594., 405.], - [276., 423., 423., 288.], - ], - [ - [348., 531., 531., 360.], - [549., 837., 837., 567.], - [549., 837., 837., 567.], - [384., 585., 585., 396.], - ], - ], - [ - [ - [240., 369., 369., 252.], - [387., 594., 594., 405.], - [387., 594., 594., 405.], - [276., 423., 423., 288.], - ], - [ - [348., 531., 531., 360.], - [549., 837., 837., 567.], - [549., 837., 837., 567.], - [384., 585., 585., 396.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - [ - 
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - ]), - bias: TestTensor::from_floats([32., 32., 32.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_channels() { + let test = Conv2dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 3, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [240., 369., 369., 252.], + [387., 594., 594., 405.], + [387., 594., 594., 405.], + [276., 423., 423., 288.], + ], + [ + [348., 531., 531., 360.], + [549., 837., 837., 567.], + [549., 837., 837., 567.], + [384., 585., 585., 396.], + ], + ], + [ + [ + [240., 369., 369., 252.], + [387., 594., 594., 405.], + [387., 594., 594., 405.], + [276., 423., 423., 288.], + ], + [ + [348., 531., 531., 360.], + [549., 837., 837., 567.], + [549., 837., 837., 567.], + [384., 585., 585., 396.], + ], + ], + ]), + weight: TestTensor::from_floats([ + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + ]), + bias: TestTensor::from_floats([32., 32., 32.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_kernel_size() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [116., 180., 192., 132.], - [198., 306., 324., 222.], - [198., 306., 324., 222.], - [148., 228., 240., 164.], - ], - [ - [212., 324., 336., 228.], - [342., 522., 540., 366.], - [342., 522., 540., 366.], - [244., 372., 384., 260.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [ - [27., 45., 54., 39.], - [52., 84., 96., 68.], - [51., 81., 90., 63.], - ], - [ - [123., 189., 198., 135.], - [180., 276., 288., 196.], - [147., 225., 234., 159.], - ], - ], - [ - [ - [27., 45., 54., 39.], - [52., 84., 96., 68.], - [51., 81., 90., 63.], - ], - [ - [123., 189., 198., 135.], - [180., 276., 288., 196.], - [147., 225., 234., 159.], - ], - ], - ]), - bias: TestTensor::from_floats([12., 12.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_kernel_size() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [116., 180., 192., 132.], + [198., 306., 324., 222.], + [198., 306., 324., 222.], + [148., 228., 240., 164.], + ], + [ + [212., 324., 336., 228.], + [342., 522., 540., 366.], + [342., 522., 540., 366.], + [244., 372., 384., 260.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [ + 
[27., 45., 54., 39.], + [52., 84., 96., 68.], + [51., 81., 90., 63.], + ], + [ + [123., 189., 198., 135.], + [180., 276., 288., 196.], + [147., 225., 234., 159.], + ], + ], + [ + [ + [27., 45., 54., 39.], + [52., 84., 96., 68.], + [51., 81., 90., 63.], + ], + [ + [123., 189., 198., 135.], + [180., 276., 288., 196.], + [147., 225., 234., 159.], + ], + ], + ]), + bias: TestTensor::from_floats([12., 12.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_padding() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [138., 138., 138., 138.], - [234., 234., 234., 234.], - [234., 234., 234., 234.], - [174., 174., 174., 174.], - ], - [ - [246., 246., 246., 246.], - [396., 396., 396., 396.], - [396., 396., 396., 396.], - [282., 282., 282., 282.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], - [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], - ], - [ - [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], - [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], - ], - ]), - bias: TestTensor::from_floats([24., 24.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_padding() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [138., 138., 138., 138.], + [234., 234., 234., 234.], + [234., 234., 234., 234.], + [174., 174., 174., 174.], + ], + [ + [246., 246., 246., 246.], + [396., 396., 396., 396.], + [396., 396., 396., 396.], + [282., 282., 282., 282.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], + [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], + ], + [ + [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], + [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], + ], + ]), + bias: TestTensor::from_floats([24., 24.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_width() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 5, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [88., 138., 138., 138., 96.], - [150., 234., 234., 234., 162.], - [150., 234., 234., 234., 162.], - [112., 174., 174., 174., 120.], - ], - [ - [160., 246., 246., 246., 168.], - [258., 396., 396., 396., 270.], - [258., 396., 396., 396., 270.], - [184., 282., 282., 282., 192.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], - [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], - ], - [ - [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], - [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], - ], - ]), - bias: TestTensor::from_floats([20., 20.]), - }; - test.assert_grads(grads); - } + #[test] + fn 
test_conv2d_different_width() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 5, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [88., 138., 138., 138., 96.], + [150., 234., 234., 234., 162.], + [150., 234., 234., 234., 162.], + [112., 174., 174., 174., 120.], + ], + [ + [160., 246., 246., 246., 168.], + [258., 396., 396., 396., 270.], + [258., 396., 396., 396., 270.], + [184., 282., 282., 282., 192.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], + [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], + ], + [ + [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], + [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], + ], + ]), + bias: TestTensor::from_floats([20., 20.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_stride_2() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 2, - stride_2: 2, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 6, - width: 6, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [26., 52., 26., 52., 26., 28.], - [52., 104., 52., 104., 52., 56.], - [26., 52., 26., 52., 26., 28.], - [52., 104., 52., 104., 52., 56.], - [26., 52., 26., 52., 26., 28.], - [32., 64., 32., 64., 32., 34.], - ], - [ - [44., 88., 44., 88., 44., 46.], - [88., 176., 88., 176., 88., 92.], - [44., 88., 44., 88., 44., 46.], - [88., 176., 88., 176., 88., 92.], - [44., 88., 44., 88., 44., 46.], - [50., 100., 50., 100., 50., 52.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], - [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], - ], - [ - [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], - [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], - ], - ]), - bias: TestTensor::from_floats([9., 9.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_stride_2() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 2, + stride_2: 2, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 6, + width: 6, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [26., 52., 26., 52., 26., 28.], + [52., 104., 52., 104., 52., 56.], + [26., 52., 26., 52., 26., 28.], + [52., 104., 52., 104., 52., 56.], + [26., 52., 26., 52., 26., 28.], + [32., 64., 32., 64., 32., 34.], + ], + [ + [44., 88., 44., 88., 44., 46.], + [88., 176., 88., 176., 88., 92.], + [44., 88., 44., 88., 44., 46.], + [88., 176., 88., 176., 88., 92.], + [44., 88., 44., 88., 44., 46.], + [50., 100., 50., 100., 50., 52.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], + [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], + ], + [ + [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], + [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], + ], + ]), + bias: TestTensor::from_floats([9., 9.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_stride() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - 
kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 3, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 8, - width: 8, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [50., 78., 78., 78., 78., 78., 78., 54.], - [62., 96., 96., 96., 96., 96., 96., 66.], - [38., 60., 60., 60., 60., 60., 60., 42.], - [50., 78., 78., 78., 78., 78., 78., 54.], - [62., 96., 96., 96., 96., 96., 96., 66.], - [38., 60., 60., 60., 60., 60., 60., 42.], - [50., 78., 78., 78., 78., 78., 78., 54.], - [62., 96., 96., 96., 96., 96., 96., 66.], - ], - [ - [86., 132., 132., 132., 132., 132., 132., 90.], - [98., 150., 150., 150., 150., 150., 150., 102.], - [74., 114., 114., 114., 114., 114., 114., 78.], - [86., 132., 132., 132., 132., 132., 132., 90.], - [98., 150., 150., 150., 150., 150., 150., 102.], - [74., 114., 114., 114., 114., 114., 114., 78.], - [86., 132., 132., 132., 132., 132., 132., 90.], - [98., 150., 150., 150., 150., 150., 150., 102.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], - [ - [1330., 1528., 1344.], - [1911., 2196., 1932.], - [2079., 2388., 2100.], - ], - ], - [ - [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], - [ - [1330., 1528., 1344.], - [1911., 2196., 1932.], - [2079., 2388., 2100.], - ], - ], - ]), - bias: TestTensor::from_floats([24., 24.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_stride() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 3, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 8, + width: 8, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [50., 78., 78., 78., 78., 78., 78., 54.], + [62., 96., 96., 96., 96., 96., 96., 66.], + [38., 60., 60., 60., 60., 60., 60., 42.], + [50., 78., 78., 78., 78., 78., 78., 54.], + [62., 96., 96., 96., 96., 96., 96., 66.], + [38., 60., 60., 60., 60., 60., 60., 42.], + [50., 78., 78., 78., 78., 78., 78., 54.], + [62., 96., 96., 96., 96., 96., 96., 66.], + ], + [ + [86., 132., 132., 132., 132., 132., 132., 90.], + [98., 150., 150., 150., 150., 150., 150., 102.], + [74., 114., 114., 114., 114., 114., 114., 78.], + [86., 132., 132., 132., 132., 132., 132., 90.], + [98., 150., 150., 150., 150., 150., 150., 102.], + [74., 114., 114., 114., 114., 114., 114., 78.], + [86., 132., 132., 132., 132., 132., 132., 90.], + [98., 150., 150., 150., 150., 150., 150., 102.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], + [ + [1330., 1528., 1344.], + [1911., 2196., 1932.], + [2079., 2388., 2100.], + ], + ], + [ + [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], + [ + [1330., 1528., 1344.], + [1911., 2196., 1932.], + [2079., 2388., 2100.], + ], + ], + ]), + bias: TestTensor::from_floats([24., 24.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_dilation_2() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 2, - dilation_2: 2, - groups: 1, - height: 6, - width: 6, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [18., 38., 38., 42., 42., 22.], - [42., 88., 88., 96., 96., 50.], - [42., 88., 88., 96., 96., 50.], - [54., 112., 112., 120., 120., 62.], - [54., 112., 112., 
120., 120., 62.], - [30., 62., 62., 66., 66., 34.], - ], - [ - [36., 74., 74., 78., 78., 40.], - [78., 160., 160., 168., 168., 86.], - [78., 160., 160., 168., 168., 86.], - [90., 184., 184., 192., 192., 98.], - [90., 184., 184., 192., 192., 98.], - [48., 98., 98., 102., 102., 52.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], - [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], - ], - [ - [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], - [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], - ], - ]), - bias: TestTensor::from_floats([16., 16.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_dilation_2() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 2, + dilation_2: 2, + groups: 1, + height: 6, + width: 6, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [18., 38., 38., 42., 42., 22.], + [42., 88., 88., 96., 96., 50.], + [42., 88., 88., 96., 96., 50.], + [54., 112., 112., 120., 120., 62.], + [54., 112., 112., 120., 120., 62.], + [30., 62., 62., 66., 66., 34.], + ], + [ + [36., 74., 74., 78., 78., 40.], + [78., 160., 160., 168., 168., 86.], + [78., 160., 160., 168., 168., 86.], + [90., 184., 184., 192., 192., 98.], + [90., 184., 184., 192., 192., 98.], + [48., 98., 98., 102., 102., 52.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], + [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], + ], + [ + [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], + [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], + ], + ]), + bias: TestTensor::from_floats([16., 16.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_dilation() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 2, - dilation_2: 3, - groups: 1, - height: 6, - width: 6, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [18., 0., 20., 20., 0., 22.], - [42., 0., 46., 46., 0., 50.], - [42., 0., 46., 46., 0., 50.], - [54., 0., 58., 58., 0., 62.], - [54., 0., 58., 58., 0., 62.], - [30., 0., 32., 32., 0., 34.], - ], - [ - [36., 0., 38., 38., 0., 40.], - [78., 0., 82., 82., 0., 86.], - [78., 0., 82., 82., 0., 86.], - [90., 0., 94., 94., 0., 98.], - [90., 0., 94., 94., 0., 98.], - [48., 0., 50., 50., 0., 52.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], - [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], - ], - [ - [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], - [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], - ], - ]), - bias: TestTensor::from_floats([8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_dilation() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 2, + dilation_2: 3, + groups: 1, + height: 6, + width: 6, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [18., 0., 20., 20., 0., 22.], + [42., 0., 46., 46., 0., 50.], + [42., 0., 46., 46., 0., 50.], + [54., 0., 58., 58., 0., 
62.], + [54., 0., 58., 58., 0., 62.], + [30., 0., 32., 32., 0., 34.], + ], + [ + [36., 0., 38., 38., 0., 40.], + [78., 0., 82., 82., 0., 86.], + [78., 0., 82., 82., 0., 86.], + [90., 0., 94., 94., 0., 98.], + [90., 0., 94., 94., 0., 98.], + [48., 0., 50., 50., 0., 52.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], + [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], + ], + [ + [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], + [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], + ], + ]), + bias: TestTensor::from_floats([8., 8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_groups() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 5, - width: 5, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [0., 1., 3., 3., 2.], - [3., 8., 15., 12., 7.], - [9., 21., 36., 27., 15.], - [9., 20., 33., 24., 13.], - [6., 13., 21., 15., 8.], - ], - [ - [9., 19., 30., 21., 11.], - [21., 44., 69., 48., 25.], - [36., 75., 117., 81., 42.], - [27., 56., 87., 60., 31.], - [15., 31., 48., 33., 17.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]], - [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]], - ]), - bias: TestTensor::from_floats([9., 9.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_groups() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 5, + width: 5, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [0., 1., 3., 3., 2.], + [3., 8., 15., 12., 7.], + [9., 21., 36., 27., 15.], + [9., 20., 33., 24., 13.], + [6., 13., 21., 15., 8.], + ], + [ + [9., 19., 30., 21., 11.], + [21., 44., 69., 48., 25.], + [36., 75., 117., 81., 42.], + [27., 56., 87., 60., 31.], + [15., 31., 48., 33., 17.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]], + [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]], + ]), + bias: TestTensor::from_floats([9., 9.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_groups_different_channels() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 3, - channels_out: 6, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 3, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [9., 20., 24., 13.], - [24., 52., 60., 32.], - [36., 76., 84., 44.], - [21., 44., 48., 25.], - ], - [ - [45., 92., 96., 49.], - [96., 196., 204., 104.], - [108., 220., 228., 116.], - [57., 116., 120., 61.], - ], - [ - [81., 164., 168., 85.], - [168., 340., 348., 176.], - [180., 364., 372., 188.], - [93., 188., 192., 97.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], - [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], - [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], - [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], - [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], - [[[138., 142., 
146.], [154., 158., 162.], [170., 174., 178.]]], - ]), - bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_groups_different_channels() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 3, + channels_out: 6, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 3, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [9., 20., 24., 13.], + [24., 52., 60., 32.], + [36., 76., 84., 44.], + [21., 44., 48., 25.], + ], + [ + [45., 92., 96., 49.], + [96., 196., 204., 104.], + [108., 220., 228., 116.], + [57., 116., 120., 61.], + ], + [ + [81., 164., 168., 85.], + [168., 340., 348., 176.], + [180., 364., 372., 188.], + [93., 188., 192., 97.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], + [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], + [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], + [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], + [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], + [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], + ]), + bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_complex() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 3, - kernel_size_1: 2, - kernel_size_2: 3, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - dilation_1: 2, - dilation_2: 3, - groups: 1, - height: 4, - width: 5, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [36., 39., 0., 39., 42.], - [81., 87., 0., 87., 93.], - [81., 87., 0., 87., 93.], - [45., 48., 0., 48., 51.], - ], - [ - [54., 57., 0., 57., 60.], - [117., 123., 0., 123., 129.], - [117., 123., 0., 123., 129.], - [63., 66., 0., 66., 69.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[15., 42., 27.], [30., 72., 42.]], - [[75., 162., 87.], [90., 192., 102.]], - ], - [ - [[15., 42., 27.], [30., 72., 42.]], - [[75., 162., 87.], [90., 192., 102.]], - ], - [ - [[15., 42., 27.], [30., 72., 42.]], - [[75., 162., 87.], [90., 192., 102.]], - ], - ]), - bias: TestTensor::from_floats([8., 8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_complex() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 3, + kernel_size_1: 2, + kernel_size_2: 3, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + dilation_1: 2, + dilation_2: 3, + groups: 1, + height: 4, + width: 5, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [36., 39., 0., 39., 42.], + [81., 87., 0., 87., 93.], + [81., 87., 0., 87., 93.], + [45., 48., 0., 48., 51.], + ], + [ + [54., 57., 0., 57., 60.], + [117., 123., 0., 123., 129.], + [117., 123., 0., 123., 129.], + [63., 66., 0., 66., 69.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[15., 42., 27.], [30., 72., 42.]], + [[75., 162., 87.], [90., 192., 102.]], + ], + [ + [[15., 42., 27.], [30., 72., 42.]], + [[75., 162., 87.], [90., 192., 102.]], + ], + [ + [[15., 42., 27.], [30., 72., 42.]], + [[75., 162., 87.], [90., 192., 102.]], + ], + ]), + bias: TestTensor::from_floats([8., 8., 8.]), + }; + test.assert_grads(grads); + } - struct Conv2dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: 
usize, - stride_1: usize, - stride_2: usize, - dilation_1: usize, - dilation_2: usize, - groups: usize, - height: usize, - width: usize, - } + struct Conv2dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + dilation_1: usize, + dilation_2: usize, + groups: usize, + height: usize, + width: usize, + } - struct Grads { - x: TestTensor<4>, - weight: TestTensor<4>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<4>, + weight: TestTensor<4>, + bias: TestTensor<1>, + } - impl Conv2dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size_1, - self.kernel_size_2, - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = conv2d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvOptions::new( - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - [self.dilation_1, self.dilation_2], - self.groups, - ), - ); - let grads = output.backward(); + impl Conv2dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size_1, + self.kernel_size_2, + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = conv2d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvOptions::new( + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + [self.dilation_1, self.dilation_2], + self.groups, + ), + ); + let grads = output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - .bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); - } + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + expected_grads + .weight + 
.to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/conv_transpose1d.rs b/burn-autodiff/src/tests/conv_transpose1d.rs index 4f05a447d8..948e51ad15 100644 --- a/burn-autodiff/src/tests/conv_transpose1d.rs +++ b/burn-autodiff/src/tests/conv_transpose1d.rs @@ -1,253 +1,253 @@ #[burn_tensor_testgen::testgen(ad_conv_transpose1d)] mod tests { - use super::*; - use burn_tensor::{module::conv_transpose1d, ops::ConvTransposeOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv_transpose1d, ops::ConvTransposeOptions, Data, Shape}; - #[test] - fn test_conv_transpose1d_basic() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], - [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], - ]), - weight: TestTensor::from_floats([ - [[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]], - [[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]], - ]), - bias: TestTensor::from_floats([12., 12.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_basic() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], + [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], + ]), + weight: TestTensor::from_floats([ + [[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]], + [[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]], + ]), + bias: TestTensor::from_floats([12., 12.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_padding() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 2, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[7., 12., 8., 3.], [19., 36., 32., 15.]], - [[7., 12., 8., 3.], [19., 36., 32., 15.]], - ]), - weight: TestTensor::from_floats([ - [[26., 22., 18.], [26., 22., 18.]], - [[42., 38., 34.], [42., 38., 34.]], - ]), - bias: TestTensor::from_floats([4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_padding() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 2, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[7., 12., 8., 3.], [19., 36., 32., 15.]], + [[7., 12., 8., 3.], [19., 36., 32., 15.]], + ]), + weight: TestTensor::from_floats([ + [[26., 22., 18.], [26., 22., 18.]], + [[42., 38., 34.], [42., 38., 34.]], + ]), + bias: TestTensor::from_floats([4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_stride() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 0, - stride: 2, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [44., 44., 44.]], - [[76., 76., 76.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([18., 18.]), - }; - 
test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_stride() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 0, + stride: 2, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [44., 44., 44.]], + [[76., 76., 76.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([18., 18.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_stride_padding_out() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 1, - stride: 2, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [44., 44., 44.]], - [[76., 76., 76.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([20., 20.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_stride_padding_out() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 1, + stride: 2, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [44., 44., 44.]], + [[76., 76., 76.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([20., 20.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_dilation() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 0, - stride: 1, - dilation: 2, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [44., 44., 44.]], - [[76., 76., 76.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([16., 16.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_dilation() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 0, + stride: 1, + dilation: 2, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [44., 44., 44.]], + [[76., 76., 76.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([16., 16.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_complex() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 4], - kernel_size: 3, - padding: 1, - padding_out: 1, - stride: 2, - dilation: 2, - groups: 2, - size: 8, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], - [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], - ], - [ - [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], - [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], - ], - ]), - weight: TestTensor::from_floats([ - [[168.0, 
184.0, 184.0], [168.0, 184.0, 184.0]], - [[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]], - ]), - bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_complex() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 4], + kernel_size: 3, + padding: 1, + padding_out: 1, + stride: 2, + dilation: 2, + groups: 2, + size: 8, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], + [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], + ], + [ + [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], + [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], + ], + ]), + weight: TestTensor::from_floats([ + [[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]], + [[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]], + ]), + bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0]), + }; + test.assert_grads(grads); + } - struct ConvTranspose1dTestCase { - batch_size: usize, - channels: [usize; 2], - kernel_size: usize, - padding: usize, - padding_out: usize, - stride: usize, - dilation: usize, - groups: usize, - size: usize, - } + struct ConvTranspose1dTestCase { + batch_size: usize, + channels: [usize; 2], + kernel_size: usize, + padding: usize, + padding_out: usize, + stride: usize, + dilation: usize, + groups: usize, + size: usize, + } - struct Grads { - x: TestTensor<3>, - weight: TestTensor<3>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<3>, + weight: TestTensor<3>, + bias: TestTensor<1>, + } - impl ConvTranspose1dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]); - let shape_weight = Shape::new([ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size, - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels[1]) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = conv_transpose1d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvTransposeOptions::new( - [self.stride], - [self.padding], - [self.padding_out], - [self.dilation], - self.groups, - ), - ); - let grads = output.backward(); + impl ConvTranspose1dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]); + let shape_weight = Shape::new([ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size, + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels[1]) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = conv_transpose1d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvTransposeOptions::new( + [self.stride], + [self.padding], + [self.padding_out], + [self.dilation], + self.groups, + ), + ); + let grads = 
output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - .bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); - } + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + expected_grads + .weight + .to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/conv_transpose2d.rs b/burn-autodiff/src/tests/conv_transpose2d.rs index 138afbf5fe..bedc9d617d 100644 --- a/burn-autodiff/src/tests/conv_transpose2d.rs +++ b/burn-autodiff/src/tests/conv_transpose2d.rs @@ -1,647 +1,643 @@ #[burn_tensor_testgen::testgen(ad_conv_transpose2d)] mod tests { - use super::*; - use burn_tensor::{module::conv_transpose2d, ops::ConvTransposeOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv_transpose2d, ops::ConvTransposeOptions, Data, Shape}; - #[test] - fn test_conv_transpose2d_basic() { - let test = ConvTranspose2dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - ], - [ - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - ], - ], - [ - [ - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - ], - [ - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], - [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], - ], - [ - [ - [1264., 1264., 1264.], - [1264., 1264., 1264.], - [1264., 1264., 1264.], - ], - [ - [1264., 1264., 1264.], - [1264., 1264., 1264.], - [1264., 1264., 1264.], - ], - ], - ]), - bias: TestTensor::from_floats([72., 72.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_basic() { + let test = ConvTranspose2dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + ], + [ + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + ], + ], + [ + [ + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + ], + [ + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + ], + ], + ]), + weight: 
TestTensor::from_floats([ + [ + [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], + [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], + ], + [ + [ + [1264., 1264., 1264.], + [1264., 1264., 1264.], + [1264., 1264., 1264.], + ], + [ + [1264., 1264., 1264.], + [1264., 1264., 1264.], + [1264., 1264., 1264.], + ], + ], + ]), + bias: TestTensor::from_floats([72., 72.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_padding() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [1, 2], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [13., 24., 20., 9.], - [15., 27., 21., 9.], - [15., 27., 21., 9.], - [7., 12., 8., 3.], - ]]]), - weight: TestTensor::from_floats([[[ - [63., 57., 51.], - [68., 60., 52.], - [39., 33., 27.], - ]]]), - bias: TestTensor::from_floats([8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_padding() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [1, 2], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [13., 24., 20., 9.], + [15., 27., 21., 9.], + [15., 27., 21., 9.], + [7., 12., 8., 3.], + ]]]), + weight: TestTensor::from_floats([[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]]), + bias: TestTensor::from_floats([8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_stride() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [2, 3], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ]]]), - weight: TestTensor::from_floats([[[ - [120., 120., 120.], - [120., 120., 120.], - [120., 120., 120.], - ]]]), - bias: TestTensor::from_floats([108.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_stride() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [2, 3], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ]]]), + weight: TestTensor::from_floats([[[ + [120., 120., 120.], + [120., 120., 120.], + [120., 120., 120.], + ]]]), + bias: TestTensor::from_floats([108.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_stride_padding_out() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [1, 2], - stride: [2, 3], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ]]]), - weight: TestTensor::from_floats([[[ - [120., 120., 120.], - [120., 120., 120.], - [120., 120., 120.], - ]]]), - bias: TestTensor::from_floats([140.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_stride_padding_out() { + let test = ConvTranspose2dTestCase { + 
batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [1, 2], + stride: [2, 3], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ]]]), + weight: TestTensor::from_floats([[[ + [120., 120., 120.], + [120., 120., 120.], + [120., 120., 120.], + ]]]), + bias: TestTensor::from_floats([140.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_dilation() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [2, 3], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ]]]), - weight: TestTensor::from_floats([[[ - [120., 120., 120.], - [120., 120., 120.], - [120., 120., 120.], - ]]]), - bias: TestTensor::from_floats([80.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_dilation() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [2, 3], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ]]]), + weight: TestTensor::from_floats([[[ + [120., 120., 120.], + [120., 120., 120.], + [120., 120., 120.], + ]]]), + bias: TestTensor::from_floats([80.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_channels() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [2, 3], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [351., 351., 351., 351.], - [351., 351., 351., 351.], - [351., 351., 351., 351.], - [351., 351., 351., 351.], - ], - [ - [1080., 1080., 1080., 1080.], - [1080., 1080., 1080., 1080.], - [1080., 1080., 1080., 1080.], - [1080., 1080., 1080., 1080.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], - [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], - [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], - ], - [ - [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], - [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], - [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], - ], - ]), - bias: TestTensor::from_floats([36., 36., 36.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_channels() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [2, 3], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [351., 351., 351., 351.], + [351., 351., 351., 351.], + [351., 351., 351., 351.], + [351., 351., 351., 351.], + ], + [ + [1080., 1080., 1080., 1080.], + [1080., 1080., 1080., 1080.], + [1080., 1080., 1080., 1080.], + [1080., 1080., 1080., 1080.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], + [[120., 
120., 120.], [120., 120., 120.], [120., 120., 120.]], + [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], + ], + [ + [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], + [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], + [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], + ], + ]), + bias: TestTensor::from_floats([36., 36., 36.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_kernel_size() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 5], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [6, 6], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - ]]]), - weight: TestTensor::from_floats([[[ - [630., 630., 630., 630., 630.], - [630., 630., 630., 630., 630.], - [630., 630., 630., 630., 630.], - ]]]), - bias: TestTensor::from_floats([80.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_kernel_size() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 5], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [6, 6], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + ]]]), + weight: TestTensor::from_floats([[[ + [630., 630., 630., 630., 630.], + [630., 630., 630., 630., 630.], + [630., 630., 630., 630., 630.], + ]]]), + bias: TestTensor::from_floats([80.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_groups() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [2, 2], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 2, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ], - [ - [117., 117., 117., 117.], - [117., 117., 117., 117.], - [117., 117., 117., 117.], - [117., 117., 117., 117.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]], - [[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]], - ]), - bias: TestTensor::from_floats([36., 36.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_groups() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [2, 2], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 2, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ], + [ + [117., 117., 117., 117.], + [117., 117., 117., 117.], + [117., 117., 117., 117.], + [117., 117., 117., 117.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]], + [[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]], + ]), + bias: 
TestTensor::from_floats([36., 36.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_complex_no_groups() { - let test = ConvTranspose2dTestCase { - batch_size: 2, - channels: [2, 3], - kernel_size: [3, 5], - padding: [1, 2], - padding_out: [1, 2], - stride: [2, 3], - dilation: [2, 3], - groups: 1, - size: [6, 8], - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [600., 735., 735., 735., 735., 735., 735., 735.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - ], - [ - [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - ], - ], - [ - [ - [600., 735., 735., 735., 735., 735., 735., 735.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - ], - [ - [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [ - [5320., 6040., 6040., 6040., 6040.], - [6048., 6864., 6864., 6864., 6864.], - [6048., 6864., 6864., 6864., 6864.], - ], - [ - [5320., 6040., 6040., 6040., 6040.], - [6048., 6864., 6864., 6864., 6864.], - [6048., 6864., 6864., 6864., 6864.], - ], - [ - [5320., 6040., 6040., 6040., 6040.], - [6048., 6864., 6864., 6864., 6864.], - [6048., 6864., 6864., 6864., 6864.], - ], - ], - [ - [ - [8680., 9880., 9880., 9880., 9880.], - [10080., 11472., 11472., 11472., 11472.], - [10080., 11472., 11472., 11472., 11472.], - ], - [ - [8680., 9880., 9880., 9880., 9880.], - [10080., 11472., 11472., 11472., 11472.], - [10080., 11472., 11472., 11472., 11472.], - ], - [ - [8680., 9880., 9880., 9880., 9880.], - [10080., 11472., 11472., 11472., 11472.], - [10080., 11472., 11472., 11472., 11472.], - ], - ], - ]), - bias: TestTensor::from_floats([896., 896., 896.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_complex_no_groups() { + let test = ConvTranspose2dTestCase { + batch_size: 2, + channels: [2, 3], + kernel_size: [3, 5], + padding: [1, 2], + padding_out: [1, 2], + stride: [2, 3], + dilation: [2, 3], + groups: 1, + size: [6, 8], + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [600., 735., 735., 735., 735., 735., 735., 735.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + ], + [ + [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 
3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + ], + ], + [ + [ + [600., 735., 735., 735., 735., 735., 735., 735.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + ], + [ + [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + ], + ], + ]), + weight: TestTensor::from_floats([ + [ + [ + [5320., 6040., 6040., 6040., 6040.], + [6048., 6864., 6864., 6864., 6864.], + [6048., 6864., 6864., 6864., 6864.], + ], + [ + [5320., 6040., 6040., 6040., 6040.], + [6048., 6864., 6864., 6864., 6864.], + [6048., 6864., 6864., 6864., 6864.], + ], + [ + [5320., 6040., 6040., 6040., 6040.], + [6048., 6864., 6864., 6864., 6864.], + [6048., 6864., 6864., 6864., 6864.], + ], + ], + [ + [ + [8680., 9880., 9880., 9880., 9880.], + [10080., 11472., 11472., 11472., 11472.], + [10080., 11472., 11472., 11472., 11472.], + ], + [ + [8680., 9880., 9880., 9880., 9880.], + [10080., 11472., 11472., 11472., 11472.], + [10080., 11472., 11472., 11472., 11472.], + ], + [ + [8680., 9880., 9880., 9880., 9880.], + [10080., 11472., 11472., 11472., 11472.], + [10080., 11472., 11472., 11472., 11472.], + ], + ], + ]), + bias: TestTensor::from_floats([896., 896., 896.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_complex_no_groups_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [4, 2], - kernel_size: [2, 3], - padding: [1, 2], - padding_out: [1, 2], - stride: [2, 3], - dilation: [1, 2], - groups: 1, - size: [10, 10], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - ], - [ - [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - ], - [ - [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 
354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - ], - [ - [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[4455., 4905., 4905.], [4500., 4950., 4950.]], - [[4455., 4905., 4905.], [4500., 4950., 4950.]], - ], - [ - [[12555., 13905., 13905.], [13500., 14950., 14950.]], - [[12555., 13905., 13905.], [13500., 14950., 14950.]], - ], - [ - [[20655., 22905., 22905.], [22500., 24950., 24950.]], - [[20655., 22905., 22905.], [22500., 24950., 24950.]], - ], - [ - [[28755., 31905., 31905.], [31500., 34950., 34950.]], - [[28755., 31905., 31905.], [31500., 34950., 34950.]], - ], - ]), - bias: TestTensor::from_floats([570., 570.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_complex_no_groups_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [4, 2], + kernel_size: [2, 3], + padding: [1, 2], + padding_out: [1, 2], + stride: [2, 3], + dilation: [1, 2], + groups: 1, + size: [10, 10], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + ], + [ + [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + ], + [ + [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 
354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + ], + [ + [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[4455., 4905., 4905.], [4500., 4950., 4950.]], + [[4455., 4905., 4905.], [4500., 4950., 4950.]], + ], + [ + [[12555., 13905., 13905.], [13500., 14950., 14950.]], + [[12555., 13905., 13905.], [13500., 14950., 14950.]], + ], + [ + [[20655., 22905., 22905.], [22500., 24950., 24950.]], + [[20655., 22905., 22905.], [22500., 24950., 24950.]], + ], + [ + [[28755., 31905., 31905.], [31500., 34950., 34950.]], + [[28755., 31905., 31905.], [31500., 34950., 34950.]], + ], + ]), + bias: TestTensor::from_floats([570., 570.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_complex_groups() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [4, 2], - kernel_size: [2, 3], - padding: [1, 2], - padding_out: [1, 2], - stride: [2, 3], - dilation: [1, 2], - groups: 2, - size: [10, 10], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - ], - [ - [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - ], - [ - [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 
87., 87., 87., 87.], - ], - [ - [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[4455., 4905., 4905.], [4500., 4950., 4950.]]], - [[[12555., 13905., 13905.], [13500., 14950., 14950.]]], - [[[20655., 22905., 22905.], [22500., 24950., 24950.]]], - [[[28755., 31905., 31905.], [31500., 34950., 34950.]]], - ]), - bias: TestTensor::from_floats([570., 570.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_complex_groups() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [4, 2], + kernel_size: [2, 3], + padding: [1, 2], + padding_out: [1, 2], + stride: [2, 3], + dilation: [1, 2], + groups: 2, + size: [10, 10], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + ], + [ + [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + ], + [ + [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + ], + [ + [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + 
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[4455., 4905., 4905.], [4500., 4950., 4950.]]], + [[[12555., 13905., 13905.], [13500., 14950., 14950.]]], + [[[20655., 22905., 22905.], [22500., 24950., 24950.]]], + [[[28755., 31905., 31905.], [31500., 34950., 34950.]]], + ]), + bias: TestTensor::from_floats([570., 570.]), + }; + test.assert_grads(grads); + } - struct ConvTranspose2dTestCase { - batch_size: usize, - channels: [usize; 2], - kernel_size: [usize; 2], - padding: [usize; 2], - padding_out: [usize; 2], - stride: [usize; 2], - dilation: [usize; 2], - groups: usize, - size: [usize; 2], - } + struct ConvTranspose2dTestCase { + batch_size: usize, + channels: [usize; 2], + kernel_size: [usize; 2], + padding: [usize; 2], + padding_out: [usize; 2], + stride: [usize; 2], + dilation: [usize; 2], + groups: usize, + size: [usize; 2], + } - struct Grads { - x: TestTensor<4>, - weight: TestTensor<4>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<4>, + weight: TestTensor<4>, + bias: TestTensor<1>, + } - impl ConvTranspose2dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([ - self.batch_size, - self.channels[0], - self.size[0], - self.size[1], - ]); - let shape_weight = Shape::new([ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size[0], - self.kernel_size[1], - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels[1]) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = conv_transpose2d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvTransposeOptions::new( - self.stride, - self.padding, - self.padding_out, - self.dilation, - self.groups, - ), - ); - let grads = output.backward(); + impl ConvTranspose2dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([ + self.batch_size, + self.channels[0], + self.size[0], + self.size[1], + ]); + let shape_weight = Shape::new([ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size[0], + self.kernel_size[1], + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels[1]) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = conv_transpose2d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvTransposeOptions::new( + self.stride, + self.padding, + self.padding_out, + self.dilation, + self.groups, + ), + ); + let grads = output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - 
.bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); - } + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + expected_grads + .weight + .to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); } + } } diff --git a/burn-autodiff/src/tests/cos.rs b/burn-autodiff/src/tests/cos.rs index af42104e8e..a89e08139a 100644 --- a/burn-autodiff/src/tests/cos.rs +++ b/burn-autodiff/src/tests/cos.rs @@ -1,30 +1,30 @@ #[burn_tensor_testgen::testgen(ad_cos)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_cos() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_cos() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1.to_data().assert_approx_eq_diff( - &Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]), - 2.0e-3, - ); - grad_2.to_data().assert_approx_eq_diff( - &Data::from([[9.222064, -39.123375], [-28.721354, 49.748356]]), - 2.0e-3, - ); - } + grad_1.to_data().assert_approx_eq_diff( + &Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]), + 2.0e-3, + ); + grad_2.to_data().assert_approx_eq_diff( + &Data::from([[9.222064, -39.123375], [-28.721354, 49.748356]]), + 2.0e-3, + ); + } } diff --git a/burn-autodiff/src/tests/cross_entropy.rs b/burn-autodiff/src/tests/cross_entropy.rs index f898f6b2f6..c22a478f4a 100644 --- a/burn-autodiff/src/tests/cross_entropy.rs +++ b/burn-autodiff/src/tests/cross_entropy.rs @@ -1,31 +1,30 @@ #[burn_tensor_testgen::testgen(ad_cross_entropy_loss)] mod tests { - use super::*; - use burn_tensor::{loss, Data, Tensor}; + use super::*; + use burn_tensor::{loss, Data, Tensor}; - #[test] - fn test_cross_entropy_loss_grad() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - let data_targets = Data::from([[0.8, 0.2], [0.9, 0.1]]); + #[test] + fn test_cross_entropy_loss_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + let data_targets = Data::from([[0.8, 0.2], [0.9, 0.1]]); - let tensor_1 = Tensor::::from_data(data_1).require_grad(); - let tensor_2 = Tensor::::from_data(data_2).require_grad(); - let tensor_targets = - Tensor::::from_data(data_targets).require_grad(); + let tensor_1 = Tensor::::from_data(data_1).require_grad(); + let 
tensor_2 = Tensor::::from_data(data_2).require_grad(); + let tensor_targets = Tensor::::from_data(data_targets).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets); - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[0.2655, 0.2655], [0.4496, 0.4496]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-1.3486, 1.3486], [-2.0637, 2.0637]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[0.2655, 0.2655], [0.4496, 0.4496]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-1.3486, 1.3486], [-2.0637, 2.0637]]), 3); + } } diff --git a/burn-autodiff/src/tests/div.rs b/burn-autodiff/src/tests/div.rs index 7ab4470921..47402e9d6d 100644 --- a/burn-autodiff/src/tests/div.rs +++ b/burn-autodiff/src/tests/div.rs @@ -1,89 +1,89 @@ #[burn_tensor_testgen::testgen(ad_div)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_div() { - let data_1 = Data::from([1.0, 7.0]); - let data_2 = Data::from([4.0, 7.0]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().div(tensor_2.clone()); - let grads = tensor_3.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([0.25, 0.1429]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([-0.0625, -0.1429]), 3); - } - - #[test] - fn should_diff_div_scalar() { - let data = Data::from([1.0, 7.0]); - - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().div_scalar(4.0); - - let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); - - assert_eq!(grad.to_data(), Data::from([0.25, 0.25])); - } - - #[test] - fn test_div_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().div(tensor_2.clone()); - let tensor_5 = tensor_4.div(tensor_3); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[0.1250, 0.0714], [0.25, 0.1667]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-0.0312, -0.0714], [-1.6250, 0.1667]]), 3); - } - - #[test] - fn test_div_complex_2() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = 
tensor_3.div(tensor_2.clone()); - - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[2.00, 2.9286], [1.3667, 2.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[0.0833, 0.0959], [-0.0556, -0.0671]]), 3); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_div() { + let data_1 = Data::from([1.0, 7.0]); + let data_2 = Data::from([4.0, 7.0]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().div(tensor_2.clone()); + let grads = tensor_3.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([0.25, 0.1429]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([-0.0625, -0.1429]), 3); + } + + #[test] + fn should_diff_div_scalar() { + let data = Data::from([1.0, 7.0]); + + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().div_scalar(4.0); + + let grads = tensor_out.backward(); + let grad = tensor.grad(&grads).unwrap(); + + assert_eq!(grad.to_data(), Data::from([0.25, 0.25])); + } + + #[test] + fn test_div_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().div(tensor_2.clone()); + let tensor_5 = tensor_4.div(tensor_3); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[0.1250, 0.0714], [0.25, 0.1667]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-0.0312, -0.0714], [-1.6250, 0.1667]]), 3); + } + + #[test] + fn test_div_complex_2() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.div(tensor_2.clone()); + + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[2.00, 2.9286], [1.3667, 2.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[0.0833, 0.0959], [-0.0556, -0.0671]]), 3); + } } diff --git a/burn-autodiff/src/tests/erf.rs b/burn-autodiff/src/tests/erf.rs index bd80c347ad..5db398fe67 100644 --- a/burn-autodiff/src/tests/erf.rs +++ b/burn-autodiff/src/tests/erf.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_erf)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_erf() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_erf() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = 
Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3); + } } diff --git a/burn-autodiff/src/tests/exp.rs b/burn-autodiff/src/tests/exp.rs index bba159bb60..bfd9293bee 100644 --- a/burn-autodiff/src/tests/exp.rs +++ b/burn-autodiff/src/tests/exp.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_exp)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_exp() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + #[test] + fn should_diff_exp() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[54.5991, 27.4746], [54.5991, 27.4746]]), 3); - grad_2.to_data().assert_approx_eq( - &Data::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]), - 3, - ); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[54.5991, 27.4746], [54.5991, 27.4746]]), 3); + grad_2.to_data().assert_approx_eq( + &Data::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]), + 3, + ); + } } diff --git a/burn-autodiff/src/tests/gather_scatter.rs b/burn-autodiff/src/tests/gather_scatter.rs index 3557f11c8a..e1384c2218 100644 --- a/burn-autodiff/src/tests/gather_scatter.rs +++ b/burn-autodiff/src/tests/gather_scatter.rs @@ -1,58 +1,56 @@ #[burn_tensor_testgen::testgen(ad_gather_scatter)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_gather_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) - .require_grad(); - let indices = Tensor::::from_data(Data::from([ - [2, 1, 0, 1, 2], - [1, 0, 2, 1, 0], - ])); - - 
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().gather(1, indices); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[94., 150., 187.], [242., 305., 304.]]) - ); - } - - #[test] - fn test_scatter_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) - .require_grad(); - let values = TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - .require_grad(); - let indices = - Tensor::::from_data(Data::from([[2, 1, 0], [2, 0, 1]])); - - let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().scatter(1, indices, values.clone()); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = values.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[127., 181., 235.], [226., 316., 406.]]) - ); - assert_eq!( - grad_2.into_data(), - Data::from([[19., 19., 19.], [64., 64., 64.]]) - ); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_gather_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); + let indices = Tensor::::from_data(Data::from([ + [2, 1, 0, 1, 2], + [1, 0, 2, 1, 0], + ])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().gather(1, indices); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[94., 150., 187.], [242., 305., 304.]]) + ); + } + + #[test] + fn test_scatter_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); + let values = + TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad(); + let indices = + Tensor::::from_data(Data::from([[2, 1, 0], [2, 0, 1]])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().scatter(1, indices, values.clone()); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = values.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[127., 181., 235.], [226., 316., 406.]]) + ); + assert_eq!( + grad_2.into_data(), + Data::from([[19., 19., 19.], [64., 64., 64.]]) + ); + } } diff --git a/burn-autodiff/src/tests/gelu.rs b/burn-autodiff/src/tests/gelu.rs index fec6eb3aa0..c39ff5ddbb 100644 --- a/burn-autodiff/src/tests/gelu.rs +++ b/burn-autodiff/src/tests/gelu.rs @@ -1,25 +1,25 @@ #[burn_tensor_testgen::testgen(ad_gelu)] mod tests { - use super::*; - use burn_tensor::{activation, Data}; + use super::*; + use burn_tensor::{activation, Data}; - #[test] - fn should_diff_gelu() { - let tensor_1 = TestAutodiffTensor::from_floats([[0.0, 1.0], [-3.0, 4.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]]).require_grad(); + #[test] + fn should_diff_gelu() { + let tensor_1 = TestAutodiffTensor::from_floats([[0.0, 1.0], [-3.0, 4.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]]).require_grad(); - let x = 
tensor_1.clone().matmul(activation::gelu(tensor_2.clone())); - let x = tensor_1.clone().matmul(x); - let grads = x.backward(); + let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone())); + let x = tensor_1.clone().matmul(x); + let grads = x.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[1.4629, 1.4629], [48.2286, 153.4629]]), 2); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-15.0000, -1.9895], [17.0000, 17.0000]]), 2); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[1.4629, 1.4629], [48.2286, 153.4629]]), 2); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-15.0000, -1.9895], [17.0000, 17.0000]]), 2); + } } diff --git a/burn-autodiff/src/tests/gradients.rs b/burn-autodiff/src/tests/gradients.rs index a1f6eda3c0..844de4f41c 100644 --- a/burn-autodiff/src/tests/gradients.rs +++ b/burn-autodiff/src/tests/gradients.rs @@ -1,25 +1,24 @@ #[burn_tensor_testgen::testgen(gradients)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Distribution}; + use super::*; + use burn_tensor::{activation, Data, Distribution}; - #[test] - fn should_update_tensor_when_grad_replace() { - let tensor_1 = TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); - let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default); + #[test] + fn should_update_tensor_when_grad_replace() { + let tensor_1 = TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); + let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default); - let x = tensor_1.clone().matmul(activation::gelu(tensor_2)); - let mut grads = x.backward(); + let x = tensor_1.clone().matmul(activation::gelu(tensor_2)); + let mut grads = x.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_1_updated = - TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); - tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner()); + let grad_1_updated = TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); + tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner()); - let grad_1_new = tensor_1.grad(&grads).unwrap(); + let grad_1_new = tensor_1.grad(&grads).unwrap(); - assert_ne!(grad_1_new.to_data(), grad_1.into_data()); - assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data()); - } + assert_ne!(grad_1_new.to_data(), grad_1.into_data()); + assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data()); + } } diff --git a/burn-autodiff/src/tests/log.rs b/burn-autodiff/src/tests/log.rs index 9c7766da97..3752aebea5 100644 --- a/burn-autodiff/src/tests/log.rs +++ b/burn-autodiff/src/tests/log.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_log)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_log() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_log() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = 
TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[60.2652, 72.3130], [60.2652, 72.3130]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[22.8614, 24.5043], [24.5729, 26.8507]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[60.2652, 72.3130], [60.2652, 72.3130]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[22.8614, 24.5043], [24.5729, 26.8507]]), 3); + } } diff --git a/burn-autodiff/src/tests/log1p.rs b/burn-autodiff/src/tests/log1p.rs index d94f5aa176..627fd00aeb 100644 --- a/burn-autodiff/src/tests/log1p.rs +++ b/burn-autodiff/src/tests/log1p.rs @@ -1,29 +1,29 @@ #[burn_tensor_testgen::testgen(ad_log1p)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_log1p() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_log1p() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[64.80622, 75.49362], [64.80622, 75.49362]]), 3); - grad_2.to_data().assert_approx_eq( - &Data::from([[22.922085, 24.475657], [24.727802, 26.864166]]), - 3, - ); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[64.80622, 75.49362], [64.80622, 75.49362]]), 3); + grad_2.to_data().assert_approx_eq( + &Data::from([[22.922085, 24.475657], [24.727802, 26.864166]]), + 3, + ); + } } diff --git a/burn-autodiff/src/tests/mask.rs b/burn-autodiff/src/tests/mask.rs index d400149cfb..fe7b25fd15 100644 --- a/burn-autodiff/src/tests/mask.rs +++ b/burn-autodiff/src/tests/mask.rs @@ -1,54 +1,53 @@ #[burn_tensor_testgen::testgen(ad_mask)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Tensor}; - - #[test] - fn should_diff_mask_fill() { - let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); - let mask = Data::::from([[true, false], [false, true]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); - let mask = Tensor::::from_bool(mask); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.mask_fill(mask, 2.0); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[7.0, 3.0], [4.0, 2.0]])); - assert_eq!(grad_2.to_data(), Data::from([[2.0, 1.0], [3.0, 7.0]])); - } - - #[test] - fn should_diff_mask_where() { - let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]]).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]]).require_grad(); - let mask = - Tensor::::from_data([[true, false], [false, true]]); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.clone().matmul(tensor_3.clone()); - let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone()); - let grads = tensor_6.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - let grad_3 = tensor_3.grad(&grads).unwrap(); - - grad_1 - .into_data() - .assert_approx_eq(&Data::from([[121.8, 55.0], [110.8, 50.0]]), 3); - grad_2 - .into_data() - .assert_approx_eq(&Data::from([[27.4, 33.4], [95.0, 115.0]]), 3); - grad_3 - .into_data() - .assert_approx_eq(&Data::from([[15., 18.], [23., 29.]]), 3); - } + use super::*; + use burn_tensor::{Bool, Data, Tensor}; + + #[test] + fn should_diff_mask_fill() { + let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); + let mask = Data::::from([[true, false], [false, true]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let mask = Tensor::::from_bool(mask); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.mask_fill(mask, 2.0); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[7.0, 3.0], [4.0, 2.0]])); + assert_eq!(grad_2.to_data(), Data::from([[2.0, 1.0], [3.0, 7.0]])); + } + + #[test] + fn should_diff_mask_where() { + let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]]).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]]).require_grad(); + let mask = Tensor::::from_data([[true, false], [false, true]]); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.clone().matmul(tensor_3.clone()); + let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone()); + let grads = tensor_6.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_3 = tensor_3.grad(&grads).unwrap(); + + grad_1 + .into_data() + .assert_approx_eq(&Data::from([[121.8, 55.0], [110.8, 50.0]]), 3); + grad_2 + .into_data() + .assert_approx_eq(&Data::from([[27.4, 33.4], [95.0, 115.0]]), 3); + grad_3 + .into_data() + .assert_approx_eq(&Data::from([[15., 18.], [23., 29.]]), 3); + } } diff --git a/burn-autodiff/src/tests/matmul.rs b/burn-autodiff/src/tests/matmul.rs index 7f5de915e5..31c7be72e0 100644 --- a/burn-autodiff/src/tests/matmul.rs +++ 
b/burn-autodiff/src/tests/matmul.rs @@ -1,78 +1,78 @@ #[burn_tensor_testgen::testgen(ad_matmul)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_matmul() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let grads = tensor_3.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); - assert_eq!( - tensor_3.clone().into_data(), - Data::from([[18.0, 28.0], [14.0, 23.0]]) - ); - } - - #[test] - fn test_matmul_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]])); - assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]])); - } - - #[test] - fn test_matmul_complex_2() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3.clone()); - let tensor_6 = tensor_1.clone().matmul(tensor_5); - - let grads = tensor_6.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[800.0, 792.0], [360.0, 592.0]]) - ); - assert_eq!( - grad_2.to_data(), - Data::from([[264., 264.0], [344.0, 344.0]]) - ); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_matmul() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let grads = tensor_3.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); + assert_eq!( + tensor_3.clone().into_data(), + Data::from([[18.0, 28.0], [14.0, 23.0]]) + ); + } + + #[test] + fn test_matmul_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 
7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]])); + assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]])); + } + + #[test] + fn test_matmul_complex_2() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3.clone()); + let tensor_6 = tensor_1.clone().matmul(tensor_5); + + let grads = tensor_6.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[800.0, 792.0], [360.0, 592.0]]) + ); + assert_eq!( + grad_2.to_data(), + Data::from([[264., 264.0], [344.0, 344.0]]) + ); + } } diff --git a/burn-autodiff/src/tests/maxmin.rs b/burn-autodiff/src/tests/maxmin.rs index 3edab63f6e..2f10eac2fc 100644 --- a/burn-autodiff/src/tests/maxmin.rs +++ b/burn-autodiff/src/tests/maxmin.rs @@ -1,45 +1,45 @@ #[burn_tensor_testgen::testgen(ad_maxmin)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_max_dim() { - let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[50.0, 34.0], [40.0, -10.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[8.0, 10.0], [56.0, 15.0]]), 5); - } - - #[test] - fn should_diff_min_dim() { - let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[-42.0, 38.0], [-34.0, -24.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_max_dim() { + let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); + + let tensor_3 = 
tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[50.0, 34.0], [40.0, -10.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[8.0, 10.0], [56.0, 15.0]]), 5); + } + + #[test] + fn should_diff_min_dim() { + let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[-42.0, 38.0], [-34.0, -24.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5); + } } diff --git a/burn-autodiff/src/tests/maxpool1d.rs b/burn-autodiff/src/tests/maxpool1d.rs index 2ccceafd22..05d8c09573 100644 --- a/burn-autodiff/src/tests/maxpool1d.rs +++ b/burn-autodiff/src/tests/maxpool1d.rs @@ -1,111 +1,110 @@ #[burn_tensor_testgen::testgen(ad_max_pool1d)] mod tests { - use super::*; - use burn_tensor::{module::max_pool1d, Data}; - - #[test] - fn test_max_pool1d_simple() { - let kernel_size = 4; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = - TestAutodiffTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool1d_with_dilation() { - let kernel_size = 4; - let padding = 0; - let stride = 1; - let dilation = 2; - - let x = TestAutodiffTensor::from_floats([[[ - 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, - 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, - 0.4610, 0.5365, 0.6880, - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - 0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0., - 0., 0., 1., - ]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool1d_complex() { - let kernel_size = 4; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = TestAutodiffTensor::from_floats([[[ - 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, - 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, - 0.4610, 0.5365, 0.6880, - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - 0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., - 1., 1., 1., - ]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - 
let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool1d_complex_with_padding() { - let kernel_size = 4; - let padding = 2; - let stride = 1; - let dilation = 1; - - let x = TestAutodiffTensor::from_floats([[[ - 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, - 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, - 0.4610, 0.5365, 0.6880, - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - 1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., - 1., 1., 3., - ]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } + use super::*; + use burn_tensor::{module::max_pool1d, Data}; + + #[test] + fn test_max_pool1d_simple() { + let kernel_size = 4; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = TestAutodiffTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool1d_with_dilation() { + let kernel_size = 4; + let padding = 0; + let stride = 1; + let dilation = 2; + + let x = TestAutodiffTensor::from_floats([[[ + 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, + 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, + 0.4610, 0.5365, 0.6880, + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + 0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0., 0., + 0., 1., + ]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool1d_complex() { + let kernel_size = 4; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = TestAutodiffTensor::from_floats([[[ + 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, + 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, + 0.4610, 0.5365, 0.6880, + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + 0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., 1., + 1., 1., + ]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool1d_complex_with_padding() { + let kernel_size = 4; + let padding = 2; + let stride = 1; + let dilation = 1; + + let x = TestAutodiffTensor::from_floats([[[ + 0.5388, 
0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, + 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, + 0.4610, 0.5365, 0.6880, + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + 1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., 1., + 1., 3., + ]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } } diff --git a/burn-autodiff/src/tests/maxpool2d.rs b/burn-autodiff/src/tests/maxpool2d.rs index 49d66212ae..a73e11ff0f 100644 --- a/burn-autodiff/src/tests/maxpool2d.rs +++ b/burn-autodiff/src/tests/maxpool2d.rs @@ -1,171 +1,171 @@ #[burn_tensor_testgen::testgen(ad_max_pool2d)] mod tests { - use super::*; - use burn_tensor::{module::max_pool2d, Data}; - - #[test] - fn test_max_pool2d_simple_1() { - let kernel_size_1 = 3; - let kernel_size_2 = 3; - let padding_1 = 0; - let padding_2 = 0; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; - - let x = TestAutodiffTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 2.0], - [0.0, 2.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool2d_simple_2() { - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; - - let x = TestAutodiffTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [1., 3., 0., 2.], - [3., 0., 0., 4.], - [1., 4., 0., 1.], - [2., 0., 3., 1.], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool2d_with_dilation() { - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 2; - let dilation_2 = 2; - - let x = TestAutodiffTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [0., 0., 0., 0.], - [1., 1., 1., 2.], - [0., 4., 4., 0.], - [0., 1., 2., 0.], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, kernel_size_2], - 
[stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool2d_complex() { - let kernel_size_1 = 4; - let kernel_size_2 = 2; - let padding_1 = 2; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 2; - let dilation_1 = 1; - let dilation_2 = 1; - - let x = TestAutodiffTensor::from_floats([[[ - [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], - [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], - [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], - [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], - [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [0., 0., 0., 3., 0.], - [4., 0., 2., 1., 0.], - [0., 0., 0., 0., 0.], - [2., 4., 0., 0., 0.], - [0., 0., 0., 0., 2.], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } + use super::*; + use burn_tensor::{module::max_pool2d, Data}; + + #[test] + fn test_max_pool2d_simple_1() { + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 0; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; + + let x = TestAutodiffTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 2.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_simple_2() { + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; + + let x = TestAutodiffTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [1., 3., 0., 2.], + [3., 0., 0., 4.], + [1., 4., 0., 1.], + [2., 0., 3., 1.], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_with_dilation() { + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 2; + let dilation_2 = 2; + + let x = TestAutodiffTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + 
[0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [0., 0., 0., 0.], + [1., 1., 1., 2.], + [0., 4., 4., 0.], + [0., 1., 2., 0.], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_complex() { + let kernel_size_1 = 4; + let kernel_size_2 = 2; + let padding_1 = 2; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 2; + let dilation_1 = 1; + let dilation_2 = 1; + + let x = TestAutodiffTensor::from_floats([[[ + [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], + [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], + [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], + [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], + [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [0., 0., 0., 3., 0.], + [4., 0., 2., 1., 0.], + [0., 0., 0., 0., 0.], + [2., 4., 0., 0., 0.], + [0., 0., 0., 0., 2.], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } } diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs index 23e84426d0..15328cc5b8 100644 --- a/burn-autodiff/src/tests/mod.rs +++ b/burn-autodiff/src/tests/mod.rs @@ -48,61 +48,61 @@ mod transpose; #[macro_export] macro_rules! 
testgen_all {
- () => {
- type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
- type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
+ () => {
+ type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
+ type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
- // Behavior
- burn_autodiff::testgen_ad_broadcast!();
- burn_autodiff::testgen_gradients!();
+ // Behavior
+ burn_autodiff::testgen_ad_broadcast!();
+ burn_autodiff::testgen_gradients!();
- // Activation
- burn_autodiff::testgen_ad_relu!();
- burn_autodiff::testgen_ad_gelu!();
+ // Activation
+ burn_autodiff::testgen_ad_relu!();
+ burn_autodiff::testgen_ad_gelu!();
- // Modules
- burn_autodiff::testgen_ad_conv1d!();
- burn_autodiff::testgen_ad_conv2d!();
- burn_autodiff::testgen_ad_conv_transpose1d!();
- burn_autodiff::testgen_ad_conv_transpose2d!();
- burn_autodiff::testgen_ad_max_pool1d!();
- burn_autodiff::testgen_ad_max_pool2d!();
- burn_autodiff::testgen_ad_avg_pool1d!();
- burn_autodiff::testgen_ad_avg_pool2d!();
- burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
- burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
- burn_autodiff::testgen_module_backward!();
+ // Modules
+ burn_autodiff::testgen_ad_conv1d!();
+ burn_autodiff::testgen_ad_conv2d!();
+ burn_autodiff::testgen_ad_conv_transpose1d!();
+ burn_autodiff::testgen_ad_conv_transpose2d!();
+ burn_autodiff::testgen_ad_max_pool1d!();
+ burn_autodiff::testgen_ad_max_pool2d!();
+ burn_autodiff::testgen_ad_avg_pool1d!();
+ burn_autodiff::testgen_ad_avg_pool2d!();
+ burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
+ burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
+ burn_autodiff::testgen_module_backward!();
- // Tensor
- burn_autodiff::testgen_ad_complex!();
- burn_autodiff::testgen_ad_multithread!();
- burn_autodiff::testgen_ad_add!();
- burn_autodiff::testgen_ad_aggregation!();
- burn_autodiff::testgen_ad_maxmin!();
- burn_autodiff::testgen_ad_cat!();
- burn_autodiff::testgen_ad_cos!();
- burn_autodiff::testgen_ad_cross_entropy_loss!();
- burn_autodiff::testgen_ad_div!();
- burn_autodiff::testgen_ad_erf!();
- burn_autodiff::testgen_ad_exp!();
- burn_autodiff::testgen_ad_slice!();
- burn_autodiff::testgen_ad_gather_scatter!();
- burn_autodiff::testgen_ad_select!();
- burn_autodiff::testgen_ad_log!();
- burn_autodiff::testgen_ad_log1p!();
- burn_autodiff::testgen_ad_mask!();
- burn_autodiff::testgen_ad_matmul!();
- burn_autodiff::testgen_ad_mul!();
- burn_autodiff::testgen_ad_neg!();
- burn_autodiff::testgen_ad_powf!();
- burn_autodiff::testgen_ad_recip!();
- burn_autodiff::testgen_ad_reshape!();
- burn_autodiff::testgen_ad_sin!();
- burn_autodiff::testgen_ad_softmax!();
- burn_autodiff::testgen_ad_sqrt!();
- burn_autodiff::testgen_ad_abs!();
- burn_autodiff::testgen_ad_sub!();
- burn_autodiff::testgen_ad_tanh!();
- burn_autodiff::testgen_ad_transpose!();
- };
+ // Tensor
+ burn_autodiff::testgen_ad_complex!();
+ burn_autodiff::testgen_ad_multithread!();
+ burn_autodiff::testgen_ad_add!();
+ burn_autodiff::testgen_ad_aggregation!();
+ burn_autodiff::testgen_ad_maxmin!();
+ burn_autodiff::testgen_ad_cat!();
+ burn_autodiff::testgen_ad_cos!();
+ burn_autodiff::testgen_ad_cross_entropy_loss!();
+ burn_autodiff::testgen_ad_div!();
+ burn_autodiff::testgen_ad_erf!();
+ burn_autodiff::testgen_ad_exp!();
+ burn_autodiff::testgen_ad_slice!();
+ burn_autodiff::testgen_ad_gather_scatter!();
+ burn_autodiff::testgen_ad_select!();
+ burn_autodiff::testgen_ad_log!();
+ burn_autodiff::testgen_ad_log1p!();
+ burn_autodiff::testgen_ad_mask!();
+ burn_autodiff::testgen_ad_matmul!();
+ 
burn_autodiff::testgen_ad_mul!(); + burn_autodiff::testgen_ad_neg!(); + burn_autodiff::testgen_ad_powf!(); + burn_autodiff::testgen_ad_recip!(); + burn_autodiff::testgen_ad_reshape!(); + burn_autodiff::testgen_ad_sin!(); + burn_autodiff::testgen_ad_softmax!(); + burn_autodiff::testgen_ad_sqrt!(); + burn_autodiff::testgen_ad_abs!(); + burn_autodiff::testgen_ad_sub!(); + burn_autodiff::testgen_ad_tanh!(); + burn_autodiff::testgen_ad_transpose!(); + }; } diff --git a/burn-autodiff/src/tests/mul.rs b/burn-autodiff/src/tests/mul.rs index 85eec40498..a214abd24b 100644 --- a/burn-autodiff/src/tests/mul.rs +++ b/burn-autodiff/src/tests/mul.rs @@ -1,64 +1,64 @@ #[burn_tensor_testgen::testgen(ad_mul)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_mul() { - let data_1 = Data::from([1.0, 7.0]); - let data_2 = Data::from([4.0, 7.0]); + #[test] + fn should_diff_mul() { + let data_1 = Data::from([1.0, 7.0]); + let data_2 = Data::from([4.0, 7.0]); - let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); - let tensor_3 = tensor_1.clone().mul(tensor_2.clone()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().mul(tensor_2.clone()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), data_2); - assert_eq!(grad_2.to_data(), data_1); - assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0])); - } + assert_eq!(grad_1.to_data(), data_2); + assert_eq!(grad_2.to_data(), data_1); + assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0])); + } - #[test] - fn should_diff_mul_scalar() { - let data = Data::from([2.0, 5.0]); + #[test] + fn should_diff_mul_scalar() { + let data = Data::from([2.0, 5.0]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().mul_scalar(4.0); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().mul_scalar(4.0); - let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); + let grads = tensor_out.backward(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0])); - assert_eq!(grad.to_data(), Data::from([4.0, 4.0])); - } + assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0])); + assert_eq!(grad.to_data(), Data::from([4.0, 4.0])); + } - #[test] - fn test_mul_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + #[test] + fn test_mul_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - let tensor_4 = tensor_1.clone().mul(tensor_2.clone()); - let tensor_5 = tensor_4.mul(tensor_3); - let tensor_6 = tensor_1.clone().mul(tensor_5); + let tensor_4 = tensor_1.clone().mul(tensor_2.clone()); + let tensor_5 = tensor_4.mul(tensor_3); + let tensor_6 = tensor_1.clone().mul(tensor_5); - let grads = tensor_6.backward(); + let grads = tensor_6.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!( - grad_1.to_data(), - Data::from([[16.0, 196.0], [104.0, -36.0]]) - ); - assert_eq!(grad_2.to_data(), Data::from([[2.0, 98.0], [338.0, 18.0]])); - } + assert_eq!( + grad_1.to_data(), + Data::from([[16.0, 196.0], [104.0, -36.0]]) + ); + assert_eq!(grad_2.to_data(), Data::from([[2.0, 98.0], [338.0, 18.0]])); + } } diff --git a/burn-autodiff/src/tests/multithread.rs b/burn-autodiff/src/tests/multithread.rs index 041572da6e..3b30b52a8a 100644 --- a/burn-autodiff/src/tests/multithread.rs +++ b/burn-autodiff/src/tests/multithread.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(ad_multithread)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_behave_the_same_with_multithread() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + #[test] + fn should_behave_the_same_with_multithread() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let with_move = || { - let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); + let with_move = || { + let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3); - // Task 1 - let tensor_1_cloned = tensor_1.clone(); - let tensor_2_cloned = tensor_2.clone(); - let tensor_5_cloned = tensor_5.clone(); + // Task 1 + let tensor_1_cloned = tensor_1.clone(); + let tensor_2_cloned = tensor_2.clone(); + let tensor_5_cloned = tensor_5.clone(); - let first_call = move || { - let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned); - tensor_6_1.matmul(tensor_1_cloned) - }; + let first_call = move || { + let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned); + tensor_6_1.matmul(tensor_1_cloned) + }; - // Task 2 - let tensor_1_cloned = tensor_1.clone(); - let tensor_2_cloned = tensor_2.clone(); - let tensor_5_cloned = tensor_5; + // Task 2 + let tensor_1_cloned = tensor_1.clone(); + let tensor_2_cloned = tensor_2.clone(); + let tensor_5_cloned = tensor_5; - let second_call = move || { - let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned); - tensor_6_2.matmul(tensor_2_cloned) - }; + let second_call = move || { + let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned); + tensor_6_2.matmul(tensor_2_cloned) + }; - let tensor_7_1_handle = std::thread::spawn(first_call); - 
let tensor_7_2_handle = std::thread::spawn(second_call); + let tensor_7_1_handle = std::thread::spawn(first_call); + let tensor_7_2_handle = std::thread::spawn(second_call); - let tensor_7_1 = tensor_7_1_handle.join().unwrap(); - let tensor_7_2 = tensor_7_2_handle.join().unwrap(); - let tensor_8 = tensor_7_1.matmul(tensor_7_2); + let tensor_7_1 = tensor_7_1_handle.join().unwrap(); + let tensor_7_2 = tensor_7_2_handle.join().unwrap(); + let tensor_8 = tensor_7_1.matmul(tensor_7_2); - let grads = tensor_8.backward(); + let grads = tensor_8.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - (grad_1, grad_2) - }; - let without_move = || { - let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); + (grad_1, grad_2) + }; + let without_move = || { + let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3); - // Task 1 - let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone()); - let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone()); + // Task 1 + let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone()); + let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone()); - // Task 2 - let tensor_6_2 = tensor_5.matmul(tensor_1.clone()); - let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone()); + // Task 2 + let tensor_6_2 = tensor_5.matmul(tensor_1.clone()); + let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone()); - let tensor_8 = tensor_7_1.matmul(tensor_7_2); + let tensor_8 = tensor_7_1.matmul(tensor_7_2); - let grads = tensor_8.backward(); + let grads = tensor_8.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - (grad_1, grad_2) - }; + (grad_1, grad_2) + }; - let (grad_1, grad_2) = without_move(); - let (grad_1_moved, grad_2_moved) = with_move(); + let (grad_1, grad_2) = without_move(); + let (grad_1_moved, grad_2_moved) = with_move(); - assert_eq!(grad_1.to_data(), grad_1_moved.to_data()); - assert_eq!(grad_2.to_data(), grad_2_moved.to_data()); - } + assert_eq!(grad_1.to_data(), grad_1_moved.to_data()); + assert_eq!(grad_2.to_data(), grad_2_moved.to_data()); + } } diff --git a/burn-autodiff/src/tests/neg.rs b/burn-autodiff/src/tests/neg.rs index 83657ea1f5..51974cb025 100644 --- a/burn-autodiff/src/tests/neg.rs +++ b/burn-autodiff/src/tests/neg.rs @@ -1,24 +1,24 @@ #[burn_tensor_testgen::testgen(ad_neg)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_neg() { - let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); + #[test] + fn should_diff_neg() { + let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg()); - let tensor_4 = tensor_3.neg(); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg()); + let tensor_4 = tensor_3.neg(); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); + } } diff --git a/burn-autodiff/src/tests/pow.rs b/burn-autodiff/src/tests/pow.rs index 7321951ddc..aadbd1d88d 100644 --- a/burn-autodiff/src/tests/pow.rs +++ b/burn-autodiff/src/tests/pow.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_powf)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_powf() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_powf() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf(0.4)); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf(0.4)); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3); + } } diff --git a/burn-autodiff/src/tests/recip.rs b/burn-autodiff/src/tests/recip.rs index c77579e273..dc6911008e 100644 --- a/burn-autodiff/src/tests/recip.rs +++ b/burn-autodiff/src/tests/recip.rs @@ -1,20 +1,21 @@ #[burn_tensor_testgen::testgen(ad_recip)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_recip() { - let data = Data::from([2.0, 5.0, 0.4]); + #[test] + fn should_diff_recip() { + let data = Data::from([2.0, 5.0, 0.4]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().recip(); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().recip(); - let grads = tensor_out.backward(); - let grad = 
tensor.grad(&grads).unwrap(); + let grads = tensor_out.backward(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5])); - grad.to_data() - .assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3); - } + assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5])); + grad + .to_data() + .assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3); + } } diff --git a/burn-autodiff/src/tests/relu.rs b/burn-autodiff/src/tests/relu.rs index 57cfd51baa..13a6de9b24 100644 --- a/burn-autodiff/src/tests/relu.rs +++ b/burn-autodiff/src/tests/relu.rs @@ -1,25 +1,25 @@ #[burn_tensor_testgen::testgen(ad_relu)] mod tests { - use super::*; - use burn_tensor::{activation, Data}; + use super::*; + use burn_tensor::{activation, Data}; - #[test] - fn should_diff_relu() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + #[test] + fn should_diff_relu() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = activation::relu(tensor_3); - let tensor_5 = tensor_4.matmul(tensor_2.clone()); - let grads = tensor_5.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::relu(tensor_3); + let tensor_5 = tensor_4.matmul(tensor_2.clone()); + let grads = tensor_5.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[-47.0, 9.0], [-35.0, 15.0]])); - assert_eq!(grad_2.to_data(), Data::from([[15.0, 13.0], [-2.0, 39.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[-47.0, 9.0], [-35.0, 15.0]])); + assert_eq!(grad_2.to_data(), Data::from([[15.0, 13.0], [-2.0, 39.0]])); + } } diff --git a/burn-autodiff/src/tests/reshape.rs b/burn-autodiff/src/tests/reshape.rs index 057241aba5..7d3bc2cdd1 100644 --- a/burn-autodiff/src/tests/reshape.rs +++ b/burn-autodiff/src/tests/reshape.rs @@ -1,24 +1,24 @@ #[burn_tensor_testgen::testgen(ad_reshape)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_reshape() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([4.0, 7.0, 2.0, 3.0]); + #[test] + fn should_diff_reshape() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([4.0, 7.0, 2.0, 3.0]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_2.clone().reshape([2, 2]); - let tensor_4 = tensor_1.clone().matmul(tensor_3); - let grads = tensor_4.backward(); + let tensor_3 = tensor_2.clone().reshape([2, 2]); + let tensor_4 = tensor_1.clone().matmul(tensor_3); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = 
tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!(grad_2.to_data(), Data::from([3.0, 3.0, 10.0, 10.0])); - } + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.to_data(), Data::from([3.0, 3.0, 10.0, 10.0])); + } } diff --git a/burn-autodiff/src/tests/select.rs b/burn-autodiff/src/tests/select.rs index 21c49f5242..9c20fa5068 100644 --- a/burn-autodiff/src/tests/select.rs +++ b/burn-autodiff/src/tests/select.rs @@ -1,54 +1,52 @@ #[burn_tensor_testgen::testgen(ad_select)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_select_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) - .require_grad(); - let indices = Tensor::::from_data(Data::from([1, 0])); - - let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().select(0, indices); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[109., 148., 187.], [37., 58., 79.]]) - ); - } - - #[test] - fn test_select_assign_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) - .require_grad(); - let values = TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - .require_grad(); - let indices = Tensor::::from_data(Data::from([1, 0])); - - let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().select_assign(0, indices, values.clone()); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = values.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[127., 199., 271.], [172., 244., 316.]]) - ); - assert_eq!( - grad_2.into_data(), - Data::from([[64., 64., 64.], [19., 19., 19.]]) - ); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_select_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); + let indices = Tensor::::from_data(Data::from([1, 0])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().select(0, indices); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[109., 148., 187.], [37., 58., 79.]]) + ); + } + + #[test] + fn test_select_assign_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); + let values = + TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad(); + let indices = Tensor::::from_data(Data::from([1, 0])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().select_assign(0, indices, values.clone()); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = values.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[127., 199., 271.], [172., 244., 316.]]) + ); + assert_eq!( + 
grad_2.into_data(), + Data::from([[64., 64., 64.], [19., 19., 19.]]) + ); + } } diff --git a/burn-autodiff/src/tests/sin.rs b/burn-autodiff/src/tests/sin.rs index 8462893d9a..2e7b544928 100644 --- a/burn-autodiff/src/tests/sin.rs +++ b/burn-autodiff/src/tests/sin.rs @@ -1,29 +1,29 @@ #[burn_tensor_testgen::testgen(ad_sin)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_sin() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_sin() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq_diff(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 2.6e-3); - grad_2.to_data().assert_approx_eq_diff( - &Data::from([[38.668987, 44.194775], [-59.97261, -80.46094]]), - 2.6e-3, - ); - } + grad_1 + .to_data() + .assert_approx_eq_diff(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 2.6e-3); + grad_2.to_data().assert_approx_eq_diff( + &Data::from([[38.668987, 44.194775], [-59.97261, -80.46094]]), + 2.6e-3, + ); + } } diff --git a/burn-autodiff/src/tests/slice.rs b/burn-autodiff/src/tests/slice.rs index 6b8b46d70b..d6bb7c1505 100644 --- a/burn-autodiff/src/tests/slice.rs +++ b/burn-autodiff/src/tests/slice.rs @@ -1,77 +1,77 @@ #[burn_tensor_testgen::testgen(ad_slice)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_matmul_with_slice() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_2.clone().slice([0..2, 0..2]); - let tensor_4 = tensor_1.clone().matmul(tensor_3); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!( - grad_2.to_data(), - Data::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]) - ); - } - - #[test] - fn should_diff_matmul_with_slice_assign() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_assigned: Data = Data::from([[9.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_assigned = TestAutodiffTensor::from_data(data_assigned).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = 
tensor_3.slice_assign([0..1, 0..1], tensor_assigned); - let tensor_5 = tensor_4.matmul(tensor_1.clone()); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[58.0, 38.0], [118.0, 82.0]])); - assert_eq!(grad_2.to_data(), Data::from([[16.0, 15.0], [24.0, 50.0]])); - } - - #[test] - fn should_diff_matmul_with_slice_assign_complex() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[9.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_2.clone().slice([0..1, 0..1]); - let tensor_6 = tensor_5.mul(tensor_3.clone()); - let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6); - let tensor_8 = tensor_7.matmul(tensor_1.clone()); - - let grads = tensor_8.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - let grad_3 = tensor_3.grad(&grads).unwrap(); - - assert_eq!(grad_3.to_data(), Data::from([[32.0]])); - assert_eq!(grad_1.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]])); - assert_eq!(grad_2.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]])); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_matmul_with_slice() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_2.clone().slice([0..2, 0..2]); + let tensor_4 = tensor_1.clone().matmul(tensor_3); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!( + grad_2.to_data(), + Data::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]) + ); + } + + #[test] + fn should_diff_matmul_with_slice_assign() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_assigned: Data = Data::from([[9.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_assigned = TestAutodiffTensor::from_data(data_assigned).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned); + let tensor_5 = tensor_4.matmul(tensor_1.clone()); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[58.0, 38.0], [118.0, 82.0]])); + assert_eq!(grad_2.to_data(), Data::from([[16.0, 15.0], [24.0, 50.0]])); + } + + #[test] + fn should_diff_matmul_with_slice_assign_complex() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[9.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_2.clone().slice([0..1, 0..1]); + let tensor_6 = tensor_5.mul(tensor_3.clone()); + let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6); + let tensor_8 = tensor_7.matmul(tensor_1.clone()); + + let grads = tensor_8.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_3 = tensor_3.grad(&grads).unwrap(); + + assert_eq!(grad_3.to_data(), Data::from([[32.0]])); + assert_eq!(grad_1.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]])); + assert_eq!(grad_2.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]])); + } } diff --git a/burn-autodiff/src/tests/softmax.rs b/burn-autodiff/src/tests/softmax.rs index e19651384a..c825ae3192 100644 --- a/burn-autodiff/src/tests/softmax.rs +++ b/burn-autodiff/src/tests/softmax.rs @@ -1,49 +1,72 @@ #[burn_tensor_testgen::testgen(ad_softmax)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; - - #[test] - fn test_softmax_grad() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = Tensor::::from_data(data_1).require_grad(); - let tensor_2 = Tensor::::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); - - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); - } - - #[test] - fn test_log_softmax_grad() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = Tensor::::from_data(data_1).require_grad(); - let tensor_2 = Tensor::::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone()); - - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3); - } + use super::*; + use burn_tensor::{activation, Data, Tensor}; + + #[test] + fn test_softmax_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + let tensor_1 = Tensor::::from_data(data_1).require_grad(); + let tensor_2 = Tensor::::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); + + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); + } + + #[test] + fn test_log_softmax_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = 
Data::from([[6.0, 7.0], [9.0, 10.0]]); + let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1).require_grad(); + let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone()); + + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3); + } + + #[test] + fn test_quiet_softmax_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1).require_grad(); + let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::quiet_softmax(tensor_3, 1).matmul(tensor_2.clone()); + + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); + } } diff --git a/burn-autodiff/src/tests/sqrt.rs b/burn-autodiff/src/tests/sqrt.rs index c9d075fedd..94b0d6c860 100644 --- a/burn-autodiff/src/tests/sqrt.rs +++ b/burn-autodiff/src/tests/sqrt.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_sqrt)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_sqrt() { - let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_sqrt() { + let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[82.1126, 99.0832], [82.1126, 99.0832]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[30.3093, 33.1204], [34.5819, 38.7694]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[82.1126, 99.0832], [82.1126, 99.0832]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[30.3093, 33.1204], [34.5819, 38.7694]]), 3); + } } diff --git a/burn-autodiff/src/tests/sub.rs b/burn-autodiff/src/tests/sub.rs index 50beae42f3..b89850f506 100644 --- a/burn-autodiff/src/tests/sub.rs +++ b/burn-autodiff/src/tests/sub.rs @@ -1,60 +1,60 @@ #[burn_tensor_testgen::testgen(ad_sub)] mod tests { - use super::*; - use
burn_tensor::Data; - #[test] - fn should_diff_sub() { - let data_1 = Data::from([2.0, 5.0]); - let data_2 = Data::from([4.0, 1.0]); + #[test] + fn should_diff_sub() { + let data_1 = Data::from([2.0, 5.0]); + let data_2 = Data::from([4.0, 1.0]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().sub(tensor_2.clone()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().sub(tensor_2.clone()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); - assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0])); - assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0])); - } + assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); + assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0])); + assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0])); + } - #[test] - fn should_diff_sub_scalar() { - let data = Data::from([2.0, 10.0]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().sub_scalar(5.0); - let grads = tensor_out.backward(); + #[test] + fn should_diff_sub_scalar() { + let data = Data::from([2.0, 10.0]); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().sub_scalar(5.0); + let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(grad.to_data(), Data::from([1.0, 1.0])); - assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0])); - } + assert_eq!(grad.to_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0])); + } - #[test] - fn test_sub_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + #[test] + fn test_sub_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - let tensor_4 = tensor_1.clone().sub(tensor_2.clone()); - let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0); - let tensor_6 = tensor_1.clone().sub(tensor_5); + let tensor_4 = tensor_1.clone().sub(tensor_2.clone()); + let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0); + let tensor_6 = tensor_1.clone().sub(tensor_5); - let grads = tensor_6.backward(); + let grads = tensor_6.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[0.0, 
0.0], [0.0, 0.0]])); - assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]])); + assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]])); + } } diff --git a/burn-autodiff/src/tests/tanh.rs b/burn-autodiff/src/tests/tanh.rs index db1b884baf..3dc8700451 100644 --- a/burn-autodiff/src/tests/tanh.rs +++ b/burn-autodiff/src/tests/tanh.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_tanh)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_tanh() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_tanh() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[8.00092, 8.000153], [8.000003, 7.999995]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[8.00092, 8.000153], [8.000003, 7.999995]]), 3); + } } diff --git a/burn-autodiff/src/tests/transpose.rs b/burn-autodiff/src/tests/transpose.rs index bead7b4671..aaf4ecd952 100644 --- a/burn-autodiff/src/tests/transpose.rs +++ b/burn-autodiff/src/tests/transpose.rs @@ -1,50 +1,50 @@ #[burn_tensor_testgen::testgen(ad_transpose)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_transpose() { - let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose()); - let tensor_4 = tensor_3.transpose(); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]])); - assert_eq!(grad_2.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]])); - } - - #[test] - fn should_diff_swap_dims() { - let tensor_1 = - TestAutodiffTensor::from_floats([[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]]) - .require_grad(); - let tensor_2 = - TestAutodiffTensor::from_floats([[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]]) - .require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2)); - let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2)); - let grads = tensor_4.backward(); - - 
let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]) - ); - assert_eq!( - grad_2.to_data(), - Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]) - ); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_transpose() { + let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose()); + let tensor_4 = tensor_3.transpose(); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]])); + assert_eq!(grad_2.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]])); + } + + #[test] + fn should_diff_swap_dims() { + let tensor_1 = + TestAutodiffTensor::from_floats([[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]]) + .require_grad(); + let tensor_2 = + TestAutodiffTensor::from_floats([[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]]) + .require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2)); + let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2)); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]) + ); + assert_eq!( + grad_2.to_data(), + Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]) + ); + } } diff --git a/burn-autodiff/src/utils.rs b/burn-autodiff/src/utils.rs index 617c101e2a..56480f4805 100644 --- a/burn-autodiff/src/utils.rs +++ b/burn-autodiff/src/utils.rs @@ -9,16 +9,16 @@ use crate::graph::NodeRef; /// /// If the object is a tensor and if one reference exists, it can be updated inplace. pub fn duplicate( - nodes: &[Option; N], - obj: Option, + nodes: &[Option; N], + obj: Option, ) -> [Option; N] { - nodes - .iter() - .map(|node| match node { - Some(_) => obj.clone(), - None => None, - }) - .collect::>() - .try_into() - .unwrap() + nodes + .iter() + .map(|node| match node { + Some(_) => obj.clone(), + None => None, + }) + .collect::>() + .try_into() + .unwrap() } diff --git a/burn-candle/src/backend.rs b/burn-candle/src/backend.rs index c2bec24177..9aab0cc306 100644 --- a/burn-candle/src/backend.rs +++ b/burn-candle/src/backend.rs @@ -4,8 +4,8 @@ use burn_tensor::backend::Backend; use candle_core::DeviceLocation; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + CandleTensor, }; /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations. @@ -15,11 +15,11 @@ use crate::{ #[derive(Clone, Copy, Default, Debug)] pub struct Candle where - F: FloatCandleElement, - I: IntCandleElement, + F: FloatCandleElement, + I: IntCandleElement, { - _float: PhantomData, - _int: PhantomData, + _float: PhantomData, + _int: PhantomData, } /// The device type for the candle backend. @@ -28,62 +28,62 @@ where /// /// Note that you need to provide the device index when using Cuda. pub enum CandleDevice { - /// CPU device. 
- Cpu, + /// CPU device. + Cpu, - /// Cuda device with the given index. The index is the index of the Cuda device in the list of - /// all Cuda devices found on the system. - Cuda(usize), + /// Cuda device with the given index. The index is the index of the Cuda device in the list of + /// all Cuda devices found on the system. + Cuda(usize), } impl From for candle_core::Device { - fn from(device: CandleDevice) -> Self { - match device { - CandleDevice::Cpu => candle_core::Device::Cpu, - CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(), - } + fn from(device: CandleDevice) -> Self { + match device { + CandleDevice::Cpu => candle_core::Device::Cpu, + CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(), } + } } impl From for CandleDevice { - fn from(device: candle_core::Device) -> Self { - match device.location() { - DeviceLocation::Cpu => CandleDevice::Cpu, - DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id), - } + fn from(device: candle_core::Device) -> Self { + match device.location() { + DeviceLocation::Cpu => CandleDevice::Cpu, + DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id), } + } } impl Default for CandleDevice { - fn default() -> Self { - Self::Cpu - } + fn default() -> Self { + Self::Cpu + } } impl Backend for Candle { - type Device = CandleDevice; + type Device = CandleDevice; - type FullPrecisionBackend = Candle; - type FullPrecisionElem = f32; + type FullPrecisionBackend = Candle; + type FullPrecisionElem = f32; - type TensorPrimitive = CandleTensor; - type FloatElem = F; + type TensorPrimitive = CandleTensor; + type FloatElem = F; - type IntTensorPrimitive = CandleTensor; - type IntElem = I; + type IntTensorPrimitive = CandleTensor; + type IntElem = I; - type BoolTensorPrimitive = CandleTensor; + type BoolTensorPrimitive = CandleTensor; - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn name() -> String { - "candle".to_string() - } + fn name() -> String { + "candle".to_string() + } - fn seed(seed: u64) { - // TODO submit an issue at Candle - panic!("Manual seed not supported by Candle. ") - } + fn seed(seed: u64) { + // TODO submit an issue at Candle + panic!("Manual seed not supported by Candle. 
") + } } diff --git a/burn-candle/src/lib.rs b/burn-candle/src/lib.rs index a0ee7f0490..34d7b992be 100644 --- a/burn-candle/src/lib.rs +++ b/burn-candle/src/lib.rs @@ -15,132 +15,132 @@ pub use tensor::*; #[cfg(test)] mod tests { - extern crate alloc; - use super::*; + extern crate alloc; + use super::*; - pub type TestBackend = Candle; - pub type ReferenceBackend = burn_tch::LibTorch; + pub type TestBackend = Candle; + pub type ReferenceBackend = burn_tch::LibTorch; - pub type TestTensor = burn_tensor::Tensor; - pub type ReferenceTensor = burn_tensor::Tensor; - pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensor = burn_tensor::Tensor; + pub type ReferenceTensor = burn_tensor::Tensor; + pub type TestTensorInt = burn_tensor::Tensor; - type TestAutodiffBackend = burn_autodiff::Autodiff; - type TestAutodiffTensor = burn_tensor::Tensor; + type TestAutodiffBackend = burn_autodiff::Autodiff; + type TestAutodiffTensor = burn_tensor::Tensor; - // test activation - burn_tensor::testgen_gelu!(); - burn_tensor::testgen_relu!(); - burn_tensor::testgen_softmax!(); - burn_tensor::testgen_sigmoid!(); - burn_tensor::testgen_silu!(); + // test activation + burn_tensor::testgen_gelu!(); + burn_tensor::testgen_relu!(); + burn_tensor::testgen_softmax!(); + burn_tensor::testgen_sigmoid!(); + burn_tensor::testgen_silu!(); - // test module - burn_tensor::testgen_module_forward!(); - burn_tensor::testgen_module_conv1d!(); - // burn_tensor::testgen_module_conv2d!(); - // burn_tensor::testgen_module_conv_transpose1d!(); - // burn_tensor::testgen_module_conv_transpose2d!(); - // burn_tensor::testgen_module_max_pool1d!(); - // burn_tensor::testgen_module_max_pool2d!(); - // burn_tensor::testgen_module_avg_pool1d!(); - // burn_tensor::testgen_module_avg_pool2d!(); - // burn_tensor::testgen_module_adaptive_avg_pool1d!(); - // burn_tensor::testgen_module_adaptive_avg_pool2d!(); + // test module + burn_tensor::testgen_module_forward!(); + burn_tensor::testgen_module_conv1d!(); + // burn_tensor::testgen_module_conv2d!(); + // burn_tensor::testgen_module_conv_transpose1d!(); + // burn_tensor::testgen_module_conv_transpose2d!(); + // burn_tensor::testgen_module_max_pool1d!(); + // burn_tensor::testgen_module_max_pool2d!(); + // burn_tensor::testgen_module_avg_pool1d!(); + // burn_tensor::testgen_module_avg_pool2d!(); + // burn_tensor::testgen_module_adaptive_avg_pool1d!(); + // burn_tensor::testgen_module_adaptive_avg_pool2d!(); - // test ops - burn_tensor::testgen_add!(); - // burn_tensor::testgen_aggregation!(); - burn_tensor::testgen_arange!(); - burn_tensor::testgen_arange_step!(); - burn_tensor::testgen_arg!(); - burn_tensor::testgen_cast!(); - burn_tensor::testgen_cat!(); - burn_tensor::testgen_recip!(); - burn_tensor::testgen_clamp!(); - burn_tensor::testgen_cos!(); - // burn_tensor::testgen_div!(); - burn_tensor::testgen_erf!(); - burn_tensor::testgen_exp!(); - burn_tensor::testgen_flatten!(); - burn_tensor::testgen_full!(); - burn_tensor::testgen_gather_scatter!(); - burn_tensor::testgen_init!(); - burn_tensor::testgen_log!(); - burn_tensor::testgen_log1p!(); - burn_tensor::testgen_map_comparison!(); - burn_tensor::testgen_mask!(); - burn_tensor::testgen_matmul!(); - burn_tensor::testgen_maxmin!(); - burn_tensor::testgen_mul!(); - burn_tensor::testgen_neg!(); - burn_tensor::testgen_powf!(); - burn_tensor::testgen_random!(); - // burn_tensor::testgen_repeat!(); - burn_tensor::testgen_reshape!(); - burn_tensor::testgen_select!(); - burn_tensor::testgen_sin!(); - // burn_tensor::testgen_slice!(); - 
burn_tensor::testgen_sqrt!(); - burn_tensor::testgen_abs!(); - burn_tensor::testgen_squeeze!(); - burn_tensor::testgen_sub!(); - burn_tensor::testgen_tanh!(); - burn_tensor::testgen_transpose!(); + // test ops + burn_tensor::testgen_add!(); + // burn_tensor::testgen_aggregation!(); + burn_tensor::testgen_arange!(); + burn_tensor::testgen_arange_step!(); + burn_tensor::testgen_arg!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_recip!(); + burn_tensor::testgen_clamp!(); + burn_tensor::testgen_cos!(); + // burn_tensor::testgen_div!(); + burn_tensor::testgen_erf!(); + burn_tensor::testgen_exp!(); + burn_tensor::testgen_flatten!(); + burn_tensor::testgen_full!(); + burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_init!(); + burn_tensor::testgen_log!(); + burn_tensor::testgen_log1p!(); + burn_tensor::testgen_map_comparison!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_matmul!(); + burn_tensor::testgen_maxmin!(); + burn_tensor::testgen_mul!(); + burn_tensor::testgen_neg!(); + burn_tensor::testgen_powf!(); + burn_tensor::testgen_random!(); + // burn_tensor::testgen_repeat!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_select!(); + burn_tensor::testgen_sin!(); + // burn_tensor::testgen_slice!(); + burn_tensor::testgen_sqrt!(); + burn_tensor::testgen_abs!(); + burn_tensor::testgen_squeeze!(); + burn_tensor::testgen_sub!(); + burn_tensor::testgen_tanh!(); + burn_tensor::testgen_transpose!(); - // test stats - burn_tensor::testgen_var!(); - burn_tensor::testgen_display!(); + // test stats + burn_tensor::testgen_var!(); + burn_tensor::testgen_display!(); - // Behavior - // burn_autodiff::testgen_ad_broadcast!(); + // Behavior + // burn_autodiff::testgen_ad_broadcast!(); - // Activation - burn_autodiff::testgen_ad_relu!(); - burn_autodiff::testgen_ad_gelu!(); + // Activation + burn_autodiff::testgen_ad_relu!(); + burn_autodiff::testgen_ad_gelu!(); - // Modules - // burn_autodiff::testgen_ad_conv1d!(); - // burn_autodiff::testgen_ad_conv2d!(); - // burn_autodiff::testgen_ad_conv_transpose1d!(); - // burn_autodiff::testgen_ad_conv_transpose2d!(); - // burn_autodiff::testgen_ad_max_pool1d!(); - // burn_autodiff::testgen_ad_max_pool2d!(); - // burn_autodiff::testgen_ad_avg_pool1d!(); - // burn_autodiff::testgen_ad_avg_pool2d!(); - // burn_autodiff::testgen_ad_adaptive_avg_pool1d!(); - // burn_autodiff::testgen_ad_adaptive_avg_pool2d!(); - burn_autodiff::testgen_module_backward!(); + // Modules + // burn_autodiff::testgen_ad_conv1d!(); + // burn_autodiff::testgen_ad_conv2d!(); + // burn_autodiff::testgen_ad_conv_transpose1d!(); + // burn_autodiff::testgen_ad_conv_transpose2d!(); + // burn_autodiff::testgen_ad_max_pool1d!(); + // burn_autodiff::testgen_ad_max_pool2d!(); + // burn_autodiff::testgen_ad_avg_pool1d!(); + // burn_autodiff::testgen_ad_avg_pool2d!(); + // burn_autodiff::testgen_ad_adaptive_avg_pool1d!(); + // burn_autodiff::testgen_ad_adaptive_avg_pool2d!(); + burn_autodiff::testgen_module_backward!(); - // Tensor - burn_autodiff::testgen_ad_complex!(); - burn_autodiff::testgen_ad_multithread!(); - burn_autodiff::testgen_ad_add!(); - burn_autodiff::testgen_ad_aggregation!(); - burn_autodiff::testgen_ad_maxmin!(); - // burn_autodiff::testgen_ad_cat!(); - burn_autodiff::testgen_ad_cos!(); - burn_autodiff::testgen_ad_cross_entropy_loss!(); - burn_autodiff::testgen_ad_div!(); - burn_autodiff::testgen_ad_erf!(); - burn_autodiff::testgen_ad_exp!(); - // burn_autodiff::testgen_ad_slice!(); - 
burn_autodiff::testgen_ad_gather_scatter!(); - burn_autodiff::testgen_ad_select!(); - burn_autodiff::testgen_ad_log!(); - burn_autodiff::testgen_ad_log1p!(); - burn_autodiff::testgen_ad_mask!(); - burn_autodiff::testgen_ad_matmul!(); - burn_autodiff::testgen_ad_mul!(); - burn_autodiff::testgen_ad_neg!(); - burn_autodiff::testgen_ad_powf!(); - burn_autodiff::testgen_ad_recip!(); - burn_autodiff::testgen_ad_reshape!(); - burn_autodiff::testgen_ad_sin!(); - burn_autodiff::testgen_ad_softmax!(); - burn_autodiff::testgen_ad_sqrt!(); - burn_autodiff::testgen_ad_abs!(); - burn_autodiff::testgen_ad_sub!(); - burn_autodiff::testgen_ad_tanh!(); - burn_autodiff::testgen_ad_transpose!(); + // Tensor + burn_autodiff::testgen_ad_complex!(); + burn_autodiff::testgen_ad_multithread!(); + burn_autodiff::testgen_ad_add!(); + burn_autodiff::testgen_ad_aggregation!(); + burn_autodiff::testgen_ad_maxmin!(); + // burn_autodiff::testgen_ad_cat!(); + burn_autodiff::testgen_ad_cos!(); + burn_autodiff::testgen_ad_cross_entropy_loss!(); + burn_autodiff::testgen_ad_div!(); + burn_autodiff::testgen_ad_erf!(); + burn_autodiff::testgen_ad_exp!(); + // burn_autodiff::testgen_ad_slice!(); + burn_autodiff::testgen_ad_gather_scatter!(); + burn_autodiff::testgen_ad_select!(); + burn_autodiff::testgen_ad_log!(); + burn_autodiff::testgen_ad_log1p!(); + burn_autodiff::testgen_ad_mask!(); + burn_autodiff::testgen_ad_matmul!(); + burn_autodiff::testgen_ad_mul!(); + burn_autodiff::testgen_ad_neg!(); + burn_autodiff::testgen_ad_powf!(); + burn_autodiff::testgen_ad_recip!(); + burn_autodiff::testgen_ad_reshape!(); + burn_autodiff::testgen_ad_sin!(); + burn_autodiff::testgen_ad_softmax!(); + burn_autodiff::testgen_ad_sqrt!(); + burn_autodiff::testgen_ad_abs!(); + burn_autodiff::testgen_ad_sub!(); + burn_autodiff::testgen_ad_tanh!(); + burn_autodiff::testgen_ad_transpose!(); } diff --git a/burn-candle/src/ops/activation.rs b/burn-candle/src/ops/activation.rs index 0cedb23aa5..fadeb94e7f 100644 --- a/burn-candle/src/ops/activation.rs +++ b/burn-candle/src/ops/activation.rs @@ -1,16 +1,16 @@ use burn_tensor::ops::{ActivationOps, FloatTensor}; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - tensor, Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + tensor, Candle, CandleTensor, }; impl ActivationOps for Candle { - fn gelu(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.gelu().unwrap()) - } + fn gelu(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.gelu().unwrap()) + } - fn relu(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.relu().unwrap()) - } + fn relu(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.relu().unwrap()) + } } diff --git a/burn-candle/src/ops/base.rs b/burn-candle/src/ops/base.rs index 643c27a73d..a241622fb6 100644 --- a/burn-candle/src/ops/base.rs +++ b/burn-candle/src/ops/base.rs @@ -3,88 +3,88 @@ use std::marker::PhantomData; use burn_tensor::{backend::Backend, Data, Reader, Shape}; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleDevice, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + Candle, CandleDevice, CandleTensor, }; use super::tensor; pub fn cat( - tensors: Vec>, - dim: usize, + tensors: Vec>, + dim: usize, ) -> CandleTensor { - let tensors: Vec = tensors.into_iter().map(|t| t.tensor).collect(); - CandleTensor::new(candle_core::Tensor::cat(&tensors, dim).unwrap()) + let 
tensors: Vec = tensors.into_iter().map(|t| t.tensor).collect(); + CandleTensor::new(candle_core::Tensor::cat(&tensors, dim).unwrap()) } pub fn from_data( - data: Data, - device: &CandleDevice, + data: Data, + device: &CandleDevice, ) -> CandleTensor { - CandleTensor::from_data(data, *device) + CandleTensor::from_data(data, *device) } pub fn into_data(tensor: CandleTensor) -> Data { - Data::new( - tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(), - tensor.shape(), - ) + Data::new( + tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(), + tensor.shape(), + ) } pub fn to_device( - tensor: CandleTensor, - device: &CandleDevice, + tensor: CandleTensor, + device: &CandleDevice, ) -> CandleTensor { - CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap()) + CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap()) } pub fn empty( - shape: Shape, - device: &CandleDevice, + shape: Shape, + device: &CandleDevice, ) -> CandleTensor { - CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(*device).into()).unwrap()) + CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(*device).into()).unwrap()) } pub fn swap_dims( - mut tensor: CandleTensor, - dim1: usize, - dim2: usize, + mut tensor: CandleTensor, + dim1: usize, + dim2: usize, ) -> CandleTensor { - CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap()) + CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap()) } pub fn reshape( - tensor: CandleTensor, - shape: Shape, + tensor: CandleTensor, + shape: Shape, ) -> CandleTensor { - CandleTensor::new(tensor.tensor.reshape(&shape.dims).unwrap()) + CandleTensor::new(tensor.tensor.reshape(&shape.dims).unwrap()) } pub fn device(tensor: &CandleTensor) -> CandleDevice { - tensor.tensor.device().clone().into() + tensor.tensor.device().clone().into() } pub fn shape(tensor: &CandleTensor) -> Shape { - tensor.shape() + tensor.shape() } pub fn slice( - tensor: CandleTensor, - ranges: [std::ops::Range; D2], + tensor: CandleTensor, + ranges: [std::ops::Range; D2], ) -> CandleTensor { - let mut narrow_tensor = tensor.tensor; - for (i, range) in ranges.iter().enumerate().take(D2) { - narrow_tensor = narrow_tensor - .narrow(i, range.start, range.end - range.start) - .unwrap() - } - CandleTensor::new(narrow_tensor) + let mut narrow_tensor = tensor.tensor; + for (i, range) in ranges.iter().enumerate().take(D2) { + narrow_tensor = narrow_tensor + .narrow(i, range.start, range.end - range.start) + .unwrap() + } + CandleTensor::new(narrow_tensor) } pub fn slice_assign( - tensor: CandleTensor, - ranges: [std::ops::Range; D2], - value: CandleTensor, + tensor: CandleTensor, + ranges: [std::ops::Range; D2], + value: CandleTensor, ) -> CandleTensor { - panic!("slice_assign not supported by Candle") + panic!("slice_assign not supported by Candle") } diff --git a/burn-candle/src/ops/bool_tensor.rs b/burn-candle/src/ops/bool_tensor.rs index e5fd153f85..4cd6fa0da9 100644 --- a/burn-candle/src/ops/bool_tensor.rs +++ b/burn-candle/src/ops/bool_tensor.rs @@ -1,112 +1,113 @@ use burn_tensor::{ - ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor}, - Data, Device, Reader, Shape, + ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor}, + Data, Device, Reader, Shape, }; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + Candle, CandleTensor, }; impl BoolTensorOps for Candle { - fn bool_empty(shape: Shape, device: &Device) -> 
BoolTensor { - super::base::empty(shape, device) - } - - fn bool_shape(tensor: &BoolTensor) -> Shape { - super::base::shape(tensor) - } - - fn bool_into_data(tensor: BoolTensor) -> Reader> { - let x: Vec = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(); - let y = x.iter().map(|b| !matches!(b, 0)).collect(); - let data = Data::new(y, tensor.shape()); - - Reader::Concrete(data) - } - - fn bool_from_data( - data: Data, - device: &Device, - ) -> BoolTensor { - let data: Data = Data::new( - data.value - .into_iter() - .map(|c| match c { - true => 1, - false => 0, - }) - .collect(), - data.shape, - ); - super::base::from_data(data, device) - } - - fn bool_into_int(tensor: BoolTensor) -> IntTensor { - CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) - } - - fn bool_into_float(tensor: BoolTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) - } - - fn bool_device(tensor: &BoolTensor) -> Device { - super::base::device(tensor) - } - - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - super::base::to_device(tensor, device) - } - - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - super::base::reshape(tensor, shape) - } - - fn bool_slice( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - ) -> BoolTensor { - super::base::slice(tensor, ranges) - } - - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - value: BoolTensor, - ) -> BoolTensor { - super::base::slice_assign(tensor, ranges, value) - } - - fn bool_cat( - tensors: Vec>, - dim: usize, - ) -> BoolTensor { - super::base::cat(tensors, dim) - } - - fn bool_equal( - lhs: BoolTensor, - rhs: BoolTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) - } - - fn bool_not(tensor: BoolTensor) -> BoolTensor { - let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap()); - CandleTensor::new(tensor.tensor.eq(&x).unwrap()) - } - - fn bool_swap_dims( - tensor: as burn_tensor::backend::Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { - super::base::swap_dims(tensor, dim1, dim2) - } + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + super::base::empty(shape, device) + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + super::base::shape(tensor) + } + + fn bool_into_data(tensor: BoolTensor) -> Reader> { + let x: Vec = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(); + let y = x.iter().map(|b| !matches!(b, 0)).collect(); + let data = Data::new(y, tensor.shape()); + + Reader::Concrete(data) + } + + fn bool_from_data( + data: Data, + device: &Device, + ) -> BoolTensor { + let data: Data = Data::new( + data + .value + .into_iter() + .map(|c| match c { + true => 1, + false => 0, + }) + .collect(), + data.shape, + ); + super::base::from_data(data, device) + } + + fn bool_into_int(tensor: BoolTensor) -> IntTensor { + CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) + } + + fn bool_into_float(tensor: BoolTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) + } + + fn bool_device(tensor: &BoolTensor) -> Device { + super::base::device(tensor) + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + super::base::to_device(tensor, device) + } + + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor { + super::base::reshape(tensor, shape) + } + + fn bool_slice( + tensor: BoolTensor, + ranges: 
[std::ops::Range; D2], + ) -> BoolTensor { + super::base::slice(tensor, ranges) + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + value: BoolTensor, + ) -> BoolTensor { + super::base::slice_assign(tensor, ranges, value) + } + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> BoolTensor { + super::base::cat(tensors, dim) + } + + fn bool_equal( + lhs: BoolTensor, + rhs: BoolTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap()); + CandleTensor::new(tensor.tensor.eq(&x).unwrap()) + } + + fn bool_swap_dims( + tensor: as burn_tensor::backend::Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { + super::base::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-candle/src/ops/candle_utils.rs b/burn-candle/src/ops/candle_utils.rs index 2a9b92cb37..499b46cff3 100644 --- a/burn-candle/src/ops/candle_utils.rs +++ b/burn-candle/src/ops/candle_utils.rs @@ -3,23 +3,23 @@ use candle_core::{DType, Device, Shape, Tensor}; use crate::element::CandleElement; pub(crate) fn fill>( - value: E, - shape: S, - dtype: DType, - device: &Device, + value: E, + shape: S, + dtype: DType, + device: &Device, ) -> Tensor { - let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::()).unwrap(); - values.expand(shape).unwrap() + let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::()).unwrap(); + values.expand(shape).unwrap() } pub(crate) fn fill_like( - value: E, - reference_tensor: &Tensor, + value: E, + reference_tensor: &Tensor, ) -> Tensor { - fill( - value, - reference_tensor.shape(), - reference_tensor.dtype(), - reference_tensor.device(), - ) + fill( + value, + reference_tensor.shape(), + reference_tensor.dtype(), + reference_tensor.device(), + ) } diff --git a/burn-candle/src/ops/int_tensor.rs b/burn-candle/src/ops/int_tensor.rs index c1c502bf33..429d27029a 100644 --- a/burn-candle/src/ops/int_tensor.rs +++ b/burn-candle/src/ops/int_tensor.rs @@ -1,362 +1,365 @@ use burn_tensor::{ - ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, - Bool, Data, Device, Reader, Shape, + ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, + Bool, Data, Device, Reader, Shape, }; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + Candle, CandleTensor, }; impl IntTensorOps for Candle { - fn int_empty(shape: Shape, device: &Device) -> IntTensor { - super::base::empty(shape, device) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - super::base::shape(tensor) - } - - fn int_into_data(tensor: IntTensor) -> Reader, D>> { - Reader::Concrete(super::base::into_data(tensor)) - } - - fn int_from_data( - data: Data, D>, - device: &Device, - ) -> IntTensor { - super::base::from_data(data, device) - } - - fn int_device(tensor: &IntTensor) -> Device { - super::base::device(tensor) - } - - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - super::base::to_device(tensor, device) - } - - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - super::base::reshape(tensor, shape) - } - - fn int_slice( - tensor: IntTensor, - indices: [std::ops::Range; D2], - ) -> IntTensor { - super::base::slice(tensor, indices) - } - - fn int_slice_assign( - tensor: IntTensor, - 
indices: [std::ops::Range; D2], - value: IntTensor, - ) -> IntTensor { - super::base::slice_assign(tensor, indices, value) - } - - fn int_into_float(tensor: IntTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) - } - - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - source: IntTensor, - ) -> IntTensor { - CandleTensor::new( - mask.tensor - .where_cond(&source.tensor, &tensor.tensor) - .unwrap(), + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + super::base::empty(shape, device) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + super::base::shape(tensor) + } + + fn int_into_data(tensor: IntTensor) -> Reader, D>> { + Reader::Concrete(super::base::into_data(tensor)) + } + + fn int_from_data( + data: Data, D>, + device: &Device, + ) -> IntTensor { + super::base::from_data(data, device) + } + + fn int_device(tensor: &IntTensor) -> Device { + super::base::device(tensor) + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + super::base::to_device(tensor, device) + } + + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + super::base::reshape(tensor, shape) + } + + fn int_slice( + tensor: IntTensor, + indices: [std::ops::Range; D2], + ) -> IntTensor { + super::base::slice(tensor, indices) + } + + fn int_slice_assign( + tensor: IntTensor, + indices: [std::ops::Range; D2], + value: IntTensor, + ) -> IntTensor { + super::base::slice_assign(tensor, indices, value) + } + + fn int_into_float(tensor: IntTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + source: IntTensor, + ) -> IntTensor { + CandleTensor::new( + mask + .tensor + .where_cond(&source.tensor, &tensor.tensor) + .unwrap(), + ) + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor { + CandleTensor::new( + mask + .tensor + .where_cond( + &super::candle_utils::fill_like::(value, &tensor.tensor), + &tensor.tensor, ) - } - - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor { - CandleTensor::new( - mask.tensor - .where_cond( - &super::candle_utils::fill_like::(value, &tensor.tensor), - &tensor.tensor, - ) - .unwrap(), - ) - } - - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .scatter_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) - } - - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .index_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - super::base::cat(tensors, dim) - } - - fn int_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) - } - - fn int_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) 
- } - - fn int_greater( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) - } - - fn int_greater_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_lower( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) - } - - fn int_lower_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) - } - - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_add( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) - } - - fn int_add_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) - } - - fn int_sub( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) - } - - fn int_sub_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) - } - - fn int_mul( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) - } - - fn int_mul_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) - } - - fn int_div( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) - } - - fn int_div_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. - panic!("Not supported by Candle") - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - CandleTensor::new( - candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(*device).into()).unwrap(), - ) - } - - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - CandleTensor::new( - candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(*device).into()).unwrap(), - ) - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); - CandleTensor::from_data( - Data::new([sum].into(), [1].into()), - Self::int_device(&tensor), - ) - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. 
- panic!("Not supported by Candle") - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmax_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmin_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn int_abs(tensor: IntTensor) -> IntTensor { - // Ugly type conversion here as Candle does not support unary ops on ints - CandleTensor::new( - tensor - .tensor - .to_dtype(F::DTYPE) - .unwrap() - .abs() - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn int_swap_dims( - tensor: as burn_tensor::backend::Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as burn_tensor::backend::Backend>::IntTensorPrimitive { - super::base::swap_dims(tensor, dim1, dim2) - } + .unwrap(), + ) + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .scatter_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .index_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + super::base::cat(tensors, dim) + } + + fn int_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) + } + + fn int_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_greater( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) + } + + fn int_greater_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_lower( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) + } + + fn int_lower_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) + } + + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_add( + lhs: IntTensor, + rhs: IntTensor, + 
) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) + } + + fn int_add_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) + } + + fn int_sub( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) + } + + fn int_sub_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) + } + + fn int_mul( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) + } + + fn int_mul_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) + } + + fn int_div( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) + } + + fn int_div_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. + panic!("Not supported by Candle") + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(*device).into()).unwrap()) + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + CandleTensor::new(candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(*device).into()).unwrap()) + } + + fn int_sum(tensor: IntTensor) -> IntTensor { + let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); + CandleTensor::from_data( + Data::new([sum].into(), [1].into()), + Self::int_device(&tensor), + ) + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. 
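+        // Worked example: a mean over a dimension of size 4 requires multiplying the sum by 1/4,
+        // and in integer arithmetic 1/4 == 0, so every mean would come out as 0.
+        // Callers that need an integer mean can sum with `int_sum_dim`, convert with
+        // `int_into_float`, and divide in floating point (a possible workaround, not
+        // something this backend provides).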
+ panic!("Not supported by Candle") + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmax_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmin_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn int_abs(tensor: IntTensor) -> IntTensor { + // Ugly type conversion here as Candle does not support unary ops on ints + CandleTensor::new( + tensor + .tensor + .to_dtype(F::DTYPE) + .unwrap() + .abs() + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn int_swap_dims( + tensor: as burn_tensor::backend::Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as burn_tensor::backend::Backend>::IntTensorPrimitive { + super::base::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-candle/src/ops/module.rs b/burn-candle/src/ops/module.rs index 0a169277fa..f8c1ae719f 100644 --- a/burn-candle/src/ops/module.rs +++ b/burn-candle/src/ops/module.rs @@ -1,223 +1,220 @@ use burn_tensor::{ - ops::{ - ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool2dBackward, - MaxPool2dWithIndices, ModuleOps, UnfoldOptions, - }, - Shape, + ops::{ + ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool2dBackward, + MaxPool2dWithIndices, ModuleOps, UnfoldOptions, + }, + Shape, }; use candle_core::ToUsize2; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - ops::base::reshape, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + ops::base::reshape, + Candle, CandleTensor, }; impl ModuleOps for Candle { - fn conv1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> FloatTensor { - let conv = x - .tensor - .conv1d( - &weight.tensor, - options.padding[0], - options.stride[0], - options.dilation[0], - options.groups, - ) - .unwrap(); - CandleTensor::new(match bias { - Some(bias) => conv - .broadcast_add(&bias.tensor.unsqueeze(1).unwrap()) - .unwrap(), - None => conv, - }) - } + fn conv1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> FloatTensor { + let conv = x + .tensor + .conv1d( + &weight.tensor, + options.padding[0], + options.stride[0], + options.dilation[0], + options.groups, + ) + .unwrap(); + CandleTensor::new(match bias { + Some(bias) => conv + .broadcast_add(&bias.tensor.unsqueeze(1).unwrap()) + .unwrap(), + None => conv, + }) + } - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor { - assert!( - options.dilation[0] == options.dilation[1] - && options.padding[0] == options.padding[1] - && options.stride[0] == options.stride[1], - "Candle does not support per dimension options in convolutions" - ); - let conv = x + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor { + assert!( + options.dilation[0] == options.dilation[1] + && options.padding[0] == options.padding[1] + && options.stride[0] == options.stride[1], + "Candle does not support per dimension options in convolutions" + ); + let conv = x + .tensor + .conv2d( + &weight.tensor, + options.padding[0], + options.stride[0], + options.dilation[0], + options.groups, + ) + .unwrap(); + CandleTensor::new(match bias { + Some(bias) => conv + .broadcast_add( + &bias .tensor - .conv2d( - &weight.tensor, - options.padding[0], - 
options.stride[0], - options.dilation[0], - options.groups, - ) - .unwrap(); - CandleTensor::new(match bias { - Some(bias) => conv - .broadcast_add( - &bias - .tensor - .unsqueeze(0) - .unwrap() - .unsqueeze(2) - .unwrap() - .unsqueeze(3) - .unwrap(), - ) - .unwrap(), - None => conv, - }) - } + .unsqueeze(0) + .unwrap() + .unsqueeze(2) + .unwrap() + .unsqueeze(3) + .unwrap(), + ) + .unwrap(), + None => conv, + }) + } - fn conv_transpose1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> FloatTensor { - panic!("Candle does not support conv_transpose1d") - } + fn conv_transpose1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> FloatTensor { + panic!("Candle does not support conv_transpose1d") + } - fn conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor { - assert!( - options.dilation[0] == options.dilation[1] - && options.padding[0] == options.padding[1] - && options.padding_out[0] == options.padding_out[1] - && options.stride[0] == options.stride[1], - "Candle does not support per dimension options in transposed convolutions" - ); - assert!( - options.groups == 1, - "Candle does not support groups in transposed convolutions" - ); - let conv_transpose = x + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + assert!( + options.dilation[0] == options.dilation[1] + && options.padding[0] == options.padding[1] + && options.padding_out[0] == options.padding_out[1] + && options.stride[0] == options.stride[1], + "Candle does not support per dimension options in transposed convolutions" + ); + assert!( + options.groups == 1, + "Candle does not support groups in transposed convolutions" + ); + let conv_transpose = x + .tensor + .conv_transpose2d( + &weight.tensor, + options.padding[0], + options.padding_out[0], + options.stride[0], + options.dilation[0], + ) + .unwrap(); + CandleTensor::new(match bias { + Some(bias) => conv_transpose + .broadcast_add( + &bias .tensor - .conv_transpose2d( - &weight.tensor, - options.padding[0], - options.padding_out[0], - options.stride[0], - options.dilation[0], - ) - .unwrap(); - CandleTensor::new(match bias { - Some(bias) => conv_transpose - .broadcast_add( - &bias - .tensor - .unsqueeze(0) - .unwrap() - .unsqueeze(2) - .unwrap() - .unsqueeze(3) - .unwrap(), - ) - .unwrap(), - None => conv_transpose, - }) - } - - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - assert!( - padding[0] == 0 && padding[1] == 0, - "Candle does not support padding in pooling" - ); - assert!( - count_include_pad, - "Candle does not support excluding pad count in pooling" - ); - CandleTensor::new( - x.tensor - .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) - .unwrap(), + .unsqueeze(0) + .unwrap() + .unsqueeze(2) + .unwrap() + .unsqueeze(3) + .unwrap(), ) - } + .unwrap(), + None => conv_transpose, + }) + } - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - panic!("avg_pool2d_backward is not supported by Candle") - } + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) 
-> FloatTensor { + assert!( + padding[0] == 0 && padding[1] == 0, + "Candle does not support padding in pooling" + ); + assert!( + count_include_pad, + "Candle does not support excluding pad count in pooling" + ); + CandleTensor::new( + x.tensor + .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) + .unwrap(), + ) + } - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor { - assert!( - padding[0] == 0 && padding[1] == 0, - "Candle does not support padding in pooling" - ); - assert!( - dilation[0] == 1 && dilation[1] == 1, - "Candle does not support dilation in pooling" - ); - CandleTensor::new( - x.tensor - .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) - .unwrap(), - ) - } + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + panic!("avg_pool2d_backward is not supported by Candle") + } + + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor { + assert!( + padding[0] == 0 && padding[1] == 0, + "Candle does not support padding in pooling" + ); + assert!( + dilation[0] == 1 && dilation[1] == 1, + "Candle does not support dilation in pooling" + ); + CandleTensor::new( + x.tensor + .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) + .unwrap(), + ) + } - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - panic!("max_pool2d_with_indices is not supported by Candle") - } + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + panic!("max_pool2d_with_indices is not supported by Candle") + } - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward> { - panic!("max_pool2d_with_indices_backward is not supported by Candle") - } + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward> { + panic!("max_pool2d_with_indices_backward is not supported by Candle") + } - fn adaptive_avg_pool2d( - x: FloatTensor, - output_size: [usize; 2], - ) -> FloatTensor { - panic!("adaptive_avg_pool2 is not supported by Candle") - } + fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { + panic!("adaptive_avg_pool2 is not supported by Candle") + } - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - panic!("adaptive_avg_pool2d_backward is not supported by Candle") - } + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + panic!("adaptive_avg_pool2d_backward is not supported by Candle") + } } diff --git a/burn-candle/src/ops/tensor.rs b/burn-candle/src/ops/tensor.rs index 3f096e92f4..5da1bf86cc 100644 --- a/burn-candle/src/ops/tensor.rs +++ b/burn-candle/src/ops/tensor.rs @@ -1,449 +1,456 @@ use std::borrow::Borrow; use 
burn_tensor::{ - ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, - Data, Device, Distribution, ElementConversion, Reader, Shape, + ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, + Data, Device, Distribution, ElementConversion, Reader, Shape, }; use candle_core::{backend::BackendStorage, shape, Tensor}; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + Candle, CandleTensor, }; impl TensorOps for Candle { - fn from_data(data: Data, device: &Device) -> CandleTensor { - CandleTensor::from_data(data, *device) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: &Device, - ) -> FloatTensor { - let shape = &shape.dims; - let device = &(*device).into(); - match distribution { - Distribution::Default => CandleTensor::new( - candle_core::Tensor::rand(0., 1., shape, device) - .unwrap() - .to_dtype(F::DTYPE) - .unwrap(), - ), - Distribution::Bernoulli(prob) => CandleTensor::new( - candle_core::Tensor::rand(0., 1., shape, device) - .unwrap() - .to_dtype(F::DTYPE) - .unwrap() - .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device)) - .unwrap() - .to_dtype(F::DTYPE) - .unwrap(), - ), - Distribution::Uniform(from, to) => { - CandleTensor::new(candle_core::Tensor::rand(from, to, shape, device).unwrap()) - } - Distribution::Normal(mean, std) => { - CandleTensor::new(candle_core::Tensor::randn(mean, std, shape, device).unwrap()) - } - } - } - - fn shape(tensor: &CandleTensor) -> Shape { - super::base::shape(tensor) - } - - fn into_data(tensor: CandleTensor) -> Reader> { - Reader::Concrete(super::base::into_data(tensor)) - } - - fn device(tensor: &CandleTensor) -> Device { - super::base::device(tensor) - } - - fn to_device( - tensor: CandleTensor, - device: &Device, - ) -> CandleTensor { - super::base::to_device(tensor, device) - } - - fn into_int(tensor: CandleTensor) -> IntTensor { - CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) - } - - fn empty(shape: Shape, device: &Device) -> FloatTensor { - super::base::empty(shape, device) - } - - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) - } - - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor / rhs.elem::()).unwrap()) - } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs.tensor).unwrap()) - } - - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, 
- ) -> FloatTensor { - super::base::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - super::base::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - CandleTensor::new( - tensor - .tensor - .scatter_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - CandleTensor::new( - tensor - .tensor - .index_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn slice( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - ) -> FloatTensor { - super::base::slice(tensor, ranges) - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - value: FloatTensor, - ) -> FloatTensor { - super::base::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor { - CandleTensor::new( - mask.tensor - .where_cond(&value.tensor, &tensor.tensor) - .unwrap(), - ) - } - - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - CandleTensor::new( - mask.tensor - .where_cond( - &super::candle_utils::fill_like::(value, &tensor.tensor), - &tensor.tensor, - ) - .unwrap(), - ) - } - - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) - } - - fn equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), + fn from_data(data: Data, device: &Device) -> CandleTensor { + CandleTensor::from_data(data, *device) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: &Device, + ) -> FloatTensor { + let shape = &shape.dims; + let device = &(*device).into(); + match distribution { + Distribution::Default => CandleTensor::new( + candle_core::Tensor::rand(0., 1., shape, device) + .unwrap() + .to_dtype(F::DTYPE) + .unwrap(), + ), + Distribution::Bernoulli(prob) => CandleTensor::new( + candle_core::Tensor::rand(0., 1., shape, device) + .unwrap() + .to_dtype(F::DTYPE) + .unwrap() + .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device)) + .unwrap() + .to_dtype(F::DTYPE) + .unwrap(), + ), + Distribution::Uniform(from, to) => { + CandleTensor::new(candle_core::Tensor::rand(from, to, shape, device).unwrap()) + } + Distribution::Normal(mean, std) => { + CandleTensor::new(candle_core::Tensor::randn(mean, std, shape, device).unwrap()) + } + } + } + + fn shape(tensor: &CandleTensor) -> Shape { + super::base::shape(tensor) + } + + fn into_data(tensor: CandleTensor) -> Reader> { + Reader::Concrete(super::base::into_data(tensor)) + } + + fn device(tensor: &CandleTensor) -> Device { + super::base::device(tensor) + } + + fn to_device( + tensor: CandleTensor, + device: &Device, + ) -> CandleTensor { + super::base::to_device(tensor, device) + } + + fn into_int(tensor: CandleTensor) -> IntTensor { + 
CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + super::base::empty(shape, device) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + CandleTensor::new((lhs.tensor / rhs.elem::()).unwrap()) + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs.tensor).unwrap()) + } + + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + super::base::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + super::base::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + CandleTensor::new( + tensor + .tensor + .scatter_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + CandleTensor::new( + tensor + .tensor + .index_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn slice( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + ) -> FloatTensor { + super::base::slice(tensor, ranges) + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + value: FloatTensor, + ) -> FloatTensor { + super::base::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor { + CandleTensor::new( + mask + .tensor + .where_cond(&value.tensor, &tensor.tensor) + .unwrap(), + ) + } + + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + CandleTensor::new( + mask + .tensor + .where_cond( + &super::candle_utils::fill_like::(value, &tensor.tensor), + &tensor.tensor, ) - } - - fn greater( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) - } - - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - 
.gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) - } - - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) - } - - fn lower_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) - } - - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs.tensor - .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn sum(tensor: FloatTensor) -> FloatTensor { - let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); - CandleTensor::from_data(Data::new([sum].into(), [1].into()), Self::device(&tensor)) - } - - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) - } - - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap()) - } - - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - CandleTensor::new(tensor.tensor.to_dtype(candle_core::DType::F32).unwrap()) - } - - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) - } - - fn exp(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.exp().unwrap()) - } - - fn log(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.log().unwrap()) - } - - fn log1p(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap()) - } - - fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { - CandleTensor::new(tensor.tensor.powf(value.elem::()).unwrap()) - } - - fn sqrt(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.sqrt().unwrap()) - } - - fn abs(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.abs().unwrap()) - } - - fn cos(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.cos().unwrap()) - } - - fn sin(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.sin().unwrap()) - } - - fn tanh(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.tanh().unwrap()) - } - - fn erf(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.erf().unwrap()) - } - - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - super::base::cat(tensors, dim) - } - - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmax_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmin_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.minimum(max).unwrap()) - } - - fn clamp_min( - tensor: 
FloatTensor, - min: FloatElem, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.maximum(min).unwrap()) - } - - fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.clamp(min, max).unwrap()) - } - - fn recip(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.recip().unwrap()) - } + .unwrap(), + ) + } + + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) + } + + fn equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) + } + + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) + } + + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn lower( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) + } + + fn lower_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) + } + + fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs + .tensor + .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn sum(tensor: FloatTensor) -> FloatTensor { + let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); + CandleTensor::from_data(Data::new([sum].into(), [1].into()), Self::device(&tensor)) + } + + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) + } + + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap()) + } + + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D> { + CandleTensor::new(tensor.tensor.to_dtype(candle_core::DType::F32).unwrap()) + } + + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) + } + + fn exp(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.exp().unwrap()) + } + + fn log(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.log().unwrap()) + } + + fn log1p(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap()) + } + + fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { + CandleTensor::new(tensor.tensor.powf(value.elem::()).unwrap()) + } + + fn sqrt(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.sqrt().unwrap()) + } + + fn abs(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.abs().unwrap()) + } + + fn cos(tensor: FloatTensor) -> 
FloatTensor { + CandleTensor::new(tensor.tensor.cos().unwrap()) + } + + fn sin(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.sin().unwrap()) + } + + fn tanh(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.tanh().unwrap()) + } + + fn erf(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.erf().unwrap()) + } + + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + super::base::cat(tensors, dim) + } + + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmax_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmin_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn clamp_max( + tensor: FloatTensor, + max: FloatElem, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.minimum(max).unwrap()) + } + + fn clamp_min( + tensor: FloatTensor, + min: FloatElem, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.maximum(min).unwrap()) + } + + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.clamp(min, max).unwrap()) + } + + fn recip(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.recip().unwrap()) + } } diff --git a/burn-candle/src/tensor.rs b/burn-candle/src/tensor.rs index 5ade8cc366..7f98b8ee3b 100644 --- a/burn-candle/src/tensor.rs +++ b/burn-candle/src/tensor.rs @@ -7,38 +7,38 @@ use crate::{element::CandleElement, CandleDevice}; /// A tensor that uses the candle backend. #[derive(Debug, Clone)] pub struct CandleTensor { - pub(crate) tensor: candle_core::Tensor, - phantom: PhantomData, + pub(crate) tensor: candle_core::Tensor, + phantom: PhantomData, } impl CandleTensor { - /// Create a new tensor. - pub fn new(tensor: candle_core::Tensor) -> Self { - Self { - tensor, - phantom: PhantomData, - } + /// Create a new tensor. + pub fn new(tensor: candle_core::Tensor) -> Self { + Self { + tensor, + phantom: PhantomData, } + } - /// Creates a new tensor from data and a device. - /// - /// # Arguments - /// - /// * `data` - The tensor's data. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// A new tensor. - pub fn from_data(data: Data, device: CandleDevice) -> Self { - let candle_shape: candle_core::Shape = (&data.shape.dims).into(); - let tensor = - candle_core::Tensor::from_slice(data.value.as_slice(), candle_shape, &device.into()); - Self::new(tensor.unwrap()) - } + /// Creates a new tensor from data and a device. + /// + /// # Arguments + /// + /// * `data` - The tensor's data. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// A new tensor. 
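+    /// # Panics
+    ///
+    /// The underlying `candle_core::Tensor::from_slice` call is unwrapped, so this
+    /// panics if candle rejects the input (e.g. a value count that does not match the shape).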
+ pub fn from_data(data: Data, device: CandleDevice) -> Self { + let candle_shape: candle_core::Shape = (&data.shape.dims).into(); + let tensor = + candle_core::Tensor::from_slice(data.value.as_slice(), candle_shape, &device.into()); + Self::new(tensor.unwrap()) + } - pub(crate) fn shape(&self) -> Shape { - let x: [usize; D] = self.tensor.dims().try_into().unwrap(); - Shape::from(x) - } + pub(crate) fn shape(&self) -> Shape { + let x: [usize; D] = self.tensor.dims().try_into().unwrap(); + Shape::from(x) + } } diff --git a/burn-common/src/benchmark.rs b/burn-common/src/benchmark.rs index a4abba2f3f..4de028cbf9 100644 --- a/burn-common/src/benchmark.rs +++ b/burn-common/src/benchmark.rs @@ -10,45 +10,45 @@ use std::time::Instant; /// Results of a benchmark run. #[derive(new, Debug)] pub struct BenchmarkResult { - durations: Vec, + durations: Vec, } impl BenchmarkResult { - /// Returns the median duration among all durations - pub fn median_duration(&self) -> Duration { - let mut sorted = self.durations.clone(); - sorted.sort(); - *sorted.get(sorted.len() / 2).unwrap() - } - pub(crate) fn mean_duration(&self) -> Duration { - self.durations.iter().sum::() / self.durations.len() as u32 - } + /// Returns the median duration among all durations + pub fn median_duration(&self) -> Duration { + let mut sorted = self.durations.clone(); + sorted.sort(); + *sorted.get(sorted.len() / 2).unwrap() + } + pub(crate) fn mean_duration(&self) -> Duration { + self.durations.iter().sum::() / self.durations.len() as u32 + } } impl Display for BenchmarkResult { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mean = self.mean_duration(); - let var = self - .durations - .iter() - .map(|duration| { - let tmp = duration.as_secs_f64() - mean.as_secs_f64(); - Duration::from_secs_f64(tmp * tmp) - }) - .sum::() - / self.durations.len() as u32; - - let mut sorted = self.durations.clone(); - sorted.sort(); - - let min = sorted.first().unwrap(); - let max = sorted.last().unwrap(); - let median = sorted.get(sorted.len() / 2).unwrap(); - let num_sample = self.durations.len(); - - f.write_str( - format!( - " + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mean = self.mean_duration(); + let var = self + .durations + .iter() + .map(|duration| { + let tmp = duration.as_secs_f64() - mean.as_secs_f64(); + Duration::from_secs_f64(tmp * tmp) + }) + .sum::() + / self.durations.len() as u32; + + let mut sorted = self.durations.clone(); + sorted.sort(); + + let min = sorted.first().unwrap(); + let max = sorted.last().unwrap(); + let median = sorted.get(sorted.len() / 2).unwrap(); + let num_sample = self.durations.len(); + + f.write_str( + format!( + " ―――――――― Result ――――――――― Samples {num_sample} Mean {mean:.3?} @@ -57,85 +57,85 @@ impl Display for BenchmarkResult { Min {min:.3?} Max {max:.3?} ―――――――――――――――――――――――――" - ) - .as_str(), - ) - } + ) + .as_str(), + ) + } } /// Benchmark trait. pub trait Benchmark { - /// Benchmark arguments. - type Args; - - /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't - /// count as included in the duration. - /// - /// # Notes - /// - /// This should not include warmup, the benchmark will be run at least one time without - /// measuring the execution time. - fn prepare(&self) -> Self::Args; - /// Execute the benchmark and returns the time it took to complete. - fn execute(&self, args: Self::Args); - /// Number of samples required to have a statistical significance. 
- fn num_samples(&self) -> usize { - 10 - } - /// Name of the benchmark. - fn name(&self) -> String; - /// Wait for computations to be over - fn sync(&self); - /// Run the benchmark a number of times. - fn run(&self) -> BenchmarkResult { - #[cfg(not(feature = "std"))] - panic!("Attempting to run benchmark in a no-std environment"); - - #[cfg(feature = "std")] - { - // Warmup - self.execute(self.prepare()); - self.sync(); - - let mut durations = Vec::with_capacity(self.num_samples()); - - for _ in 0..self.num_samples() { - // Prepare - let args = self.prepare(); - self.sync(); - - // Execute the benchmark - let start = Instant::now(); - self.execute(args); - self.sync(); - let end = Instant::now(); - - // Register the duration - durations.push(end - start); - } - - BenchmarkResult { durations } - } + /// Benchmark arguments. + type Args; + + /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't + /// count as included in the duration. + /// + /// # Notes + /// + /// This should not include warmup, the benchmark will be run at least one time without + /// measuring the execution time. + fn prepare(&self) -> Self::Args; + /// Execute the benchmark and returns the time it took to complete. + fn execute(&self, args: Self::Args); + /// Number of samples required to have a statistical significance. + fn num_samples(&self) -> usize { + 10 + } + /// Name of the benchmark. + fn name(&self) -> String; + /// Wait for computations to be over + fn sync(&self); + /// Run the benchmark a number of times. + fn run(&self) -> BenchmarkResult { + #[cfg(not(feature = "std"))] + panic!("Attempting to run benchmark in a no-std environment"); + + #[cfg(feature = "std")] + { + // Warmup + self.execute(self.prepare()); + self.sync(); + + let mut durations = Vec::with_capacity(self.num_samples()); + + for _ in 0..self.num_samples() { + // Prepare + let args = self.prepare(); + self.sync(); + + // Execute the benchmark + let start = Instant::now(); + self.execute(args); + self.sync(); + let end = Instant::now(); + + // Register the duration + durations.push(end - start); + } + + BenchmarkResult { durations } } + } } #[cfg(feature = "std")] /// Runs the given benchmark on the device and prints result and information. pub fn run_benchmark(benchmark: BM) where - BM: Benchmark, + BM: Benchmark, { - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(); - let output = std::process::Command::new("git") - .args(["rev-porse", "HEAD"]) - .output() - .unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - - println!("Timestamp: {}", timestamp); - println!("Git Hash: {}", str::trim(&git_hash)); - println!("Benchmarking - {}{}", benchmark.name(), benchmark.run()); + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis(); + let output = std::process::Command::new("git") + .args(["rev-porse", "HEAD"]) + .output() + .unwrap(); + let git_hash = String::from_utf8(output.stdout).unwrap(); + + println!("Timestamp: {}", timestamp); + println!("Git Hash: {}", str::trim(&git_hash)); + println!("Benchmarking - {}{}", benchmark.name(), benchmark.run()); } diff --git a/burn-common/src/id.rs b/burn-common/src/id.rs index 25c2161817..45dd4b0c9f 100644 --- a/burn-common/src/id.rs +++ b/burn-common/src/id.rs @@ -6,70 +6,70 @@ use uuid::{Builder, Bytes}; pub struct IdGenerator {} impl IdGenerator { - /// Generates a new ID in the form of a UUID. 
- pub fn generate() -> String { - let random_bytes: Bytes = gen_random(); + /// Generates a new ID in the form of a UUID. + pub fn generate() -> String { + let random_bytes: Bytes = gen_random(); - let uuid = Builder::from_random_bytes(random_bytes).into_uuid(); + let uuid = Builder::from_random_bytes(random_bytes).into_uuid(); - uuid.as_hyphenated().to_string() - } + uuid.as_hyphenated().to_string() + } } #[cfg(test)] mod tests { - use super::*; - - use alloc::{collections::BTreeSet, string::String}; + use super::*; - #[cfg(feature = "std")] - use dashmap::DashSet; //Concurrent HashMap - #[cfg(feature = "std")] - use std::{sync::Arc, thread}; + use alloc::{collections::BTreeSet, string::String}; - #[test] - fn not_empty_test() { - assert!(!IdGenerator::generate().is_empty()); - } + #[cfg(feature = "std")] + use dashmap::DashSet; //Concurrent HashMap + #[cfg(feature = "std")] + use std::{sync::Arc, thread}; - #[test] - fn uniqueness_test() { - const IDS_CNT: usize = 10_000; + #[test] + fn not_empty_test() { + assert!(!IdGenerator::generate().is_empty()); + } - let mut set: BTreeSet = BTreeSet::new(); + #[test] + fn uniqueness_test() { + const IDS_CNT: usize = 10_000; - for _i in 0..IDS_CNT { - assert!(set.insert(IdGenerator::generate())); - } + let mut set: BTreeSet = BTreeSet::new(); - assert_eq!(set.len(), IDS_CNT); + for _i in 0..IDS_CNT { + assert!(set.insert(IdGenerator::generate())); } - #[cfg(feature = "std")] - #[test] - fn thread_safety_test() { - const NUM_THREADS: usize = 10; - const NUM_REPEATS: usize = 1_000; - const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS; + assert_eq!(set.len(), IDS_CNT); + } - let set: Arc> = Arc::new(DashSet::new()); + #[cfg(feature = "std")] + #[test] + fn thread_safety_test() { + const NUM_THREADS: usize = 10; + const NUM_REPEATS: usize = 1_000; + const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS; - let mut handles = vec![]; + let set: Arc> = Arc::new(DashSet::new()); - for _ in 0..NUM_THREADS { - let set = set.clone(); + let mut handles = vec![]; - let handle = thread::spawn(move || { - for _i in 0..NUM_REPEATS { - assert!(set.insert(IdGenerator::generate())); - } - }); - handles.push(handle); - } + for _ in 0..NUM_THREADS { + let set = set.clone(); - for handle in handles { - handle.join().unwrap(); + let handle = thread::spawn(move || { + for _i in 0..NUM_REPEATS { + assert!(set.insert(IdGenerator::generate())); } - assert_eq!(set.len(), EXPECTED_TOTAL_IDS); + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); } + assert_eq!(set.len(), EXPECTED_TOTAL_IDS); + } } diff --git a/burn-common/src/rand.rs b/burn-common/src/rand.rs index c9198930c3..8e50e48c46 100644 --- a/burn-common/src/rand.rs +++ b/burn-common/src/rand.rs @@ -7,15 +7,15 @@ use rand::prelude::Distribution; #[cfg(feature = "std")] #[inline(always)] pub fn get_seeded_rng() -> StdRng { - StdRng::from_entropy() + StdRng::from_entropy() } /// Returns a seeded random number generator using a pre-generated seed. #[cfg(not(feature = "std"))] #[inline(always)] pub fn get_seeded_rng() -> StdRng { - const CONST_SEED: u64 = 42; - StdRng::seed_from_u64(CONST_SEED) + const CONST_SEED: u64 = 42; + StdRng::seed_from_u64(CONST_SEED) } /// Generates random data from a thread-local RNG. @@ -23,9 +23,9 @@ pub fn get_seeded_rng() -> StdRng { #[inline] pub fn gen_random() -> T where - Standard: Distribution, + Standard: Distribution, { - rand::thread_rng().gen() + rand::thread_rng().gen() } /// Generates random data from a mutex-protected RNG. 
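// Usage sketch for the hunk above (illustrative only, not part of the patch): `gen_random`
// resolves to the thread-local implementation on `std` and to the mutex-protected one below
// otherwise, so call sites only need a type annotation, e.g.
//   let random_bytes: Bytes = gen_random();   // as in `IdGenerator::generate` above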
@@ -33,13 +33,13 @@ where #[inline] pub fn gen_random() -> T where - Standard: Distribution, + Standard: Distribution, { - use crate::stub::Mutex; - static RNG: Mutex> = Mutex::new(None); - let mut rng = RNG.lock().unwrap(); - if rng.is_none() { - *rng = Some(get_seeded_rng()); - } - rng.as_mut().unwrap().gen() + use crate::stub::Mutex; + static RNG: Mutex> = Mutex::new(None); + let mut rng = RNG.lock().unwrap(); + if rng.is_none() { + *rng = Some(get_seeded_rng()); + } + rng.as_mut().unwrap().gen() } diff --git a/burn-common/src/reader.rs b/burn-common/src/reader.rs index 91f4492c1a..7a9d0b7af7 100644 --- a/burn-common/src/reader.rs +++ b/burn-common/src/reader.rs @@ -5,111 +5,111 @@ use core::marker::PhantomData; #[async_trait::async_trait] /// Allows to create async reader. pub trait AsyncReader: Send { - /// Read asynchronously. - async fn read(self: Box) -> T; + /// Read asynchronously. + async fn read(self: Box) -> T; } /// Define how data is read, sync or async. pub enum Reader { - /// Concrete variant. - Concrete(T), - /// Sync data variant. - Sync(Box>), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Async data variant. - Async(Box>), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Future data variant. - Future(core::pin::Pin + Send>>), + /// Concrete variant. + Concrete(T), + /// Sync data variant. + Sync(Box>), + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Async data variant. + Async(Box>), + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Future data variant. + Future(core::pin::Pin + Send>>), } /// Allows to create sync reader. pub trait SyncReader: Send { - /// Read synchronously. - fn read(self: Box) -> T; + /// Read synchronously. + fn read(self: Box) -> T; } #[derive(new)] struct MappedReader { - reader: Reader, - mapper: F, - _output: PhantomData, + reader: Reader, + mapper: F, + _output: PhantomData, } impl SyncReader for MappedReader where - I: Send, - O: Send, - F: Send + FnOnce(I) -> O, + I: Send, + O: Send, + F: Send + FnOnce(I) -> O, { - fn read(self: Box) -> O { - let input = self - .reader - .read_sync() - .expect("Only sync data supported in a sync reader."); + fn read(self: Box) -> O { + let input = self + .reader + .read_sync() + .expect("Only sync data supported in a sync reader."); - (self.mapper)(input) - } + (self.mapper)(input) + } } #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] #[async_trait::async_trait] impl AsyncReader for MappedReader where - I: Send, - O: Send, - F: Send + FnOnce(I) -> O, + I: Send, + O: Send, + F: Send + FnOnce(I) -> O, { - async fn read(self: Box) -> O { - let input = self.reader.read().await; - (self.mapper)(input) - } + async fn read(self: Box) -> O { + let input = self.reader.read().await; + (self.mapper)(input) + } } impl Reader { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Read the data. - pub async fn read(self) -> T { - match self { - Self::Concrete(data) => data, - Self::Sync(reader) => reader.read(), - Self::Async(func) => func.read().await, - Self::Future(future) => future.await, - } + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Read the data. + pub async fn read(self) -> T { + match self { + Self::Concrete(data) => data, + Self::Sync(reader) => reader.read(), + Self::Async(func) => func.read().await, + Self::Future(future) => future.await, } + } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Read the data. 
- pub fn read(self) -> T { - match self { - Self::Concrete(data) => data, - Self::Sync(reader) => reader.read(), - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Read the data. + pub fn read(self) -> T { + match self { + Self::Concrete(data) => data, + Self::Sync(reader) => reader.read(), } + } - /// Read the data only if sync, returns None if an async reader. - pub fn read_sync(self) -> Option { - match self { - Self::Concrete(data) => Some(data), - Self::Sync(reader) => Some(reader.read()), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - Self::Async(_func) => return None, - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - Self::Future(_future) => return None, - } + /// Read the data only if sync, returns None if an async reader. + pub fn read_sync(self) -> Option { + match self { + Self::Concrete(data) => Some(data), + Self::Sync(reader) => Some(reader.read()), + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + Self::Async(_func) => return None, + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + Self::Future(_future) => return None, } + } - /// Map the current reader to another type. - pub fn map O>(self, mapper: F) -> Reader - where - T: 'static + Send, - O: 'static + Send, - F: 'static + Send, - { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - return Reader::Async(Box::new(MappedReader::new(self, mapper))); + /// Map the current reader to another type. + pub fn map O>(self, mapper: F) -> Reader + where + T: 'static + Send, + O: 'static + Send, + F: 'static + Send, + { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + return Reader::Async(Box::new(MappedReader::new(self, mapper))); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - Reader::Sync(Box::new(MappedReader::new(self, mapper))) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + Reader::Sync(Box::new(MappedReader::new(self, mapper))) + } } diff --git a/burn-common/src/stub.rs b/burn-common/src/stub.rs index 93d0715f5f..ecec818581 100644 --- a/burn-common/src/stub.rs +++ b/burn-common/src/stub.rs @@ -1,5 +1,5 @@ use spin::{ - Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard, + Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard, }; /// A mutual exclusion primitive useful for protecting shared data @@ -10,23 +10,23 @@ use spin::{ /// [Mutex] wrapper to make `spin::Mutex` API compatible with `std::sync::Mutex` to swap #[derive(Debug)] pub struct Mutex { - inner: MutexImported, + inner: MutexImported, } impl Mutex { - /// Creates a new mutex in an unlocked state ready for use. - #[inline(always)] - pub const fn new(value: T) -> Self { - Self { - inner: MutexImported::new(value), - } + /// Creates a new mutex in an unlocked state ready for use. + #[inline(always)] + pub const fn new(value: T) -> Self { + Self { + inner: MutexImported::new(value), } + } - /// Locks the mutex blocking the current thread until it is able to do so. - #[inline(always)] - pub fn lock(&self) -> Result, alloc::string::String> { - Ok(self.inner.lock()) - } + /// Locks the mutex blocking the current thread until it is able to do so. + #[inline(always)] + pub fn lock(&self) -> Result, alloc::string::String> { + Ok(self.inner.lock()) + } } /// A reader-writer lock which is exclusively locked for writing or shared for reading. 
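Since the `Reader` type from burn-common/src/reader.rs above is how data moves out of backends, a small hedged sketch of the synchronous path may help; `fetch` is invented for the example and a non-wasm target is assumed.

    use burn_common::reader::Reader;

    fn fetch() -> Reader<Vec<u8>> {
        // Data that is already available is wrapped in the `Concrete` variant.
        Reader::Concrete(vec![1, 2, 3])
    }

    fn main() {
        // `map` defers the closure until the reader is actually consumed.
        let doubled = fetch().map(|bytes| bytes.into_iter().map(|b| b * 2).collect::<Vec<u8>>());
        assert_eq!(doubled.read(), vec![2, 4, 6]); // sync `read` on non-wasm targets
    }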
@@ -35,31 +35,31 @@ impl Mutex { /// [RwLock] wrapper to make `spin::RwLock` API compatible with `std::sync::RwLock` to swap #[derive(Debug)] pub struct RwLock { - inner: RwLockImported, + inner: RwLockImported, } impl RwLock { - /// Creates a new reader-writer lock in an unlocked state ready for use. - #[inline(always)] - pub const fn new(value: T) -> Self { - Self { - inner: RwLockImported::new(value), - } + /// Creates a new reader-writer lock in an unlocked state ready for use. + #[inline(always)] + pub const fn new(value: T) -> Self { + Self { + inner: RwLockImported::new(value), } + } - /// Locks this rwlock with shared read access, blocking the current thread - /// until it can be acquired. - #[inline(always)] - pub fn read(&self) -> Result, alloc::string::String> { - Ok(self.inner.read()) - } + /// Locks this rwlock with shared read access, blocking the current thread + /// until it can be acquired. + #[inline(always)] + pub fn read(&self) -> Result, alloc::string::String> { + Ok(self.inner.read()) + } - /// Locks this rwlock with exclusive write access, blocking the current thread - /// until it can be acquired. - #[inline(always)] - pub fn write(&self) -> Result, alloc::string::String> { - Ok(self.inner.write()) - } + /// Locks this rwlock with exclusive write access, blocking the current thread + /// until it can be acquired. + #[inline(always)] + pub fn write(&self) -> Result, alloc::string::String> { + Ok(self.inner.write()) + } } /// A unique identifier for a running thread. diff --git a/burn-compute/src/channel/base.rs b/burn-compute/src/channel/base.rs index 9b2c7e3db5..78259fdadb 100644 --- a/burn-compute/src/channel/base.rs +++ b/burn-compute/src/channel/base.rs @@ -5,18 +5,18 @@ use burn_common::reader::Reader; /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety pub trait ComputeChannel: Clone + core::fmt::Debug { - /// Given a handle, returns owned resource as bytes - fn read(&self, handle: &Handle) -> Reader>; + /// Given a handle, returns owned resource as bytes + fn read(&self, handle: &Handle) -> Reader>; - /// Given a resource as bytes, stores it and returns the resource handle - fn create(&self, data: &[u8]) -> Handle; + /// Given a resource as bytes, stores it and returns the resource handle + fn create(&self, data: &[u8]) -> Handle; - /// Reserves `size` bytes in the storage, and returns a handle over them - fn empty(&self, size: usize) -> Handle; + /// Reserves `size` bytes in the storage, and returns a handle over them + fn empty(&self, size: usize) -> Handle; - /// Executes the `kernel` over the given `handles`. - fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]); + /// Executes the `kernel` over the given `handles`. + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]); - /// Wait for the completion of every task in the server. - fn sync(&self); + /// Wait for the completion of every task in the server. + fn sync(&self); } diff --git a/burn-compute/src/channel/cell.rs b/burn-compute/src/channel/cell.rs index 002b237271..cb110e4d77 100644 --- a/burn-compute/src/channel/cell.rs +++ b/burn-compute/src/channel/cell.rs @@ -15,52 +15,53 @@ use burn_common::reader::Reader; /// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels. 
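Because every channel implements the same `ComputeChannel` trait from burn-compute/src/channel/base.rs above, calling code can stay generic over the threading strategy. A hedged sketch (the `roundtrip` helper is invented and the module re-export paths are assumed):

    use burn_compute::{channel::ComputeChannel, server::ComputeServer};

    /// Store `payload` on the server and read it straight back through the channel.
    fn roundtrip<Server, Channel>(channel: &Channel, payload: &[u8]) -> Vec<u8>
    where
        Server: ComputeServer,
        Channel: ComputeChannel<Server>,
    {
        let handle = channel.create(payload); // upload the bytes
        channel.sync();                       // wait for outstanding work
        channel
            .read(&handle)                    // Reader over the owned bytes
            .read_sync()
            .expect("expected a synchronous reader from this channel")
    }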
#[derive(Debug)] pub struct RefCellComputeChannel { - server: Arc>, + server: Arc>, } impl Clone for RefCellComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - } + fn clone(&self) -> Self { + Self { + server: self.server.clone(), } + } } impl RefCellComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - /// Create a new cell compute channel. - pub fn new(server: Server) -> Self { - Self { - server: Arc::new(core::cell::RefCell::new(server)), - } + /// Create a new cell compute channel. + pub fn new(server: Server) -> Self { + Self { + server: Arc::new(core::cell::RefCell::new(server)), } + } } impl ComputeChannel for RefCellComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - fn read(&self, handle: &Handle) -> Reader> { - self.server.borrow_mut().read(handle) - } + fn read(&self, handle: &Handle) -> Reader> { + self.server.borrow_mut().read(handle) + } - fn create(&self, resource: &[u8]) -> Handle { - self.server.borrow_mut().create(resource) - } + fn create(&self, resource: &[u8]) -> Handle { + self.server.borrow_mut().create(resource) + } - fn empty(&self, size: usize) -> Handle { - self.server.borrow_mut().empty(size) - } + fn empty(&self, size: usize) -> Handle { + self.server.borrow_mut().empty(size) + } - fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle]) { - self.server - .borrow_mut() - .execute(kernel_description, handles) - } + fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle]) { + self + .server + .borrow_mut() + .execute(kernel_description, handles) + } - fn sync(&self) { - self.server.borrow_mut().sync() - } + fn sync(&self) { + self.server.borrow_mut().sync() + } } diff --git a/burn-compute/src/channel/mpsc.rs b/burn-compute/src/channel/mpsc.rs index e0f07ccca8..8bd6bc576c 100644 --- a/burn-compute/src/channel/mpsc.rs +++ b/burn-compute/src/channel/mpsc.rs @@ -1,6 +1,6 @@ use std::{ - sync::{mpsc, Arc}, - thread, + sync::{mpsc, Arc}, + thread, }; use burn_common::reader::Reader; @@ -13,146 +13,150 @@ use crate::server::{ComputeServer, Handle}; #[derive(Debug)] pub struct MpscComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - state: Arc>, + state: Arc>, } #[derive(Debug)] struct MpscComputeChannelState where - Server: ComputeServer, + Server: ComputeServer, { - _handle: thread::JoinHandle<()>, - sender: mpsc::SyncSender>, + _handle: thread::JoinHandle<()>, + sender: mpsc::SyncSender>, } type Callback = mpsc::SyncSender; enum Message where - Server: ComputeServer, + Server: ComputeServer, { - Read(Handle, Callback>>), - Create(Vec, Callback>), - Empty(usize, Callback>), - ExecuteKernel(Server::Kernel, Vec>), - Sync(Callback<()>), + Read(Handle, Callback>>), + Create(Vec, Callback>), + Empty(usize, Callback>), + ExecuteKernel(Server::Kernel, Vec>), + Sync(Callback<()>), } impl MpscComputeChannel where - Server: ComputeServer + 'static, + Server: ComputeServer + 'static, { - /// Create a new mpsc compute channel. 
- pub fn new(mut server: Server, bound: usize) -> Self { - let (sender, receiver) = mpsc::sync_channel(bound); - - let _handle = thread::spawn(move || { - while let Ok(message) = receiver.recv() { - match message { - Message::Read(handle, callback) => { - let data = server.read(&handle); - core::mem::drop(handle); - callback.send(data).unwrap(); - } - Message::Create(data, callback) => { - let handle = server.create(&data); - callback.send(handle).unwrap(); - } - Message::Empty(size, callback) => { - let handle = server.empty(size); - callback.send(handle).unwrap(); - } - Message::ExecuteKernel(kernel, handles) => { - server.execute(kernel, &handles.iter().collect::>()); - } - Message::Sync(callback) => { - server.sync(); - callback.send(()).unwrap(); - } - }; - } - }); - - let state = Arc::new(MpscComputeChannelState { sender, _handle }); - - Self { state } - } + /// Create a new mpsc compute channel. + pub fn new(mut server: Server, bound: usize) -> Self { + let (sender, receiver) = mpsc::sync_channel(bound); + + let _handle = thread::spawn(move || { + while let Ok(message) = receiver.recv() { + match message { + Message::Read(handle, callback) => { + let data = server.read(&handle); + core::mem::drop(handle); + callback.send(data).unwrap(); + } + Message::Create(data, callback) => { + let handle = server.create(&data); + callback.send(handle).unwrap(); + } + Message::Empty(size, callback) => { + let handle = server.empty(size); + callback.send(handle).unwrap(); + } + Message::ExecuteKernel(kernel, handles) => { + server.execute(kernel, &handles.iter().collect::>()); + } + Message::Sync(callback) => { + server.sync(); + callback.send(()).unwrap(); + } + }; + } + }); + + let state = Arc::new(MpscComputeChannelState { sender, _handle }); + + Self { state } + } } impl Clone for MpscComputeChannel { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } + fn clone(&self) -> Self { + Self { + state: self.state.clone(), } + } } impl ComputeChannel for MpscComputeChannel where - Server: ComputeServer + 'static, + Server: ComputeServer + 'static, { - fn read(&self, handle: &Handle) -> Reader> { - let (callback, response) = mpsc::sync_channel(1); - - self.state - .sender - .send(Message::Read(handle.clone(), callback)) - .unwrap(); - - self.response(response) - } - - fn create(&self, data: &[u8]) -> Handle { - let (callback, response) = mpsc::sync_channel(1); - - self.state - .sender - .send(Message::Create(data.to_vec(), callback)) - .unwrap(); - - self.response(response) - } - - fn empty(&self, size: usize) -> Handle { - let (callback, response) = mpsc::sync_channel(1); - - self.state - .sender - .send(Message::Empty(size, callback)) - .unwrap(); - - self.response(response) - } - - fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { - self.state - .sender - .send(Message::ExecuteKernel( - kernel, - handles - .iter() - .map(|h| (*h).clone()) - .collect::>>(), - )) - .unwrap() - } - - fn sync(&self) { - let (callback, response) = mpsc::sync_channel(1); - - self.state.sender.send(Message::Sync(callback)).unwrap(); - - self.response(response) - } + fn read(&self, handle: &Handle) -> Reader> { + let (callback, response) = mpsc::sync_channel(1); + + self + .state + .sender + .send(Message::Read(handle.clone(), callback)) + .unwrap(); + + self.response(response) + } + + fn create(&self, data: &[u8]) -> Handle { + let (callback, response) = mpsc::sync_channel(1); + + self + .state + .sender + .send(Message::Create(data.to_vec(), callback)) + .unwrap(); + + 
self.response(response) + } + + fn empty(&self, size: usize) -> Handle { + let (callback, response) = mpsc::sync_channel(1); + + self + .state + .sender + .send(Message::Empty(size, callback)) + .unwrap(); + + self.response(response) + } + + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self + .state + .sender + .send(Message::ExecuteKernel( + kernel, + handles + .iter() + .map(|h| (*h).clone()) + .collect::>>(), + )) + .unwrap() + } + + fn sync(&self) { + let (callback, response) = mpsc::sync_channel(1); + + self.state.sender.send(Message::Sync(callback)).unwrap(); + + self.response(response) + } } impl MpscComputeChannel { - fn response(&self, response: mpsc::Receiver) -> Response { - match response.recv() { - Ok(val) => val, - Err(err) => panic!("Can't connect to the server correctly {err:?}"), - } + fn response(&self, response: mpsc::Receiver) -> Response { + match response.recv() { + Ok(val) => val, + Err(err) => panic!("Can't connect to the server correctly {err:?}"), } + } } diff --git a/burn-compute/src/channel/mutex.rs b/burn-compute/src/channel/mutex.rs index 140b850eb0..731d7e1fb0 100644 --- a/burn-compute/src/channel/mutex.rs +++ b/burn-compute/src/channel/mutex.rs @@ -9,49 +9,49 @@ use spin::Mutex; /// on every operation #[derive(Debug)] pub struct MutexComputeChannel { - server: Arc>, + server: Arc>, } impl Clone for MutexComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - } + fn clone(&self) -> Self { + Self { + server: self.server.clone(), } + } } impl MutexComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - /// Create a new mutex compute channel. - pub fn new(server: Server) -> Self { - Self { - server: Arc::new(Mutex::new(server)), - } + /// Create a new mutex compute channel. + pub fn new(server: Server) -> Self { + Self { + server: Arc::new(Mutex::new(server)), } + } } impl ComputeChannel for MutexComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - fn read(&self, handle: &Handle) -> Reader> { - self.server.lock().read(handle) - } + fn read(&self, handle: &Handle) -> Reader> { + self.server.lock().read(handle) + } - fn create(&self, data: &[u8]) -> Handle { - self.server.lock().create(data) - } + fn create(&self, data: &[u8]) -> Handle { + self.server.lock().create(data) + } - fn empty(&self, size: usize) -> Handle { - self.server.lock().empty(size) - } + fn empty(&self, size: usize) -> Handle { + self.server.lock().empty(size) + } - fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { - self.server.lock().execute(kernel, handles) - } + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.server.lock().execute(kernel, handles) + } - fn sync(&self) { - self.server.lock().sync() - } + fn sync(&self) { + self.server.lock().sync() + } } diff --git a/burn-compute/src/client.rs b/burn-compute/src/client.rs index 3832652aeb..8c989e4ce4 100644 --- a/burn-compute/src/client.rs +++ b/burn-compute/src/client.rs @@ -1,7 +1,7 @@ use crate::{ - channel::ComputeChannel, - server::{ComputeServer, Handle}, - tune::{AutotuneOperationSet, Tuner}, + channel::ComputeChannel, + server::{ComputeServer, Handle}, + tune::{AutotuneOperationSet, Tuner}, }; use alloc::vec::Vec; use alloc::{boxed::Box, sync::Arc}; @@ -13,71 +13,72 @@ use spin::Mutex; /// It should be obtained for a specific device via the Compute struct. 
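The mpsc channel above serializes every call into a message that carries a one-shot reply sender. This standalone sketch (not part of the patch) shows the same request/callback pattern in isolation:

    use std::{sync::mpsc, thread};

    enum Message {
        Add(i32, i32, mpsc::SyncSender<i32>),
    }

    fn main() {
        let (sender, receiver) = mpsc::sync_channel::<Message>(10);

        let _worker = thread::spawn(move || {
            while let Ok(Message::Add(a, b, callback)) = receiver.recv() {
                callback.send(a + b).unwrap(); // reply on the per-request channel
            }
        });

        let (callback, response) = mpsc::sync_channel(1);
        sender.send(Message::Add(2, 3, callback)).unwrap();
        assert_eq!(response.recv().unwrap(), 5);
    }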
#[derive(Debug)] pub struct ComputeClient { - channel: Channel, - tuner: Arc>>, - _server: PhantomData, + channel: Channel, + tuner: Arc>>, + _server: PhantomData, } impl Clone for ComputeClient where - S: ComputeServer, - C: ComputeChannel, + S: ComputeServer, + C: ComputeChannel, { - fn clone(&self) -> Self { - Self { - channel: self.channel.clone(), - tuner: self.tuner.clone(), - _server: PhantomData, - } + fn clone(&self) -> Self { + Self { + channel: self.channel.clone(), + tuner: self.tuner.clone(), + _server: PhantomData, } + } } impl ComputeClient where - Server: ComputeServer, - Channel: ComputeChannel, + Server: ComputeServer, + Channel: ComputeChannel, { - /// Create a new client. - pub fn new(channel: Channel, tuner: Arc>>) -> Self { - Self { - channel, - tuner, - _server: PhantomData, - } + /// Create a new client. + pub fn new(channel: Channel, tuner: Arc>>) -> Self { + Self { + channel, + tuner, + _server: PhantomData, } + } - /// Given a handle, returns owned resource as bytes. - pub fn read(&self, handle: &Handle) -> Reader> { - self.channel.read(handle) - } + /// Given a handle, returns owned resource as bytes. + pub fn read(&self, handle: &Handle) -> Reader> { + self.channel.read(handle) + } - /// Given a resource, stores it and returns the resource handle. - pub fn create(&self, data: &[u8]) -> Handle { - self.channel.create(data) - } + /// Given a resource, stores it and returns the resource handle. + pub fn create(&self, data: &[u8]) -> Handle { + self.channel.create(data) + } - /// Reserves `size` bytes in the storage, and returns a handle over them. - pub fn empty(&self, size: usize) -> Handle { - self.channel.empty(size) - } + /// Reserves `size` bytes in the storage, and returns a handle over them. + pub fn empty(&self, size: usize) -> Handle { + self.channel.empty(size) + } - /// Executes the `kernel` over the given `handles`. - pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { - self.channel.execute(kernel, handles) - } + /// Executes the `kernel` over the given `handles`. + pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.channel.execute(kernel, handles) + } - /// Wait for the completion of every task in the server. - pub fn sync(&self) { - self.channel.sync() - } + /// Wait for the completion of every task in the server. + pub fn sync(&self) { + self.channel.sync() + } - /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks - pub fn execute_autotune( - &self, - autotune_operation_set: Box>, - ) { - self.tuner - .lock() - .execute_autotune(autotune_operation_set, self); - } + /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks + pub fn execute_autotune( + &self, + autotune_operation_set: Box>, + ) { + self + .tuner + .lock() + .execute_autotune(autotune_operation_set, self); + } } diff --git a/burn-compute/src/compute.rs b/burn-compute/src/compute.rs index 33f5d34337..a9dd96cbff 100644 --- a/burn-compute/src/compute.rs +++ b/burn-compute/src/compute.rs @@ -5,79 +5,79 @@ use hashbrown::HashMap; /// The compute type has the responsibility to retrieve the correct compute client based on the /// given device. 
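Code one layer up is usually written against the `ComputeClient` shown above, which simply forwards to its channel. A hedged generic sketch (the `zeros` helper is invented; module re-export paths are assumed):

    use burn_compute::{
        channel::ComputeChannel,
        client::ComputeClient,
        server::{ComputeServer, Handle},
    };

    fn zeros<S, C>(client: &ComputeClient<S, C>, len: usize) -> Handle<S>
    where
        S: ComputeServer,
        C: ComputeChannel<S>,
    {
        // `create` copies the bytes into server-managed memory and hands back a handle.
        client.create(&vec![0u8; len])
    }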
pub struct Compute { - clients: spin::Mutex>>>, + clients: spin::Mutex>>>, } impl Compute where - Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, - Server: ComputeServer, - Channel: ComputeChannel, + Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, + Server: ComputeServer, + Channel: ComputeChannel, { - /// Create a new compute. - pub const fn new() -> Self { - Self { - clients: spin::Mutex::new(None), - } + /// Create a new compute. + pub const fn new() -> Self { + Self { + clients: spin::Mutex::new(None), } + } - /// Get the compute client for the given device. - /// - /// Provide the init function to create a new client if it isn't already initialized. - pub fn client(&self, device: &Device, init: Init) -> ComputeClient - where - Init: Fn() -> ComputeClient, - { - let mut clients = self.clients.lock(); + /// Get the compute client for the given device. + /// + /// Provide the init function to create a new client if it isn't already initialized. + pub fn client(&self, device: &Device, init: Init) -> ComputeClient + where + Init: Fn() -> ComputeClient, + { + let mut clients = self.clients.lock(); - if clients.is_none() { - Self::register_inner(device, init(), &mut clients); - } + if clients.is_none() { + Self::register_inner(device, init(), &mut clients); + } - match clients.deref_mut() { - Some(clients) => match clients.get(device) { - Some(client) => client.clone(), - None => { - let client = init(); - clients.insert(device.clone(), client.clone()); - client - } - }, - _ => unreachable!(), + match clients.deref_mut() { + Some(clients) => match clients.get(device) { + Some(client) => client.clone(), + None => { + let client = init(); + clients.insert(device.clone(), client.clone()); + client } + }, + _ => unreachable!(), } + } - /// Register the compute client for the given device. - /// - /// # Note - /// - /// This function is mostly useful when the creation of the compute client can't be done - /// synchronously and require special context. - /// - /// # Panics - /// - /// If a client is already registered for the given device. - pub fn register(&self, device: &Device, client: ComputeClient) { - let mut clients = self.clients.lock(); + /// Register the compute client for the given device. + /// + /// # Note + /// + /// This function is mostly useful when the creation of the compute client can't be done + /// synchronously and require special context. + /// + /// # Panics + /// + /// If a client is already registered for the given device. 
+ pub fn register(&self, device: &Device, client: ComputeClient) { + let mut clients = self.clients.lock(); - Self::register_inner(device, client, &mut clients); - } + Self::register_inner(device, client, &mut clients); + } - fn register_inner( - device: &Device, - client: ComputeClient, - clients: &mut Option>>, - ) { - if clients.is_none() { - *clients = Some(HashMap::new()); - } + fn register_inner( + device: &Device, + client: ComputeClient, + clients: &mut Option>>, + ) { + if clients.is_none() { + *clients = Some(HashMap::new()); + } - if let Some(clients) = clients { - if clients.contains_key(device) { - panic!("Client already created for device {:?}", device); - } + if let Some(clients) = clients { + if clients.contains_key(device) { + panic!("Client already created for device {:?}", device); + } - clients.insert(device.clone(), client); - } + clients.insert(device.clone(), client); } + } } diff --git a/burn-compute/src/id.rs b/burn-compute/src/id.rs index 33ba53c044..90fb4e28c8 100644 --- a/burn-compute/src/id.rs +++ b/burn-compute/src/id.rs @@ -1,53 +1,53 @@ #[macro_export(local_inner_macros)] /// Create a new storage ID type. macro_rules! storage_id_type { - ($name:ident) => { - #[derive(Clone, Hash, PartialEq, Eq)] - /// Storage ID. - pub struct $name { - id: alloc::sync::Arc, - } + ($name:ident) => { + #[derive(Clone, Hash, PartialEq, Eq)] + /// Storage ID. + pub struct $name { + id: alloc::sync::Arc, + } - impl $name { - /// Create a new ID. - pub fn new() -> Self { - Self { - id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), - } - } + impl $name { + /// Create a new ID. + pub fn new() -> Self { + Self { + id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), } + } + } - impl Default for $name { - fn default() -> Self { - Self::new() - } - } - }; + impl Default for $name { + fn default() -> Self { + Self::new() + } + } + }; } #[macro_export(local_inner_macros)] /// Create a new memory ID type. macro_rules! memory_id_type { - ($name:ident) => { - #[derive(Clone, Hash, PartialEq, Eq, Debug)] - /// Memory ID. - pub struct $name { - id: alloc::sync::Arc, - } + ($name:ident) => { + #[derive(Clone, Hash, PartialEq, Eq, Debug)] + /// Memory ID. + pub struct $name { + id: alloc::sync::Arc, + } - impl $name { - /// Create a new ID. - pub(crate) fn new() -> Self { - Self { - id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), - } - } + impl $name { + /// Create a new ID. + pub(crate) fn new() -> Self { + Self { + id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), } + } + } - impl Default for $name { - fn default() -> Self { - Self::new() - } - } - }; + impl Default for $name { + fn default() -> Self { + Self::new() + } + } + }; } diff --git a/burn-compute/src/memory_management/base.rs b/burn-compute/src/memory_management/base.rs index 4a6310cf9e..1be75be6cc 100644 --- a/burn-compute/src/memory_management/base.rs +++ b/burn-compute/src/memory_management/base.rs @@ -6,8 +6,8 @@ use crate::storage::ComputeStorage; /// It is responsible for determining if the memory segment can be mutated, /// for instance by keeping track of a reference count pub trait MemoryHandle: Clone + Send + core::fmt::Debug { - /// Checks if the underlying memory can be safely mutated. - fn can_mut(&self) -> bool; + /// Checks if the underlying memory can be safely mutated. + fn can_mut(&self) -> bool; } /// The MemoryManagement trait encapsulates strategies for (de)allocating memory. 
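For reference, the ID macros from burn-compute/src/id.rs shown a little earlier expand to small newtypes around a reference-counted UUID string. A hedged usage sketch (`MyStorageId` is invented; the consuming crate is assumed to depend on `burn-compute` and `burn-common` under those names, with an `alloc` path in scope):

    extern crate alloc; // the macro body refers to `alloc::sync::Arc`

    burn_compute::storage_id_type!(MyStorageId);

    fn main() {
        let a = MyStorageId::new();
        let b = a.clone();
        assert!(a == b);                  // clones share the same underlying id
        assert!(a != MyStorageId::new()); // fresh ids are (with overwhelming probability) distinct
    }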
@@ -16,38 +16,38 @@ pub trait MemoryHandle: Clone + Send + core::fmt::Debug { /// The MemoryManagement can only reserve memory space or get the resource located at a space. /// Modification of the resource data should be done directly on the resource. pub trait MemoryManagement: Send + core::fmt::Debug { - /// The associated type Handle must implement MemoryHandle - type Handle: MemoryHandle; + /// The associated type Handle must implement MemoryHandle + type Handle: MemoryHandle; - /// Returns the resource from the storage at the specified handle - fn get(&mut self, handle: &Self::Handle) -> Storage::Resource; + /// Returns the resource from the storage at the specified handle + fn get(&mut self, handle: &Self::Handle) -> Storage::Resource; - /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it - fn reserve(&mut self, size: usize) -> Self::Handle; + /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it + fn reserve(&mut self, size: usize) -> Self::Handle; - /// Bypass the memory allocation algorithm to allocate data directly. - /// - /// # Notes - /// - /// Can be useful for servers that want specific control over memory. - fn alloc(&mut self, size: usize) -> Self::Handle; + /// Bypass the memory allocation algorithm to allocate data directly. + /// + /// # Notes + /// + /// Can be useful for servers that want specific control over memory. + fn alloc(&mut self, size: usize) -> Self::Handle; - /// Bypass the memory allocation algorithm to deallocate data directly. - /// - /// # Notes - /// - /// Can be useful for servers that want specific control over memory. - fn dealloc(&mut self, handle: &Self::Handle); + /// Bypass the memory allocation algorithm to deallocate data directly. + /// + /// # Notes + /// + /// Can be useful for servers that want specific control over memory. + fn dealloc(&mut self, handle: &Self::Handle); - /// Fetch the storage used by the memory manager. - /// - /// # Notes - /// - /// The storage should probably not be used for allocations since the handles won't be - /// compatible with the ones provided by the current trait. Prefer using the - /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions. - /// - /// This is useful if you need to time the deallocations based on async computation, or to - /// change the mode of storage for different reasons. - fn storage(&mut self) -> &mut Storage; + /// Fetch the storage used by the memory manager. + /// + /// # Notes + /// + /// The storage should probably not be used for allocations since the handles won't be + /// compatible with the ones provided by the current trait. Prefer using the + /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions. + /// + /// This is useful if you need to time the deallocations based on async computation, or to + /// change the mode of storage for different reasons. 
+ fn storage(&mut self) -> &mut Storage; } diff --git a/burn-compute/src/memory_management/simple.rs b/burn-compute/src/memory_management/simple.rs index e6bb4fb37d..1605ee0839 100644 --- a/burn-compute/src/memory_management/simple.rs +++ b/burn-compute/src/memory_management/simple.rs @@ -1,7 +1,7 @@ use super::{MemoryHandle, MemoryManagement}; use crate::{ - memory_id_type, - storage::{ComputeStorage, StorageHandle, StorageUtilization}, + memory_id_type, + storage::{ComputeStorage, StorageHandle, StorageUtilization}, }; use alloc::{sync::Arc, vec::Vec}; use hashbrown::HashMap; @@ -12,451 +12,451 @@ memory_id_type!(ChunkId); memory_id_type!(SliceId); impl ChunkId { - /// A chunk is free if it is only referred by the chunk hashmap. - fn is_free(&self) -> bool { - Arc::strong_count(&self.id) <= 1 - } + /// A chunk is free if it is only referred by the chunk hashmap. + fn is_free(&self) -> bool { + Arc::strong_count(&self.id) <= 1 + } } impl SliceId { - /// A slice is free if it is only referred by the slice hashmap and the chunk it is in. - fn is_free(&self) -> bool { - Arc::strong_count(&self.id) <= 2 - } + /// A slice is free if it is only referred by the slice hashmap and the chunk it is in. + fn is_free(&self) -> bool { + Arc::strong_count(&self.id) <= 2 + } } /// The SimpleHandle is a memory handle, referring to either a chunk or a slice. #[derive(Debug, Clone)] pub enum SimpleHandle { - /// A whole chunk of memory. - Chunk(ChunkId), - /// A slice of a chunk of memory. - Slice(SliceId), + /// A whole chunk of memory. + Chunk(ChunkId), + /// A slice of a chunk of memory. + Slice(SliceId), } /// The strategy defines the frequency at which deallocation of unused memory chunks should occur. #[derive(Debug)] pub enum DeallocStrategy { - /// Once every n calls to reserve. - PeriodTick { - /// Number of calls to be executed before triggering the deallocation. - period: usize, - /// Current state. Should start at zero. - state: usize, - }, - #[cfg(feature = "std")] - /// Once every period of time - PeriodTime { - /// Number of time before triggering the deallocation. - period: std::time::Duration, - /// Current state. Should start at now. - state: std::time::Instant, - }, - /// Never deallocate. - Never, + /// Once every n calls to reserve. + PeriodTick { + /// Number of calls to be executed before triggering the deallocation. + period: usize, + /// Current state. Should start at zero. + state: usize, + }, + #[cfg(feature = "std")] + /// Once every period of time + PeriodTime { + /// Number of time before triggering the deallocation. + period: std::time::Duration, + /// Current state. Should start at now. + state: std::time::Instant, + }, + /// Never deallocate. + Never, } /// The strategy defines when to reuse chunk with slices. #[derive(Debug)] pub enum SliceStrategy { - /// Never use slices. - Never, - /// Ratio needed before the chunk can be used as a slice. Between 0 and 1. - Ratio(f32), - /// When the reserved memory is at least {} bytes. - MinimumSize(usize), - /// When the reserved memory less than {} bytes. - MaximumSize(usize), + /// Never use slices. + Never, + /// Ratio needed before the chunk can be used as a slice. Between 0 and 1. + Ratio(f32), + /// When the reserved memory is at least {} bytes. + MinimumSize(usize), + /// When the reserved memory less than {} bytes. + MaximumSize(usize), } impl SliceStrategy { - /// If the chunk can be used with a slice. 
- pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool { - if chunk_size < reserved_size { - return false; - } + /// If the chunk can be used with a slice. + pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool { + if chunk_size < reserved_size { + return false; + } - match self { - SliceStrategy::Never => false, - SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio, - SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes, - SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes, - } + match self { + SliceStrategy::Never => false, + SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio, + SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes, + SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes, } + } } impl DeallocStrategy { - /// Create a new strategy with the given period. - pub fn new_period_tick(period: usize) -> Self { - DeallocStrategy::PeriodTick { period, state: 0 } - } - - fn should_dealloc(&mut self) -> bool { - match self { - DeallocStrategy::PeriodTick { period, state } => { - *state = (*state + 1) % *period; - *state == 0 - } - #[cfg(feature = "std")] - DeallocStrategy::PeriodTime { period, state } => { - if &state.elapsed() > period { - *state = std::time::Instant::now(); - true - } else { - false - } - } - DeallocStrategy::Never => false, + /// Create a new strategy with the given period. + pub fn new_period_tick(period: usize) -> Self { + DeallocStrategy::PeriodTick { period, state: 0 } + } + + fn should_dealloc(&mut self) -> bool { + match self { + DeallocStrategy::PeriodTick { period, state } => { + *state = (*state + 1) % *period; + *state == 0 + } + #[cfg(feature = "std")] + DeallocStrategy::PeriodTime { period, state } => { + if &state.elapsed() > period { + *state = std::time::Instant::now(); + true + } else { + false } + } + DeallocStrategy::Never => false, } + } } /// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks. pub struct SimpleMemoryManagement { - chunks: HashMap)>, - slices: HashMap, - dealloc_strategy: DeallocStrategy, - slice_strategy: SliceStrategy, - storage: Storage, + chunks: HashMap)>, + slices: HashMap, + dealloc_strategy: DeallocStrategy, + slice_strategy: SliceStrategy, + storage: Storage, } impl core::fmt::Debug for SimpleMemoryManagement { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - alloc::format!( - "SimpleMemoryManagement {:?} - {:?}", - self.dealloc_strategy, - core::any::type_name::(), - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + alloc::format!( + "SimpleMemoryManagement {:?} - {:?}", + self.dealloc_strategy, + core::any::type_name::(), + ) + .as_str(), + ) + } } impl MemoryHandle for SimpleHandle { - /// Returns true if referenced by only one tensor, and only once by the - /// memory management hashmaps - fn can_mut(&self) -> bool { - // One reference in the chunk hashmap, another owned by one tensor. - const REFERENCE_LIMIT_CHUNK: usize = 2; - // One reference in the chunk hashmap (for the chunk on which this slice is built), - // another in the slice hashmap for this slice, and another owned by one tensor. 
- const REFERENCE_LIMIT_SLICE: usize = 3; - - match &self { - SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK, - SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE, - } + /// Returns true if referenced by only one tensor, and only once by the + /// memory management hashmaps + fn can_mut(&self) -> bool { + // One reference in the chunk hashmap, another owned by one tensor. + const REFERENCE_LIMIT_CHUNK: usize = 2; + // One reference in the chunk hashmap (for the chunk on which this slice is built), + // another in the slice hashmap for this slice, and another owned by one tensor. + const REFERENCE_LIMIT_SLICE: usize = 3; + + match &self { + SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK, + SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE, } + } } impl MemoryManagement for SimpleMemoryManagement { - type Handle = SimpleHandle; + type Handle = SimpleHandle; - /// Returns the resource from the storage, for the specified handle. - fn get(&mut self, handle: &Self::Handle) -> Storage::Resource { - let resource = match &handle { - SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0, - SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0, - }; + /// Returns the resource from the storage, for the specified handle. + fn get(&mut self, handle: &Self::Handle) -> Storage::Resource { + let resource = match &handle { + SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0, + SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0, + }; - self.storage.get(resource) - } + self.storage.get(resource) + } - /// Reserves memory of specified size using the reserve algorithm, and return - /// a handle to the reserved memory. - /// - /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy. - fn reserve(&mut self, size: usize) -> Self::Handle { - self.cleanup_slices(); + /// Reserves memory of specified size using the reserve algorithm, and return + /// a handle to the reserved memory. + /// + /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy. + fn reserve(&mut self, size: usize) -> Self::Handle { + self.cleanup_slices(); - let handle = self.reserve_algorithm(size); + let handle = self.reserve_algorithm(size); - if self.dealloc_strategy.should_dealloc() { - self.cleanup_chunks(); - } - - handle + if self.dealloc_strategy.should_dealloc() { + self.cleanup_chunks(); } - fn alloc(&mut self, size: usize) -> Self::Handle { - self.create_chunk(size) - } + handle + } - fn dealloc(&mut self, handle: &Self::Handle) { - match handle { - SimpleHandle::Chunk(id) => { - if let Some((handle, _slices)) = self.chunks.remove(id) { - self.storage.dealloc(handle.id); - } - } - SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"), + fn alloc(&mut self, size: usize) -> Self::Handle { + self.create_chunk(size) + } + + fn dealloc(&mut self, handle: &Self::Handle) { + match handle { + SimpleHandle::Chunk(id) => { + if let Some((handle, _slices)) = self.chunks.remove(id) { + self.storage.dealloc(handle.id); } + } + SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"), } + } - fn storage(&mut self) -> &mut Storage { - &mut self.storage - } + fn storage(&mut self) -> &mut Storage { + &mut self.storage + } } impl SimpleMemoryManagement { - /// Creates a new instance using the given storage, deallocation strategy and slice strategy. 
- pub fn new( - storage: Storage, - dealloc_strategy: DeallocStrategy, - slice_strategy: SliceStrategy, - ) -> Self { - Self { - chunks: HashMap::new(), - slices: HashMap::new(), - dealloc_strategy, - slice_strategy, - storage, - } + /// Creates a new instance using the given storage, deallocation strategy and slice strategy. + pub fn new( + storage: Storage, + dealloc_strategy: DeallocStrategy, + slice_strategy: SliceStrategy, + ) -> Self { + Self { + chunks: HashMap::new(), + slices: HashMap::new(), + dealloc_strategy, + slice_strategy, + storage, } + } - fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { - // Looks for a large enough, existing but unused chunk of memory. - let chunk = self.find_free_chunk(size); - - match chunk { - Some((chunk_id, chunk_size)) => { - if size == chunk_size { - // If there is one of exactly the same size, it reuses it. - SimpleHandle::Chunk(chunk_id.clone()) - } else { - // Otherwise creates a slice of the right size upon it, always starting at zero. - self.create_slice(size, chunk_id) - } - } - // If no chunk available, creates one of exactly the right size. - None => self.create_chunk(size), - } - } + fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { + // Looks for a large enough, existing but unused chunk of memory. + let chunk = self.find_free_chunk(size); - /// Finds the smallest of the free and large enough chunks to fit `size` - /// Returns the chunk's id and size. - fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> { - let mut size_diff_current = usize::MAX; - let mut current = None; - - for (chunk_id, (resource, slices)) in self.chunks.iter() { - // If chunk is already used, we do not choose it - if !slices.is_empty() || !chunk_id.is_free() { - continue; - } - - let resource_size = resource.size(); - - // If we find a chunk of exactly the right size, we stop searching altogether - if size == resource_size { - current = Some((chunk_id, resource)); - break; - } - - // Finds the smallest of the large enough chunks that can accept a slice - // of the given size - if self.slice_strategy.can_use_chunk(resource_size, size) { - let size_diff = resource_size - size; - - if size_diff < size_diff_current { - current = Some((chunk_id, resource)); - size_diff_current = size_diff; - } - } + match chunk { + Some((chunk_id, chunk_size)) => { + if size == chunk_size { + // If there is one of exactly the same size, it reuses it. + SimpleHandle::Chunk(chunk_id.clone()) + } else { + // Otherwise creates a slice of the right size upon it, always starting at zero. + self.create_slice(size, chunk_id) } - - current.map(|(id, handle)| (id.clone(), handle.size())) + } + // If no chunk available, creates one of exactly the right size. + None => self.create_chunk(size), } - - /// Creates a slice of size `size` upon the given chunk. - /// - /// For now slices must start at zero, therefore there can be only one per chunk - fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle { - let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap(); - let slice_id = SliceId::new(); - - let storage = StorageHandle { - id: handle.id.clone(), - utilization: StorageUtilization::Slice(0, size), - }; - - if slices.is_empty() { - self.slices.insert(slice_id.clone(), (storage, chunk_id)); - } else { - panic!("Can't have more than 1 slice yet."); + } + + /// Finds the smallest of the free and large enough chunks to fit `size` + /// Returns the chunk's id and size. 
+ fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> { + let mut size_diff_current = usize::MAX; + let mut current = None; + + for (chunk_id, (resource, slices)) in self.chunks.iter() { + // If chunk is already used, we do not choose it + if !slices.is_empty() || !chunk_id.is_free() { + continue; + } + + let resource_size = resource.size(); + + // If we find a chunk of exactly the right size, we stop searching altogether + if size == resource_size { + current = Some((chunk_id, resource)); + break; + } + + // Finds the smallest of the large enough chunks that can accept a slice + // of the given size + if self.slice_strategy.can_use_chunk(resource_size, size) { + let size_diff = resource_size - size; + + if size_diff < size_diff_current { + current = Some((chunk_id, resource)); + size_diff_current = size_diff; } - - slices.push(slice_id.clone()); - - SimpleHandle::Slice(slice_id) + } } - /// Creates a chunk of given size by allocating on the storage. - fn create_chunk(&mut self, size: usize) -> SimpleHandle { - let resource = self.storage.alloc(size); - let chunk_id = ChunkId::new(); + current.map(|(id, handle)| (id.clone(), handle.size())) + } - self.chunks.insert(chunk_id.clone(), (resource, Vec::new())); + /// Creates a slice of size `size` upon the given chunk. + /// + /// For now slices must start at zero, therefore there can be only one per chunk + fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle { + let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap(); + let slice_id = SliceId::new(); - SimpleHandle::Chunk(chunk_id) - } + let storage = StorageHandle { + id: handle.id.clone(), + utilization: StorageUtilization::Slice(0, size), + }; - /// Deallocates free chunks and remove them from chunks map. - fn cleanup_chunks(&mut self) { - let mut ids_to_remove = Vec::new(); - - self.chunks.iter().for_each(|(chunk_id, _resource)| { - if chunk_id.is_free() { - ids_to_remove.push(chunk_id.clone()); - } - }); - - ids_to_remove - .iter() - .map(|chunk_id| self.chunks.remove(chunk_id).unwrap()) - .for_each(|(resource, _slices)| { - self.storage.dealloc(resource.id); - }); + if slices.is_empty() { + self.slices.insert(slice_id.clone(), (storage, chunk_id)); + } else { + panic!("Can't have more than 1 slice yet."); } - /// Removes free slices from slice map and corresponding chunks. - fn cleanup_slices(&mut self) { - let mut ids_to_remove = Vec::new(); - - self.slices.iter().for_each(|(slice_id, _resource)| { - if slice_id.is_free() { - ids_to_remove.push(slice_id.clone()); - } - }); - - ids_to_remove - .iter() - .map(|slice_id| { - let value = self.slices.remove(slice_id).unwrap(); - (slice_id, value.1) - }) - .for_each(|(slice_id, chunk_id)| { - let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap(); - slices.retain(|id| id != slice_id); - }); - } + slices.push(slice_id.clone()); + + SimpleHandle::Slice(slice_id) + } + + /// Creates a chunk of given size by allocating on the storage. + fn create_chunk(&mut self, size: usize) -> SimpleHandle { + let resource = self.storage.alloc(size); + let chunk_id = ChunkId::new(); + + self.chunks.insert(chunk_id.clone(), (resource, Vec::new())); + + SimpleHandle::Chunk(chunk_id) + } + + /// Deallocates free chunks and remove them from chunks map. 
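Putting the pieces together, here is a hedged end-to-end sketch of the simple memory manager over the bytes storage used by the tests further below; it exercises the `MemoryManagement` trait through this implementation, and the module re-export paths are assumed.

    use burn_compute::{
        memory_management::{DeallocStrategy, MemoryManagement, SimpleMemoryManagement, SliceStrategy},
        storage::BytesStorage,
    };

    fn main() {
        // Free unused chunks every 100 calls to `reserve`; only reuse a chunk as a
        // slice when at least half of it would be occupied.
        let mut memory = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::new_period_tick(100),
            SliceStrategy::Ratio(0.5),
        );

        let handle = memory.reserve(1024); // no free chunk yet, so a fresh one is allocated
        let resource = memory.get(&handle);
        resource.write().fill(0);          // initialize the freshly allocated bytes
        assert_eq!(resource.read().len(), 1024);
    }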
+ fn cleanup_chunks(&mut self) { + let mut ids_to_remove = Vec::new(); + + self.chunks.iter().for_each(|(chunk_id, _resource)| { + if chunk_id.is_free() { + ids_to_remove.push(chunk_id.clone()); + } + }); + + ids_to_remove + .iter() + .map(|chunk_id| self.chunks.remove(chunk_id).unwrap()) + .for_each(|(resource, _slices)| { + self.storage.dealloc(resource.id); + }); + } + + /// Removes free slices from slice map and corresponding chunks. + fn cleanup_slices(&mut self) { + let mut ids_to_remove = Vec::new(); + + self.slices.iter().for_each(|(slice_id, _resource)| { + if slice_id.is_free() { + ids_to_remove.push(slice_id.clone()); + } + }); + + ids_to_remove + .iter() + .map(|slice_id| { + let value = self.slices.remove(slice_id).unwrap(); + (slice_id, value.1) + }) + .for_each(|(slice_id, chunk_id)| { + let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap(); + slices.retain(|id| id != slice_id); + }); + } } #[cfg(test)] mod tests { - use crate::{ - memory_management::{MemoryHandle, MemoryManagement, SliceStrategy}, - storage::BytesStorage, - }; - - use super::{DeallocStrategy, SimpleMemoryManagement}; - - #[test] - fn can_mut_with_single_tensor_reference() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - - let chunk_size = 4; - let simple_handle = memory_management.create_chunk(chunk_size); - - let x = simple_handle.clone(); - core::mem::drop(simple_handle); - - assert!(x.can_mut()); - } - - #[test] - fn two_tensor_references_remove_mutability() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - - let chunk_size = 4; - let simple_handle = memory_management.create_chunk(chunk_size); - - let x = simple_handle.clone(); - - assert!(!simple_handle.can_mut()); - assert!(!x.can_mut()) + use crate::{ + memory_management::{MemoryHandle, MemoryManagement, SliceStrategy}, + storage::BytesStorage, + }; + + use super::{DeallocStrategy, SimpleMemoryManagement}; + + #[test] + fn can_mut_with_single_tensor_reference() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + + let chunk_size = 4; + let simple_handle = memory_management.create_chunk(chunk_size); + + let x = simple_handle.clone(); + core::mem::drop(simple_handle); + + assert!(x.can_mut()); + } + + #[test] + fn two_tensor_references_remove_mutability() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + + let chunk_size = 4; + let simple_handle = memory_management.create_chunk(chunk_size); + + let x = simple_handle.clone(); + + assert!(!simple_handle.can_mut()); + assert!(!x.can_mut()) + } + + #[test] + fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + let chunk_size = 4; + let _chunk_handle = memory_management.reserve(chunk_size); + let _new_handle = memory_management.reserve(chunk_size); + + assert_eq!(memory_management.chunks.len(), 2); + } + + #[test] + fn when_empty_chunk_is_cleaned_upexists_it_disappears() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + let chunk_size = 4; + let chunk_handle = 
memory_management.reserve(chunk_size); + drop(chunk_handle); + memory_management.cleanup_chunks(); + + assert_eq!(memory_management.chunks.len(), 0); + } + + #[test] + fn never_dealloc_strategy_never_deallocs() { + let mut never_dealloc = DeallocStrategy::Never; + for _ in 0..20 { + assert!(!never_dealloc.should_dealloc()) } - - #[test] - fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - let chunk_size = 4; - let _chunk_handle = memory_management.reserve(chunk_size); - let _new_handle = memory_management.reserve(chunk_size); - - assert_eq!(memory_management.chunks.len(), 2); + } + + #[test] + fn period_tick_dealloc_strategy_should_dealloc_after_period() { + let period = 3; + let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period); + + for _ in 0..3 { + for _ in 0..period - 1 { + assert!(!period_tick_dealloc.should_dealloc()); + } + assert!(period_tick_dealloc.should_dealloc()); } + } - #[test] - fn when_empty_chunk_is_cleaned_upexists_it_disappears() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - let chunk_size = 4; - let chunk_handle = memory_management.reserve(chunk_size); - drop(chunk_handle); - memory_management.cleanup_chunks(); - - assert_eq!(memory_management.chunks.len(), 0); - } + #[test] + fn slice_strategy_minimum_bytes() { + let strategy = SliceStrategy::MinimumSize(100); - #[test] - fn never_dealloc_strategy_never_deallocs() { - let mut never_dealloc = DeallocStrategy::Never; - for _ in 0..20 { - assert!(!never_dealloc.should_dealloc()) - } - } + assert!(strategy.can_use_chunk(200, 101)); + assert!(!strategy.can_use_chunk(200, 99)); + } - #[test] - fn period_tick_dealloc_strategy_should_dealloc_after_period() { - let period = 3; - let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period); + #[test] + fn slice_strategy_maximum_bytes() { + let strategy = SliceStrategy::MaximumSize(100); - for _ in 0..3 { - for _ in 0..period - 1 { - assert!(!period_tick_dealloc.should_dealloc()); - } - assert!(period_tick_dealloc.should_dealloc()); - } - } + assert!(strategy.can_use_chunk(200, 99)); + assert!(!strategy.can_use_chunk(200, 101)); + } - #[test] - fn slice_strategy_minimum_bytes() { - let strategy = SliceStrategy::MinimumSize(100); + #[test] + fn slice_strategy_ratio() { + let strategy = SliceStrategy::Ratio(0.9); - assert!(strategy.can_use_chunk(200, 101)); - assert!(!strategy.can_use_chunk(200, 99)); - } - - #[test] - fn slice_strategy_maximum_bytes() { - let strategy = SliceStrategy::MaximumSize(100); - - assert!(strategy.can_use_chunk(200, 99)); - assert!(!strategy.can_use_chunk(200, 101)); - } - - #[test] - fn slice_strategy_ratio() { - let strategy = SliceStrategy::Ratio(0.9); - - assert!(strategy.can_use_chunk(200, 180)); - assert!(!strategy.can_use_chunk(200, 179)); - } + assert!(strategy.can_use_chunk(200, 180)); + assert!(!strategy.can_use_chunk(200, 179)); + } } diff --git a/burn-compute/src/server.rs b/burn-compute/src/server.rs index 3682ed487a..e8b75413e6 100644 --- a/burn-compute/src/server.rs +++ b/burn-compute/src/server.rs @@ -1,9 +1,9 @@ use core::fmt::Debug; use crate::{ - memory_management::{MemoryHandle, MemoryManagement}, - storage::ComputeStorage, - tune::AutotuneKey, + memory_management::{MemoryHandle, MemoryManagement}, + storage::ComputeStorage, + tune::AutotuneKey, 
}; use alloc::vec::Vec; use burn_common::reader::Reader; @@ -14,54 +14,54 @@ use burn_common::reader::Reader; /// [compute channel](crate::channel::ComputeChannel) for thread safety. pub trait ComputeServer: Send + core::fmt::Debug where - Self: Sized, + Self: Sized, { - /// The kernel type defines the computation algorithms. - type Kernel: Send; - /// The [storage](ComputeStorage) type defines how data is stored and accessed. - type Storage: ComputeStorage; - /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type. - type MemoryManagement: MemoryManagement; - /// The key used to cache operations used on specific inputs in autotune - type AutotuneKey: AutotuneKey; + /// The kernel type defines the computation algorithms. + type Kernel: Send; + /// The [storage](ComputeStorage) type defines how data is stored and accessed. + type Storage: ComputeStorage; + /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type. + type MemoryManagement: MemoryManagement; + /// The key used to cache operations used on specific inputs in autotune + type AutotuneKey: AutotuneKey; - /// Given a handle, returns the owned resource as bytes. - fn read(&mut self, handle: &Handle) -> Reader>; + /// Given a handle, returns the owned resource as bytes. + fn read(&mut self, handle: &Handle) -> Reader>; - /// Given a resource as bytes, stores it and returns the memory handle. - fn create(&mut self, data: &[u8]) -> Handle; + /// Given a resource as bytes, stores it and returns the memory handle. + fn create(&mut self, data: &[u8]) -> Handle; - /// Reserves `size` bytes in the storage, and returns a handle over them. - fn empty(&mut self, size: usize) -> Handle; + /// Reserves `size` bytes in the storage, and returns a handle over them. + fn empty(&mut self, size: usize) -> Handle; - /// Executes the `kernel` over the given memory `handles`. - /// - /// Kernels have mutable access to every resource they are given - /// and are responsible of determining which should be read or written. - fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]); + /// Executes the `kernel` over the given memory `handles`. + /// + /// Kernels have mutable access to every resource they are given + /// and are responsible of determining which should be read or written. + fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]); - /// Wait for the completion of every task in the server. - fn sync(&mut self); + /// Wait for the completion of every task in the server. + fn sync(&mut self); } /// Server handle containing the [memory handle](MemoryManagement::Handle). #[derive(new, Debug)] pub struct Handle { - /// Handle for the memory in use. - pub memory: >::Handle, + /// Handle for the memory in use. + pub memory: >::Handle, } impl Handle { - /// If the tensor handle can be mut with an inplace operation. - pub fn can_mut(&self) -> bool { - self.memory.can_mut() - } + /// If the tensor handle can be mut with an inplace operation. 
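For orientation, a hedged sketch of driving a `ComputeServer` directly (in practice servers sit behind a channel and client); `upload_and_run` and its arguments are placeholders for whatever a concrete backend defines, and a synchronous reader is assumed.

    use burn_compute::server::ComputeServer;

    fn upload_and_run<S: ComputeServer>(
        server: &mut S,
        kernel: S::Kernel,
        output_len: usize,
        data: &[u8],
    ) -> Vec<u8> {
        let input = server.create(data);            // handle over the stored bytes
        let output = server.empty(output_len);      // reserved, uninitialized output buffer
        server.execute(kernel, &[&input, &output]); // the kernel decides what it reads or writes
        server.sync();
        server
            .read(&output)
            .read_sync()
            .expect("expected a synchronous reader from this server")
    }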
+ pub fn can_mut(&self) -> bool { + self.memory.can_mut() + } } impl Clone for Handle { - fn clone(&self) -> Self { - Self { - memory: self.memory.clone(), - } + fn clone(&self) -> Self { + Self { + memory: self.memory.clone(), } + } } diff --git a/burn-compute/src/storage/base.rs b/burn-compute/src/storage/base.rs index ce6be5bceb..37a293f408 100644 --- a/burn-compute/src/storage/base.rs +++ b/burn-compute/src/storage/base.rs @@ -6,43 +6,43 @@ storage_id_type!(StorageId); /// Defines if data uses a full memory chunk or a slice of it. #[derive(Clone)] pub enum StorageUtilization { - /// Full memory chunk of specified size - Full(usize), - /// Slice of memory chunk with start index and size. - Slice(usize, usize), + /// Full memory chunk of specified size + Full(usize), + /// Slice of memory chunk with start index and size. + Slice(usize, usize), } /// Contains the [storage id](StorageId) of a resource and the way it is used. #[derive(new)] pub struct StorageHandle { - /// Storage id. - pub id: StorageId, - /// How the storage is used. - pub utilization: StorageUtilization, + /// Storage id. + pub id: StorageId, + /// How the storage is used. + pub utilization: StorageUtilization, } impl StorageHandle { - /// Returns the size the handle is pointing to in memory. - pub fn size(&self) -> usize { - match self.utilization { - StorageUtilization::Full(size) => size, - StorageUtilization::Slice(_, size) => size, - } + /// Returns the size the handle is pointing to in memory. + pub fn size(&self) -> usize { + match self.utilization { + StorageUtilization::Full(size) => size, + StorageUtilization::Slice(_, size) => size, } + } } /// Storage types are responsible for allocating and deallocating memory. pub trait ComputeStorage: Send { - /// The resource associated type determines the way data is implemented and how - /// it can be accessed by kernels. - type Resource: Send; + /// The resource associated type determines the way data is implemented and how + /// it can be accessed by kernels. + type Resource: Send; - /// Returns the underlying resource for a specified storage handle - fn get(&mut self, handle: &StorageHandle) -> Self::Resource; + /// Returns the underlying resource for a specified storage handle + fn get(&mut self, handle: &StorageHandle) -> Self::Resource; - /// Allocates `size` units of memory and returns a handle to it - fn alloc(&mut self, size: usize) -> StorageHandle; + /// Allocates `size` units of memory and returns a handle to it + fn alloc(&mut self, size: usize) -> StorageHandle; - /// Deallocates the memory pointed by the given storage id. - fn dealloc(&mut self, id: StorageId); + /// Deallocates the memory pointed by the given storage id. + fn dealloc(&mut self, id: StorageId); } diff --git a/burn-compute/src/storage/bytes_cpu.rs b/burn-compute/src/storage/bytes_cpu.rs index bfaf07950e..7f54d180fe 100644 --- a/burn-compute/src/storage/bytes_cpu.rs +++ b/burn-compute/src/storage/bytes_cpu.rs @@ -5,13 +5,13 @@ use hashbrown::HashMap; /// The bytes storage maps ids to pointers of bytes in a contiguous layout. #[derive(Default)] pub struct BytesStorage { - memory: HashMap, + memory: HashMap, } impl core::fmt::Debug for BytesStorage { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str("BytesStorage") - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str("BytesStorage") + } } /// Can send to other threads. 
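// --- Editor's note: a minimal usage sketch, not part of this patch. ---
// It illustrates how `StorageUtilization` drives `StorageHandle::size()` in the
// storage hunk above. The function name is hypothetical, and the import assumes
// these types are re-exported from `burn_compute::storage` like `BytesStorage`.
use burn_compute::storage::{StorageHandle, StorageId, StorageUtilization};

fn storage_handle_size_sketch() {
    let id = StorageId::new();

    // A handle over a full 64-byte chunk reports the chunk size.
    let full = StorageHandle::new(id.clone(), StorageUtilization::Full(64));
    assert_eq!(full.size(), 64);

    // A handle over a slice (offset 24, length 8) of the same chunk reports only
    // the slice length, mirroring the `test_slices` test further down.
    let slice = StorageHandle::new(id, StorageUtilization::Slice(24, 8));
    assert_eq!(slice.size(), 8);
}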
@@ -20,108 +20,108 @@ unsafe impl Send for BytesResource {} /// This struct is a pointer to a memory chunk or slice. pub struct BytesResource { - ptr: *mut u8, - utilization: StorageUtilization, + ptr: *mut u8, + utilization: StorageUtilization, } /// This struct refers to a specific (contiguous) layout of bytes. struct AllocatedBytes { - ptr: *mut u8, - layout: Layout, + ptr: *mut u8, + layout: Layout, } impl BytesResource { - fn get_exact_location_and_length(&self) -> (*mut u8, usize) { - match self.utilization { - StorageUtilization::Full(len) => (self.ptr, len), - StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) }, - } + fn get_exact_location_and_length(&self) -> (*mut u8, usize) { + match self.utilization { + StorageUtilization::Full(len) => (self.ptr, len), + StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) }, } + } - /// Returns the resource as a mutable slice of bytes. - pub fn write<'a>(&self) -> &'a mut [u8] { - let (ptr, len) = self.get_exact_location_and_length(); + /// Returns the resource as a mutable slice of bytes. + pub fn write<'a>(&self) -> &'a mut [u8] { + let (ptr, len) = self.get_exact_location_and_length(); - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - } + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + } - /// Returns the resource as an immutable slice of bytes. - pub fn read<'a>(&self) -> &'a [u8] { - let (ptr, len) = self.get_exact_location_and_length(); + /// Returns the resource as an immutable slice of bytes. + pub fn read<'a>(&self) -> &'a [u8] { + let (ptr, len) = self.get_exact_location_and_length(); - unsafe { core::slice::from_raw_parts(ptr, len) } - } + unsafe { core::slice::from_raw_parts(ptr, len) } + } } impl ComputeStorage for BytesStorage { - type Resource = BytesResource; + type Resource = BytesResource; - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let allocated_bytes = self.memory.get_mut(&handle.id).unwrap(); + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let allocated_bytes = self.memory.get_mut(&handle.id).unwrap(); - BytesResource { - ptr: allocated_bytes.ptr, - utilization: handle.utilization.clone(), - } + BytesResource { + ptr: allocated_bytes.ptr, + utilization: handle.utilization.clone(), } + } - fn alloc(&mut self, size: usize) -> StorageHandle { - let id = StorageId::new(); - let handle = StorageHandle { - id: id.clone(), - utilization: StorageUtilization::Full(size), - }; + fn alloc(&mut self, size: usize) -> StorageHandle { + let id = StorageId::new(); + let handle = StorageHandle { + id: id.clone(), + utilization: StorageUtilization::Full(size), + }; - unsafe { - let layout = Layout::array::(size).unwrap(); - let ptr = alloc(layout); - let memory = AllocatedBytes { ptr, layout }; + unsafe { + let layout = Layout::array::(size).unwrap(); + let ptr = alloc(layout); + let memory = AllocatedBytes { ptr, layout }; - self.memory.insert(id, memory); - } - - handle + self.memory.insert(id, memory); } - fn dealloc(&mut self, id: StorageId) { - if let Some(memory) = self.memory.remove(&id) { - unsafe { - dealloc(memory.ptr, memory.layout); - } - } + handle + } + + fn dealloc(&mut self, id: StorageId) { + if let Some(memory) = self.memory.remove(&id) { + unsafe { + dealloc(memory.ptr, memory.layout); + } } + } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_can_alloc_and_dealloc() { - let mut storage = BytesStorage::default(); - let handle_1 = storage.alloc(64); - - assert_eq!(handle_1.size(), 64); - 
storage.dealloc(handle_1.id); - } - - #[test] - fn test_slices() { - let mut storage = BytesStorage::default(); - let handle_1 = storage.alloc(64); - let handle_2 = StorageHandle::new(handle_1.id.clone(), StorageUtilization::Slice(24, 8)); - - storage - .get(&handle_1) - .write() - .iter_mut() - .enumerate() - .for_each(|(i, b)| { - *b = i as u8; - }); - - let bytes = storage.get(&handle_2).read().to_vec(); - storage.dealloc(handle_1.id); - assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]); - } + use super::*; + + #[test] + fn test_can_alloc_and_dealloc() { + let mut storage = BytesStorage::default(); + let handle_1 = storage.alloc(64); + + assert_eq!(handle_1.size(), 64); + storage.dealloc(handle_1.id); + } + + #[test] + fn test_slices() { + let mut storage = BytesStorage::default(); + let handle_1 = storage.alloc(64); + let handle_2 = StorageHandle::new(handle_1.id.clone(), StorageUtilization::Slice(24, 8)); + + storage + .get(&handle_1) + .write() + .iter_mut() + .enumerate() + .for_each(|(i, b)| { + *b = i as u8; + }); + + let bytes = storage.get(&handle_2).read().to_vec(); + storage.dealloc(handle_1.id); + assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]); + } } diff --git a/burn-compute/src/tune/operation.rs b/burn-compute/src/tune/operation.rs index 548b5f59f2..1dd1d19e60 100644 --- a/burn-compute/src/tune/operation.rs +++ b/burn-compute/src/tune/operation.rs @@ -6,30 +6,30 @@ use core::hash::Hash; /// Groups operations of the same type for autotune pub trait AutotuneOperationSet: Send { - /// The key used in the tune cache - fn key(&self) -> K; + /// The key used in the tune cache + fn key(&self) -> K; - /// All candidate operations for autotuning this operation type - /// Operations can run on toy tensors of relevant size - fn autotunables(&self) -> Vec>; + /// All candidate operations for autotuning this operation type + /// Operations can run on toy tensors of relevant size + fn autotunables(&self) -> Vec>; - /// Returns the operation for the given index, matching the order - /// returned by autotunables. Operation obtained here runs on original tensors - fn fastest(self: Box, fastest_index: usize) -> Box; + /// Returns the operation for the given index, matching the order + /// returned by autotunables. Operation obtained here runs on original tensors + fn fastest(self: Box, fastest_index: usize) -> Box; } /// Contains operation to run and inputs on which to run it pub trait AutotuneOperation { - /// Runs the operation - fn execute(self: Box); + /// Runs the operation + fn execute(self: Box); - /// The name of the operation. - fn name(&self) -> &str { - core::any::type_name::() - } + /// The name of the operation. 
+ fn name(&self) -> &str { + core::any::type_name::() + } - /// Clones the operation and inputs - fn clone(&self) -> Box; + /// Clones the operation and inputs + fn clone(&self) -> Box; } /// Trait alias diff --git a/burn-compute/src/tune/tune_benchmark.rs b/burn-compute/src/tune/tune_benchmark.rs index dd0231340d..bf4051e77c 100644 --- a/burn-compute/src/tune/tune_benchmark.rs +++ b/burn-compute/src/tune/tune_benchmark.rs @@ -11,30 +11,30 @@ use alloc::string::{String, ToString}; /// A benchmark that runs on server handles #[derive(new)] pub struct TuneBenchmark { - operation: Box, - client: ComputeClient, + operation: Box, + client: ComputeClient, } impl> Benchmark for TuneBenchmark { - type Args = Box; + type Args = Box; - fn prepare(&self) -> Self::Args { - self.operation.clone() - } + fn prepare(&self) -> Self::Args { + self.operation.clone() + } - fn num_samples(&self) -> usize { - 10 - } + fn num_samples(&self) -> usize { + 10 + } - fn execute(&self, operation: Self::Args) { - AutotuneOperation::execute(operation); - } + fn execute(&self, operation: Self::Args) { + AutotuneOperation::execute(operation); + } - fn name(&self) -> String { - "Autotune".to_string() - } + fn name(&self) -> String { + "Autotune".to_string() + } - fn sync(&self) { - self.client.sync(); - } + fn sync(&self) { + self.client.sync(); + } } diff --git a/burn-compute/src/tune/tune_cache.rs b/burn-compute/src/tune/tune_cache.rs index d91b70ec12..30d63619da 100644 --- a/burn-compute/src/tune/tune_cache.rs +++ b/burn-compute/src/tune/tune_cache.rs @@ -7,37 +7,37 @@ use hashbrown::HashMap; /// Use to find and reuse the best kernel for some input #[derive(Debug, Default)] pub(crate) struct TuneCache { - cache: HashMap, + cache: HashMap, } /// Result of the cache try pub enum TuneCacheResult { - /// An operation is found and given - Hit(Box), - /// No operation is found and the set is given back for ownership - Miss(Box>), + /// An operation is found and given + Hit(Box), + /// No operation is found and the set is given back for ownership + Miss(Box>), } impl TuneCache { - pub(crate) fn new() -> Self { - TuneCache { - cache: HashMap::new(), - } + pub(crate) fn new() -> Self { + TuneCache { + cache: HashMap::new(), } + } - #[allow(clippy::borrowed_box)] - pub(crate) fn try_cache( - &self, - autotune_operation_set: Box>, - ) -> TuneCacheResult { - let index = self.cache.get(&autotune_operation_set.key()); - if let Some(&i) = index { - return TuneCacheResult::Hit(autotune_operation_set.fastest(i)); - } - TuneCacheResult::Miss(autotune_operation_set) + #[allow(clippy::borrowed_box)] + pub(crate) fn try_cache( + &self, + autotune_operation_set: Box>, + ) -> TuneCacheResult { + let index = self.cache.get(&autotune_operation_set.key()); + if let Some(&i) = index { + return TuneCacheResult::Hit(autotune_operation_set.fastest(i)); } + TuneCacheResult::Miss(autotune_operation_set) + } - pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) { - self.cache.insert(key, fastest_index); - } + pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) { + self.cache.insert(key, fastest_index); + } } diff --git a/burn-compute/src/tune/tuner.rs b/burn-compute/src/tune/tuner.rs index c9a9afeb8f..a2d8a3e7fe 100644 --- a/burn-compute/src/tune/tuner.rs +++ b/burn-compute/src/tune/tuner.rs @@ -14,87 +14,87 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa #[derive(Debug, Default)] /// Executes autotune benchmarking and caching pub struct Tuner { - tune_cache: TuneCache, - _channel: 
PhantomData, + tune_cache: TuneCache, + _channel: PhantomData, } impl> Tuner { - /// Returns a tuner with empty cache - pub fn new() -> Self { - Self { - tune_cache: TuneCache::new(), - _channel: PhantomData, - } + /// Returns a tuner with empty cache + pub fn new() -> Self { + Self { + tune_cache: TuneCache::new(), + _channel: PhantomData, } - - pub(crate) fn execute_autotune( - &mut self, - autotune_operation_set: Box>, - client: &ComputeClient, - ) { - let operation = match self.tune_cache.try_cache(autotune_operation_set) { - super::TuneCacheResult::Hit(ops) => ops, - super::TuneCacheResult::Miss(set) => self.autotuning(set, client), - }; - - AutotuneOperation::execute(operation); + } + + pub(crate) fn execute_autotune( + &mut self, + autotune_operation_set: Box>, + client: &ComputeClient, + ) { + let operation = match self.tune_cache.try_cache(autotune_operation_set) { + super::TuneCacheResult::Hit(ops) => ops, + super::TuneCacheResult::Miss(set) => self.autotuning(set, client), + }; + + AutotuneOperation::execute(operation); + } + + fn autotuning( + &mut self, + autotune_operation_set: Box>, + client: &ComputeClient, + ) -> Box { + let key = autotune_operation_set.key(); + let autotunables = autotune_operation_set.autotunables(); + let mut names = Vec::with_capacity(autotunables.len()); + + // Run all autotune benchmarks + let results: Vec = autotunables + .into_iter() + .map(|op| { + names.push(op.name().to_string()); + self.run_benchmark(op, client) + }) + .collect(); + + for (name, result) in names.iter().zip(results.iter()) { + log::info!("Benchmark result {name}-{key} => {result}"); } - fn autotuning( - &mut self, - autotune_operation_set: Box>, - client: &ComputeClient, - ) -> Box { - let key = autotune_operation_set.key(); - let autotunables = autotune_operation_set.autotunables(); - let mut names = Vec::with_capacity(autotunables.len()); - - // Run all autotune benchmarks - let results: Vec = autotunables - .into_iter() - .map(|op| { - names.push(op.name().to_string()); - self.run_benchmark(op, client) - }) - .collect(); - - for (name, result) in names.iter().zip(results.iter()) { - log::info!("Benchmark result {name}-{key} => {result}"); - } - - // Finds the fastest operation, stores it and returns it - let fastest_index = self.find_fastest(results); - let fastest_name = names.get(fastest_index).unwrap(); - log::info!("Fastest result {fastest_name}-{key}"); - - self.tune_cache.cache_insert(key, fastest_index); - match self.tune_cache.try_cache(autotune_operation_set) { - super::TuneCacheResult::Hit(ops) => ops, - super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"), - } - } + // Finds the fastest operation, stores it and returns it + let fastest_index = self.find_fastest(results); + let fastest_name = names.get(fastest_index).unwrap(); + log::info!("Fastest result {fastest_name}-{key}"); - fn run_benchmark( - &mut self, - operation: Box, - client: &ComputeClient, - ) -> BenchmarkResult { - TuneBenchmark::new(operation, client.clone()).run() + self.tune_cache.cache_insert(key, fastest_index); + match self.tune_cache.try_cache(autotune_operation_set) { + super::TuneCacheResult::Hit(ops) => ops, + super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"), } - - fn find_fastest(&self, results: Vec) -> usize { - let mut smallest_duration = Duration::MAX; - let mut fastest_tunable = None; - - for (i, result) in results.into_iter().enumerate() { - let duration = result.median_duration(); - - if duration < smallest_duration { - 
smallest_duration = duration; - fastest_tunable = Some(i); - } - } - - fastest_tunable.expect("At least one kernel needed. ") + } + + fn run_benchmark( + &mut self, + operation: Box, + client: &ComputeClient, + ) -> BenchmarkResult { + TuneBenchmark::new(operation, client.clone()).run() + } + + fn find_fastest(&self, results: Vec) -> usize { + let mut smallest_duration = Duration::MAX; + let mut fastest_tunable = None; + + for (i, result) in results.into_iter().enumerate() { + let duration = result.median_duration(); + + if duration < smallest_duration { + smallest_duration = duration; + fastest_tunable = Some(i); + } } + + fastest_tunable.expect("At least one kernel needed. ") + } } diff --git a/burn-compute/tests/dummy/compute.rs b/burn-compute/tests/dummy/compute.rs index 54081a1510..e651082565 100644 --- a/burn-compute/tests/dummy/compute.rs +++ b/burn-compute/tests/dummy/compute.rs @@ -19,14 +19,14 @@ pub type DummyClient = ComputeClient; static COMPUTE: Compute = Compute::new(); pub fn client(device: &DummyDevice) -> DummyClient { - COMPUTE.client(device, || { - let storage = BytesStorage::default(); - let memory_management = - SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never); - let server = DummyServer::new(memory_management); - let channel = MutexComputeChannel::new(server); - let tuner = Arc::new(Mutex::new(Tuner::new())); + COMPUTE.client(device, || { + let storage = BytesStorage::default(); + let memory_management = + SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never); + let server = DummyServer::new(memory_management); + let channel = MutexComputeChannel::new(server); + let tuner = Arc::new(Mutex::new(Tuner::new())); - ComputeClient::new(channel, tuner) - }) + ComputeClient::new(channel, tuner) + }) } diff --git a/burn-compute/tests/dummy/kernel.rs b/burn-compute/tests/dummy/kernel.rs index 30a67d5538..b2f8cf2668 100644 --- a/burn-compute/tests/dummy/kernel.rs +++ b/burn-compute/tests/dummy/kernel.rs @@ -2,24 +2,24 @@ use burn_compute::storage::BytesResource; /// The DummyKernel trait should be implemented for every supported operation pub trait DummyKernel: Sync + Send { - fn compute(&self, resources: &mut [BytesResource]); + fn compute(&self, resources: &mut [BytesResource]); } /// Contains the algorithm for element-wise addition pub struct DummyElementwiseAddition; impl DummyKernel for DummyElementwiseAddition { - fn compute(&self, inputs: &mut [BytesResource]) { - // Notice how the kernel is responsible for determining which inputs - // are read-only and which are writable. - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + // Notice how the kernel is responsible for determining which inputs + // are read-only and which are writable. 
+ let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - out[i] = lhs[i] + rhs[i]; - } + for i in 0..size { + out[i] = lhs[i] + rhs[i]; } + } } diff --git a/burn-compute/tests/dummy/server.rs b/burn-compute/tests/dummy/server.rs index 55d8f49c2e..5749a407ab 100644 --- a/burn-compute/tests/dummy/server.rs +++ b/burn-compute/tests/dummy/server.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use burn_common::reader::Reader; use burn_compute::{ - memory_management::{MemoryManagement, SimpleMemoryManagement}, - server::{ComputeServer, Handle}, - storage::BytesStorage, + memory_management::{MemoryManagement, SimpleMemoryManagement}, + server::{ComputeServer, Handle}, + storage::BytesStorage, }; use derive_new::new; @@ -14,51 +14,51 @@ use super::DummyKernel; /// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks. #[derive(new, Debug)] pub struct DummyServer> { - memory_management: MM, + memory_management: MM, } impl ComputeServer for DummyServer where - MM: MemoryManagement, + MM: MemoryManagement, { - type Kernel = Arc; - type Storage = BytesStorage; - type MemoryManagement = MM; - type AutotuneKey = String; + type Kernel = Arc; + type Storage = BytesStorage; + type MemoryManagement = MM; + type AutotuneKey = String; - fn read(&mut self, handle: &Handle) -> Reader> { - let bytes = self.memory_management.get(&handle.memory); + fn read(&mut self, handle: &Handle) -> Reader> { + let bytes = self.memory_management.get(&handle.memory); - Reader::Concrete(bytes.read().to_vec()) - } - - fn create(&mut self, data: &[u8]) -> Handle { - let handle = self.memory_management.reserve(data.len()); - let resource = self.memory_management.get(&handle); + Reader::Concrete(bytes.read().to_vec()) + } - let bytes = resource.write(); + fn create(&mut self, data: &[u8]) -> Handle { + let handle = self.memory_management.reserve(data.len()); + let resource = self.memory_management.get(&handle); - for (i, val) in data.iter().enumerate() { - bytes[i] = *val; - } + let bytes = resource.write(); - Handle::new(handle) + for (i, val) in data.iter().enumerate() { + bytes[i] = *val; } - fn empty(&mut self, size: usize) -> Handle { - Handle::new(self.memory_management.reserve(size)) - } + Handle::new(handle) + } - fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]) { - let mut resources = handles - .iter() - .map(|handle| self.memory_management.get(&handle.memory)) - .collect::>(); + fn empty(&mut self, size: usize) -> Handle { + Handle::new(self.memory_management.reserve(size)) + } - kernel.compute(&mut resources); - } + fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]) { + let mut resources = handles + .iter() + .map(|handle| self.memory_management.get(&handle.memory)) + .collect::>(); - fn sync(&mut self) { - // Nothing to do with dummy backend. - } + kernel.compute(&mut resources); + } + + fn sync(&mut self) { + // Nothing to do with dummy backend. + } } diff --git a/burn-compute/tests/dummy/tune/autotune_operations.rs b/burn-compute/tests/dummy/tune/autotune_operations.rs index 5af0eaa472..fec4ab029c 100644 --- a/burn-compute/tests/dummy/tune/autotune_operations.rs +++ b/burn-compute/tests/dummy/tune/autotune_operations.rs @@ -9,25 +9,25 @@ use crate::dummy::{DummyChannel, DummyKernel, DummyServer}; /// Extended kernel that accounts for additional parameters, i.e. needed /// information that does not count as an input/output. 
pub struct OneKernelAutotuneOperation { - kernel: Arc, - client: ComputeClient, - shapes: Vec>, - handles: Vec>, + kernel: Arc, + client: ComputeClient, + shapes: Vec>, + handles: Vec>, } impl AutotuneOperation for OneKernelAutotuneOperation { - /// Executes the operation on given handles and server, with the additional parameters - fn execute(self: Box) { - let handle_refs: &Vec<&Handle> = &self.handles.iter().collect(); - self.client.execute(self.kernel.clone(), handle_refs); - } + /// Executes the operation on given handles and server, with the additional parameters + fn execute(self: Box) { + let handle_refs: &Vec<&Handle> = &self.handles.iter().collect(); + self.client.execute(self.kernel.clone(), handle_refs); + } - fn clone(&self) -> Box { - Box::new(Self { - kernel: self.kernel.clone(), - client: self.client.clone(), - shapes: self.shapes.clone(), - handles: self.handles.clone(), - }) - } + fn clone(&self) -> Box { + Box::new(Self { + kernel: self.kernel.clone(), + client: self.client.clone(), + shapes: self.shapes.clone(), + handles: self.handles.clone(), + }) + } } diff --git a/burn-compute/tests/dummy/tune/kernels.rs b/burn-compute/tests/dummy/tune/kernels.rs index bd0058310b..ffce69802f 100644 --- a/burn-compute/tests/dummy/tune/kernels.rs +++ b/burn-compute/tests/dummy/tune/kernels.rs @@ -14,93 +14,93 @@ pub struct CacheTestSlowOn3; pub struct ParameteredKernel; impl DummyKernel for DummyElementwiseAdditionSlowWrong { - fn compute(&self, inputs: &mut [BytesResource]) { - // Slow and wrong on purpose, for tests - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + // Slow and wrong on purpose, for tests + let lhs = &inputs[0].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i] - } + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = lhs[i] } + } } impl DummyKernel for DummyElementwiseMultiplication { - fn compute(&self, inputs: &mut [BytesResource]) { - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - out[i] = lhs[i] * rhs[i]; - } + for i in 0..size { + out[i] = lhs[i] * rhs[i]; } + } } impl DummyKernel for DummyElementwiseMultiplicationSlowWrong { - fn compute(&self, inputs: &mut [BytesResource]) { - // Slow and wrong on purpose, for tests - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + // Slow and wrong on purpose, for tests + let lhs = &inputs[0].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i]; - } + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = lhs[i]; } + } } impl DummyKernel for CacheTestFastOn3 { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for testing cache only - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - if size == 3 { - out[..size].copy_from_slice(&lhs[..size]); - } else { - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i]; - } - } + fn compute(&self, 
inputs: &mut [BytesResource]) { + // This is an artificial kernel designed for testing cache only + let lhs = &inputs[0].read(); + let out = &mut inputs[2].write(); + + let size = lhs.len(); + if size == 3 { + out[..size].copy_from_slice(&lhs[..size]); + } else { + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = lhs[i]; + } } + } } impl DummyKernel for CacheTestSlowOn3 { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for testing cache only - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - if size == 3 { - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = rhs[i]; - } - } else { - out[..size].copy_from_slice(&rhs[..size]); - } + fn compute(&self, inputs: &mut [BytesResource]) { + // This is an artificial kernel designed for testing cache only + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); + + let size = lhs.len(); + if size == 3 { + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = rhs[i]; + } + } else { + out[..size].copy_from_slice(&rhs[..size]); } + } } impl DummyKernel for ParameteredKernel { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for info buffer - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - let info = &inputs[3].read(); - - for i in 0..lhs.len() { - out[i] = lhs[i] + rhs[i] + info[0]; - } + fn compute(&self, inputs: &mut [BytesResource]) { + // This is an artificial kernel designed for info buffer + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); + let info = &inputs[3].read(); + + for i in 0..lhs.len() { + out[i] = lhs[i] + rhs[i] + info[0]; } + } } diff --git a/burn-compute/tests/dummy/tune/operation_sets.rs b/burn-compute/tests/dummy/tune/operation_sets.rs index f5b30e0727..89a68f1aa6 100644 --- a/burn-compute/tests/dummy/tune/operation_sets.rs +++ b/burn-compute/tests/dummy/tune/operation_sets.rs @@ -1,170 +1,170 @@ use std::sync::Arc; use burn_compute::{ - server::Handle, - tune::{AutotuneOperation, AutotuneOperationSet}, + server::Handle, + tune::{AutotuneOperation, AutotuneOperationSet}, }; use crate::dummy::{ - CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition, - DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer, - OneKernelAutotuneOperation, + CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition, + DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer, + OneKernelAutotuneOperation, }; use super::DummyElementwiseAdditionSlowWrong; pub struct AdditionAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - handles: Vec>, + client: DummyClient, + key: String, + shapes: Vec>, + handles: Vec>, } impl AdditionAutotuneOperationSet { - pub fn new( - client: DummyClient, - shapes: Vec>, - handles: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "add", log_shape_input_key(&shapes)), - shapes, - handles, - } + pub fn new( + client: DummyClient, + shapes: Vec>, + handles: Vec>, + ) -> Self { + Self { + client, + key: format!("{}-{}", "add", log_shape_input_key(&shapes)), + shapes, + handles, } + } } impl AutotuneOperationSet for AdditionAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - 
vec![ - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseAddition), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseAdditionSlowWrong), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } + fn key(&self) -> String { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + vec![ + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseAddition), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseAdditionSlowWrong), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + self.autotunables()[fastest_index].clone() + } } pub struct MultiplicationAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - handles: Vec>, + client: DummyClient, + key: String, + shapes: Vec>, + handles: Vec>, } impl MultiplicationAutotuneOperationSet { - pub fn new( - client: DummyClient, - shapes: Vec>, - handles: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "mul", log_shape_input_key(&shapes)), - shapes, - handles, - } + pub fn new( + client: DummyClient, + shapes: Vec>, + handles: Vec>, + ) -> Self { + Self { + client, + key: format!("{}-{}", "mul", log_shape_input_key(&shapes)), + shapes, + handles, } + } } impl AutotuneOperationSet for MultiplicationAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - vec![ - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseMultiplicationSlowWrong), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseMultiplication), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } + fn key(&self) -> String { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + vec![ + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseMultiplicationSlowWrong), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseMultiplication), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + self.autotunables()[fastest_index].clone() + } } pub struct CacheTestAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - handles: Vec>, + client: DummyClient, + key: String, + shapes: Vec>, + handles: Vec>, } impl CacheTestAutotuneOperationSet { - pub fn new( - client: DummyClient, - shapes: Vec>, - handles: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)), - shapes, - handles, - } + pub fn new( + client: DummyClient, + shapes: Vec>, + handles: Vec>, + ) -> Self { + Self { + client, + key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)), + shapes, + handles, } + } } impl AutotuneOperationSet for CacheTestAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - vec![ - 
Box::new(OneKernelAutotuneOperation::new( - Arc::new(CacheTestFastOn3), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(CacheTestSlowOn3), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } + fn key(&self) -> String { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + vec![ + Box::new(OneKernelAutotuneOperation::new( + Arc::new(CacheTestFastOn3), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + Box::new(OneKernelAutotuneOperation::new( + Arc::new(CacheTestSlowOn3), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + self.autotunables()[fastest_index].clone() + } } pub fn log_shape_input_key(shapes: &[Vec]) -> String { - let mut hash = String::new(); - let lhs = &shapes[0]; - for size in lhs { - let exp = f32::ceil(f32::log2(*size as f32)) as u32; - hash.push_str(2_u32.pow(exp).to_string().as_str()); - hash.push(','); - } - hash + let mut hash = String::new(); + let lhs = &shapes[0]; + for size in lhs { + let exp = f32::ceil(f32::log2(*size as f32)) as u32; + hash.push_str(2_u32.pow(exp).to_string().as_str()); + hash.push(','); + } + hash } diff --git a/burn-compute/tests/integration_test.rs b/burn-compute/tests/integration_test.rs index 79532f7f9a..c9db430423 100644 --- a/burn-compute/tests/integration_test.rs +++ b/burn-compute/tests/integration_test.rs @@ -8,141 +8,141 @@ use serial_test::serial; #[test] fn created_resource_is_the_same_when_read() { - let client = client(&DummyDevice); - let resource = Vec::from([0, 1, 2]); - let resource_description = client.create(&resource); + let client = client(&DummyDevice); + let resource = Vec::from([0, 1, 2]); + let resource_description = client.create(&resource); - let obtained_resource = client.read(&resource_description); + let obtained_resource = client.read(&resource_description); - assert_eq!(resource, obtained_resource.read()) + assert_eq!(resource, obtained_resource.read()) } #[test] fn empty_allocates_memory() { - let client = client(&DummyDevice); - let size = 4; - let resource_description = client.empty(size); - let empty_resource = client.read(&resource_description); + let client = client(&DummyDevice); + let size = 4; + let resource_description = client.empty(size); + let empty_resource = client.read(&resource_description); - assert_eq!(empty_resource.read().len(), 4); + assert_eq!(empty_resource.read().len(), 4); } #[test] fn execute_elementwise_addition() { - let client = client(&DummyDevice); - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); + let client = client(&DummyDevice); + let lhs = client.create(&[0, 1, 2]); + let rhs = client.create(&[4, 4, 4]); + let out = client.empty(3); - client.execute(Arc::new(DummyElementwiseAddition), &[&lhs, &rhs, &out]); + client.execute(Arc::new(DummyElementwiseAddition), &[&lhs, &rhs, &out]); - let obtained_resource = client.read(&out); + let obtained_resource = client.read(&out); - assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])) + assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])) } #[test] #[serial] #[cfg(feature = "std")] fn autotune_basic_addition_execution() { - let client = client(&DummyDevice); + let client = client(&DummyDevice); - let shapes = vec![vec![1, 3], 
vec![1, 3], vec![1, 3]]; - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - let handles = vec![lhs, rhs, out.clone()]; + let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs = client.create(&[0, 1, 2]); + let rhs = client.create(&[4, 4, 4]); + let out = client.empty(3); + let handles = vec![lhs, rhs, out.clone()]; - let addition_autotune_kernel = - dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, handles); - client.execute_autotune(Box::new(addition_autotune_kernel)); + let addition_autotune_kernel = + dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, handles); + client.execute_autotune(Box::new(addition_autotune_kernel)); - let obtained_resource = client.read(&out); + let obtained_resource = client.read(&out); - // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])); + // If slow kernel was selected it would output [0, 1, 2] + assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])); } #[test] #[serial] #[cfg(feature = "std")] fn autotune_basic_multiplication_execution() { - let client = client(&DummyDevice); + let client = client(&DummyDevice); - let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - let handles = vec![lhs, rhs, out.clone()]; + let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs = client.create(&[0, 1, 2]); + let rhs = client.create(&[4, 4, 4]); + let out = client.empty(3); + let handles = vec![lhs, rhs, out.clone()]; - let multiplication_autotune_kernel = - dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, handles); - client.execute_autotune(Box::new(multiplication_autotune_kernel)); + let multiplication_autotune_kernel = + dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, handles); + client.execute_autotune(Box::new(multiplication_autotune_kernel)); - let obtained_resource = client.read(&out); + let obtained_resource = client.read(&out); - // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource.read(), Vec::from([0, 4, 8])); + // If slow kernel was selected it would output [0, 1, 2] + assert_eq!(obtained_resource.read(), Vec::from([0, 4, 8])); } #[test] #[serial] #[cfg(feature = "std")] fn autotune_cache_hit_test() { - let client = client(&DummyDevice); - - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = client.empty(3); - let handles_1 = vec![lhs_1, rhs_1, out_1]; - - let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]]; - let lhs_2 = client.create(&[0, 1, 2, 3]); - let rhs_2 = client.create(&[5, 6, 7, 8]); - let out_2 = client.empty(4); - let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; - - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - let cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); - client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); - - let obtained_resource = client.read(&out_2); - - // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs - assert_eq!(obtained_resource.read(), Vec::from([0, 1, 2, 3])); + let client = client(&DummyDevice); + + let shapes_1 = 
vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs_1 = client.create(&[0, 1, 2]); + let rhs_1 = client.create(&[4, 4, 4]); + let out_1 = client.empty(3); + let handles_1 = vec![lhs_1, rhs_1, out_1]; + + let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]]; + let lhs_2 = client.create(&[0, 1, 2, 3]); + let rhs_2 = client.create(&[5, 6, 7, 8]); + let out_2 = client.empty(4); + let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; + + let cache_test_autotune_kernel_1 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); + let cache_test_autotune_kernel_2 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); + client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); + client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); + + let obtained_resource = client.read(&out_2); + + // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs + assert_eq!(obtained_resource.read(), Vec::from([0, 1, 2, 3])); } #[test] #[serial] #[cfg(feature = "std")] fn autotune_cache_miss_test() { - let client = client(&DummyDevice); - - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = client.empty(3); - let handles_1 = vec![lhs_1, rhs_1, out_1]; - - let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]]; - let lhs_2 = client.create(&[0, 1, 2, 3, 4]); - let rhs_2 = client.create(&[5, 6, 7, 8, 9]); - let out_2 = client.empty(5); - let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; - - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - let cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); - client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); - - let obtained_resource = client.read(&out_2); - - // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs - assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9])); + let client = client(&DummyDevice); + + let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs_1 = client.create(&[0, 1, 2]); + let rhs_1 = client.create(&[4, 4, 4]); + let out_1 = client.empty(3); + let handles_1 = vec![lhs_1, rhs_1, out_1]; + + let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]]; + let lhs_2 = client.create(&[0, 1, 2, 3, 4]); + let rhs_2 = client.create(&[5, 6, 7, 8, 9]); + let out_2 = client.empty(5); + let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; + + let cache_test_autotune_kernel_1 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); + let cache_test_autotune_kernel_2 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); + client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); + client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); + + let obtained_resource = client.read(&out_2); + + // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs + assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9])); } diff --git a/burn-core/src/config.rs b/burn-core/src/config.rs index 94166228da..ba5340a30b 100644 --- a/burn-core/src/config.rs +++ b/burn-core/src/config.rs @@ -4,28 +4,28 @@ pub use burn_derive::Config; /// Configuration IO error. #[derive(Debug)] pub enum ConfigError { - /// Invalid format. 
- InvalidFormat(String), + /// Invalid format. + InvalidFormat(String), - /// File not found. - FileNotFound(String), + /// File not found. + FileNotFound(String), } impl core::fmt::Display for ConfigError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mut message = "Config error => ".to_string(); + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut message = "Config error => ".to_string(); - match self { - Self::InvalidFormat(err) => { - message += format!("Invalid format: {err}").as_str(); - } - Self::FileNotFound(err) => { - message += format!("File not found: {err}").as_str(); - } - }; + match self { + Self::InvalidFormat(err) => { + message += format!("Invalid format: {err}").as_str(); + } + Self::FileNotFound(err) => { + message += format!("File not found: {err}").as_str(); + } + }; - f.write_str(message.as_str()) - } + f.write_str(message.as_str()) + } } // TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765) @@ -34,51 +34,50 @@ impl std::error::Error for ConfigError {} /// Configuration trait. pub trait Config: serde::Serialize + serde::de::DeserializeOwned { - /// Saves the configuration to a file. - /// - /// # Arguments - /// - /// * `file` - File to save the configuration to. - /// - /// # Returns - /// - /// The output of the save operation. - #[cfg(feature = "std")] - fn save>(&self, file: P) -> std::io::Result<()> { - std::fs::write(file, config_to_json(self)) - } + /// Saves the configuration to a file. + /// + /// # Arguments + /// + /// * `file` - File to save the configuration to. + /// + /// # Returns + /// + /// The output of the save operation. + #[cfg(feature = "std")] + fn save>(&self, file: P) -> std::io::Result<()> { + std::fs::write(file, config_to_json(self)) + } - /// Loads the configuration from a file. - /// - /// # Arguments - /// - /// * `file` - File to load the configuration from. - /// - /// # Returns - /// - /// The loaded configuration. - #[cfg(feature = "std")] - fn load>(file: P) -> Result { - let content = std::fs::read_to_string(file.as_ref()) - .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?; - config_from_str(&content) - } + /// Loads the configuration from a file. + /// + /// # Arguments + /// + /// * `file` - File to load the configuration from. + /// + /// # Returns + /// + /// The loaded configuration. + #[cfg(feature = "std")] + fn load>(file: P) -> Result { + let content = std::fs::read_to_string(file.as_ref()) + .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?; + config_from_str(&content) + } - /// Loads the configuration from a binary buffer. - /// - /// # Arguments - /// - /// * `data` - Binary buffer to load the configuration from. - /// - /// # Returns - /// - /// The loaded configuration. - fn load_binary(data: &[u8]) -> Result { - let content = core::str::from_utf8(data).map_err(|_| { - ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string()) - })?; - config_from_str(content) - } + /// Loads the configuration from a binary buffer. + /// + /// # Arguments + /// + /// * `data` - Binary buffer to load the configuration from. + /// + /// # Returns + /// + /// The loaded configuration. + fn load_binary(data: &[u8]) -> Result { + let content = core::str::from_utf8(data) + .map_err(|_| ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string()))?; + config_from_str(content) + } } /// Converts a configuration to a JSON string. 
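// --- Editor's note: a hedged usage sketch, not part of this patch. ---
// `MyTrainingConfig` and `config_roundtrip_sketch` are hypothetical names; the
// sketch relies only on items visible in this file (the `Config` trait, whose
// methods are all defaulted, `ConfigError`, `config_to_json`, and serde), and
// the import path assumes `config` is a public module of `burn-core`.
use burn_core::config::{config_to_json, Config, ConfigError};

#[derive(serde::Serialize, serde::Deserialize)]
struct MyTrainingConfig {
    seed: u64,
    learning_rate: f64,
}

// Any (de)serializable type can opt in, since every trait method has a default.
impl Config for MyTrainingConfig {}

fn config_roundtrip_sketch() -> Result<(), ConfigError> {
    let config = MyTrainingConfig {
        seed: 42,
        learning_rate: 1e-4,
    };

    // Serialize to pretty JSON, then parse the same bytes back through the
    // defaulted `load_binary` method defined above.
    let json = config_to_json(&config);
    let loaded = MyTrainingConfig::load_binary(json.as_bytes())?;

    assert_eq!(config_to_json(&loaded), json);
    Ok(())
}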
@@ -91,9 +90,9 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned { /// /// The JSON string. pub fn config_to_json(config: &C) -> String { - serde_json::to_string_pretty(config).unwrap() + serde_json::to_string_pretty(config).unwrap() } fn config_from_str(content: &str) -> Result { - serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}"))) + serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}"))) } diff --git a/burn-core/src/data/dataloader/base.rs b/burn-core/src/data/dataloader/base.rs index 0222248ad3..7619f95b47 100644 --- a/burn-core/src/data/dataloader/base.rs +++ b/burn-core/src/data/dataloader/base.rs @@ -4,21 +4,21 @@ use core::iter::Iterator; /// A progress struct that can be used to track the progress of a data loader. #[derive(Clone, Debug)] pub struct Progress { - /// The number of items that have been processed. - pub items_processed: usize, + /// The number of items that have been processed. + pub items_processed: usize, - /// The total number of items that need to be processed. - pub items_total: usize, + /// The total number of items that need to be processed. + pub items_total: usize, } /// A data loader iterator that can be used to iterate over a data loader. pub trait DataLoaderIterator: Iterator { - /// Returns the progress of the data loader. - fn progress(&self) -> Progress; + /// Returns the progress of the data loader. + fn progress(&self) -> Progress; } /// A data loader that can be used to iterate over a dataset. pub trait DataLoader { - /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader. - fn iter<'a>(&'a self) -> Box + 'a>; + /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader. + fn iter<'a>(&'a self) -> Box + 'a>; } diff --git a/burn-core/src/data/dataloader/batch.rs b/burn-core/src/data/dataloader/batch.rs index 978127bad8..34b82c953d 100644 --- a/burn-core/src/data/dataloader/batch.rs +++ b/burn-core/src/data/dataloader/batch.rs @@ -1,254 +1,253 @@ use super::{ - batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, MultiThreadDataLoader, - Progress, + batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, MultiThreadDataLoader, Progress, }; use burn_dataset::{ - transform::{PartialDataset, ShuffledDataset}, - Dataset, + transform::{PartialDataset, ShuffledDataset}, + Dataset, }; use rand::{distributions::Standard, prelude::Distribution, rngs::StdRng, Rng, SeedableRng}; use std::sync::Arc; /// A data loader that can be used to iterate over a dataset in batches. pub struct BatchDataLoader { - strategy: Box>, - dataset: Arc>, - batcher: Arc>, - rng: Option>, + strategy: Box>, + dataset: Arc>, + batcher: Arc>, + rng: Option>, } impl BatchDataLoader { - /// Creates a new batch data loader. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader - /// iterator is created. - /// - /// # Returns - /// - /// The batch data loader. - pub fn new( - strategy: Box>, - dataset: Arc>, - batcher: Arc>, - rng: Option, - ) -> Self { - Self { - strategy, - dataset, - batcher, - rng: rng.map(spin::Mutex::new), - } + /// Creates a new batch data loader. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. 
+ /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader + /// iterator is created. + /// + /// # Returns + /// + /// The batch data loader. + pub fn new( + strategy: Box>, + dataset: Arc>, + batcher: Arc>, + rng: Option, + ) -> Self { + Self { + strategy, + dataset, + batcher, + rng: rng.map(spin::Mutex::new), } + } } /// A data loader iterator that can be used to iterate over a data loader. struct BatchDataloaderIterator { - current_index: usize, - strategy: Box>, - dataset: Arc>, - batcher: Arc>, + current_index: usize, + strategy: Box>, + dataset: Arc>, + batcher: Arc>, } impl BatchDataLoader where - I: Send + Sync + Clone + 'static, - O: Send + Sync + Clone + 'static, + I: Send + Sync + Clone + 'static, + O: Send + Sync + Clone + 'static, { - /// Creates a new multi-threaded batch data loader. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// * `num_threads` - The number of threads. - /// - /// # Returns - /// - /// The multi-threaded batch data loader. - pub fn multi_thread( - strategy: Box>, - dataset: Arc>, - batcher: Arc>, - num_threads: usize, - mut rng: Option, - ) -> MultiThreadDataLoader { - let datasets = PartialDataset::split(dataset, num_threads); - - let mut dataloaders: Vec + Send + Sync>> = - Vec::with_capacity(num_threads); - - // Create more rngs from the first one, one for each new dataloader. - let rngs = (0..num_threads).map(|_| { - rng.as_mut() - .map(|rng| StdRng::seed_from_u64(Distribution::sample(&Standard, rng))) - }); - - for (dataset, rng) in datasets.into_iter().zip(rngs) { - let strategy = strategy.new_like(); - let dataloader = - BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone(), rng); - let dataloader = Arc::new(dataloader); - dataloaders.push(dataloader); - } - MultiThreadDataLoader::new(dataloaders) + /// Creates a new multi-threaded batch data loader. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// * `num_threads` - The number of threads. + /// + /// # Returns + /// + /// The multi-threaded batch data loader. + pub fn multi_thread( + strategy: Box>, + dataset: Arc>, + batcher: Arc>, + num_threads: usize, + mut rng: Option, + ) -> MultiThreadDataLoader { + let datasets = PartialDataset::split(dataset, num_threads); + + let mut dataloaders: Vec + Send + Sync>> = + Vec::with_capacity(num_threads); + + // Create more rngs from the first one, one for each new dataloader. + let rngs = (0..num_threads).map(|_| { + rng + .as_mut() + .map(|rng| StdRng::seed_from_u64(Distribution::sample(&Standard, rng))) + }); + + for (dataset, rng) in datasets.into_iter().zip(rngs) { + let strategy = strategy.new_like(); + let dataloader = BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone(), rng); + let dataloader = Arc::new(dataloader); + dataloaders.push(dataloader); } + MultiThreadDataLoader::new(dataloaders) + } } impl DataLoader for BatchDataLoader { - fn iter<'a>(&'a self) -> Box + 'a> { - // When starting a new iteration, we first check if the dataloader was created with an rng, - // implying that we should shuffle the dataset beforehand, while advancing the current - // rng to ensure that each new iteration shuffles the dataset differently. 
- let dataset = match &self.rng { - Some(rng) => { - let mut rng = rng.lock(); - - Arc::new(ShuffledDataset::with_seed( - self.dataset.clone(), - rng.sample(Standard), - )) - } - None => self.dataset.clone(), - }; - Box::new(BatchDataloaderIterator::new( - self.strategy.new_like(), - dataset, - self.batcher.clone(), + fn iter<'a>(&'a self) -> Box + 'a> { + // When starting a new iteration, we first check if the dataloader was created with an rng, + // implying that we should shuffle the dataset beforehand, while advancing the current + // rng to ensure that each new iteration shuffles the dataset differently. + let dataset = match &self.rng { + Some(rng) => { + let mut rng = rng.lock(); + + Arc::new(ShuffledDataset::with_seed( + self.dataset.clone(), + rng.sample(Standard), )) - } + } + None => self.dataset.clone(), + }; + Box::new(BatchDataloaderIterator::new( + self.strategy.new_like(), + dataset, + self.batcher.clone(), + )) + } } impl BatchDataloaderIterator { - /// Creates a new batch data loader iterator. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The batch data loader iterator. - pub fn new( - strategy: Box>, - dataset: Arc>, - batcher: Arc>, - ) -> Self { - BatchDataloaderIterator { - current_index: 0, - strategy, - dataset, - batcher, - } + /// Creates a new batch data loader iterator. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// + /// # Returns + /// + /// The batch data loader iterator. + pub fn new( + strategy: Box>, + dataset: Arc>, + batcher: Arc>, + ) -> Self { + BatchDataloaderIterator { + current_index: 0, + strategy, + dataset, + batcher, } + } } impl Iterator for BatchDataloaderIterator { - type Item = O; - - fn next(&mut self) -> Option { - while let Some(item) = self.dataset.get(self.current_index) { - self.current_index += 1; - self.strategy.add(item); + type Item = O; - if let Some(items) = self.strategy.batch(false) { - return Some(self.batcher.batch(items)); - } - } + fn next(&mut self) -> Option { + while let Some(item) = self.dataset.get(self.current_index) { + self.current_index += 1; + self.strategy.add(item); - if let Some(items) = self.strategy.batch(true) { - return Some(self.batcher.batch(items)); - } + if let Some(items) = self.strategy.batch(false) { + return Some(self.batcher.batch(items)); + } + } - None + if let Some(items) = self.strategy.batch(true) { + return Some(self.batcher.batch(items)); } + + None + } } impl DataLoaderIterator for BatchDataloaderIterator { - fn progress(&self) -> Progress { - Progress { - items_processed: self.current_index, - items_total: self.dataset.len(), - } + fn progress(&self) -> Progress { + Progress { + items_processed: self.current_index, + items_total: self.dataset.len(), } + } } #[cfg(test)] mod tests { - use std::collections::HashSet; - - use super::*; - use crate::data::dataloader::batcher::TestBatcher; - use crate::data::dataloader::FixBatchStrategy; - use crate::data::dataset::FakeDataset; - - #[test] - fn test_batch_dataloader() { - let batcher = Arc::new(TestBatcher::new()); - let dataset = Arc::new(FakeDataset::::new(27)); - let dataloader = BatchDataLoader::new( - Box::new(FixBatchStrategy::new(5)), - dataset.clone(), - batcher, - None, - ); - - let mut items_dataset = HashSet::new(); - let mut items_dataloader = HashSet::new(); - - for item in dataset.iter() { - 
items_dataset.insert(item); - } - - for items in dataloader.iter() { - for item in items { - items_dataloader.insert(item); - } - } - - assert_eq!(items_dataset, items_dataloader); + use std::collections::HashSet; + + use super::*; + use crate::data::dataloader::batcher::TestBatcher; + use crate::data::dataloader::FixBatchStrategy; + use crate::data::dataset::FakeDataset; + + #[test] + fn test_batch_dataloader() { + let batcher = Arc::new(TestBatcher::new()); + let dataset = Arc::new(FakeDataset::::new(27)); + let dataloader = BatchDataLoader::new( + Box::new(FixBatchStrategy::new(5)), + dataset.clone(), + batcher, + None, + ); + + let mut items_dataset = HashSet::new(); + let mut items_dataloader = HashSet::new(); + + for item in dataset.iter() { + items_dataset.insert(item); + } + + for items in dataloader.iter() { + for item in items { + items_dataloader.insert(item); + } } - #[test] - fn test_multi_thread_batch_dataloader() { - let batcher = Arc::new(TestBatcher::new()); - let dataset = Arc::new(FakeDataset::::new(27)); - let dataloader_single_thread = BatchDataLoader::new( - Box::new(FixBatchStrategy::new(5)), - dataset.clone(), - batcher.clone(), - None, - ); - let dataloader_multi_thread = BatchDataLoader::multi_thread( - Box::new(FixBatchStrategy::new(5)), - dataset, - batcher, - 4, - None, - ); - - let mut items_single_thread = HashSet::new(); - let mut items_multi_thread = HashSet::new(); - - for items in dataloader_single_thread.iter() { - for item in items { - items_single_thread.insert(item); - } - } - - for items in dataloader_multi_thread.iter() { - for item in items { - items_multi_thread.insert(item); - } - } - - assert_eq!(items_single_thread, items_multi_thread); + assert_eq!(items_dataset, items_dataloader); + } + + #[test] + fn test_multi_thread_batch_dataloader() { + let batcher = Arc::new(TestBatcher::new()); + let dataset = Arc::new(FakeDataset::::new(27)); + let dataloader_single_thread = BatchDataLoader::new( + Box::new(FixBatchStrategy::new(5)), + dataset.clone(), + batcher.clone(), + None, + ); + let dataloader_multi_thread = BatchDataLoader::multi_thread( + Box::new(FixBatchStrategy::new(5)), + dataset, + batcher, + 4, + None, + ); + + let mut items_single_thread = HashSet::new(); + let mut items_multi_thread = HashSet::new(); + + for items in dataloader_single_thread.iter() { + for item in items { + items_single_thread.insert(item); + } } + + for items in dataloader_multi_thread.iter() { + for item in items { + items_multi_thread.insert(item); + } + } + + assert_eq!(items_single_thread, items_multi_thread); + } } diff --git a/burn-core/src/data/dataloader/batcher.rs b/burn-core/src/data/dataloader/batcher.rs index 724a2e3a54..0e52444da1 100644 --- a/burn-core/src/data/dataloader/batcher.rs +++ b/burn-core/src/data/dataloader/batcher.rs @@ -1,15 +1,15 @@ /// A trait for batching items of type `I` into items of type `O`. pub trait Batcher: Send + Sync { - /// Batches the given items. - /// - /// # Arguments - /// - /// * `items` - The items to batch. - /// - /// # Returns - /// - /// The batched items. - fn batch(&self, items: Vec) -> O; + /// Batches the given items. + /// + /// # Arguments + /// + /// * `items` - The items to batch. + /// + /// # Returns + /// + /// The batched items. 
+ fn batch(&self, items: Vec) -> O; } #[cfg(test)] @@ -17,7 +17,7 @@ pub trait Batcher: Send + Sync { pub struct TestBatcher; #[cfg(test)] impl Batcher> for TestBatcher { - fn batch(&self, items: Vec) -> Vec { - items - } + fn batch(&self, items: Vec) -> Vec { + items + } } diff --git a/burn-core/src/data/dataloader/builder.rs b/burn-core/src/data/dataloader/builder.rs index d6227ebc49..8c6d29154b 100644 --- a/burn-core/src/data/dataloader/builder.rs +++ b/burn-core/src/data/dataloader/builder.rs @@ -5,113 +5,113 @@ use std::sync::Arc; /// A builder for data loaders. pub struct DataLoaderBuilder { - strategy: Option>>, - batcher: Arc>, - num_threads: Option, - shuffle: Option, + strategy: Option>>, + batcher: Arc>, + num_threads: Option, + shuffle: Option, } impl DataLoaderBuilder where - I: Send + Sync + Clone + std::fmt::Debug + 'static, - O: Send + Sync + Clone + std::fmt::Debug + 'static, + I: Send + Sync + Clone + std::fmt::Debug + 'static, + O: Send + Sync + Clone + std::fmt::Debug + 'static, { - /// Creates a new data loader builder. - /// - /// # Arguments - /// - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The data loader builder. - pub fn new(batcher: B) -> Self - where - B: Batcher + 'static, - { - Self { - batcher: Arc::new(batcher), - strategy: None, - num_threads: None, - shuffle: None, - } + /// Creates a new data loader builder. + /// + /// # Arguments + /// + /// * `batcher` - The batcher. + /// + /// # Returns + /// + /// The data loader builder. + pub fn new(batcher: B) -> Self + where + B: Batcher + 'static, + { + Self { + batcher: Arc::new(batcher), + strategy: None, + num_threads: None, + shuffle: None, } + } - /// Sets the batch size to a fix number.The [fix batch strategy](FixBatchStrategy) - /// will be used. - /// - /// # Arguments - /// - /// * `batch_size` - The batch size. - /// - /// # Returns - /// - /// The data loader builder. - pub fn batch_size(mut self, batch_size: usize) -> Self { - self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size))); - self - } + /// Sets the batch size to a fix number.The [fix batch strategy](FixBatchStrategy) + /// will be used. + /// + /// # Arguments + /// + /// * `batch_size` - The batch size. + /// + /// # Returns + /// + /// The data loader builder. + pub fn batch_size(mut self, batch_size: usize) -> Self { + self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size))); + self + } - /// Sets the seed for shuffling. - /// - /// Each time the dataloader starts a new iteration, the dataset will be shuffled. - /// - /// # Arguments - /// - /// * `seed` - The seed. - /// - /// # Returns - /// - /// The data loader builder. - pub fn shuffle(mut self, seed: u64) -> Self { - self.shuffle = Some(seed); - self - } + /// Sets the seed for shuffling. + /// + /// Each time the dataloader starts a new iteration, the dataset will be shuffled. + /// + /// # Arguments + /// + /// * `seed` - The seed. + /// + /// # Returns + /// + /// The data loader builder. + pub fn shuffle(mut self, seed: u64) -> Self { + self.shuffle = Some(seed); + self + } - /// Sets the number of workers. - /// - /// # Arguments - /// - /// * `num_workers` - The number of workers. - /// - /// # Returns - /// - /// The data loader builder. - pub fn num_workers(mut self, num_workers: usize) -> Self { - self.num_threads = Some(num_workers); - self - } + /// Sets the number of workers. + /// + /// # Arguments + /// + /// * `num_workers` - The number of workers. + /// + /// # Returns + /// + /// The data loader builder. 
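// A minimal sketch (not part of the patch) making the item-to-output contract of `Batcher`
// concrete: a batcher that folds a chunk of integers into their sum. The `<I, O>` type
// parameters are assumed; they were elided in this rendering of the diff.
mod batcher_sketch {
    use crate::data::dataloader::batcher::Batcher;

    pub struct SumBatcher;

    impl Batcher<i32, i32> for SumBatcher {
        fn batch(&self, items: Vec<i32>) -> i32 {
            // One output item per incoming chunk of items.
            items.iter().sum()
        }
    }
}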
+ pub fn num_workers(mut self, num_workers: usize) -> Self { + self.num_threads = Some(num_workers); + self + } - /// Builds the data loader. - /// - /// # Arguments - /// - /// * `dataset` - The dataset. - /// - /// # Returns - /// - /// The data loader. - pub fn build(self, dataset: D) -> Arc> - where - D: Dataset + 'static, - { - let dataset = Arc::new(dataset); + /// Builds the data loader. + /// + /// # Arguments + /// + /// * `dataset` - The dataset. + /// + /// # Returns + /// + /// The data loader. + pub fn build(self, dataset: D) -> Arc> + where + D: Dataset + 'static, + { + let dataset = Arc::new(dataset); - let rng = self.shuffle.map(StdRng::seed_from_u64); - let strategy = match self.strategy { - Some(strategy) => strategy, - None => Box::new(FixBatchStrategy::new(1)), - }; - if let Some(num_threads) = self.num_threads { - return Arc::new(BatchDataLoader::multi_thread( - strategy, - dataset, - self.batcher, - num_threads, - rng, - )); - } - - Arc::new(BatchDataLoader::new(strategy, dataset, self.batcher, rng)) + let rng = self.shuffle.map(StdRng::seed_from_u64); + let strategy = match self.strategy { + Some(strategy) => strategy, + None => Box::new(FixBatchStrategy::new(1)), + }; + if let Some(num_threads) = self.num_threads { + return Arc::new(BatchDataLoader::multi_thread( + strategy, + dataset, + self.batcher, + num_threads, + rng, + )); } + + Arc::new(BatchDataLoader::new(strategy, dataset, self.batcher, rng)) + } } diff --git a/burn-core/src/data/dataloader/multithread.rs b/burn-core/src/data/dataloader/multithread.rs index 00fabf2957..28bbb5e0b9 100644 --- a/burn-core/src/data/dataloader/multithread.rs +++ b/burn-core/src/data/dataloader/multithread.rs @@ -7,134 +7,134 @@ const MAX_QUEUED_ITEMS: usize = 100; /// A multi-threaded data loader that can be used to iterate over a dataset. pub struct MultiThreadDataLoader { - dataloaders: Vec + Send + Sync>>, + dataloaders: Vec + Send + Sync>>, } /// A message that can be sent between threads. #[derive(Debug)] pub enum Message { - /// A batch of items. - Batch(usize, O, Progress), + /// A batch of items. + Batch(usize, O, Progress), - /// The thread is done. - Done, + /// The thread is done. + Done, } struct MultiThreadsDataloaderIterator { - num_done: usize, - workers: Vec>, - receiver: mpsc::Receiver>, - progresses: HashMap, + num_done: usize, + workers: Vec>, + receiver: mpsc::Receiver>, + progresses: HashMap, } impl MultiThreadDataLoader { - /// Creates a new multi-threaded data loader. - /// - /// # Arguments - /// - /// * `dataloaders` - The data loaders. - /// - /// # Returns - /// - /// The multi-threaded data loader. - pub fn new(dataloaders: Vec + Send + Sync>>) -> Self { - Self { dataloaders } - } + /// Creates a new multi-threaded data loader. + /// + /// # Arguments + /// + /// * `dataloaders` - The data loaders. + /// + /// # Returns + /// + /// The multi-threaded data loader. 
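// A usage sketch (not part of the patch) wiring the builder to a dataset and batcher. The
// generic parameters and trait bounds are assumed from the builder's own bounds above; they
// were elided in this rendering of the diff.
mod builder_sketch {
    use crate::data::dataloader::{batcher::Batcher, DataLoader, DataLoaderBuilder};
    use crate::data::dataset::Dataset;
    use std::sync::Arc;

    pub fn build_loader<I, O, B, D>(dataset: D, batcher: B) -> Arc<dyn DataLoader<O>>
    where
        I: Send + Sync + Clone + std::fmt::Debug + 'static,
        O: Send + Sync + Clone + std::fmt::Debug + 'static,
        B: Batcher<I, O> + 'static,
        D: Dataset<I> + 'static,
    {
        DataLoaderBuilder::new(batcher)
            .batch_size(32) // FixBatchStrategy under the hood
            .shuffle(42) // reshuffles with a newly derived seed on every iteration
            .num_workers(4) // switches to the multi-threaded loader
            .build(dataset)
    }
}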
+ pub fn new(dataloaders: Vec + Send + Sync>>) -> Self { + Self { dataloaders } + } } impl DataLoader for MultiThreadDataLoader where - O: Send + 'static + std::fmt::Debug, + O: Send + 'static + std::fmt::Debug, { - fn iter<'a>(&'a self) -> Box + 'a> { - let (sender, receiver) = mpsc::sync_channel::>(MAX_QUEUED_ITEMS); - - let handlers: Vec<_> = self - .dataloaders - .clone() - .into_iter() - .enumerate() - .map(|(index, dataloader)| { - let dataloader_cloned = dataloader; - let sender_cloned = sender.clone(); - - thread::spawn(move || { - let mut iterator = dataloader_cloned.iter(); - while let Some(item) = iterator.next() { - let progress = iterator.progress(); - - match sender_cloned.send(Message::Batch(index, item, progress)) { - Ok(_) => {} - // The receiver is probably gone, no need to panic, just need to stop - // iterating. - Err(_) => return, - }; - } - // Same thing. - sender_cloned.send(Message::Done).ok(); - }) - }) - .collect(); - - Box::new(MultiThreadsDataloaderIterator::new(receiver, handlers)) - } + fn iter<'a>(&'a self) -> Box + 'a> { + let (sender, receiver) = mpsc::sync_channel::>(MAX_QUEUED_ITEMS); + + let handlers: Vec<_> = self + .dataloaders + .clone() + .into_iter() + .enumerate() + .map(|(index, dataloader)| { + let dataloader_cloned = dataloader; + let sender_cloned = sender.clone(); + + thread::spawn(move || { + let mut iterator = dataloader_cloned.iter(); + while let Some(item) = iterator.next() { + let progress = iterator.progress(); + + match sender_cloned.send(Message::Batch(index, item, progress)) { + Ok(_) => {} + // The receiver is probably gone, no need to panic, just need to stop + // iterating. + Err(_) => return, + }; + } + // Same thing. + sender_cloned.send(Message::Done).ok(); + }) + }) + .collect(); + + Box::new(MultiThreadsDataloaderIterator::new(receiver, handlers)) + } } impl MultiThreadsDataloaderIterator { - pub fn new(receiver: mpsc::Receiver>, workers: Vec>) -> Self { - MultiThreadsDataloaderIterator { - num_done: 0, - workers, - receiver, - progresses: HashMap::new(), - } + pub fn new(receiver: mpsc::Receiver>, workers: Vec>) -> Self { + MultiThreadsDataloaderIterator { + num_done: 0, + workers, + receiver, + progresses: HashMap::new(), } + } } impl DataLoaderIterator for MultiThreadsDataloaderIterator { - fn progress(&self) -> Progress { - let mut items_total = 0; - let mut items_processed = 0; + fn progress(&self) -> Progress { + let mut items_total = 0; + let mut items_processed = 0; - for progress in self.progresses.values() { - items_total += progress.items_total; - items_processed += progress.items_processed; - } + for progress in self.progresses.values() { + items_total += progress.items_total; + items_processed += progress.items_processed; + } - Progress { - items_processed, - items_total, - } + Progress { + items_processed, + items_total, } + } } impl Iterator for MultiThreadsDataloaderIterator { - type Item = O; + type Item = O; - fn next(&mut self) -> Option { - if self.workers.is_empty() { - return None; - } + fn next(&mut self) -> Option { + if self.workers.is_empty() { + return None; + } - loop { - let item = self.receiver.recv(); - let item = item.unwrap(); - - match item { - Message::Batch(index, item, progress) => { - self.progresses.insert(index, progress); - return Some(item); - } - Message::Done => { - self.num_done += 1; - } - }; + loop { + let item = self.receiver.recv(); + let item = item.unwrap(); + + match item { + Message::Batch(index, item, progress) => { + self.progresses.insert(index, progress); + return 
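// A standalone sketch (std only, not part of the patch) of the fan-in pattern used by the
// multi-threaded loader above: every worker streams batches through one bounded channel and
// signals completion with `Done`; the consumer stops once all workers have reported `Done`.
fn fan_in_demo() {
    use std::sync::mpsc;
    use std::thread;

    enum Msg {
        Batch(usize, String),
        Done,
    }

    let (sender, receiver) = mpsc::sync_channel::<Msg>(100);
    let workers: Vec<_> = (0..4)
        .map(|index| {
            let sender = sender.clone();
            thread::spawn(move || {
                for i in 0..3 {
                    // The receiver may already be gone; just stop iterating in that case.
                    if sender.send(Msg::Batch(index, format!("batch {i}"))).is_err() {
                        return;
                    }
                }
                sender.send(Msg::Done).ok();
            })
        })
        .collect();
    drop(sender); // only the workers' clones keep the channel alive

    let mut num_done = 0;
    while num_done < workers.len() {
        match receiver.recv().expect("workers still hold senders") {
            Msg::Batch(index, payload) => println!("worker {index}: {payload}"),
            Msg::Done => num_done += 1,
        }
    }
    for worker in workers {
        worker.join().unwrap();
    }
}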
Some(item); + } + Message::Done => { + self.num_done += 1; + } + }; - if self.num_done == self.workers.len() { - while let Some(worker) = self.workers.pop() { - worker.join().unwrap(); - } - return None; - } + if self.num_done == self.workers.len() { + while let Some(worker) = self.workers.pop() { + worker.join().unwrap(); } + return None; + } } + } } diff --git a/burn-core/src/data/dataloader/strategy.rs b/burn-core/src/data/dataloader/strategy.rs index 9e09207edf..f2302947ba 100644 --- a/burn-core/src/data/dataloader/strategy.rs +++ b/burn-core/src/data/dataloader/strategy.rs @@ -1,76 +1,76 @@ /// A strategy to batch items. pub trait BatchStrategy: Send + Sync { - /// Adds an item to the strategy. - /// - /// # Arguments - /// - /// * `item` - The item to add. - fn add(&mut self, item: I); + /// Adds an item to the strategy. + /// + /// # Arguments + /// + /// * `item` - The item to add. + fn add(&mut self, item: I); - /// Batches the items. - /// - /// # Arguments - /// - /// * `force` - Whether to force batching. - /// - /// # Returns - /// - /// The batched items. - fn batch(&mut self, force: bool) -> Option>; + /// Batches the items. + /// + /// # Arguments + /// + /// * `force` - Whether to force batching. + /// + /// # Returns + /// + /// The batched items. + fn batch(&mut self, force: bool) -> Option>; - /// Creates a new strategy of the same type. - /// - /// # Returns - /// - /// The new strategy. - fn new_like(&self) -> Box>; + /// Creates a new strategy of the same type. + /// + /// # Returns + /// + /// The new strategy. + fn new_like(&self) -> Box>; } /// A strategy to batch items with a fixed batch size. pub struct FixBatchStrategy { - items: Vec, - batch_size: usize, + items: Vec, + batch_size: usize, } impl FixBatchStrategy { - /// Creates a new strategy to batch items with a fixed batch size. - /// - /// # Arguments - /// - /// * `batch_size` - The batch size. - /// - /// # Returns - /// - /// The strategy. - pub fn new(batch_size: usize) -> Self { - FixBatchStrategy { - items: Vec::with_capacity(batch_size), - batch_size, - } + /// Creates a new strategy to batch items with a fixed batch size. + /// + /// # Arguments + /// + /// * `batch_size` - The batch size. + /// + /// # Returns + /// + /// The strategy. + pub fn new(batch_size: usize) -> Self { + FixBatchStrategy { + items: Vec::with_capacity(batch_size), + batch_size, } + } } impl BatchStrategy for FixBatchStrategy { - fn add(&mut self, item: I) { - self.items.push(item); - } - - fn batch(&mut self, force: bool) -> Option> { - if self.items.len() < self.batch_size && !force { - return None; - } + fn add(&mut self, item: I) { + self.items.push(item); + } - let mut items = Vec::with_capacity(self.batch_size); - std::mem::swap(&mut items, &mut self.items); + fn batch(&mut self, force: bool) -> Option> { + if self.items.len() < self.batch_size && !force { + return None; + } - if items.is_empty() { - return None; - } + let mut items = Vec::with_capacity(self.batch_size); + std::mem::swap(&mut items, &mut self.items); - Some(items) + if items.is_empty() { + return None; } - fn new_like(&self) -> Box> { - Box::new(Self::new(self.batch_size)) - } + Some(items) + } + + fn new_like(&self) -> Box> { + Box::new(Self::new(self.batch_size)) + } } diff --git a/burn-core/src/data/mod.rs b/burn-core/src/data/mod.rs index 5489cae640..6e81a9a0a7 100644 --- a/burn-core/src/data/mod.rs +++ b/burn-core/src/data/mod.rs @@ -3,5 +3,5 @@ pub mod dataloader; /// Dataset module. 
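// A usage sketch (not part of the patch) for the `FixBatchStrategy` above. The
// `FixBatchStrategy` path is the one used by the tests earlier in this patch; re-exporting
// `BatchStrategy` from the same module is an assumption. `batch(false)` yields only once
// `batch_size` items were added, while `batch(true)` flushes whatever remains.
fn fix_batch_strategy_demo() {
    use crate::data::dataloader::{BatchStrategy, FixBatchStrategy};

    let mut strategy = FixBatchStrategy::new(3);
    strategy.add(1);
    strategy.add(2);
    assert!(strategy.batch(false).is_none()); // only 2 of 3 items so far
    strategy.add(3);
    assert_eq!(strategy.batch(false), Some(vec![1, 2, 3])); // a full batch
    strategy.add(4);
    assert_eq!(strategy.batch(true), Some(vec![4])); // forced flush of the remainder
}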
pub mod dataset { - pub use burn_dataset::*; + pub use burn_dataset::*; } diff --git a/burn-core/src/grad_clipping/base.rs b/burn-core/src/grad_clipping/base.rs index 91a6be069d..5f38487841 100644 --- a/burn-core/src/grad_clipping/base.rs +++ b/burn-core/src/grad_clipping/base.rs @@ -6,138 +6,138 @@ use burn_tensor::backend::Backend; /// Gradient Clipping provides a way to mitigate exploding gradients #[derive(Config)] pub enum GradientClippingConfig { - /// Clip the gradient by value. - Value(f32), + /// Clip the gradient by value. + Value(f32), - /// Clip the gradient by norm. - Norm(f32), + /// Clip the gradient by norm. + Norm(f32), } impl GradientClippingConfig { - /// Initialize the gradient clipping. - /// - /// # Returns - /// - /// The gradient clipping. - pub fn init(&self) -> GradientClipping { - match self { - GradientClippingConfig::Value(val) => GradientClipping::Value(*val), - GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val), - } + /// Initialize the gradient clipping. + /// + /// # Returns + /// + /// The gradient clipping. + pub fn init(&self) -> GradientClipping { + match self { + GradientClippingConfig::Value(val) => GradientClipping::Value(*val), + GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val), } + } } /// Gradient Clipping provides a way to mitigate exploding gradients /// by clipping every component of the gradient by value or by norm during /// backpropagation. pub enum GradientClipping { - /// Clip the gradient by value. - Value(f32), + /// Clip the gradient by value. + Value(f32), - /// Clip the gradient by norm. - Norm(f32), + /// Clip the gradient by norm. + Norm(f32), } impl GradientClipping { - /// Clip the gradient. - /// - /// # Arguments - /// - /// * `grad` - The gradient to clip. - /// - /// # Returns - /// - /// The clipped gradient. - pub fn clip_gradient(&self, grad: Tensor) -> Tensor { - match self { - GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold), - GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm), - } + /// Clip the gradient. + /// + /// # Arguments + /// + /// * `grad` - The gradient to clip. + /// + /// # Returns + /// + /// The clipped gradient. 
+ pub fn clip_gradient(&self, grad: Tensor) -> Tensor { + match self { + GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold), + GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm), } - - fn clip_by_value( - &self, - grad: Tensor, - threshold: f32, - ) -> Tensor { - let greater_mask = grad.clone().greater_elem(threshold); - let lower_mask = grad.clone().lower_elem(-threshold); - - let clipped_grad = grad.mask_fill(greater_mask, threshold); - - clipped_grad.mask_fill(lower_mask, -threshold) - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - fn clip_by_norm( - &self, - _grad: Tensor, - _threshold: f32, - ) -> Tensor { - todo!("Not yet supported on wasm"); - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn clip_by_norm( - &self, - grad: Tensor, - threshold: f32, - ) -> Tensor { - use burn_tensor::ElementConversion; - - let norm = Self::l2_norm(grad.clone()); - let norm_float = norm.into_scalar().elem::(); - - if norm_float > threshold { - let scale = threshold / norm_float; - grad.mul_scalar(scale) - } else { - grad - } + } + + fn clip_by_value( + &self, + grad: Tensor, + threshold: f32, + ) -> Tensor { + let greater_mask = grad.clone().greater_elem(threshold); + let lower_mask = grad.clone().lower_elem(-threshold); + + let clipped_grad = grad.mask_fill(greater_mask, threshold); + + clipped_grad.mask_fill(lower_mask, -threshold) + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + fn clip_by_norm( + &self, + _grad: Tensor, + _threshold: f32, + ) -> Tensor { + todo!("Not yet supported on wasm"); + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn clip_by_norm( + &self, + grad: Tensor, + threshold: f32, + ) -> Tensor { + use burn_tensor::ElementConversion; + + let norm = Self::l2_norm(grad.clone()); + let norm_float = norm.into_scalar().elem::(); + + if norm_float > threshold { + let scale = threshold / norm_float; + grad.mul_scalar(scale) + } else { + grad } + } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn l2_norm(tensor: Tensor) -> Tensor { - let squared = tensor.powf(2.0); - let sum = squared.sum(); + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn l2_norm(tensor: Tensor) -> Tensor { + let squared = tensor.powf(2.0); + let sum = squared.sum(); - sum.sqrt() - } + sum.sqrt() + } } #[cfg(test)] mod tests { - use super::*; - use crate::tensor::Tensor; - use crate::TestBackend; - - #[test] - fn test_clip_by_value() { - let gradient: Tensor = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]); - - let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient); - let clipped_gradient_data = clipped_gradient.into_data(); - - for value in clipped_gradient_data.value { - assert!(value <= 0.5); - } + use super::*; + use crate::tensor::Tensor; + use crate::TestBackend; + + #[test] + fn test_clip_by_value() { + let gradient: Tensor = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]); + + let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient); + let clipped_gradient_data = clipped_gradient.into_data(); + + for value in clipped_gradient_data.value { + assert!(value <= 0.5); } + } - #[test] - fn test_clip_by_norm() { - let gradient: Tensor = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 
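// Illustration only (not the patch's API): the two clipping rules above restated on a plain
// slice. Clip-by-value clamps each component into [-threshold, threshold]; clip-by-norm
// rescales the whole gradient by threshold / ||g||_2 when the L2 norm exceeds the threshold.
fn clip_by_value_slice(grad: &mut [f32], threshold: f32) {
    for g in grad.iter_mut() {
        *g = g.clamp(-threshold, threshold);
    }
}

fn clip_by_norm_slice(grad: &mut [f32], threshold: f32) {
    let norm = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
    if norm > threshold {
        let scale = threshold / norm;
        for g in grad.iter_mut() {
            *g *= scale;
        }
    }
}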
0.7893, 0.5684, 0.5939, 0.8883], - ]); + #[test] + fn test_clip_by_norm() { + let gradient: Tensor = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]); - let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient); - let clipped_gradient_data = clipped_gradient.into_data(); + let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient); + let clipped_gradient_data = clipped_gradient.into_data(); - for value in clipped_gradient_data.value { - assert!(value <= 0.88); - } + for value in clipped_gradient_data.value { + assert!(value <= 0.88); } + } } diff --git a/burn-core/src/lr_scheduler/base.rs b/burn-core/src/lr_scheduler/base.rs index 23c56f8a58..662a590d27 100644 --- a/burn-core/src/lr_scheduler/base.rs +++ b/burn-core/src/lr_scheduler/base.rs @@ -2,16 +2,16 @@ use crate::{record::Record, LearningRate}; /// Learning rate scheduler defines how the learning rate will evolve during training. pub trait LrScheduler: Send + Sync { - /// Scheduler associative type to be used when saving and loading the state. - type Record: Record; + /// Scheduler associative type to be used when saving and loading the state. + type Record: Record; - /// Perform the scheduler step, potentially updating its state, and returning the effective - /// learning rate. - fn step(&mut self) -> LearningRate; + /// Perform the scheduler step, potentially updating its state, and returning the effective + /// learning rate. + fn step(&mut self) -> LearningRate; - /// Get the current state of the scheduler as a [record](Record). - fn to_record(&self) -> Self::Record; + /// Get the current state of the scheduler as a [record](Record). + fn to_record(&self) -> Self::Record; - /// Load the state of the scheduler as a [record](Record). - fn load_record(self, record: Self::Record) -> Self; + /// Load the state of the scheduler as a [record](Record). + fn load_record(self, record: Self::Record) -> Self; } diff --git a/burn-core/src/lr_scheduler/constant.rs b/burn-core/src/lr_scheduler/constant.rs index eb41f1c108..36820bf33b 100644 --- a/burn-core/src/lr_scheduler/constant.rs +++ b/burn-core/src/lr_scheduler/constant.rs @@ -8,39 +8,39 @@ use crate::LearningRate; /// You can also use [learning rate](LearningRate) which the same effect. 
#[derive(new, Clone, Debug)] pub struct ConstantLr { - lr: LearningRate, + lr: LearningRate, } impl From for ConstantLr { - fn from(lr: LearningRate) -> Self { - Self { lr } - } + fn from(lr: LearningRate) -> Self { + Self { lr } + } } impl LrScheduler for ConstantLr { - type Record = (); + type Record = (); - fn step(&mut self) -> LearningRate { - self.lr - } + fn step(&mut self) -> LearningRate { + self.lr + } - fn to_record(&self) -> Self::Record {} + fn to_record(&self) -> Self::Record {} - fn load_record(self, _record: Self::Record) -> Self { - self - } + fn load_record(self, _record: Self::Record) -> Self { + self + } } impl LrScheduler for LearningRate { - type Record = (); + type Record = (); - fn step(&mut self) -> LearningRate { - *self - } + fn step(&mut self) -> LearningRate { + *self + } - fn to_record(&self) -> Self::Record {} + fn to_record(&self) -> Self::Record {} - fn load_record(self, _record: Self::Record) -> Self { - self - } + fn load_record(self, _record: Self::Record) -> Self { + self + } } diff --git a/burn-core/src/lr_scheduler/noam.rs b/burn-core/src/lr_scheduler/noam.rs index 622ee5c5b9..2cb415c535 100644 --- a/burn-core/src/lr_scheduler/noam.rs +++ b/burn-core/src/lr_scheduler/noam.rs @@ -6,87 +6,87 @@ use crate::{config::Config, LearningRate}; /// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler. #[derive(Config)] pub struct NoamLrSchedulerConfig { - /// The initial learning rate. - init_lr: LearningRate, - /// The number of steps before the exponential decay stats. - #[config(default = 4000)] - warmup_steps: usize, - /// The size of the model. - #[config(default = 512)] - model_size: usize, + /// The initial learning rate. + init_lr: LearningRate, + /// The number of steps before the exponential decay stats. + #[config(default = 4000)] + warmup_steps: usize, + /// The size of the model. + #[config(default = 512)] + model_size: usize, } /// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762). #[derive(Clone, Debug)] pub struct NoamLrScheduler { - warmup_steps: f64, - embedding_size: f64, - init_lr: LearningRate, - step: f64, + warmup_steps: f64, + embedding_size: f64, + init_lr: LearningRate, + step: f64, } impl NoamLrSchedulerConfig { - /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler. - pub fn init(&self) -> NoamLrScheduler { - NoamLrScheduler { - warmup_steps: self.warmup_steps as f64, - embedding_size: self.model_size as f64, - init_lr: self.init_lr, - step: 0.0, - } + /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler. 
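// A usage sketch (not part of the patch); the `crate::lr_scheduler::constant` path is assumed
// from the file layout. `ConstantLr` echoes the same rate on every step, and the impl above
// lets a bare `LearningRate` (f64) be used the same way.
fn constant_lr_demo() {
    use crate::lr_scheduler::{constant::ConstantLr, LrScheduler};

    let mut scheduler = ConstantLr::from(1e-3);
    assert_eq!(scheduler.step(), 1e-3);
    assert_eq!(scheduler.step(), 1e-3); // unchanged on subsequent calls
}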
+ pub fn init(&self) -> NoamLrScheduler { + NoamLrScheduler { + warmup_steps: self.warmup_steps as f64, + embedding_size: self.model_size as f64, + init_lr: self.init_lr, + step: 0.0, } + } } impl LrScheduler for NoamLrScheduler { - type Record = usize; + type Record = usize; - fn step(&mut self) -> LearningRate { - self.step += 1.0; + fn step(&mut self) -> LearningRate { + self.step += 1.0; - let arg1 = self.step.powf(-0.5); - let arg2 = self.step * self.warmup_steps.powf(-1.5); + let arg1 = self.step.powf(-0.5); + let arg2 = self.step * self.warmup_steps.powf(-1.5); - self.init_lr * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2) - } + self.init_lr * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2) + } - fn to_record(&self) -> Self::Record { - self.step as usize - } + fn to_record(&self) -> Self::Record { + self.step as usize + } - fn load_record(mut self, record: Self::Record) -> Self { - self.step = record as f64; - self - } + fn load_record(mut self, record: Self::Record) -> Self { + self.step = record as f64; + self + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_function_increase_and_decrease() { - let warmup_steps = 100; - let mut scheduler = NoamLrSchedulerConfig::new(10.0) - .with_warmup_steps(warmup_steps) - .init(); - let mut lr_current = 0.0; + #[test] + fn test_function_increase_and_decrease() { + let warmup_steps = 100; + let mut scheduler = NoamLrSchedulerConfig::new(10.0) + .with_warmup_steps(warmup_steps) + .init(); + let mut lr_current = 0.0; - for _ in 0..warmup_steps { - let lr = scheduler.step(); - assert!( - lr > lr_current, - "Learning rate should increase before the warmup_steps is reached." - ); - lr_current = lr; - } + for _ in 0..warmup_steps { + let lr = scheduler.step(); + assert!( + lr > lr_current, + "Learning rate should increase before the warmup_steps is reached." + ); + lr_current = lr; + } - for _ in 0..warmup_steps { - let lr = scheduler.step(); - assert!( - lr < lr_current, - "Learning rate should decrease after the warmup_steps is reached." - ); - lr_current = lr; - } + for _ in 0..warmup_steps { + let lr = scheduler.step(); + assert!( + lr < lr_current, + "Learning rate should decrease after the warmup_steps is reached." + ); + lr_current = lr; } + } } diff --git a/burn-core/src/module/base.rs b/burn-core/src/module/base.rs index a54f00624d..f97f31692d 100644 --- a/burn-core/src/module/base.rs +++ b/burn-core/src/module/base.rs @@ -2,8 +2,8 @@ use alloc::vec::Vec; use super::ParamId; use crate::{ - record::Record, - tensor::backend::{AutodiffBackend, Backend}, + record::Record, + tensor::backend::{AutodiffBackend, Backend}, }; pub use burn_derive::Module; use burn_tensor::Tensor; @@ -11,54 +11,54 @@ use burn_tensor::Tensor; // At the moment, our plan is to continue experimenting with the macro internally and monitor its development. // We may consider making it public in the future. macro_rules! 
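// Illustration only (not the patch's API): the Noam schedule computed by `step` above, written
// as a pure function of the 1-based step count. The rate grows roughly linearly during warmup
// and then decays as step^(-0.5).
fn noam_lr(init_lr: f64, model_size: usize, warmup_steps: usize, step: usize) -> f64 {
    let step = step as f64;
    let warmup = warmup_steps as f64;
    init_lr * (model_size as f64).powf(-0.5) * f64::min(step.powf(-0.5), step * warmup.powf(-1.5))
}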
module { - (map=$module:ident, ops=$item:expr) => {{ - struct Mapper; - impl ModuleMapper for Mapper { - fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { - let func = $item; - func(tensor) - } - } - let mut mapper = Mapper; - $module.map(&mut mapper) - }}; - (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{ - struct Mapper<'a, B: Backend> { - capture: &'a $ty, - backend: core::marker::PhantomData, - } - impl<'a, B: Backend> ModuleMapper for Mapper<'a, B> { - fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { - let func = $item; - func(tensor, self.capture) - } - } - let mut mapper = Mapper { - capture: $capture, - backend: core::marker::PhantomData, - }; - $module.map(&mut mapper) - }}; - (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{ - struct Visitor<'a, B: Backend> { - state: &'a mut $state_ty, - backend: core::marker::PhantomData, - } - impl<'a, B: Backend> ModuleVisitor for Visitor<'a, B> { - fn visit(&mut self, _id: &ParamId, tensor: &Tensor) { - let func = $item; - func(tensor, &mut self.state) - } - } - #[allow(clippy::redundant_closure_call)] - let mut state = $init(); - let mut visitor = Visitor { - state: &mut state, - backend: core::marker::PhantomData, - }; - $module.visit(&mut visitor); - state - }}; + (map=$module:ident, ops=$item:expr) => {{ + struct Mapper; + impl ModuleMapper for Mapper { + fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { + let func = $item; + func(tensor) + } + } + let mut mapper = Mapper; + $module.map(&mut mapper) + }}; + (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{ + struct Mapper<'a, B: Backend> { + capture: &'a $ty, + backend: core::marker::PhantomData, + } + impl<'a, B: Backend> ModuleMapper for Mapper<'a, B> { + fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { + let func = $item; + func(tensor, self.capture) + } + } + let mut mapper = Mapper { + capture: $capture, + backend: core::marker::PhantomData, + }; + $module.map(&mut mapper) + }}; + (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{ + struct Visitor<'a, B: Backend> { + state: &'a mut $state_ty, + backend: core::marker::PhantomData, + } + impl<'a, B: Backend> ModuleVisitor for Visitor<'a, B> { + fn visit(&mut self, _id: &ParamId, tensor: &Tensor) { + let func = $item; + func(tensor, &mut self.state) + } + } + #[allow(clippy::redundant_closure_call)] + let mut state = $init(); + let mut visitor = Visitor { + state: &mut state, + backend: core::marker::PhantomData, + }; + $module.visit(&mut visitor); + state + }}; } /// Trait for all neural network modules. @@ -91,163 +91,163 @@ macro_rules! module { /// } /// ``` pub trait Module: Clone + Send + Sync + core::fmt::Debug { - /// Type to save and load the module. - type Record: Record; - - /// Get the device list of the module and all of its sub-modules. - fn devices(&self) -> Vec { - module!( - visit = self, - ops = |tensor: &Tensor, state: &mut Vec| { - let device = tensor.device(); - if !state.contains(&device) { - state.push(device); - } - }, - state = Vec, - init = Vec::new - ) - } - - /// Fork the module and all of its sub-modules to the given device. - /// - /// # Notes - /// - /// This is similar to [to_device](Module::to_device), but it ensures the module will - /// have its own autodiff graph. 
- fn fork(self, device: &B::Device) -> Self { - module!( - map = self, - ops = |tensor: Tensor, device: &B::Device| { - let is_require_grad = tensor.is_require_grad(); - let mut tensor = tensor.to_device(device).detach(); - - if is_require_grad { - tensor = tensor.require_grad(); - } - - tensor - }, - capture = { device: B::Device } - ) - } - - /// Move the module and all of its sub-modules to the given device. - /// - /// # Warnings - /// - /// The device operations will be registered in the autodiff graph. Therefore, be sure to call - /// backward only one time even if you have the same module on multiple devices. If you want to - /// call backward multiple times, look into using [fork](Module::fork) instead. - fn to_device(self, device: &B::Device) -> Self { - module!( - map = self, - ops = |tensor: Tensor, device: &B::Device| tensor.to_device(device), - capture = { device: B::Device } - ) - } - - /// Each tensor in the module tree will not require grad. - /// - /// # Warnings - /// - /// This should not be used for inference, use [valid](AutodiffModule::valid) when using - /// AD modules. This is mostly useful when performing partial finetuning, which is updating only - /// a small fraction of the parameters instead of finetuning all of them. - fn no_grad(self) -> Self { - module!( - map = self, - ops = |tensor: Tensor| tensor.set_require_grad(false) - ) - } - - /// Get the number of parameters the module has, including all of its sub-modules. - fn num_params(&self) -> usize { - module!( - visit = self, - ops = |tensor: &Tensor, state: &mut usize| { - *state += tensor.shape().num_elements(); - }, - state = usize, - init = || 0 - ) - } - /// Visit each tensor in the module with a [visitor](ModuleVisitor). - fn visit>(&self, visitor: &mut V); - - /// Map each tensor in the module with a [mapper](ModuleMapper). - fn map>(self, mapper: &mut M) -> Self; - - /// Load the module state from a record. - fn load_record(self, record: Self::Record) -> Self; - - /// Convert the module into a record containing the state. - fn into_record(self) -> Self::Record; - - #[cfg(feature = "std")] - /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder). - /// - /// List of supported file recorders: - /// - /// * [default](crate::record::DefaultFileRecorder) - /// * [bincode](crate::record::BinFileRecorder) - /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder) - /// * [json pretty](crate::record::PrettyJsonFileRecorder) - /// * [json compressed with gzip](crate::record::JsonGzFileRecorder) - /// * [named mpk](crate::record::NamedMpkFileRecorder) - /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder) - /// - /// ## Notes - /// - /// The file extension is automatically added depending on the file recorder provided, you - /// don't have to specify it. - fn save_file>( - self, - file_path: PB, - recorder: &FR, - ) -> Result<(), crate::record::RecorderError> { - let record = Self::into_record(self); - recorder.record(record, file_path.into()) - } + /// Type to save and load the module. + type Record: Record; + + /// Get the device list of the module and all of its sub-modules. + fn devices(&self) -> Vec { + module!( + visit = self, + ops = |tensor: &Tensor, state: &mut Vec| { + let device = tensor.device(); + if !state.contains(&device) { + state.push(device); + } + }, + state = Vec, + init = Vec::new + ) + } + + /// Fork the module and all of its sub-modules to the given device. 
+ /// + /// # Notes + /// + /// This is similar to [to_device](Module::to_device), but it ensures the module will + /// have its own autodiff graph. + fn fork(self, device: &B::Device) -> Self { + module!( + map = self, + ops = |tensor: Tensor, device: &B::Device| { + let is_require_grad = tensor.is_require_grad(); + let mut tensor = tensor.to_device(device).detach(); + + if is_require_grad { + tensor = tensor.require_grad(); + } - #[cfg(feature = "std")] - /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder). - /// - /// The recorder should be the same as the one used to save the module, see - /// [save_file](Self::save_file). - /// - /// ## Notes - /// - /// The file extension is automatically added depending on the file recorder provided, you - /// don't have to specify it. - fn load_file>( - self, - file_path: PB, - recorder: &FR, - ) -> Result { - let record = recorder.load(file_path.into())?; - - Ok(self.load_record(record)) - } + tensor + }, + capture = { device: B::Device } + ) + } + + /// Move the module and all of its sub-modules to the given device. + /// + /// # Warnings + /// + /// The device operations will be registered in the autodiff graph. Therefore, be sure to call + /// backward only one time even if you have the same module on multiple devices. If you want to + /// call backward multiple times, look into using [fork](Module::fork) instead. + fn to_device(self, device: &B::Device) -> Self { + module!( + map = self, + ops = |tensor: Tensor, device: &B::Device| tensor.to_device(device), + capture = { device: B::Device } + ) + } + + /// Each tensor in the module tree will not require grad. + /// + /// # Warnings + /// + /// This should not be used for inference, use [valid](AutodiffModule::valid) when using + /// AD modules. This is mostly useful when performing partial finetuning, which is updating only + /// a small fraction of the parameters instead of finetuning all of them. + fn no_grad(self) -> Self { + module!( + map = self, + ops = |tensor: Tensor| tensor.set_require_grad(false) + ) + } + + /// Get the number of parameters the module has, including all of its sub-modules. + fn num_params(&self) -> usize { + module!( + visit = self, + ops = |tensor: &Tensor, state: &mut usize| { + *state += tensor.shape().num_elements(); + }, + state = usize, + init = || 0 + ) + } + /// Visit each tensor in the module with a [visitor](ModuleVisitor). + fn visit>(&self, visitor: &mut V); + + /// Map each tensor in the module with a [mapper](ModuleMapper). + fn map>(self, mapper: &mut M) -> Self; + + /// Load the module state from a record. + fn load_record(self, record: Self::Record) -> Self; + + /// Convert the module into a record containing the state. + fn into_record(self) -> Self::Record; + + #[cfg(feature = "std")] + /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder). 
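// A sketch (not part of the patch) contrasting the two placement methods above, with the
// backend and module bounds assumed: `to_device` keeps the move inside the autodiff graph,
// while `fork` detaches the copy and re-flags gradients so it carries its own graph.
fn place<B, M>(module: M, device: &B::Device) -> (M, M)
where
    B: burn_tensor::backend::AutodiffBackend,
    M: crate::module::Module<B>,
{
    (module.clone().to_device(device), module.fork(device))
}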
+ /// + /// List of supported file recorders: + /// + /// * [default](crate::record::DefaultFileRecorder) + /// * [bincode](crate::record::BinFileRecorder) + /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder) + /// * [json pretty](crate::record::PrettyJsonFileRecorder) + /// * [json compressed with gzip](crate::record::JsonGzFileRecorder) + /// * [named mpk](crate::record::NamedMpkFileRecorder) + /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder) + /// + /// ## Notes + /// + /// The file extension is automatically added depending on the file recorder provided, you + /// don't have to specify it. + fn save_file>( + self, + file_path: PB, + recorder: &FR, + ) -> Result<(), crate::record::RecorderError> { + let record = Self::into_record(self); + recorder.record(record, file_path.into()) + } + + #[cfg(feature = "std")] + /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder). + /// + /// The recorder should be the same as the one used to save the module, see + /// [save_file](Self::save_file). + /// + /// ## Notes + /// + /// The file extension is automatically added depending on the file recorder provided, you + /// don't have to specify it. + fn load_file>( + self, + file_path: PB, + recorder: &FR, + ) -> Result { + let record = recorder.load(file_path.into())?; + + Ok(self.load_record(record)) + } } /// Module visitor trait. pub trait ModuleVisitor { - /// Visit a tensor in the module. - fn visit(&mut self, id: &ParamId, tensor: &Tensor); + /// Visit a tensor in the module. + fn visit(&mut self, id: &ParamId, tensor: &Tensor); } /// Module mapper trait. pub trait ModuleMapper { - /// Map a tensor in the module. - fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor; + /// Map a tensor in the module. + fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor; } /// Module with auto-differentiation backend. pub trait AutodiffModule: Module + Send + Sync + core::fmt::Debug { - /// Inner module without auto-differentiation. - type InnerModule: Module; + /// Inner module without auto-differentiation. + type InnerModule: Module; - /// Get the same module, but on the inner backend without auto-differentiation. - fn valid(&self) -> Self::InnerModule; + /// Get the same module, but on the inner backend without auto-differentiation. + fn valid(&self) -> Self::InnerModule; } diff --git a/burn-core/src/module/param/base.rs b/burn-core/src/module/param/base.rs index 72174cba70..0e76a6ed50 100644 --- a/burn-core/src/module/param/base.rs +++ b/burn-core/src/module/param/base.rs @@ -4,37 +4,37 @@ use alloc::format; /// Define a parameter. #[derive(new, Debug, Clone)] pub struct Param { - pub(crate) id: ParamId, - pub(crate) value: T, + pub(crate) id: ParamId, + pub(crate) value: T, } impl core::fmt::Display for Param { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("Param: {}", self.id).as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(format!("Param: {}", self.id).as_str()) + } } impl Param { - /// Gets the parameter value. - /// - /// # Returns - /// - /// The parameter value. - pub fn val(&self) -> T { - self.value.clone() - } + /// Gets the parameter value. + /// + /// # Returns + /// + /// The parameter value. + pub fn val(&self) -> T { + self.value.clone() + } - /// Execute the given function on the inner value. 
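// A usage sketch (not part of the patch, std feature only) for `save_file` / `load_file`
// above, assuming `BinFileRecorder` implements `Default` and takes `FullPrecisionSettings`.
// The same recorder type must be used for saving and loading; the file extension is added by
// the recorder itself.
fn checkpoint<B, M>(trained: M, fresh: M) -> M
where
    B: burn_tensor::backend::Backend,
    M: crate::module::Module<B>,
{
    use crate::record::{BinFileRecorder, FullPrecisionSettings};
    use std::path::PathBuf;

    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    trained
        .save_file(PathBuf::from("/tmp/demo_model"), &recorder)
        .expect("saving the module record should succeed");
    fresh
        .load_file(PathBuf::from("/tmp/demo_model"), &recorder)
        .expect("loading the module record should succeed")
}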
- pub fn map T>(mut self, func: F) -> Self { - self.value = func(self.value); - self - } + /// Execute the given function on the inner value. + pub fn map T>(mut self, func: F) -> Self { + self.value = func(self.value); + self + } } impl core::ops::Deref for Param { - type Target = T; + type Target = T; - fn deref(&self) -> &Self::Target { - &self.value - } + fn deref(&self) -> &Self::Target { + &self.value + } } diff --git a/burn-core/src/module/param/constant.rs b/burn-core/src/module/param/constant.rs index 9c33d1409a..67f6e2bb52 100644 --- a/burn-core/src/module/param/constant.rs +++ b/burn-core/src/module/param/constant.rs @@ -1,14 +1,14 @@ use core::marker::PhantomData; use crate::{ - self as burn, - module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}, - record::Record, + self as burn, + module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}, + record::Record, }; use burn::record::PrecisionSettings; use burn_tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, + backend::{AutodiffBackend, Backend}, + Tensor, }; use super::ParamId; @@ -18,76 +18,76 @@ use super::ParamId; pub struct ConstantRecord; impl serde::Serialize for ConstantRecord { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - // nothing to serialize - S::serialize_none(serializer) - } + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + // nothing to serialize + S::serialize_none(serializer) + } } impl<'de> serde::Deserialize<'de> for ConstantRecord { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_option(serde::de::IgnoredAny).ok(); - Ok(ConstantRecord::new()) - } + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_option(serde::de::IgnoredAny).ok(); + Ok(ConstantRecord::new()) + } } impl Record for ConstantRecord { - type Item = ConstantRecord; + type Item = ConstantRecord; - fn into_item(self) -> Self::Item { - self - } + fn into_item(self) -> Self::Item { + self + } - fn from_item(item: Self::Item) -> Self { - item - } + fn from_item(item: Self::Item) -> Self { + item + } } /// Constant macro. #[macro_export] macro_rules! 
constant { - (module) => { - type Record = burn::module::ConstantRecord; - - fn visit>(&self, _visitor: &mut V) { - // Nothing to do - } - - fn map>(self, _mapper: &mut M) -> Self { - self - } - - fn load_record(self, _record: Self::Record) -> Self { - self - } - - fn into_record(self) -> Self::Record { - burn::module::ConstantRecord::new() - } - }; - - (ad_module, $type:ty) => { - type InnerModule = $type; - - fn valid(&self) -> Self::InnerModule { - self.clone() - } - }; - - ($type:ty) => { - impl burn::module::Module for $type { - constant!(module); - } - - impl burn::module::AutodiffModule for $type { - constant!(ad_module, $type); - } - }; + (module) => { + type Record = burn::module::ConstantRecord; + + fn visit>(&self, _visitor: &mut V) { + // Nothing to do + } + + fn map>(self, _mapper: &mut M) -> Self { + self + } + + fn load_record(self, _record: Self::Record) -> Self { + self + } + + fn into_record(self) -> Self::Record { + burn::module::ConstantRecord::new() + } + }; + + (ad_module, $type:ty) => { + type InnerModule = $type; + + fn valid(&self) -> Self::InnerModule { + self.clone() + } + }; + + ($type:ty) => { + impl burn::module::Module for $type { + constant!(module); + } + + impl burn::module::AutodiffModule for $type { + constant!(ad_module, $type); + } + }; } // General Types @@ -114,121 +114,121 @@ constant!(i16); constant!(i8); impl Module for Tensor { - type Record = ConstantRecord; - - fn visit>(&self, visitor: &mut V) { - // Important: - // We need to implement visit method for Tensor Module because - // to_device will be called during the visit method of the ModuleVisitor - - // We are using a dummy param id because the visit method requires a param id - let dummy_param_id = ParamId::new(); - visitor.visit(&dummy_param_id, self) - } - - fn map>(self, mapper: &mut M) -> Self { - // Important: - // We need to implement visit method for Tensor Module because - // to_device will be called during the visit method of the ModuleVisitor - - // We are using a dummy param id because the visit method requires a param id - let dummy_param_id = ParamId::new(); - mapper.map(&dummy_param_id, self) - } - - fn into_record(self) -> Self::Record { - ConstantRecord - } - - fn load_record(self, _record: Self::Record) -> Self { - self - } + type Record = ConstantRecord; + + fn visit>(&self, visitor: &mut V) { + // Important: + // We need to implement visit method for Tensor Module because + // to_device will be called during the visit method of the ModuleVisitor + + // We are using a dummy param id because the visit method requires a param id + let dummy_param_id = ParamId::new(); + visitor.visit(&dummy_param_id, self) + } + + fn map>(self, mapper: &mut M) -> Self { + // Important: + // We need to implement visit method for Tensor Module because + // to_device will be called during the visit method of the ModuleVisitor + + // We are using a dummy param id because the visit method requires a param id + let dummy_param_id = ParamId::new(); + mapper.map(&dummy_param_id, self) + } + + fn into_record(self) -> Self::Record { + ConstantRecord + } + + fn load_record(self, _record: Self::Record) -> Self { + self + } } impl AutodiffModule for Tensor { - type InnerModule = Tensor; + type InnerModule = Tensor; - fn valid(&self) -> Self::InnerModule { - self.clone().inner() - } + fn valid(&self) -> Self::InnerModule { + self.clone().inner() + } } impl Module for PhantomData { - type Record = ConstantRecord; + type Record = ConstantRecord; - fn visit>(&self, _visitor: &mut V) { - // Nothing to do - } + 
fn visit>(&self, _visitor: &mut V) { + // Nothing to do + } - fn map>(self, _mapper: &mut M) -> Self { - self - } + fn map>(self, _mapper: &mut M) -> Self { + self + } - fn load_record(self, _record: Self::Record) -> Self { - self - } + fn load_record(self, _record: Self::Record) -> Self { + self + } - fn into_record(self) -> Self::Record { - ConstantRecord::new() - } + fn into_record(self) -> Self::Record { + ConstantRecord::new() + } } impl AutodiffModule for PhantomData { - type InnerModule = PhantomData; + type InnerModule = PhantomData; - fn valid(&self) -> Self::InnerModule { - PhantomData - } + fn valid(&self) -> Self::InnerModule { + PhantomData + } } #[cfg(all(test, feature = "std"))] mod tests { - use core::marker::PhantomData; + use core::marker::PhantomData; - use burn_tensor::backend::Backend; - use burn_tensor::Tensor; + use burn_tensor::backend::Backend; + use burn_tensor::Tensor; - use crate::TestBackend; - use crate::{ - record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, - TestAutodiffBackend, - }; - use burn::module::Module; + use crate::TestBackend; + use crate::{ + record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, + TestAutodiffBackend, + }; + use burn::module::Module; - use crate as burn; + use crate as burn; - #[test] - fn tensor_load_record_setting() { - let tensor = Tensor::::ones([3, 3]); + #[test] + fn tensor_load_record_setting() { + let tensor = Tensor::::ones([3, 3]); - let byte_recorder = BinBytesRecorder::::default(); - let bytes = byte_recorder - .record(tensor.clone().into_record(), ()) - .unwrap(); + let byte_recorder = BinBytesRecorder::::default(); + let bytes = byte_recorder + .record(tensor.clone().into_record(), ()) + .unwrap(); - let no_grad_is_require_grad = tensor - .clone() - .no_grad() - .load_record(byte_recorder.load(bytes.clone()).unwrap()) - .is_require_grad(); + let no_grad_is_require_grad = tensor + .clone() + .no_grad() + .load_record(byte_recorder.load(bytes.clone()).unwrap()) + .is_require_grad(); - let with_default_is_require_grad = tensor - .load_record(byte_recorder.load(bytes).unwrap()) - .is_require_grad(); + let with_default_is_require_grad = tensor + .load_record(byte_recorder.load(bytes).unwrap()) + .is_require_grad(); - assert!(!no_grad_is_require_grad); - assert!(!with_default_is_require_grad); - } + assert!(!no_grad_is_require_grad); + assert!(!with_default_is_require_grad); + } - #[test] - fn empty_module_with_phantom() { - #[derive(Module, Debug, new)] - struct EmptyModule { - _phantom: PhantomData, - } + #[test] + fn empty_module_with_phantom() { + #[derive(Module, Debug, new)] + struct EmptyModule { + _phantom: PhantomData, + } - let _module = EmptyModule::::new(); + let _module = EmptyModule::::new(); - assert_eq!(core::mem::size_of::>(), 0); - } + assert_eq!(core::mem::size_of::>(), 0); + } } diff --git a/burn-core/src/module/param/id.rs b/burn-core/src/module/param/id.rs index 2828cf38c7..8ef7607013 100644 --- a/burn-core/src/module/param/id.rs +++ b/burn-core/src/module/param/id.rs @@ -4,45 +4,45 @@ use burn_common::id::IdGenerator; /// Parameter ID. 
#[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct ParamId { - value: String, + value: String, } impl From<&str> for ParamId { - fn from(val: &str) -> Self { - Self { - value: val.to_string(), - } + fn from(val: &str) -> Self { + Self { + value: val.to_string(), } + } } impl From for ParamId { - fn from(value: String) -> Self { - Self { value } - } + fn from(value: String) -> Self { + Self { value } + } } impl Default for ParamId { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl ParamId { - /// Create a new parameter ID. - pub fn new() -> Self { - Self { - value: IdGenerator::generate(), - } + /// Create a new parameter ID. + pub fn new() -> Self { + Self { + value: IdGenerator::generate(), } + } - /// Convert the parameter ID into a string. - pub fn into_string(self) -> String { - self.value - } + /// Convert the parameter ID into a string. + pub fn into_string(self) -> String { + self.value + } } impl core::fmt::Display for ParamId { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(self.value.as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.value.as_str()) + } } diff --git a/burn-core/src/module/param/primitive.rs b/burn-core/src/module/param/primitive.rs index dfa23e2d85..e306d17441 100644 --- a/burn-core/src/module/param/primitive.rs +++ b/burn-core/src/module/param/primitive.rs @@ -5,153 +5,156 @@ use core::fmt::Debug; impl Module for Option where - T: Module + Debug + Send + Sync + Clone, - B: Backend, + T: Module + Debug + Send + Sync + Clone, + B: Backend, { - type Record = Option; + type Record = Option; - fn visit>(&self, visitor: &mut V) { - if let Some(module) = self { - module.visit(visitor) - } + fn visit>(&self, visitor: &mut V) { + if let Some(module) = self { + module.visit(visitor) } + } - fn map>(self, mapper: &mut M) -> Self { - self.map(|module| module.map(mapper)) - } + fn map>(self, mapper: &mut M) -> Self { + self.map(|module| module.map(mapper)) + } - fn load_record(self, record: Self::Record) -> Self { - self.zip(record) - .map(|(module, record)| module.load_record(record)) - } + fn load_record(self, record: Self::Record) -> Self { + self + .zip(record) + .map(|(module, record)| module.load_record(record)) + } - fn into_record(self) -> Self::Record { - self.map(Module::into_record) - } + fn into_record(self) -> Self::Record { + self.map(Module::into_record) + } } impl AutodiffModule for Option where - T: AutodiffModule + Debug + Send + Sync + Clone, - B: AutodiffBackend, + T: AutodiffModule + Debug + Send + Sync + Clone, + B: AutodiffBackend, { - type InnerModule = Option; + type InnerModule = Option; - fn valid(&self) -> Self::InnerModule { - self.as_ref().map(|module| module.valid()) - } + fn valid(&self) -> Self::InnerModule { + self.as_ref().map(|module| module.valid()) + } } impl Module for Vec where - T: Module + Debug + Send + Sync + Clone, - B: Backend, + T: Module + Debug + Send + Sync + Clone, + B: Backend, { - type Record = Vec; - - fn num_params(&self) -> usize { - let mut num_params = 0; - for module in self.iter() { - num_params += module.num_params(); - } - - num_params - } - - fn visit>(&self, visitor: &mut V) { - self.iter().for_each(|module| { - module.visit(visitor); - }); - } - - fn map>(self, mapper: &mut M) -> Self { - self.into_iter().map(|module| module.map(mapper)).collect() - } - - fn into_record(self) -> Self::Record { - self.into_iter().map(Module::into_record).collect() - } - - fn 
load_record(self, record: Self::Record) -> Self { - self.into_iter() - .zip(record) - .map(|(module, record)| module.load_record(record)) - .collect() - } + type Record = Vec; + + fn num_params(&self) -> usize { + let mut num_params = 0; + for module in self.iter() { + num_params += module.num_params(); + } + + num_params + } + + fn visit>(&self, visitor: &mut V) { + self.iter().for_each(|module| { + module.visit(visitor); + }); + } + + fn map>(self, mapper: &mut M) -> Self { + self.into_iter().map(|module| module.map(mapper)).collect() + } + + fn into_record(self) -> Self::Record { + self.into_iter().map(Module::into_record).collect() + } + + fn load_record(self, record: Self::Record) -> Self { + self + .into_iter() + .zip(record) + .map(|(module, record)| module.load_record(record)) + .collect() + } } impl AutodiffModule for Vec where - T: AutodiffModule + Debug + Send + Sync + Clone, - B: AutodiffBackend, + T: AutodiffModule + Debug + Send + Sync + Clone, + B: AutodiffBackend, { - type InnerModule = Vec; + type InnerModule = Vec; - fn valid(&self) -> Self::InnerModule { - self.iter().map(|module| module.valid()).collect() - } + fn valid(&self) -> Self::InnerModule { + self.iter().map(|module| module.valid()).collect() + } } impl Module for [T; N] where - T: Module + Debug + Send + Sync + Clone + Copy, - T::Record: Debug, - B: Backend, + T: Module + Debug + Send + Sync + Clone + Copy, + T::Record: Debug, + B: Backend, { - type Record = [T::Record; N]; - - fn devices(&self) -> Vec<::Device> { - let mut devices = Vec::new(); - for module in self.iter() { - devices.append(&mut module.devices()); - } - devices - } - - fn num_params(&self) -> usize { - let mut num_params = 0; - for module in self.iter() { - num_params += module.num_params(); - } - - num_params - } - - fn visit>(&self, visitor: &mut V) { - self.iter().for_each(|module| { - module.visit(visitor); - }); - } - - fn map>(self, mapper: &mut M) -> Self { - self.map(|module| module.map(mapper)) - } - - fn load_record(self, record: Self::Record) -> Self { - self.into_iter() - .zip(record) - .map(|(module, record)| module.load_record(record)) - .collect::>() - .try_into() - .unwrap() - } - - fn into_record(self) -> Self::Record { - self.map(Module::into_record) - } + type Record = [T::Record; N]; + + fn devices(&self) -> Vec<::Device> { + let mut devices = Vec::new(); + for module in self.iter() { + devices.append(&mut module.devices()); + } + devices + } + + fn num_params(&self) -> usize { + let mut num_params = 0; + for module in self.iter() { + num_params += module.num_params(); + } + + num_params + } + + fn visit>(&self, visitor: &mut V) { + self.iter().for_each(|module| { + module.visit(visitor); + }); + } + + fn map>(self, mapper: &mut M) -> Self { + self.map(|module| module.map(mapper)) + } + + fn load_record(self, record: Self::Record) -> Self { + self + .into_iter() + .zip(record) + .map(|(module, record)| module.load_record(record)) + .collect::>() + .try_into() + .unwrap() + } + + fn into_record(self) -> Self::Record { + self.map(Module::into_record) + } } impl AutodiffModule for [T; N] where - T: AutodiffModule + Debug + Send + Sync + Clone + Copy, - T::InnerModule: Copy + Debug, - >::Record: Debug, - >::Record: Debug, - B: AutodiffBackend, + T: AutodiffModule + Debug + Send + Sync + Clone + Copy, + T::InnerModule: Copy + Debug, + >::Record: Debug, + >::Record: Debug, + B: AutodiffBackend, { - type InnerModule = [T::InnerModule; N]; + type InnerModule = [T::InnerModule; N]; - fn valid(&self) -> Self::InnerModule { - 
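// A sketch (not part of the patch): because `Module` is implemented for `Vec<T>` above, a
// homogeneous stack of sub-modules is itself a module, so its parameters can be counted with
// a single call (the `&Vec` receiver is required, since the impl is on `Vec`, not on slices).
#[allow(clippy::ptr_arg)]
fn total_params<B, M>(layers: &Vec<M>) -> usize
where
    B: burn_tensor::backend::Backend,
    M: crate::module::Module<B>,
{
    layers.num_params()
}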
self.map(|module| module.valid()) - } + fn valid(&self) -> Self::InnerModule { + self.map(|module| module.valid()) + } } diff --git a/burn-core/src/module/param/running.rs b/burn-core/src/module/param/running.rs index ab17f181f0..952adc2180 100644 --- a/burn-core/src/module/param/running.rs +++ b/burn-core/src/module/param/running.rs @@ -3,31 +3,31 @@ use alloc::sync::Arc; use super::ParamId; use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param}; use burn_tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, + backend::{AutodiffBackend, Backend}, + Tensor, }; #[cfg(feature = "std")] mod threading { - pub(super) use std::collections::HashMap; - pub(super) use std::sync::{Mutex, RwLock}; - pub(super) use std::thread::ThreadId; - - #[inline(always)] - pub(super) fn get_thread_current_id() -> ThreadId { - std::thread::current().id() - } + pub(super) use std::collections::HashMap; + pub(super) use std::sync::{Mutex, RwLock}; + pub(super) use std::thread::ThreadId; + + #[inline(always)] + pub(super) fn get_thread_current_id() -> ThreadId { + std::thread::current().id() + } } #[cfg(not(feature = "std"))] mod threading { - pub(super) use burn_common::stub::{Mutex, RwLock, ThreadId}; - pub(super) use hashbrown::HashMap; + pub(super) use burn_common::stub::{Mutex, RwLock, ThreadId}; + pub(super) use hashbrown::HashMap; - #[inline(always)] - pub(super) fn get_thread_current_id() -> ThreadId { - panic!("Current thread id is not available") - } + #[inline(always)] + pub(super) fn get_thread_current_id() -> ThreadId { + panic!("Current thread id is not available") + } } // Re-export items from the disabled/enabled blocks @@ -40,152 +40,152 @@ use threading::*; /// The state value is the average of all updates on all threads. #[derive(Clone, Debug)] pub struct RunningState { - id: ParamId, - values: Arc>>, - value: Arc>, + id: ParamId, + values: Arc>>, + value: Arc>, } impl Module for RunningState> { - type Record = Param>; + type Record = Param>; - fn visit>(&self, visitor: &mut V) { - let tensor = self.value.read().unwrap(); + fn visit>(&self, visitor: &mut V) { + let tensor = self.value.read().unwrap(); - visitor.visit(&self.id, &tensor) - } + visitor.visit(&self.id, &tensor) + } - fn map>(self, mapper: &mut M) -> Self { - let mut tensor = self.value.write().unwrap(); - let tensor_out = mapper.map(&self.id, tensor.clone()); + fn map>(self, mapper: &mut M) -> Self { + let mut tensor = self.value.write().unwrap(); + let tensor_out = mapper.map(&self.id, tensor.clone()); - *tensor = tensor_out; - core::mem::drop(tensor); + *tensor = tensor_out; + core::mem::drop(tensor); - self - } + self + } - fn into_record(self) -> Self::Record { - self.sync(); - let tensor = self.value.read().unwrap(); + fn into_record(self) -> Self::Record { + self.sync(); + let tensor = self.value.read().unwrap(); - Param::new(self.id, tensor.clone()) - } + Param::new(self.id, tensor.clone()) + } - fn load_record(mut self, record: Self::Record) -> Self { - let mut tensor = self.value.write().unwrap(); - *tensor = record.value.to_device(&tensor.device()); - self.id = record.id; + fn load_record(mut self, record: Self::Record) -> Self { + let mut tensor = self.value.write().unwrap(); + *tensor = record.value.to_device(&tensor.device()); + self.id = record.id; - core::mem::drop(tensor); + core::mem::drop(tensor); - self - } + self + } } impl RunningState> { - /// Create a new running state. 
- pub fn new(value: Tensor) -> Self { - Self { - id: ParamId::new(), - values: Arc::new(Mutex::new(HashMap::new())), - value: Arc::new(RwLock::new(value)), - } + /// Create a new running state. + pub fn new(value: Tensor) -> Self { + Self { + id: ParamId::new(), + values: Arc::new(Mutex::new(HashMap::new())), + value: Arc::new(RwLock::new(value)), } - - /// Create a new running state. - pub fn with_id(id: ParamId, value: Tensor) -> Self { - Self { - id, - values: Arc::new(Mutex::new(HashMap::new())), - value: Arc::new(RwLock::new(value)), - } + } + + /// Create a new running state. + pub fn with_id(id: ParamId, value: Tensor) -> Self { + Self { + id, + values: Arc::new(Mutex::new(HashMap::new())), + value: Arc::new(RwLock::new(value)), } - - /// Create a new running state from a record. - pub fn from_record(record: Param>) -> Self { - Self { - id: record.id, - values: Arc::new(Mutex::new(HashMap::new())), - value: Arc::new(RwLock::new(record.value)), - } + } + + /// Create a new running state from a record. + pub fn from_record(record: Param>) -> Self { + Self { + id: record.id, + values: Arc::new(Mutex::new(HashMap::new())), + value: Arc::new(RwLock::new(record.value)), } + } - /// Update the value on the current thread. - pub fn update(&self, value: Tensor) { - let thread_id = get_thread_current_id(); - let mut map = self.values.lock().unwrap(); - - if map.contains_key(&thread_id) { - self.update_value(&mut map); - } + /// Update the value on the current thread. + pub fn update(&self, value: Tensor) { + let thread_id = get_thread_current_id(); + let mut map = self.values.lock().unwrap(); - map.insert(thread_id, value); + if map.contains_key(&thread_id) { + self.update_value(&mut map); } - /// Get the current value, - /// - /// # Note - /// - /// The current value might be outdated by one update. - pub fn value(&self) -> Tensor { - let value = self.value.read().unwrap(); - value.clone() + map.insert(thread_id, value); + } + + /// Get the current value, + /// + /// # Note + /// + /// The current value might be outdated by one update. + pub fn value(&self) -> Tensor { + let value = self.value.read().unwrap(); + value.clone() + } + + /// Get the current value and make sure it is sync. + /// + /// # Note + /// + /// Don't use this function after an update on the same thread where other threads might have to + /// register their update before the actual synchronization needs to happen. + pub fn value_sync(&self) -> Tensor { + let thread_id = get_thread_current_id(); + let mut map = self.values.lock().unwrap(); + + if map.contains_key(&thread_id) { + self.update_value(&mut map); } - /// Get the current value and make sure it is sync. - /// - /// # Note - /// - /// Don't use this function after an update on the same thread where other threads might have to - /// register their update before the actual synchronization needs to happen. 
- pub fn value_sync(&self) -> Tensor { - let thread_id = get_thread_current_id(); - let mut map = self.values.lock().unwrap(); - - if map.contains_key(&thread_id) { - self.update_value(&mut map); - } - - let value = self.value.read().unwrap(); - value.clone() - } + let value = self.value.read().unwrap(); + value.clone() + } - fn sync(&self) { - let mut map = self.values.lock().unwrap(); + fn sync(&self) { + let mut map = self.values.lock().unwrap(); - if !map.is_empty() { - self.update_value(&mut map); - } + if !map.is_empty() { + self.update_value(&mut map); } + } - fn update_value(&self, map: &mut HashMap>) { - let mut value_updated = None; - let mut counter = 0; + fn update_value(&self, map: &mut HashMap>) { + let mut value_updated = None; + let mut counter = 0; - for (_key, tensor) in map.drain() { - counter += 1; + for (_key, tensor) in map.drain() { + counter += 1; - value_updated = match value_updated { - Some(current) => Some(tensor.add(current)), - None => Some(tensor), - }; - } + value_updated = match value_updated { + Some(current) => Some(tensor.add(current)), + None => Some(tensor), + }; + } - if let Some(value) = value_updated { - let value = value.div_scalar(counter); - let mut value_old = self.value.write().unwrap(); - *value_old = value; - } + if let Some(value) = value_updated { + let value = value.div_scalar(counter); + let mut value_old = self.value.write().unwrap(); + *value_old = value; } + } } impl AutodiffModule for RunningState> { - type InnerModule = RunningState>; + type InnerModule = RunningState>; - fn valid(&self) -> Self::InnerModule { - self.sync(); - let value = self.value(); + fn valid(&self) -> Self::InnerModule { + self.sync(); + let value = self.value(); - RunningState::with_id(self.id.clone(), value.inner()) - } + RunningState::with_id(self.id.clone(), value.inner()) + } } diff --git a/burn-core/src/module/param/tensor.rs b/burn-core/src/module/param/tensor.rs index 3bbe519dc4..c9936e0190 100644 --- a/burn-core/src/module/param/tensor.rs +++ b/burn-core/src/module/param/tensor.rs @@ -1,104 +1,104 @@ use super::{Param, ParamId}; use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}; use crate::tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, + backend::{AutodiffBackend, Backend}, + Tensor, }; impl From> for Param> { - fn from(value: Tensor) -> Self { - Param::new(ParamId::new(), value.require_grad()) - } + fn from(value: Tensor) -> Self { + Param::new(ParamId::new(), value.require_grad()) + } } impl Module for Param> { - type Record = Param>; + type Record = Param>; - fn visit>(&self, visitor: &mut V) { - visitor.visit(&self.id, &self.value) - } + fn visit>(&self, visitor: &mut V) { + visitor.visit(&self.id, &self.value) + } - fn map>(self, mapper: &mut M) -> Self { - let value = mapper.map(&self.id, self.value); - Self::new(self.id, value) - } + fn map>(self, mapper: &mut M) -> Self { + let value = mapper.map(&self.id, self.value); + Self::new(self.id, value) + } - fn into_record(self) -> Self::Record { - self - } + fn into_record(self) -> Self::Record { + self + } - fn load_record(self, record: Self::Record) -> Self { - let mut tensor = record.value.detach(); - let device = self.device(); + fn load_record(self, record: Self::Record) -> Self { + let mut tensor = record.value.detach(); + let device = self.device(); - // Make sure we load the record into the same module device. - if tensor.device() != device { - tensor = tensor.to_device(&device).detach(); - } + // Make sure we load the record into the same module device. 
+ if tensor.device() != device { + tensor = tensor.to_device(&device).detach(); + } - // Make sure we load the record with the same autodiff setting. - tensor = tensor.set_require_grad(self.is_require_grad()); + // Make sure we load the record with the same autodiff setting. + tensor = tensor.set_require_grad(self.is_require_grad()); - Self::new(record.id, tensor) - } + Self::new(record.id, tensor) + } } impl AutodiffModule for Param> { - type InnerModule = Param>; - - fn valid(&self) -> Self::InnerModule { - Param::new( - self.id.clone(), - self.value.clone().inner().set_require_grad(false), - ) - } + type InnerModule = Param>; + + fn valid(&self) -> Self::InnerModule { + Param::new( + self.id.clone(), + self.value.clone().inner().set_require_grad(false), + ) + } } #[cfg(all(test, feature = "std"))] mod tests { - use super::*; - use crate::{ - module::Module, - nn::LinearConfig, - record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, - TestAutodiffBackend, - }; - - #[test] - fn test_load_record_setting() { - let tensor = Tensor::::ones([3, 3]); - - let byte_recorder = BinBytesRecorder::::default(); - let bytes = byte_recorder - .record(Param::from(tensor.clone()).into_record(), ()) - .unwrap(); - - let no_grad_is_require_grad = Param::from(tensor.clone()) - .no_grad() - .load_record(byte_recorder.load(bytes.clone()).unwrap()) - .value - .is_require_grad(); - - let with_default_is_require_grad = Param::from(tensor) - .load_record(byte_recorder.load(bytes).unwrap()) - .value - .is_require_grad(); - - assert!(!no_grad_is_require_grad); - assert!(with_default_is_require_grad); - } - - #[test] - fn test_init_with_record_setting() { - let config = LinearConfig::new(32, 32); - let module_init = config.init::(); - - let record = module_init.clone().into_record(); - let module_init_with = config.init_with::(record); - - assert_eq!( - module_init.weight.is_require_grad(), - module_init_with.weight.is_require_grad() - ); - } + use super::*; + use crate::{ + module::Module, + nn::LinearConfig, + record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, + TestAutodiffBackend, + }; + + #[test] + fn test_load_record_setting() { + let tensor = Tensor::::ones([3, 3]); + + let byte_recorder = BinBytesRecorder::::default(); + let bytes = byte_recorder + .record(Param::from(tensor.clone()).into_record(), ()) + .unwrap(); + + let no_grad_is_require_grad = Param::from(tensor.clone()) + .no_grad() + .load_record(byte_recorder.load(bytes.clone()).unwrap()) + .value + .is_require_grad(); + + let with_default_is_require_grad = Param::from(tensor) + .load_record(byte_recorder.load(bytes).unwrap()) + .value + .is_require_grad(); + + assert!(!no_grad_is_require_grad); + assert!(with_default_is_require_grad); + } + + #[test] + fn test_init_with_record_setting() { + let config = LinearConfig::new(32, 32); + let module_init = config.init::(); + + let record = module_init.clone().into_record(); + let module_init_with = config.init_with::(record); + + assert_eq!( + module_init.weight.is_require_grad(), + module_init_with.weight.is_require_grad() + ); + } } diff --git a/burn-core/src/module/param/visitor.rs b/burn-core/src/module/param/visitor.rs index 9e27e3b6d3..95bffe173c 100644 --- a/burn-core/src/module/param/visitor.rs +++ b/burn-core/src/module/param/visitor.rs @@ -5,28 +5,28 @@ use burn_tensor::{backend::Backend, Tensor}; use core::marker::PhantomData; struct ParamIdCollector<'a, M> { - param_ids: &'a mut Vec, - phantom: PhantomData, + param_ids: &'a mut Vec, + phantom: PhantomData, } impl<'a, B, M> 
ModuleVisitor for ParamIdCollector<'a, M> where - B: Backend, - M: Module, + B: Backend, + M: Module, { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { - self.param_ids.push(id.clone()); - } + fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + self.param_ids.push(id.clone()); + } } /// List all the parameter ids in a module. pub fn list_param_ids, B: Backend>(module: &M) -> Vec { - let mut params_ids = Vec::new(); - let mut visitor = ParamIdCollector { - param_ids: &mut params_ids, - phantom: PhantomData::, - }; - module.visit(&mut visitor); + let mut params_ids = Vec::new(); + let mut visitor = ParamIdCollector { + param_ids: &mut params_ids, + phantom: PhantomData::, + }; + module.visit(&mut visitor); - params_ids + params_ids } diff --git a/burn-core/src/nn/attention/mask.rs b/burn-core/src/nn/attention/mask.rs index b59f3a4f0a..fb7c6f3f91 100644 --- a/burn-core/src/nn/attention/mask.rs +++ b/burn-core/src/nn/attention/mask.rs @@ -6,136 +6,136 @@ use burn_tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Shape, T /// /// The mask can be used in Transformer modules to train models to generate tensors sequentially. pub fn generate_autoregressive_mask( - batch_size: usize, - seq_length: usize, - device: &B::Device, + batch_size: usize, + seq_length: usize, + device: &B::Device, ) -> Tensor { - let mut mask = Tensor::::zeros([1, seq_length, seq_length]); + let mut mask = Tensor::::zeros([1, seq_length, seq_length]); - for i in 0..(seq_length - 1) { - let values = Tensor::::ones([1, 1, seq_length - (i + 1)]); - mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values); - } + for i in 0..(seq_length - 1) { + let values = Tensor::::ones([1, 1, seq_length - (i + 1)]); + mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values); + } - mask = mask.to_device(device).repeat(0, batch_size); + mask = mask.to_device(device).repeat(0, batch_size); - mask.equal_elem(1_i64.elem::()) + mask.equal_elem(1_i64.elem::()) } /// Generate a padding attention mask. pub struct GeneratePaddingMask { - /// The generated tensor. - pub tensor: Tensor, + /// The generated tensor. + pub tensor: Tensor, - /// The generated mask. - pub mask: Tensor, + /// The generated mask. + pub mask: Tensor, } /// Generation padding attention mask. 
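// Reviewer note: the function below pads variable-length token sequences to a common length
// with `pad_token` and returns a boolean mask over the padded positions. The same idea over
// plain `Vec`s, for reference only; `pad_and_mask` is an illustrative name, not a burn API,
// and it skips the `max_seq_length` truncation the real code performs.
fn pad_and_mask(pad_token: usize, tokens_list: Vec<Vec<usize>>) -> (Vec<Vec<usize>>, Vec<Vec<bool>>) {
    let max_len = tokens_list.iter().map(|t| t.len()).max().unwrap_or(0);
    let mut padded = Vec::new();
    let mut mask = Vec::new();
    for mut tokens in tokens_list {
        let original_len = tokens.len();
        // Fill the tail with the pad token, as the tensor version does via slice_assign.
        tokens.resize(max_len, pad_token);
        padded.push(tokens);
        // `true` marks a padded position (the tensor version compares against pad_token itself).
        mask.push((0..max_len).map(|i| i >= original_len).collect());
    }
    (padded, mask)
}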
pub fn generate_padding_mask( - pad_token: usize, - tokens_list: Vec>, - max_seq_length: Option, - device: &B::Device, + pad_token: usize, + tokens_list: Vec>, + max_seq_length: Option, + device: &B::Device, ) -> GeneratePaddingMask { - let mut max_size = 0; - let batch_size = tokens_list.len(); - - for tokens in tokens_list.iter() { - if tokens.len() > max_size { - max_size = tokens.len(); - } - - if let Some(max_seq_length) = max_seq_length { - if tokens.len() >= max_seq_length { - max_size = max_seq_length; - break; - } - } + let mut max_size = 0; + let batch_size = tokens_list.len(); + + for tokens in tokens_list.iter() { + if tokens.len() > max_size { + max_size = tokens.len(); } - let mut tensor = Tensor::zeros([batch_size, max_size]); - tensor = tensor.add_scalar(pad_token as i64); - - for (index, tokens) in tokens_list.into_iter().enumerate() { - let mut seq_length = tokens.len(); - let mut tokens = tokens; - - if let Some(max_seq_length) = max_seq_length { - if seq_length > max_seq_length { - seq_length = max_seq_length; - let _ = tokens.split_off(seq_length); - } - } - - tensor = tensor.slice_assign( - [index..index + 1, 0..tokens.len()], - Tensor::from_data(Data::new( - tokens.into_iter().map(|e| (e as i64).elem()).collect(), - Shape::new([1, seq_length]), - )), - ); + if let Some(max_seq_length) = max_seq_length { + if tokens.len() >= max_seq_length { + max_size = max_seq_length; + break; + } } + } + + let mut tensor = Tensor::zeros([batch_size, max_size]); + tensor = tensor.add_scalar(pad_token as i64); + + for (index, tokens) in tokens_list.into_iter().enumerate() { + let mut seq_length = tokens.len(); + let mut tokens = tokens; - let mask = tensor - .clone() - .equal_elem(pad_token as i64) - .to_device(device); - let tensor = tensor.to_device(device); + if let Some(max_seq_length) = max_seq_length { + if seq_length > max_seq_length { + seq_length = max_seq_length; + let _ = tokens.split_off(seq_length); + } + } - GeneratePaddingMask { tensor, mask } + tensor = tensor.slice_assign( + [index..index + 1, 0..tokens.len()], + Tensor::from_data(Data::new( + tokens.into_iter().map(|e| (e as i64).elem()).collect(), + Shape::new([1, seq_length]), + )), + ); + } + + let mask = tensor + .clone() + .equal_elem(pad_token as i64) + .to_device(device); + let tensor = tensor.to_device(device); + + GeneratePaddingMask { tensor, mask } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use alloc::vec; - use burn_tensor::Data; - - #[test] - fn test_generate_autoregressive_mask() { - let device = ::Device::default(); - - let mask = generate_autoregressive_mask::(2, 3, &device); - - assert_eq!( - mask.into_data(), - Data::from([ - [ - [false, true, true], - [false, false, true], - [false, false, false], - ], - [ - [false, true, true], - [false, false, true], - [false, false, false], - ] - ]) - ); - } - - #[test] - fn test_generate_padding_mask() { - let device = ::Device::default(); - let tokens = vec![ - vec![3, 3, 3], - vec![3, 3, 3], - vec![3, 3, 3, 4], - vec![3, 3, 3, 4, 10, 15], - ]; - - let mask = generate_padding_mask::(0, tokens, None, &device); - - assert_eq!( - mask.mask.into_data(), - Data::from([ - [false, false, false, true, true, true], - [false, false, false, true, true, true], - [false, false, false, false, true, true], - [false, false, false, false, false, false], - ]) - ); - } + use super::*; + use crate::TestBackend; + use alloc::vec; + use burn_tensor::Data; + + #[test] + fn test_generate_autoregressive_mask() { + let device = ::Device::default(); + + let 
mask = generate_autoregressive_mask::(2, 3, &device); + + assert_eq!( + mask.into_data(), + Data::from([ + [ + [false, true, true], + [false, false, true], + [false, false, false], + ], + [ + [false, true, true], + [false, false, true], + [false, false, false], + ] + ]) + ); + } + + #[test] + fn test_generate_padding_mask() { + let device = ::Device::default(); + let tokens = vec![ + vec![3, 3, 3], + vec![3, 3, 3], + vec![3, 3, 3, 4], + vec![3, 3, 3, 4, 10, 15], + ]; + + let mask = generate_padding_mask::(0, tokens, None, &device); + + assert_eq!( + mask.mask.into_data(), + Data::from([ + [false, false, false, true, true, true], + [false, false, false, true, true, true], + [false, false, false, false, true, true], + [false, false, false, false, false, false], + ]) + ); + } } diff --git a/burn-core/src/nn/attention/mha.rs b/burn-core/src/nn/attention/mha.rs index 516166d656..b46e793e89 100644 --- a/burn-core/src/nn/attention/mha.rs +++ b/burn-core/src/nn/attention/mha.rs @@ -3,33 +3,39 @@ use crate as burn; use crate::nn::cache::TensorCache; use crate::nn::Initializer; use crate::{ - config::Config, - module::Module, - nn, - tensor::{activation, backend::Backend, Bool, Tensor}, + config::Config, + module::Module, + nn, + tensor::{activation, backend::Backend, Bool, Tensor}, }; use libm::sqrtf; /// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer. #[derive(Config)] pub struct MultiHeadAttentionConfig { - /// The size of the each linear layer. - d_model: usize, - /// The number of heads. - n_heads: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - dropout: f64, - /// The minimum value a float can take. Default: -1.0e4 - /// This is used to mask attention scores before calculating attention weights. - /// A value too low might result in NaN. - #[config(default = -1.0e4)] - min_float: f64, - /// The type of function used to initialize neural network parameters - #[config( - default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" - )] - pub initializer: Initializer, + /// The size of the each linear layer. + d_model: usize, + /// The number of heads. + n_heads: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + dropout: f64, + /// The minimum value a float can take. Default: -1.0e4 + /// This is used to mask attention scores before calculating attention weights. + /// A value too low might result in NaN. + #[config(default = -1.0e4)] + min_float: f64, + /// Use "quiet softmax" instead of regular softmax. + /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + quiet_softmax: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] + pub initializer: Initializer, } /// The multihead attention module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). @@ -42,412 +48,415 @@ pub struct MultiHeadAttentionConfig { /// - output: [Linear](nn::Linear) layer with `d_model` input and output features. 
#[derive(Module, Debug)] pub struct MultiHeadAttention { - query: nn::Linear, - key: nn::Linear, - value: nn::Linear, - output: nn::Linear, - dropout: nn::Dropout, - activation: nn::GELU, - n_heads: usize, - d_k: usize, - min_float: f64, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + output: nn::Linear, + dropout: nn::Dropout, + activation: nn::GELU, + n_heads: usize, + d_k: usize, + min_float: f64, + quiet_softmax: bool, } /// [Multihead attention](MultiHeadAttention) forward pass input argument. #[derive(Debug, Clone)] pub struct MhaInput { - query: Tensor, - key: Tensor, - value: Tensor, - mask_pad: Option>, - mask_attn: Option>, + query: Tensor, + key: Tensor, + value: Tensor, + mask_pad: Option>, + mask_attn: Option>, } impl MultiHeadAttentionConfig { - /// Initialize a new [multihead attention](MultiHeadAttention) module. - pub fn init(&self) -> MultiHeadAttention { - let linear = |config: &Self| { - nn::LinearConfig::new(config.d_model, config.d_model) - .with_initializer(self.initializer.clone()) - .init() - }; - - MultiHeadAttention { - query: linear(self), - key: linear(self), - value: linear(self), - output: linear(self), - dropout: nn::DropoutConfig::new(self.dropout).init(), - activation: nn::GELU::new(), - n_heads: self.n_heads, - d_k: self.d_model / self.n_heads, - min_float: self.min_float, - } + /// Initialize a new [multihead attention](MultiHeadAttention) module. + pub fn init(&self) -> MultiHeadAttention { + let linear = |config: &Self| { + nn::LinearConfig::new(config.d_model, config.d_model) + .with_initializer(self.initializer.clone()) + .init() + }; + + MultiHeadAttention { + query: linear(self), + key: linear(self), + value: linear(self), + output: linear(self), + dropout: nn::DropoutConfig::new(self.dropout).init(), + activation: nn::GELU::new(), + n_heads: self.n_heads, + d_k: self.d_model / self.n_heads, + min_float: self.min_float, + quiet_softmax: self.quiet_softmax, } - - /// Initialize a new [multihead attention](MultiHeadAttention) module with a - /// [record](MultiHeadAttentionRecord). - pub fn init_with( - &self, - record: MultiHeadAttentionRecord, - ) -> MultiHeadAttention { - let linear = |config: &Self, record| { - nn::LinearConfig::new(config.d_model, config.d_model).init_with(record) - }; - - MultiHeadAttention { - query: linear(self, record.query), - key: linear(self, record.key), - value: linear(self, record.value), - output: linear(self, record.output), - dropout: nn::DropoutConfig::new(self.dropout).init(), - activation: nn::GELU::new(), - n_heads: self.n_heads, - d_k: self.d_model / self.n_heads, - min_float: self.min_float, - } + } + + /// Initialize a new [multihead attention](MultiHeadAttention) module with a + /// [record](MultiHeadAttentionRecord). + pub fn init_with( + &self, + record: MultiHeadAttentionRecord, + ) -> MultiHeadAttention { + let linear = |config: &Self, record| { + nn::LinearConfig::new(config.d_model, config.d_model).init_with(record) + }; + + MultiHeadAttention { + query: linear(self, record.query), + key: linear(self, record.key), + value: linear(self, record.value), + output: linear(self, record.output), + dropout: nn::DropoutConfig::new(self.dropout).init(), + activation: nn::GELU::new(), + n_heads: self.n_heads, + d_k: self.d_model / self.n_heads, + min_float: self.min_float, + quiet_softmax: self.quiet_softmax, } + } } impl MhaInput { - /// Create a [multihead attention](MultiHeadAttention) input argument - /// by setting the query, key and value to the given tensor. 
- pub fn self_attn(tensor: Tensor) -> Self { - Self { - query: tensor.clone(), - key: tensor.clone(), - value: tensor, - mask_pad: None, - mask_attn: None, - } - } - - /// Create a [multihead attention](MultiHeadAttention) input argument. - pub fn new(query: Tensor, key: Tensor, value: Tensor) -> Self { - Self { - query, - key, - value, - mask_pad: None, - mask_attn: None, - } + /// Create a [multihead attention](MultiHeadAttention) input argument + /// by setting the query, key and value to the given tensor. + pub fn self_attn(tensor: Tensor) -> Self { + Self { + query: tensor.clone(), + key: tensor.clone(), + value: tensor, + mask_pad: None, + mask_attn: None, } - - /// Register the padding mask. - pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { - self.mask_pad = Some(mask_pad); - self - } - - /// Register the attention mask. - pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { - self.mask_attn = Some(mask_attn); - self + } + + /// Create a [multihead attention](MultiHeadAttention) input argument. + pub fn new(query: Tensor, key: Tensor, value: Tensor) -> Self { + Self { + query, + key, + value, + mask_pad: None, + mask_attn: None, } + } + + /// Register the padding mask. + pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { + self.mask_pad = Some(mask_pad); + self + } + + /// Register the attention mask. + pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { + self.mask_attn = Some(mask_attn); + self + } } /// [Multihead attention](MultiHeadAttention) outputs. #[derive(Debug, Clone)] pub struct MhaOutput { - /// The attention weights [batch_size, seq_length_1, seq_length_2]. - pub weights: Tensor, - /// The context tensor [batch_size, seq_length_1, d_model]. - pub context: Tensor, + /// The attention weights [batch_size, seq_length_1, seq_length_2]. + pub weights: Tensor, + /// The context tensor [batch_size, seq_length_1, d_model]. + pub context: Tensor, } impl MultiHeadAttention { - /// Applies the forward pass on the input tensors. - /// - /// # Shapes - /// - /// - query: `[batch_size, seq_length_1, d_model]` - /// - key: `[batch_size, seq_length_2, d_model]` - /// - value: `[batch_size, seq_length_2, d_model]` - /// - output: `[batch_size, seq_length_1, d_model]` - pub fn forward(&self, input: MhaInput) -> MhaOutput { - let [batch_size, seq_length_1, d_model] = input.query.dims(); - - let query = self.attention_linear(input.query, &self.query); - let key = self.attention_linear(input.key, &self.key); - let value = self.attention_linear(input.value, &self.value); - - let attn_scores = self.attn_scores(query, key); - let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); - - let context = weights.clone().matmul(value); - let context = context - .swap_dims(1, 2) - .reshape([batch_size, seq_length_1, d_model]); - let context = self.output.forward(context); - - MhaOutput { weights, context } - } - - /// Applies the forward pass using a cache. 
- /// - /// # Shapes - /// - /// - query: `[batch_size, seq_length_1, d_model]` - /// - key: `[batch_size, seq_length_2, d_model]` - /// - value: `[batch_size, seq_length_2, d_model]` - /// - output: `[batch_size, seq_length_1, d_model]` - pub fn forward_cache(&self, input: MhaInput, cache: &mut MhaCache) -> MhaOutput { - let [batch_size, seq_length_1, d_model] = input.query.dims(); - - let query = cache - .query - .forward(input.query, |t| self.attention_linear(t, &self.query)); - let key = cache - .key - .forward(input.key, |t| self.attention_linear(t, &self.key)); - let value = cache - .value - .forward(input.value, |t| self.attention_linear(t, &self.value)); - - let attn_scores = self.attn_scores(query, key); - let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); - - let context = weights.clone().matmul(value); - let context = context - .swap_dims(1, 2) - .reshape([batch_size, seq_length_1, d_model]); - - let context = cache.output.forward(context, |t| self.output.forward(t)); - - MhaOutput { weights, context } + /// Applies the forward pass on the input tensors. + /// + /// # Shapes + /// + /// - query: `[batch_size, seq_length_1, d_model]` + /// - key: `[batch_size, seq_length_2, d_model]` + /// - value: `[batch_size, seq_length_2, d_model]` + /// - output: `[batch_size, seq_length_1, d_model]` + pub fn forward(&self, input: MhaInput) -> MhaOutput { + let [batch_size, seq_length_1, d_model] = input.query.dims(); + + let query = self.attention_linear(input.query, &self.query); + let key = self.attention_linear(input.key, &self.key); + let value = self.attention_linear(input.value, &self.value); + + let attn_scores = self.attn_scores(query, key); + let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); + + let context = weights.clone().matmul(value); + let context = context + .swap_dims(1, 2) + .reshape([batch_size, seq_length_1, d_model]); + let context = self.output.forward(context); + + MhaOutput { weights, context } + } + + /// Applies the forward pass using a cache. 
+ /// + /// # Shapes + /// + /// - query: `[batch_size, seq_length_1, d_model]` + /// - key: `[batch_size, seq_length_2, d_model]` + /// - value: `[batch_size, seq_length_2, d_model]` + /// - output: `[batch_size, seq_length_1, d_model]` + pub fn forward_cache(&self, input: MhaInput, cache: &mut MhaCache) -> MhaOutput { + let [batch_size, seq_length_1, d_model] = input.query.dims(); + + let query = cache + .query + .forward(input.query, |t| self.attention_linear(t, &self.query)); + let key = cache + .key + .forward(input.key, |t| self.attention_linear(t, &self.key)); + let value = cache + .value + .forward(input.value, |t| self.attention_linear(t, &self.value)); + + let attn_scores = self.attn_scores(query, key); + let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); + + let context = weights.clone().matmul(value); + let context = context + .swap_dims(1, 2) + .reshape([batch_size, seq_length_1, d_model]); + + let context = cache.output.forward(context, |t| self.output.forward(t)); + + MhaOutput { weights, context } + } + + fn attn_scores(&self, query: Tensor, key: Tensor) -> Tensor { + let attn_scores = query + .matmul(key.transpose()) + .div_scalar(sqrtf(self.d_k as f32)); + + self.dropout.forward(attn_scores) + } + + fn attn_weights( + &self, + mut attn_scores: Tensor, + mask_pad: Option>, + mask_attn: Option>, + ) -> Tensor { + if let Some(mask_pad) = mask_pad { + let [batch_size, seq_length] = mask_pad.dims(); + + attn_scores = attn_scores.mask_fill( + mask_pad.reshape([batch_size, 1, 1, seq_length]), + self.min_float, + ); } - fn attn_scores(&self, query: Tensor, key: Tensor) -> Tensor { - let attn_scores = query - .matmul(key.transpose()) - .div_scalar(sqrtf(self.d_k as f32)); + if let Some(mask_attn) = mask_attn { + let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims(); - self.dropout.forward(attn_scores) + attn_scores = attn_scores.mask_fill( + mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]), + self.min_float, + ); } - fn attn_weights( - &self, - mut attn_scores: Tensor, - mask_pad: Option>, - mask_attn: Option>, - ) -> Tensor { - if let Some(mask_pad) = mask_pad { - let [batch_size, seq_length] = mask_pad.dims(); - - attn_scores = attn_scores.mask_fill( - mask_pad.reshape([batch_size, 1, 1, seq_length]), - self.min_float, - ); - } - - if let Some(mask_attn) = mask_attn { - let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims(); - - attn_scores = attn_scores.mask_fill( - mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]), - self.min_float, - ); - } - - activation::softmax(attn_scores, 3) - } - - fn attention_linear(&self, x: Tensor, linear: &nn::Linear) -> Tensor { - let [batch_size, seq_length, _d_model] = x.dims(); - linear - .forward(x) - .reshape([batch_size, seq_length, self.n_heads, self.d_k]) - .swap_dims(1, 2) + if self.quiet_softmax { + activation::quiet_softmax(attn_scores, 3) + } else { + activation::softmax(attn_scores, 3) } + } + + fn attention_linear(&self, x: Tensor, linear: &nn::Linear) -> Tensor { + let [batch_size, seq_length, _d_model] = x.dims(); + linear + .forward(x) + .reshape([batch_size, seq_length, self.n_heads, self.d_k]) + .swap_dims(1, 2) + } } /// Cache for the [Multi Head Attention](MultiHeadAttention) layer. /// /// To be used during inference when decoding tokens. 
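// Reviewer note: a minimal scalar sketch of the quiet softmax selected by the new
// `quiet_softmax` flag, for reference while reading `attn_weights` above. Quiet softmax adds
// a constant 1 to the softmax denominator, so a head can assign near-zero total weight when
// no key is relevant. This shows the formula only; it is not the implementation behind
// `activation::quiet_softmax`, and `quiet_softmax_1d` is an illustrative name.
fn quiet_softmax_1d(scores: &[f32]) -> Vec<f32> {
    // Shift by the maximum for numerical stability, as a regular softmax does.
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
    // After the shift, the "+1" becomes exp(-max):
    //   exp(x_i) / (1 + sum_j exp(x_j)) == exp(x_i - m) / (exp(-m) + sum_j exp(x_j - m))
    let denom: f32 = exps.iter().sum::<f32>() + (-max).exp();
    exps.iter().map(|&e| e / denom).collect()
}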
pub struct MhaCache { - query: MhaLinearCache, - key: MhaLinearCache, - value: MhaLinearCache, - output: MhaLinearCache, + query: MhaLinearCache, + key: MhaLinearCache, + value: MhaLinearCache, + output: MhaLinearCache, } enum MhaLinearCache { - Autoregressive(TensorCache, usize), - Full(TensorCache), + Autoregressive(TensorCache, usize), + Full(TensorCache), } impl MhaCache { - /// Initialize a cache for autoregressive inference. - pub fn autoregressive() -> Self { - Self { - query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), - } + /// Initialize a cache for autoregressive inference. + pub fn autoregressive() -> Self { + Self { + query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), } - - /// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and - /// values (cross-attention). - pub fn autoregressive_cross_attention() -> Self { - Self { - query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - key: MhaLinearCache::Full(TensorCache::empty()), - value: MhaLinearCache::Full(TensorCache::empty()), - output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), - } + } + + /// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and + /// values (cross-attention). + pub fn autoregressive_cross_attention() -> Self { + Self { + query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + key: MhaLinearCache::Full(TensorCache::empty()), + value: MhaLinearCache::Full(TensorCache::empty()), + output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), } + } } impl MhaLinearCache { - pub fn forward) -> Tensor>( - &mut self, - tensor: Tensor, - func: F, - ) -> Tensor { - match self { - MhaLinearCache::Autoregressive(cache, dim) => { - cache.forward_autoregressive(tensor, *dim, func) - } - MhaLinearCache::Full(cache) => cache.forward_full(tensor, func), - } + pub fn forward) -> Tensor>( + &mut self, + tensor: Tensor, + func: F, + ) -> Tensor { + match self { + MhaLinearCache::Autoregressive(cache, dim) => { + cache.forward_autoregressive(tensor, *dim, func) + } + MhaLinearCache::Full(cache) => cache.forward_full(tensor, func), } + } } #[cfg(test)] mod tests { - use super::*; - use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; - use alloc::vec::Vec; - use burn::tensor::{Distribution, Shape}; - use burn_tensor::Int; - - #[test] - fn test_self_attention_shapes() { - let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - let input = MhaInput::self_attn(Tensor::random( - [batch_size, seq_length, d_model], - Distribution::Default, - )); - - let output = mha.forward(input); - - assert_eq!( - output.context.shape(), - Shape::new([batch_size, seq_length, d_model]), - "Context should have the correct shape", - ); - assert_eq!( - output.weights.shape(), - Shape::new([batch_size, n_heads, seq_length, seq_length]), - "Weights should have the correct shape", - ); + use super::*; + use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; + use alloc::vec::Vec; + use burn::tensor::{Distribution, 
Shape}; + use burn_tensor::Int; + + #[test] + fn test_self_attention_shapes() { + let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + let input = MhaInput::self_attn(Tensor::random( + [batch_size, seq_length, d_model], + Distribution::Default, + )); + + let output = mha.forward(input); + + assert_eq!( + output.context.shape(), + Shape::new([batch_size, seq_length, d_model]), + "Context should have the correct shape", + ); + assert_eq!( + output.weights.shape(), + Shape::new([batch_size, n_heads, seq_length, seq_length]), + "Weights should have the correct shape", + ); + } + + #[test] + fn test_generic_mha_shapes() { + let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + let input = MhaInput::new( + Tensor::random([batch_size, seq_length_1, d_model], Distribution::Default), + Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), + Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), + ); + + let output = mha.forward(input); + + assert_eq!( + output.context.shape(), + Shape::new([batch_size, seq_length_1, d_model]), + "Context should have the correct shape", + ); + assert_eq!( + output.weights.shape(), + Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]), + "Weights should have the correct shape", + ); + } + + #[test] + fn test_self_attention_mask_pad() { + let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + + // Create a padding mask + let mask_pad: Tensor = Tensor::zeros([batch_size, seq_length]); + let mask_pad = mask_pad.slice_assign( + [0..batch_size, seq_length - num_padded..seq_length], + Tensor::ones([batch_size, num_padded]), + ); + let mask_pad = mask_pad.equal_elem(1); + + let tensor_1 = + Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); + // Change the end of the tensor + let tensor_2 = tensor_1.clone().slice_assign( + [ + 0..batch_size, + seq_length - num_padded..seq_length, + 0..d_model, + ], + Tensor::random([batch_size, num_padded, d_model], Distribution::Default), + ); + + let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone()); + let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad); + + let output_1 = mha.forward(input_1); + let output_2 = mha.forward(input_2); + + // Check that the beginning of each tensor is the same + output_1 + .context + .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) + .into_data() + .assert_approx_eq( + &output_2 + .context + .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) + .into_data(), + 3, + ); + } + + #[test] + fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() { + let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + + let tensor = + Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); + let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); + let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn); + + let output_1 = mha.forward(input); + let mut output_2 = Vec::new(); + let mut cache = MhaCache::autoregressive(); + + for i in 1..seq_length + 1 { + let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); + let input = 
MhaInput::self_attn(tensor); + let next_tok = + mha + .forward_cache(input, &mut cache) + .context + .slice([0..batch_size, i - 1..i, 0..d_model]); + output_2.push(next_tok); } - #[test] - fn test_generic_mha_shapes() { - let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - let input = MhaInput::new( - Tensor::random([batch_size, seq_length_1, d_model], Distribution::Default), - Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), - Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), - ); - - let output = mha.forward(input); - - assert_eq!( - output.context.shape(), - Shape::new([batch_size, seq_length_1, d_model]), - "Context should have the correct shape", - ); - assert_eq!( - output.weights.shape(), - Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]), - "Weights should have the correct shape", - ); - } - - #[test] - fn test_self_attention_mask_pad() { - let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - - // Create a padding mask - let mask_pad: Tensor = Tensor::zeros([batch_size, seq_length]); - let mask_pad = mask_pad.slice_assign( - [0..batch_size, seq_length - num_padded..seq_length], - Tensor::ones([batch_size, num_padded]), - ); - let mask_pad = mask_pad.equal_elem(1); - - let tensor_1 = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - ); - // Change the end of the tensor - let tensor_2 = tensor_1.clone().slice_assign( - [ - 0..batch_size, - seq_length - num_padded..seq_length, - 0..d_model, - ], - Tensor::random([batch_size, num_padded, d_model], Distribution::Default), - ); - - let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone()); - let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad); - - let output_1 = mha.forward(input_1); - let output_2 = mha.forward(input_2); - - // Check that the beginning of each tensor is the same - output_1 - .context - .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) - .into_data() - .assert_approx_eq( - &output_2 - .context - .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) - .into_data(), - 3, - ); - } + let output_2 = Tensor::cat(output_2, 1); - #[test] - fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() { - let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - - let tensor = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - ); - let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); - let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn); - - let output_1 = mha.forward(input); - let mut output_2 = Vec::new(); - let mut cache = MhaCache::autoregressive(); - - for i in 1..seq_length + 1 { - let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); - let input = MhaInput::self_attn(tensor); - let next_tok = mha.forward_cache(input, &mut cache).context.slice([ - 0..batch_size, - i - 1..i, - 0..d_model, - ]); - output_2.push(next_tok); - } - - let output_2 = Tensor::cat(output_2, 1); - - output_1 - .context - .into_data() - .assert_approx_eq(&output_2.into_data(), 3); - } + output_1 + .context + .into_data() + .assert_approx_eq(&output_2.into_data(), 3); + } } diff --git 
a/burn-core/src/nn/cache/autoregressive.rs b/burn-core/src/nn/cache/autoregressive.rs index 8d1f1b5bb6..1cabecd6cb 100644 --- a/burn-core/src/nn/cache/autoregressive.rs +++ b/burn-core/src/nn/cache/autoregressive.rs @@ -5,47 +5,47 @@ use crate::tensor::backend::Backend; use crate::tensor::Tensor; impl TensorCache { - pub(crate) fn forward_autoregressive( - &mut self, - tensor: Tensor, - dim_cat: usize, - func: F, - ) -> Tensor - where - F: Fn(Tensor) -> Tensor, - { - let mut tensor_old = CacheState::Empty; - core::mem::swap(&mut self.state, &mut tensor_old); - - let tensor_new = match tensor_old { - CacheState::Value(tensor_old) => { - let [batch_size, seq_length, d_model] = tensor.dims(); - let next_seq_token = - tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]); - let next_seq_token = func(next_seq_token); - - Tensor::cat(vec![tensor_old, next_seq_token], dim_cat) - } - _ => func(tensor), - }; - - self.state = CacheState::Value(tensor_new.clone()); - tensor_new - } - - pub(crate) fn forward_full(&mut self, tensor: Tensor, func: F) -> Tensor - where - F: Fn(Tensor) -> Tensor, - { - let mut tensor_old = CacheState::Empty; - core::mem::swap(&mut self.state, &mut tensor_old); - - let tensor_new = match tensor_old { - CacheState::Value(tensor_old) => tensor_old, - _ => func(tensor), - }; - - self.state = CacheState::Value(tensor_new.clone()); - tensor_new - } + pub(crate) fn forward_autoregressive( + &mut self, + tensor: Tensor, + dim_cat: usize, + func: F, + ) -> Tensor + where + F: Fn(Tensor) -> Tensor, + { + let mut tensor_old = CacheState::Empty; + core::mem::swap(&mut self.state, &mut tensor_old); + + let tensor_new = match tensor_old { + CacheState::Value(tensor_old) => { + let [batch_size, seq_length, d_model] = tensor.dims(); + let next_seq_token = + tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]); + let next_seq_token = func(next_seq_token); + + Tensor::cat(vec![tensor_old, next_seq_token], dim_cat) + } + _ => func(tensor), + }; + + self.state = CacheState::Value(tensor_new.clone()); + tensor_new + } + + pub(crate) fn forward_full(&mut self, tensor: Tensor, func: F) -> Tensor + where + F: Fn(Tensor) -> Tensor, + { + let mut tensor_old = CacheState::Empty; + core::mem::swap(&mut self.state, &mut tensor_old); + + let tensor_new = match tensor_old { + CacheState::Value(tensor_old) => tensor_old, + _ => func(tensor), + }; + + self.state = CacheState::Value(tensor_new.clone()); + tensor_new + } } diff --git a/burn-core/src/nn/cache/base.rs b/burn-core/src/nn/cache/base.rs index 322c65c810..baa85bd414 100644 --- a/burn-core/src/nn/cache/base.rs +++ b/burn-core/src/nn/cache/base.rs @@ -2,24 +2,24 @@ use crate::tensor::backend::Backend; use crate::tensor::Tensor; pub(crate) enum CacheState { - Value(T), - Empty, + Value(T), + Empty, } /// A cache for a tensor. pub struct TensorCache { - pub(crate) state: CacheState>, + pub(crate) state: CacheState>, } impl TensorCache { - /// Creates a new empty cache. - /// - /// # Returns - /// - /// The empty cache. - pub fn empty() -> Self { - Self { - state: CacheState::Empty, - } + /// Creates a new empty cache. + /// + /// # Returns + /// + /// The empty cache. 
+ pub fn empty() -> Self { + Self { + state: CacheState::Empty, } + } } diff --git a/burn-core/src/nn/conv/checks.rs b/burn-core/src/nn/conv/checks.rs index ca470f7b29..bf2253c8d6 100644 --- a/burn-core/src/nn/conv/checks.rs +++ b/burn-core/src/nn/conv/checks.rs @@ -1,8 +1,8 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize, groups: usize) { - let channels_in_div_by_group = channels_in % groups == 0; - let channels_out_div_by_group = channels_out % groups == 0; + let channels_in_div_by_group = channels_in % groups == 0; + let channels_out_div_by_group = channels_out % groups == 0; - if !channels_in_div_by_group && !channels_out_div_by_group { - panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}"); - } + if !channels_in_div_by_group && !channels_out_div_by_group { + panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}"); + } } diff --git a/burn-core/src/nn/conv/conv1d.rs b/burn-core/src/nn/conv/conv1d.rs index 8a5b79bb7b..2ceacebad7 100644 --- a/burn-core/src/nn/conv/conv1d.rs +++ b/burn-core/src/nn/conv/conv1d.rs @@ -15,30 +15,30 @@ use super::checks; /// Configuration to create an [1D convolution](Conv1d) layer. #[derive(Config, Debug)] pub struct Conv1dConfig { - /// The number of input channels. - pub channels_in: usize, - /// The number of output channels. - pub channels_out: usize, - /// The size of the kernel. - pub kernel_size: usize, - /// The stride of the convolution. - #[config(default = "1")] - pub stride: usize, - /// Spacing between kernel elements. - #[config(default = "1")] - pub dilation: usize, - /// Controls the connections between input and output channels. - #[config(default = "1")] - pub groups: usize, - /// The padding configuration. - #[config(default = "PaddingConfig1d::Valid")] - pub padding: PaddingConfig1d, - /// If bias should be added to the output. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of input channels. + pub channels_in: usize, + /// The number of output channels. + pub channels_out: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The stride of the convolution. + #[config(default = "1")] + pub stride: usize, + /// Spacing between kernel elements. + #[config(default = "1")] + pub dilation: usize, + /// Controls the connections between input and output channels. + #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "PaddingConfig1d::Valid")] + pub padding: PaddingConfig1d, + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 1D convolution over input tensors. 
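// Reviewer note: a small reference sketch (not a burn API) for the padding arithmetic used in
// the hunk below. `calculate_padding_1d` picks a padding so that the standard 1D convolution
// output length works out; under the usual convention, `Valid` means zero padding and `Same`
// targets an output length equal to the input length (for stride 1).
fn conv1d_output_length(
    length_in: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    // length_out = (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
    (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}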
@@ -50,111 +50,113 @@ pub struct Conv1dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct Conv1d { - weight: Param>, - bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: PaddingConfig1d, + weight: Param>, + bias: Option>>, + stride: usize, + kernel_size: usize, + dilation: usize, + groups: usize, + padding: PaddingConfig1d, } impl Conv1dConfig { - /// Initialize a new [conv1d](Conv1d) module. - pub fn init(&self) -> Conv1d { - checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); - - let shape = [ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size, - ]; - - let fan_in: usize = self.channels_in / self.groups * self.kernel_size; - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self.initializer - .init_with([self.channels_out], Some(fan_in), None), - ); - } - - Conv1d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, - groups: self.groups, - } + /// Initialize a new [conv1d](Conv1d) module. + pub fn init(&self) -> Conv1d { + checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); + + let shape = [ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size, + ]; + + let fan_in: usize = self.channels_in / self.groups * self.kernel_size; + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self + .initializer + .init_with([self.channels_out], Some(fan_in), None), + ); } - /// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord). - pub fn init_with(&self, record: Conv1dRecord) -> Conv1d { - Conv1d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, - groups: self.groups, - } + + Conv1d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, + groups: self.groups, + } + } + /// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord). + pub fn init_with(&self, record: Conv1dRecord) -> Conv1d { + Conv1d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, + groups: self.groups, } + } } impl Conv1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, length_in], - /// - output: [batch_size, channels_out, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels, length] = input.dims(); - let padding = self - .padding - .calculate_padding_1d(length, self.kernel_size, self.stride); - - conv1d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvOptions::new([self.stride], [padding], [self.dilation], self.groups), - ) - } + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, length_in], + /// - output: [batch_size, channels_out, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels, length] = input.dims(); + let padding = self + .padding + .calculate_padding_1d(length, self.kernel_size, self.stride); + + conv1d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvOptions::new([self.stride], [padding], [self.dilation], self.groups), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = Conv1dConfig::new(5, 5, 5); - let k = (config.channels_in * config.kernel_size) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv.weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = Conv1dConfig::new(5, 5, 5); + let k = (config.channels_in * config.kernel_size) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv + .weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/conv/conv2d.rs b/burn-core/src/nn/conv/conv2d.rs index ed27f3a8a8..c114b02d10 100644 --- a/burn-core/src/nn/conv/conv2d.rs +++ b/burn-core/src/nn/conv/conv2d.rs @@ -16,28 +16,28 @@ use super::checks; /// Configuration to create an [2D convolution](Conv2d) layer. #[derive(Config, Debug)] pub struct Conv2dConfig { - /// The number of channels. - pub channels: [usize; 2], - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The stride of the convolution. - #[config(default = "[1, 1]")] - pub stride: [usize; 2], - /// Spacing between kernel elements. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], - /// Controls the connections between input and output channels. - #[config(default = "1")] - pub groups: usize, - /// The padding configuration. - #[config(default = "PaddingConfig2d::Valid")] - pub padding: PaddingConfig2d, - /// If bias should be added to the output. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of channels. + pub channels: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The stride of the convolution. + #[config(default = "[1, 1]")] + pub stride: [usize; 2], + /// Spacing between kernel elements. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. 
+ #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "PaddingConfig2d::Valid")] + pub padding: PaddingConfig2d, + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 2D convolution over input tensors. @@ -49,112 +49,115 @@ pub struct Conv2dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct Conv2d { - weight: Param>, - bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: PaddingConfig2d, + weight: Param>, + bias: Option>>, + stride: [usize; 2], + kernel_size: [usize; 2], + dilation: [usize; 2], + groups: usize, + padding: PaddingConfig2d, } impl Conv2dConfig { - /// Initialize a new [conv2d](Conv2d) module. - pub fn init(&self) -> Conv2d { - checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - - let shape = [ - self.channels[1], - self.channels[0] / self.groups, - self.kernel_size[0], - self.kernel_size[1], - ]; - - let fan_in = self.channels[0] / self.groups * self.kernel_size.iter().product::(); - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self.initializer - .init_with([self.channels[1]], Some(fan_in), None), - ); - } - - Conv2d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - dilation: self.dilation, - padding: self.padding.clone(), - groups: self.groups, - } + /// Initialize a new [conv2d](Conv2d) module. + pub fn init(&self) -> Conv2d { + checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + + let shape = [ + self.channels[1], + self.channels[0] / self.groups, + self.kernel_size[0], + self.kernel_size[1], + ]; + + let fan_in = self.channels[0] / self.groups * self.kernel_size.iter().product::(); + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self + .initializer + .init_with([self.channels[1]], Some(fan_in), None), + ); } - /// Initialize a new [conv2d](Conv2d) module with a [record](Conv2dRecord). - pub fn init_with(&self, record: Conv2dRecord) -> Conv2d { - Conv2d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - dilation: self.dilation, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - groups: self.groups, - } + Conv2d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + dilation: self.dilation, + padding: self.padding.clone(), + groups: self.groups, } + } + + /// Initialize a new [conv2d](Conv2d) module with a [record](Conv2dRecord). + pub fn init_with(&self, record: Conv2dRecord) -> Conv2d { + Conv2d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + dilation: self.dilation, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + groups: self.groups, + } + } } impl Conv2d { - /// Applies the forward pass on the input tensor. 
- /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, height_in, width_in], - /// - output: [batch_size, channels_out, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels_in, height_in, width_in] = input.dims(); - let padding = - self.padding - .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); - conv2d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvOptions::new(self.stride, padding, self.dilation, self.groups), - ) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, height_in, width_in], + /// - output: [batch_size, channels_out, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, height_in, width_in] = input.dims(); + let padding = + self + .padding + .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); + conv2d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvOptions::new(self.stride, padding, self.dilation, self.groups), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = Conv2dConfig::new([5, 1], [5, 5]); - let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv.weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = Conv2dConfig::new([5, 1], [5, 5]); + let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv + .weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/conv/conv_transpose1d.rs b/burn-core/src/nn/conv/conv_transpose1d.rs index fb25d6f3ee..4309738e68 100644 --- a/burn-core/src/nn/conv/conv_transpose1d.rs +++ b/burn-core/src/nn/conv/conv_transpose1d.rs @@ -15,31 +15,31 @@ use super::checks; /// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer. #[derive(Config, Debug)] pub struct ConvTranspose1dConfig { - /// The number of channels. - pub channels: [usize; 2], - /// The size of the kernel. - pub kernel_size: usize, - /// The stride of the convolution. - #[config(default = "1")] - pub stride: usize, - /// Spacing between kernel elements. - #[config(default = "1")] - pub dilation: usize, - /// Controls the connections between input and output channels. - #[config(default = "1")] - pub groups: usize, - /// The padding configuration. 
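// Illustrative usage sketch for the Conv2d module whose diff ends just above. The elided
// generics (e.g. `Tensor<B, 4>`, `Conv2d<B>`) and the `burn::nn::conv` paths are assumptions.
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::tensor::{backend::Backend, Tensor};

fn conv2d_demo<B: Backend>(input: Tensor<B, 4>) -> Tensor<B, 4> {
    // channels: [in = 3, out = 8], kernel_size: [3, 3]; other options keep their defaults.
    let conv: Conv2d<B> = Conv2dConfig::new([3, 8], [3, 3]).init();
    // input: [batch_size, 3, height_in, width_in]
    // output: [batch_size, 8, height_out, width_out]
    conv.forward(input)
}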
- #[config(default = "0")] - pub padding: usize, - /// The padding output configuration. - #[config(default = "0")] - pub padding_out: usize, - /// If bias should be added to the output. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of channels. + pub channels: [usize; 2], + /// The size of the kernel. + pub kernel_size: usize, + /// The stride of the convolution. + #[config(default = "1")] + pub stride: usize, + /// Spacing between kernel elements. + #[config(default = "1")] + pub dilation: usize, + /// Controls the connections between input and output channels. + #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "0")] + pub padding: usize, + /// The padding output configuration. + #[config(default = "0")] + pub padding_out: usize, + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 1D transposed convolution over input tensors. @@ -51,116 +51,118 @@ pub struct ConvTranspose1dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct ConvTranspose1d { - weight: Param>, - bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: usize, - padding_out: usize, + weight: Param>, + bias: Option>>, + stride: usize, + kernel_size: usize, + dilation: usize, + groups: usize, + padding: usize, + padding_out: usize, } impl ConvTranspose1dConfig { - /// Initialize a new [conv transpose 1d](ConvTranspose1d) module. - pub fn init(&self) -> ConvTranspose1d { - checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - - let shape = [ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size, - ]; - - let fan_in = self.channels[1] / self.groups * self.kernel_size; - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self.initializer - .init_with([self.channels[1]], Some(fan_in), None), - ); - } - - ConvTranspose1d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - dilation: self.dilation, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, - } + /// Initialize a new [conv transpose 1d](ConvTranspose1d) module. + pub fn init(&self) -> ConvTranspose1d { + checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + + let shape = [ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size, + ]; + + let fan_in = self.channels[1] / self.groups * self.kernel_size; + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self + .initializer + .init_with([self.channels[1]], Some(fan_in), None), + ); } - /// Initialize a new [conv transpose 1d](ConvTranspose1d) module with a [record](ConvTranspose1dRecord). 
- pub fn init_with(&self, record: ConvTranspose1dRecord) -> ConvTranspose1d { - ConvTranspose1d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - dilation: self.dilation, - kernel_size: self.kernel_size, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, - } + ConvTranspose1d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + dilation: self.dilation, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, } + } + + /// Initialize a new [conv transpose 1d](ConvTranspose1d) module with a [record](ConvTranspose1dRecord). + pub fn init_with(&self, record: ConvTranspose1dRecord) -> ConvTranspose1d { + ConvTranspose1d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + dilation: self.dilation, + kernel_size: self.kernel_size, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, + } + } } impl ConvTranspose1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, length_in], - /// - output: [batch_size, channels_out, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - conv_transpose1d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvTransposeOptions::new( - [self.stride], - [self.padding], - [self.padding_out], - [self.dilation], - self.groups, - ), - ) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, length_in], + /// - output: [batch_size, channels_out, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + conv_transpose1d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvTransposeOptions::new( + [self.stride], + [self.padding], + [self.padding_out], + [self.dilation], + self.groups, + ), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = ConvTranspose1dConfig::new([5, 1], 5); - let k = (config.channels[1] * config.kernel_size) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv.weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = ConvTranspose1dConfig::new([5, 1], 5); + let k = (config.channels[1] * config.kernel_size) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv + .weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/conv/conv_transpose2d.rs b/burn-core/src/nn/conv/conv_transpose2d.rs index af4a249e62..7b6fa0ac67 100644 --- 
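// Illustrative usage sketch for the ConvTranspose1d module whose diff ends just above.
// The elided generics (`Tensor<B, 3>`) are assumed, as is the `with_stride` setter that
// the #[derive(Config)] macro generates for the `stride` field.
use burn::nn::conv::{ConvTranspose1d, ConvTranspose1dConfig};
use burn::tensor::{backend::Backend, Tensor};

fn conv_transpose1d_demo<B: Backend>(input: Tensor<B, 3>) -> Tensor<B, 3> {
    // channels: [in = 16, out = 8], kernel_size = 4, stride = 2: the transposed
    // convolution upsamples the length dimension by roughly the stride factor.
    let conv: ConvTranspose1d<B> = ConvTranspose1dConfig::new([16, 8], 4)
        .with_stride(2)
        .init();
    conv.forward(input)
}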
a/burn-core/src/nn/conv/conv_transpose2d.rs +++ b/burn-core/src/nn/conv/conv_transpose2d.rs @@ -15,31 +15,31 @@ use super::checks; /// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer. #[derive(Config, Debug)] pub struct ConvTranspose2dConfig { - /// The number of channels. - pub channels: [usize; 2], - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The stride of the convolution. - #[config(default = "[1, 1]")] - pub stride: [usize; 2], - /// Spacing between kernel elements. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], - /// Controls the connections between input and output channels. - #[config(default = "1")] - pub groups: usize, - /// The padding configuration. - #[config(default = "[0, 0]")] - pub padding: [usize; 2], - /// The padding output configuration. - #[config(default = "[0, 0]")] - pub padding_out: [usize; 2], - /// If bias should be added to the output. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of channels. + pub channels: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The stride of the convolution. + #[config(default = "[1, 1]")] + pub stride: [usize; 2], + /// Spacing between kernel elements. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. + #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "[0, 0]")] + pub padding: [usize; 2], + /// The padding output configuration. + #[config(default = "[0, 0]")] + pub padding_out: [usize; 2], + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 2D transposed convolution over input tensors. @@ -51,118 +51,119 @@ pub struct ConvTranspose2dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct ConvTranspose2d { - weight: Param>, - bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: [usize; 2], - padding_out: [usize; 2], + weight: Param>, + bias: Option>>, + stride: [usize; 2], + kernel_size: [usize; 2], + dilation: [usize; 2], + groups: usize, + padding: [usize; 2], + padding_out: [usize; 2], } impl ConvTranspose2dConfig { - /// Initialize a new [conv transpose 2d](ConvTranspose2d) module. 
- pub fn init(&self) -> ConvTranspose2d { - checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - - let shape = [ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size[0], - self.kernel_size[1], - ]; - - let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::(); - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self.initializer - .init_with([self.channels[1]], Some(fan_in), None), - ); - } - - ConvTranspose2d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - dilation: self.dilation, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, - } + /// Initialize a new [conv transpose 2d](ConvTranspose2d) module. + pub fn init(&self) -> ConvTranspose2d { + checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + + let shape = [ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size[0], + self.kernel_size[1], + ]; + + let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::(); + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self + .initializer + .init_with([self.channels[1]], Some(fan_in), None), + ); } - /// Initialize a new [conv transpose 2d](ConvTranspose2d) module with a [record](ConvTranspose2dRecord). - pub fn init_with(&self, record: ConvTranspose2dRecord) -> ConvTranspose2d { - ConvTranspose2d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - dilation: self.dilation, - kernel_size: self.kernel_size, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, - } + ConvTranspose2d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + dilation: self.dilation, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, } + } + + /// Initialize a new [conv transpose 2d](ConvTranspose2d) module with a [record](ConvTranspose2dRecord). + pub fn init_with(&self, record: ConvTranspose2dRecord) -> ConvTranspose2d { + ConvTranspose2d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + dilation: self.dilation, + kernel_size: self.kernel_size, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, + } + } } impl ConvTranspose2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, height_in, width_in], - /// - output: [batch_size, channels_out, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - conv_transpose2d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvTransposeOptions::new( - self.stride, - self.padding, - self.padding_out, - self.dilation, - self.groups, - ), - ) - } + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, height_in, width_in], + /// - output: [batch_size, channels_out, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + conv_transpose2d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvTransposeOptions::new( + self.stride, + self.padding, + self.padding_out, + self.dilation, + self.groups, + ), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = ConvTranspose2dConfig::new([5, 1], [5, 5]); - let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = - ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv.weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = ConvTranspose2dConfig::new([5, 1], [5, 5]); + let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv + .weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/dropout.rs b/burn-core/src/nn/dropout.rs index 109040e56d..51e4bd14ef 100644 --- a/burn-core/src/nn/dropout.rs +++ b/burn-core/src/nn/dropout.rs @@ -8,8 +8,8 @@ use crate::tensor::{Distribution, Tensor}; /// Configuration to create a [Dropout](Dropout) layer. #[derive(Config, Debug)] pub struct DropoutConfig { - /// The probability of randomly zeroes some elements of the input tensor during training. - pub prob: f64, + /// The probability of randomly zeroes some elements of the input tensor during training. + pub prob: f64, } /// Set at random some elements of the input tensor to zero during training. @@ -20,65 +20,65 @@ pub struct DropoutConfig { /// The input is also scaled during training to `1 / (1 - prob_keep)`. #[derive(Module, Clone, Debug)] pub struct Dropout { - prob: f64, + prob: f64, } impl DropoutConfig { - /// Initialize a new [dropout](Dropout) module. - pub fn init(&self) -> Dropout { - Dropout { prob: self.prob } - } + /// Initialize a new [dropout](Dropout) module. + pub fn init(&self) -> Dropout { + Dropout { prob: self.prob } + } } impl Dropout { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any]` - /// - output: `[..., any]` - pub fn forward(&self, input: Tensor) -> Tensor { - if !B::ad_enabled() || self.prob == 0.0 { - return input; - } - - let prob_keep = 1.0 - self.prob; - let random = input.random_like(Distribution::Bernoulli(prob_keep)); - let x = input * random; - - x * (1.0 / prob_keep) + /// Applies the forward pass on the input tensor. 
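// Illustrative usage sketch for the ConvTranspose2d module whose diff ends just above.
// Elided generics (`Tensor<B, 4>`) and the `with_stride` setter generated by
// #[derive(Config)] are assumptions.
use burn::nn::conv::{ConvTranspose2d, ConvTranspose2dConfig};
use burn::tensor::{backend::Backend, Tensor};

fn conv_transpose2d_demo<B: Backend>(input: Tensor<B, 4>) -> Tensor<B, 4> {
    // channels: [in = 16, out = 8], kernel_size: [2, 2], stride: [2, 2] -> a common
    // 2x spatial upsampling block.
    let conv: ConvTranspose2d<B> = ConvTranspose2dConfig::new([16, 8], [2, 2])
        .with_stride([2, 2])
        .init();
    // input: [batch_size, 16, h, w] -> output: [batch_size, 8, ~2h, ~2w]
    conv.forward(input)
}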
+ /// + /// # Shapes + /// + /// - input: `[..., any]` + /// - output: `[..., any]` + pub fn forward(&self, input: Tensor) -> Tensor { + if !B::ad_enabled() || self.prob == 0.0 { + return input; } + + let prob_keep = 1.0 - self.prob; + let random = input.random_like(Distribution::Bernoulli(prob_keep)); + let x = input * random; + + x * (1.0 / prob_keep) + } } #[cfg(test)] mod tests { - use super::*; - use crate::tensor::Shape; + use super::*; + use crate::tensor::Shape; - #[cfg(feature = "std")] - use crate::{TestAutodiffBackend, TestBackend}; + #[cfg(feature = "std")] + use crate::{TestAutodiffBackend, TestBackend}; - #[cfg(not(feature = "std"))] - use crate::TestBackend; + #[cfg(not(feature = "std"))] + use crate::TestBackend; - #[cfg(feature = "std")] - #[test] - fn with_ad_backend_should_mark_input() { - let tensor = Tensor::::ones(Shape::new([100, 100])); - let dropout = DropoutConfig::new(0.5).init(); + #[cfg(feature = "std")] + #[test] + fn with_ad_backend_should_mark_input() { + let tensor = Tensor::::ones(Shape::new([100, 100])); + let dropout = DropoutConfig::new(0.5).init(); - let output = dropout.forward(tensor.clone()); + let output = dropout.forward(tensor.clone()); - assert_ne!(tensor.to_data(), output.to_data()); - } + assert_ne!(tensor.to_data(), output.to_data()); + } - #[test] - fn without_ad_backend_should_not_change_input() { - let tensor = Tensor::::ones(Shape::new([100, 100])); - let dropout = DropoutConfig::new(0.5).init(); + #[test] + fn without_ad_backend_should_not_change_input() { + let tensor = Tensor::::ones(Shape::new([100, 100])); + let dropout = DropoutConfig::new(0.5).init(); - let output = dropout.forward(tensor.clone()); + let output = dropout.forward(tensor.clone()); - assert_eq!(tensor.to_data(), output.to_data()); - } + assert_eq!(tensor.to_data(), output.to_data()); + } } diff --git a/burn-core/src/nn/embedding.rs b/burn-core/src/nn/embedding.rs index 49f26eb727..f3deb03a82 100644 --- a/burn-core/src/nn/embedding.rs +++ b/burn-core/src/nn/embedding.rs @@ -11,13 +11,13 @@ use burn_tensor::Int; /// Configuration to create an [Embedding](Embedding) layer. #[derive(Config)] pub struct EmbeddingConfig { - /// The number of embedding vectors. - n_embedding: usize, - /// The size of each vector. - d_model: usize, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")] - pub initializer: Initializer, + /// The number of embedding vectors. + n_embedding: usize, + /// The size of each vector. + d_model: usize, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")] + pub initializer: Initializer, } /// Lookup table to store a fix number of vectors. @@ -28,80 +28,80 @@ pub struct EmbeddingConfig { /// `N(0, 1)` #[derive(Module, Debug)] pub struct Embedding { - weight: Param>, + weight: Param>, } impl EmbeddingConfig { - /// Initialize a new [embedding](Embedding) module. - pub fn init(&self) -> Embedding { - let weight = self - .initializer - .init([self.n_embedding, self.d_model]) - .require_grad(); + /// Initialize a new [embedding](Embedding) module. + pub fn init(&self) -> Embedding { + let weight = self + .initializer + .init([self.n_embedding, self.d_model]) + .require_grad(); - Embedding { - weight: Param::from(weight), - } + Embedding { + weight: Param::from(weight), } - /// Initialize a new [embedding](Embedding) module with a [record](EmbeddingRecord). 
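// Illustrative sketch for the Dropout module diffed above. With an autodiff-enabled
// backend, roughly `prob` of the elements are zeroed and the survivors are scaled by
// 1 / (1 - prob); on a plain inference backend the input passes through unchanged.
// The `burn::nn` path is an assumption.
use burn::nn::DropoutConfig;
use burn::tensor::{backend::Backend, Tensor};

fn dropout_demo<B: Backend>(input: Tensor<B, 2>) -> Tensor<B, 2> {
    let dropout = DropoutConfig::new(0.5).init();
    // During training, surviving elements are scaled by 1 / (1 - 0.5) = 2 so the
    // expected activation magnitude is preserved.
    dropout.forward(input)
}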
- pub fn init_with(&self, record: EmbeddingRecord) -> Embedding { - Embedding { - weight: record.weight, - } + } + /// Initialize a new [embedding](Embedding) module with a [record](EmbeddingRecord). + pub fn init_with(&self, record: EmbeddingRecord) -> Embedding { + Embedding { + weight: record.weight, } + } } impl Embedding { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, seq_length] - /// - output: [batch_size, d_model] - pub fn forward(&self, input: Tensor) -> Tensor { - burn_tensor::module::embedding(self.weight.val(), input) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, seq_length] + /// - output: [batch_size, d_model] + pub fn forward(&self, input: Tensor) -> Tensor { + burn_tensor::module::embedding(self.weight.val(), input) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; + use super::*; + use crate::TestBackend; + use burn_tensor::Data; - #[test] - fn initializer_default() { - TestBackend::seed(0); + #[test] + fn initializer_default() { + TestBackend::seed(0); - let config = EmbeddingConfig::new(100, 10); - let embed = config.init::(); - let weights = embed.weight.val().reshape([1000]); - let (var_act, mean_act) = weights.var_mean(0); + let config = EmbeddingConfig::new(100, 10); + let embed = config.init::(); + let weights = embed.weight.val().reshape([1000]); + let (var_act, mean_act) = weights.var_mean(0); - assert_eq!( - config.initializer, - Initializer::Normal { - mean: 0.0, - std: 1.0 - } - ); - var_act.to_data().assert_approx_eq(&Data::from([1.0f32]), 0); - mean_act - .to_data() - .assert_approx_eq(&Data::from([0.0f32]), 0); - } + assert_eq!( + config.initializer, + Initializer::Normal { + mean: 0.0, + std: 1.0 + } + ); + var_act.to_data().assert_approx_eq(&Data::from([1.0f32]), 0); + mean_act + .to_data() + .assert_approx_eq(&Data::from([0.0f32]), 0); + } - #[test] - fn initializer_zeros() { - TestBackend::seed(0); + #[test] + fn initializer_zeros() { + TestBackend::seed(0); - let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros); - let embed = config.init::(); + let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros); + let embed = config.init::(); - assert_eq!(config.initializer, Initializer::Zeros); - embed - .weight - .to_data() - .assert_approx_eq(&Data::zeros(embed.weight.shape()), 3); - } + assert_eq!(config.initializer, Initializer::Zeros); + embed + .weight + .to_data() + .assert_approx_eq(&Data::zeros(embed.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/gelu.rs b/burn-core/src/nn/gelu.rs index 020b6e5ee0..f0c392c4c4 100644 --- a/burn-core/src/nn/gelu.rs +++ b/burn-core/src/nn/gelu.rs @@ -9,18 +9,18 @@ use crate::tensor::Tensor; pub struct GELU {} impl GELU { - /// Create the module. - pub fn new() -> Self { - Self {} - } + /// Create the module. + pub fn new() -> Self { + Self {} + } - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any]` - /// - output: `[..., any]` - pub fn forward(&self, input: Tensor) -> Tensor { - crate::tensor::activation::gelu(input) - } + /// Applies the forward pass on the input tensor. 
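// Illustrative sketch for the Embedding module whose diff ends just above. The integer
// index tensor, the 3-D output shape and the elided generics are assumptions inferred
// from the call to `burn_tensor::module::embedding`.
use burn::nn::{Embedding, EmbeddingConfig};
use burn::tensor::{backend::Backend, Int, Tensor};

fn embedding_demo<B: Backend>(token_ids: Tensor<B, 2, Int>) -> Tensor<B, 3> {
    // 10_000 embedding vectors of size 512, drawn from N(0, 1) by default.
    let embed: Embedding<B> = EmbeddingConfig::new(10_000, 512).init();
    // token_ids: [batch_size, seq_length] -> output: [batch_size, seq_length, 512]
    embed.forward(token_ids)
}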
+ /// + /// # Shapes + /// + /// - input: `[..., any]` + /// - output: `[..., any]` + pub fn forward(&self, input: Tensor) -> Tensor { + crate::tensor::activation::gelu(input) + } } diff --git a/burn-core/src/nn/initializer.rs b/burn-core/src/nn/initializer.rs index 51575b682d..7a97403cef 100644 --- a/burn-core/src/nn/initializer.rs +++ b/burn-core/src/nn/initializer.rs @@ -10,363 +10,351 @@ use crate as burn; /// Enum specifying with what values a tensor should be initialized #[derive(Config, Debug, PartialEq)] pub enum Initializer { - /// Fills tensor with specified value everywhere - Constant { - /// The value to fill the tensor with - value: f64, - }, - /// Fills tensor with 1s everywhere - Ones, - /// Fills tensor with 0s everywhere - Zeros, - /// Fills tensor with values drawn uniformly between specified values - Uniform { - /// The minimum value to draw from - min: f64, - - /// The maximum value to draw from - max: f64, - }, - /// Fills tensor with values drawn from normal distribution with specified mean and std - Normal { - /// The mean of the normal distribution - mean: f64, - - /// The standard deviation of the normal distribution - std: f64, - }, - /// Fills tensor with values according to the uniform version of Kaiming initialization - KaimingUniform { - /// The gain to use in initialization formula - gain: f64, - - /// Whether to use fan out only in initialization formula - fan_out_only: bool, - }, - /// Fills tensor with values according to the uniform version of Kaiming initialization - KaimingNormal { - /// The gain to use in initialization formula - gain: f64, - - /// Whether to use fan out only in initialization formula - fan_out_only: bool, - }, - /// Fills tensor with values according to the uniform version of Xavier Glorot initialization - /// described in [Understanding the difficulty of training deep feedforward neural networks - /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) - XavierUniform { - /// The gain to use in initialization formula - gain: f64, - }, - /// Fills tensor with values according to the normal version of Xavier Glorot initialization - /// described in [Understanding the difficulty of training deep feedforward neural networks - /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) - XavierNormal { - /// The gain to use in initialization formula - gain: f64, - }, + /// Fills tensor with specified value everywhere + Constant { + /// The value to fill the tensor with + value: f64, + }, + /// Fills tensor with 1s everywhere + Ones, + /// Fills tensor with 0s everywhere + Zeros, + /// Fills tensor with values drawn uniformly between specified values + Uniform { + /// The minimum value to draw from + min: f64, + + /// The maximum value to draw from + max: f64, + }, + /// Fills tensor with values drawn from normal distribution with specified mean and std + Normal { + /// The mean of the normal distribution + mean: f64, + + /// The standard deviation of the normal distribution + std: f64, + }, + /// Fills tensor with values according to the uniform version of Kaiming initialization + KaimingUniform { + /// The gain to use in initialization formula + gain: f64, + + /// Whether to use fan out only in initialization formula + fan_out_only: bool, + }, + /// Fills tensor with values according to the uniform version of Kaiming initialization + KaimingNormal { + /// The gain to use in initialization formula + gain: f64, + + /// Whether to use fan out only in initialization formula + fan_out_only: bool, + }, + /// Fills tensor with 
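// Illustrative sketch for the GELU module diffed above: a stateless wrapper around
// `burn::tensor::activation::gelu`, kept as a module so it can sit inside other
// `Module` definitions. Paths are assumptions.
use burn::nn::GELU;
use burn::tensor::{backend::Backend, Tensor};

fn gelu_demo<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    GELU::new().forward(x)
}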
values according to the uniform version of Xavier Glorot initialization + /// described in [Understanding the difficulty of training deep feedforward neural networks + /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) + XavierUniform { + /// The gain to use in initialization formula + gain: f64, + }, + /// Fills tensor with values according to the normal version of Xavier Glorot initialization + /// described in [Understanding the difficulty of training deep feedforward neural networks + /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) + XavierNormal { + /// The gain to use in initialization formula + gain: f64, + }, } impl Initializer { - /// Inits a tensor of given shape with values depending on initializer kind. - /// - /// # Params - /// - /// - shape: Shape of the initiated tensor. - pub fn init>>(&self, shape: S) -> Tensor { - self.init_with(shape, None, None) - } - - /// Inits a tensor of given shape with values depending on initializer kind, with the possibility - /// of specifying fan in and fan out - /// - /// # Params - /// - /// - shape: Shape of the initiated tensor. - /// - fan_in: `Option`, the fan in to use in initialization formula, if needed - /// - fan_out: `Option`, the fan out to use in initialization formula, if needed - pub fn init_with>>( - &self, - shape: S, - fan_in: Option, - fan_out: Option, - ) -> Tensor { - let shape = shape.into(); - match self { - Initializer::Constant { value } => Tensor::::full(shape, *value), - Initializer::Ones => Tensor::::ones(shape), - Initializer::Zeros => Tensor::::zeros(shape), - Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max), - Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std), - Initializer::KaimingUniform { gain, fan_out_only } => { - let a = sqrt(3.0) * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); - uniform_draw(shape, -a, a) - } - Initializer::KaimingNormal { gain, fan_out_only } => { - let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); - normal_draw(shape, 0.0, std) - } - Initializer::XavierUniform { gain } => { - let a = sqrt(3.0) * *gain * self.xavier_std(fan_in, fan_out); - uniform_draw(shape, -a, a) - } - Initializer::XavierNormal { gain } => { - let std = *gain * self.xavier_std(fan_in, fan_out); - normal_draw(shape, 0.0, std) - } - } + /// Inits a tensor of given shape with values depending on initializer kind. + /// + /// # Params + /// + /// - shape: Shape of the initiated tensor. + pub fn init>>(&self, shape: S) -> Tensor { + self.init_with(shape, None, None) + } + + /// Inits a tensor of given shape with values depending on initializer kind, with the possibility + /// of specifying fan in and fan out + /// + /// # Params + /// + /// - shape: Shape of the initiated tensor. 
+ /// - fan_in: `Option`, the fan in to use in initialization formula, if needed + /// - fan_out: `Option`, the fan out to use in initialization formula, if needed + pub fn init_with>>( + &self, + shape: S, + fan_in: Option, + fan_out: Option, + ) -> Tensor { + let shape = shape.into(); + match self { + Initializer::Constant { value } => Tensor::::full(shape, *value), + Initializer::Ones => Tensor::::ones(shape), + Initializer::Zeros => Tensor::::zeros(shape), + Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max), + Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std), + Initializer::KaimingUniform { gain, fan_out_only } => { + let a = sqrt(3.0) * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); + uniform_draw(shape, -a, a) + } + Initializer::KaimingNormal { gain, fan_out_only } => { + let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); + normal_draw(shape, 0.0, std) + } + Initializer::XavierUniform { gain } => { + let a = sqrt(3.0) * *gain * self.xavier_std(fan_in, fan_out); + uniform_draw(shape, -a, a) + } + Initializer::XavierNormal { gain } => { + let std = *gain * self.xavier_std(fan_in, fan_out); + normal_draw(shape, 0.0, std) + } } + } - fn kaiming_std( - &self, - fan_out_only: bool, - fan_in: Option, - fan_out: Option, - ) -> f64 { - let fan = if fan_out_only { fan_out } else { fan_in }; - let fan = fan.expect( - "Can't use Kaiming initialization without specifying fan. Use init_with method.", - ); + fn kaiming_std(&self, fan_out_only: bool, fan_in: Option, fan_out: Option) -> f64 { + let fan = if fan_out_only { fan_out } else { fan_in }; + let fan = + fan.expect("Can't use Kaiming initialization without specifying fan. Use init_with method."); - 1.0 / sqrt(fan as f64) - } + 1.0 / sqrt(fan as f64) + } - fn xavier_std(&self, fan_in: Option, fan_out: Option) -> f64 { - let fan_in = fan_in.expect( + fn xavier_std(&self, fan_in: Option, fan_out: Option) -> f64 { + let fan_in = fan_in.expect( "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in.", ); - let fan_out = fan_out.expect( + let fan_out = fan_out.expect( "Can't use Xavier initialization without specifying fan out. 
Use init_with method and provide fan_out.", ); - sqrt(2.0 / (fan_in + fan_out) as f64) - } + sqrt(2.0 / (fan_in + fan_out) as f64) + } } fn uniform_draw>>( - shape: S, - low: f64, - high: f64, + shape: S, + low: f64, + high: f64, ) -> Tensor { - let distribution = - Distribution::Uniform(low.elem::(), high.elem::()); - Tensor::::random(shape, distribution) + let distribution = Distribution::Uniform(low.elem::(), high.elem::()); + Tensor::::random(shape, distribution) } fn normal_draw>>( - shape: S, - mean: f64, - std: f64, + shape: S, + mean: f64, + std: f64, ) -> Tensor { - let distribution = Distribution::Normal(mean, std); - Tensor::::random(shape, distribution) + let distribution = Distribution::Normal(mean, std); + Tensor::::random(shape, distribution) } #[cfg(test)] mod tests { - use super::*; - - use burn_tensor::Data; - - pub type TB = burn_ndarray::NdArray; - - fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor) { - let (actual_vars, actual_means) = tensor.clone().var_mean(0); - - for i in 0..tensor.shape().dims[0] { - let actual_var = actual_vars.to_data().value[i] as f64; - let actual_mean = actual_means.to_data().value[i] as f64; - - assert!( - (expected_var - actual_var).abs() <= 0.1, - "Expected variance to be between {expected_var} += 0.1, but got {actual_var}" - ); - assert!( - (expected_mean - actual_mean).abs() <= 0.1, - "Expected mean to be between {expected_mean} += 0.1, but got {actual_mean}" - ); - } - } - - #[test] - fn initializer_uniform_init() { - TB::seed(0); - - let (min, max) = (0.0, 1.0); - let uniform = Initializer::Uniform { min, max }; - let tensor: Tensor = uniform.init([2, 2, 2, 2]); - - tensor.into_data().assert_within_range(min..max); - } + use super::*; - #[test] - fn initializer_normal_init() { - // seed random generator - TB::seed(0); - let (mean, std) = (0.0, 1.0); - let normal: Tensor = Initializer::Normal { mean, std }.init([1000]); - let (var_act, mean_act) = normal.var_mean(0); + use burn_tensor::Data; - let var_act: f32 = var_act.into_scalar().elem(); - let mean_act: f32 = mean_act.into_scalar().elem(); + pub type TB = burn_ndarray::NdArray; - assert!( - var_act > 0.9 && var_act < 1.1, - "Expected variance to be between 1.0 += 0.1, but got {var_act}" - ); - assert!( - mean_act > -0.1 && mean_act < 0.1, - "Expected mean to be between 0.0 += 0.1, but got {mean_act}" - ); - } + fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor) { + let (actual_vars, actual_means) = tensor.clone().var_mean(0); - #[test] - fn initializer_constant_init() { - let value = 5.0; - let constants: Tensor = Initializer::Constant { value }.init([2, 2, 2, 2]); - constants - .sum() - .to_data() - .assert_approx_eq(&Data::from([value as f32 * 16.0]), 3); - } - - #[test] - fn initializer_zeros_init() { - let zeros: Tensor = Initializer::Zeros.init([2, 2, 2, 2]); - zeros - .sum() - .to_data() - .assert_approx_eq(&Data::from([0.0]), 3); - } + for i in 0..tensor.shape().dims[0] { + let actual_var = actual_vars.to_data().value[i] as f64; + let actual_mean = actual_means.to_data().value[i] as f64; - #[test] - fn initializer_ones_init() { - let ones: Tensor = Initializer::Ones.init([2, 2, 2, 2]); - ones.sum() - .to_data() - .assert_approx_eq(&Data::from([16.0]), 3); + assert!( + (expected_var - actual_var).abs() <= 0.1, + "Expected variance to be between {expected_var} += 0.1, but got {actual_var}" + ); + assert!( + (expected_mean - actual_mean).abs() <= 0.1, + "Expected mean to be between {expected_mean} += 0.1, but got 
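// Illustrative sketch of the Initializer API above, spelling out the bounds used by the
// Kaiming and Xavier branches. The fan values are arbitrary examples; elided generics
// are assumptions.
use burn::nn::Initializer;
use burn::tensor::{backend::Backend, Tensor};

fn init_demo<B: Backend>() -> (Tensor<B, 2>, Tensor<B, 2>) {
    let (fan_in, fan_out) = (128usize, 64usize);

    // KaimingUniform draws from U(-a, a) with
    //   a = sqrt(3) * gain * (1 / sqrt(fan)),  fan = fan_in when fan_out_only is false.
    let w_kaiming = Initializer::KaimingUniform {
        gain: 1.0,
        fan_out_only: false,
    }
    .init_with([fan_out, fan_in], Some(fan_in), None);

    // XavierNormal draws from N(0, std^2) with std = gain * sqrt(2 / (fan_in + fan_out)).
    let w_xavier = Initializer::XavierNormal { gain: 1.0 }.init_with(
        [fan_out, fan_in],
        Some(fan_in),
        Some(fan_out),
    );

    (w_kaiming, w_xavier)
}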
{actual_mean}" + ); } - - #[test] - fn initializer_kaiming_uniform_init() { - TB::seed(0); - - let gain = 2_f64; - let (fan_in, fan_out) = (5, 6); - let k = gain * sqrt(3.0 / fan_in as f64); - - let tensor: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: false, - } - .init_with([fan_out, fan_in], Some(fan_in), None); - tensor.into_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_uniform_init() { + TB::seed(0); + + let (min, max) = (0.0, 1.0); + let uniform = Initializer::Uniform { min, max }; + let tensor: Tensor = uniform.init([2, 2, 2, 2]); + + tensor.into_data().assert_within_range(min..max); + } + + #[test] + fn initializer_normal_init() { + // seed random generator + TB::seed(0); + let (mean, std) = (0.0, 1.0); + let normal: Tensor = Initializer::Normal { mean, std }.init([1000]); + let (var_act, mean_act) = normal.var_mean(0); + + let var_act: f32 = var_act.into_scalar().elem(); + let mean_act: f32 = mean_act.into_scalar().elem(); + + assert!( + var_act > 0.9 && var_act < 1.1, + "Expected variance to be between 1.0 += 0.1, but got {var_act}" + ); + assert!( + mean_act > -0.1 && mean_act < 0.1, + "Expected mean to be between 0.0 += 0.1, but got {mean_act}" + ); + } + + #[test] + fn initializer_constant_init() { + let value = 5.0; + let constants: Tensor = Initializer::Constant { value }.init([2, 2, 2, 2]); + constants + .sum() + .to_data() + .assert_approx_eq(&Data::from([value as f32 * 16.0]), 3); + } + + #[test] + fn initializer_zeros_init() { + let zeros: Tensor = Initializer::Zeros.init([2, 2, 2, 2]); + zeros + .sum() + .to_data() + .assert_approx_eq(&Data::from([0.0]), 3); + } + + #[test] + fn initializer_ones_init() { + let ones: Tensor = Initializer::Ones.init([2, 2, 2, 2]); + ones + .sum() + .to_data() + .assert_approx_eq(&Data::from([16.0]), 3); + } + + #[test] + fn initializer_kaiming_uniform_init() { + TB::seed(0); + + let gain = 2_f64; + let (fan_in, fan_out) = (5, 6); + let k = gain * sqrt(3.0 / fan_in as f64); + + let tensor: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: false, } - - #[test] - fn initializer_kaiming_normal_init() { - TB::seed(0); - - let gain = 2.; - let (fan_in, fan_out) = (1000, 10); - let expected_mean = 0_f64; - - let expected_var = (gain * sqrt(1. / (fan_in as f64))).powf(2.); - let tensor: Tensor = Initializer::KaimingNormal { - gain, - fan_out_only: false, - } - .init_with([fan_out, fan_in], Some(fan_in), None); - assert_normal_init(expected_mean, expected_var, &tensor) + .init_with([fan_out, fan_in], Some(fan_in), None); + tensor.into_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_kaiming_normal_init() { + TB::seed(0); + + let gain = 2.; + let (fan_in, fan_out) = (1000, 10); + let expected_mean = 0_f64; + + let expected_var = (gain * sqrt(1. 
/ (fan_in as f64))).powf(2.); + let tensor: Tensor = Initializer::KaimingNormal { + gain, + fan_out_only: false, } - - #[test] - fn initializer_kaiming_uniform_init_bias() { - TB::seed(0); - - let gain = 2_f64; - let shape = [3]; - let fan_in = 5; - let k = gain * sqrt(3.0 / fan_in as f64); - - let tensor: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: false, - } - .init_with(shape, Some(fan_in), None); - tensor.into_data().assert_within_range(-k..k); + .init_with([fan_out, fan_in], Some(fan_in), None); + assert_normal_init(expected_mean, expected_var, &tensor) + } + + #[test] + fn initializer_kaiming_uniform_init_bias() { + TB::seed(0); + + let gain = 2_f64; + let shape = [3]; + let fan_in = 5; + let k = gain * sqrt(3.0 / fan_in as f64); + + let tensor: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: false, } + .init_with(shape, Some(fan_in), None); + tensor.into_data().assert_within_range(-k..k); + } - #[test] - fn initializer_kaiming_uniform_init_fan_out() { - TB::seed(0); + #[test] + fn initializer_kaiming_uniform_init_fan_out() { + TB::seed(0); - let gain = 2_f64; - let (fan_in, fan_out) = (5, 6); - let k = gain * sqrt(3.0 / fan_out as f64); + let gain = 2_f64; + let (fan_in, fan_out) = (5, 6); + let k = gain * sqrt(3.0 / fan_out as f64); - let tensor: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: true, - } - .init_with([fan_out, fan_in], None, Some(fan_out)); - tensor.into_data().assert_within_range(-k..k); + let tensor: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: true, } + .init_with([fan_out, fan_in], None, Some(fan_out)); + tensor.into_data().assert_within_range(-k..k); + } - #[test] - #[should_panic] - fn initializer_kaiming_uniform_no_fan() { - TB::seed(0); - - let gain = 2_f64; - let (fan_in, fan_out) = (5, 6); - - let _: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: false, - } - .init([fan_out, fan_in]); - } - - #[test] - fn initializer_xavier_uniform_init() { - TB::seed(0); - - let gain = 2.; - let (fan_in, fan_out) = (5, 6); - let bound = gain * sqrt(6. / (fan_in + fan_out) as f64); - let tensor: Tensor = Initializer::XavierUniform { gain }.init_with( - [fan_out, fan_in], - Some(fan_in), - Some(fan_out), - ); - - tensor.into_data().assert_within_range(-bound..bound); - } - - #[test] - fn initializer_xavier_normal_init() { - TB::seed(0); - - let gain = 2.; - let (fan_in, fan_out) = (1000, 10); - let expected_mean = 0_f64; - - let expected_var = (gain * sqrt(2. / (fan_in as f64 + fan_out as f64))).powf(2.); - let tensor: Tensor = Initializer::XavierNormal { gain }.init_with( - [fan_out, fan_in], - Some(fan_in), - Some(fan_out), - ); - assert_normal_init(expected_mean, expected_var, &tensor) - } + #[test] + #[should_panic] + fn initializer_kaiming_uniform_no_fan() { + TB::seed(0); - #[test] - #[should_panic] - fn initializer_xavier_uniform_no_fan() { - TB::seed(0); + let gain = 2_f64; + let (fan_in, fan_out) = (5, 6); - let gain = 2.; - let (fan_in, fan_out) = (5, 6); - let _: Tensor = Initializer::XavierUniform { gain }.init([fan_out, fan_in]); + let _: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: false, } + .init([fan_out, fan_in]); + } + + #[test] + fn initializer_xavier_uniform_init() { + TB::seed(0); + + let gain = 2.; + let (fan_in, fan_out) = (5, 6); + let bound = gain * sqrt(6. 
/ (fan_in + fan_out) as f64); + let tensor: Tensor = + Initializer::XavierUniform { gain }.init_with([fan_out, fan_in], Some(fan_in), Some(fan_out)); + + tensor.into_data().assert_within_range(-bound..bound); + } + + #[test] + fn initializer_xavier_normal_init() { + TB::seed(0); + + let gain = 2.; + let (fan_in, fan_out) = (1000, 10); + let expected_mean = 0_f64; + + let expected_var = (gain * sqrt(2. / (fan_in as f64 + fan_out as f64))).powf(2.); + let tensor: Tensor = + Initializer::XavierNormal { gain }.init_with([fan_out, fan_in], Some(fan_in), Some(fan_out)); + assert_normal_init(expected_mean, expected_var, &tensor) + } + + #[test] + #[should_panic] + fn initializer_xavier_uniform_no_fan() { + TB::seed(0); + + let gain = 2.; + let (fan_in, fan_out) = (5, 6); + let _: Tensor = Initializer::XavierUniform { gain }.init([fan_out, fan_in]); + } } diff --git a/burn-core/src/nn/linear.rs b/burn-core/src/nn/linear.rs index 0b3ef20db2..266bc8eea7 100644 --- a/burn-core/src/nn/linear.rs +++ b/burn-core/src/nn/linear.rs @@ -11,16 +11,16 @@ use super::Initializer; /// Configuration to create a [Linear](Linear) layer. #[derive(Config, Debug)] pub struct LinearConfig { - /// The size of the input features. - pub d_input: usize, - /// The size of the output features. - pub d_output: usize, - /// If a bias should be applied during the linear transformation. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0), fan_out_only:false}")] - pub initializer: Initializer, + /// The size of the input features. + pub d_input: usize, + /// The size of the output features. + pub d_output: usize, + /// If a bias should be applied during the linear transformation. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0), fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a linear transformation to the input tensor: @@ -28,131 +28,131 @@ pub struct LinearConfig { /// `O = IW + b` #[derive(Module, Debug)] pub struct Linear { - /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution: - /// `U(-k, k)`, where `k = sqrt(1 / d_input)` - pub weight: Param>, - /// Vector of size `d_output` initialized from a uniform distribution: - /// `U(-k, k)`, where `k = sqrt(1 / d_input)` - pub bias: Option>>, + /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution: + /// `U(-k, k)`, where `k = sqrt(1 / d_input)` + pub weight: Param>, + /// Vector of size `d_output` initialized from a uniform distribution: + /// `U(-k, k)`, where `k = sqrt(1 / d_input)` + pub bias: Option>>, } impl LinearConfig { - /// Initialize a new [linear](Linear) module. - pub fn init(&self) -> Linear { - let shape = [self.d_input, self.d_output]; - let weight = self - .initializer - .init_with(shape, Some(self.d_input), Some(self.d_output)); - let bias = if self.bias { - Some(self.initializer.init_with( - [self.d_output], - Some(self.d_input), - Some(self.d_output), - )) - } else { - None - }; - - Linear { - weight: Param::from(weight), - bias: bias.map(Param::from), - } + /// Initialize a new [linear](Linear) module. 
+ pub fn init(&self) -> Linear { + let shape = [self.d_input, self.d_output]; + let weight = self + .initializer + .init_with(shape, Some(self.d_input), Some(self.d_output)); + let bias = if self.bias { + Some( + self + .initializer + .init_with([self.d_output], Some(self.d_input), Some(self.d_output)), + ) + } else { + None + }; + + Linear { + weight: Param::from(weight), + bias: bias.map(Param::from), } + } - /// Initialize a new [linear](Linear) module with a [record](LinearRecord). - pub fn init_with(&self, record: LinearRecord) -> Linear { - Linear { - weight: record.weight, - bias: record.bias, - } + /// Initialize a new [linear](Linear) module with a [record](LinearRecord). + pub fn init_with(&self, record: LinearRecord) -> Linear { + Linear { + weight: record.weight, + bias: record.bias, } + } } impl Linear { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any, d_input]` - /// - output: `[..., any, d_output]` - pub fn forward(&self, input: Tensor) -> Tensor { - let output = input.matmul(self.weight.val().unsqueeze()); - - match &self.bias { - Some(bias) => output + bias.val().unsqueeze(), - None => output, - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[..., any, d_input]` + /// - output: `[..., any, d_output]` + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.matmul(self.weight.val().unsqueeze()); + + match &self.bias { + Some(bias) => output + bias.val().unsqueeze(), + None => output, } + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::{Data, Shape}; - use libm::sqrt; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = LinearConfig::new(5, 5); - let k = sqrt(1.0 / config.d_input as f64) as f32; - let linear = config.init::(); - - assert_eq!( - config.initializer, - Initializer::KaimingUniform { - gain: 1.0 / sqrt(3.0), - fan_out_only: false - } - ); - linear.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros); - let linear = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - linear - .weight - .to_data() - .assert_approx_eq(&Data::zeros(linear.weight.shape()), 3); - } - - #[test] - fn test_linear_forward_no_bias() { - TestBackend::seed(0); - - let value = 2.; - let config = LinearConfig::new(2, 3) - .with_initializer(Initializer::Constant { value }) - .with_bias(false); - let linear = config.init(); - - let input = Tensor::::ones(Shape::new([1, 2])); - let result = linear.forward(input); - let expected_result = Tensor::::from_data([[4., 4., 4.]]); - - assert_eq!(result.into_data(), expected_result.into_data()); - } - - #[test] - fn test_linear_forward_with_bias() { - TestBackend::seed(0); - - let value = 2.; - let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); - let linear = config.init(); - - let input = Tensor::::ones(Shape::new([1, 2])); - let result = linear.forward(input); - let expected_result = Tensor::::from_data([[6., 6., 6.]]); - - assert_eq!(result.into_data(), expected_result.into_data()); - } + use super::*; + use crate::TestBackend; + use burn_tensor::{Data, Shape}; + use libm::sqrt; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = LinearConfig::new(5, 5); + let k = sqrt(1.0 / config.d_input as f64) as f32; + let linear = config.init::(); + + 
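// Illustrative sketch of the Linear module above: y = x · W + b. With every parameter
// initialized to the constant 2 and a ones input of shape [1, 2], each output element is
// 1*2 + 1*2 + 2 = 6, matching `test_linear_forward_with_bias` below. The elided generics
// (`Tensor<B, 2>`) are assumptions.
use burn::nn::{Initializer, LinearConfig};
use burn::tensor::{backend::Backend, Shape, Tensor};

fn linear_demo<B: Backend>() -> Tensor<B, 2> {
    let linear = LinearConfig::new(2, 3)
        .with_initializer(Initializer::Constant { value: 2. })
        .init();
    let input = Tensor::<B, 2>::ones(Shape::new([1, 2]));
    linear.forward(input) // [[6., 6., 6.]]
}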
assert_eq!( + config.initializer, + Initializer::KaimingUniform { + gain: 1.0 / sqrt(3.0), + fan_out_only: false + } + ); + linear.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros); + let linear = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + linear + .weight + .to_data() + .assert_approx_eq(&Data::zeros(linear.weight.shape()), 3); + } + + #[test] + fn test_linear_forward_no_bias() { + TestBackend::seed(0); + + let value = 2.; + let config = LinearConfig::new(2, 3) + .with_initializer(Initializer::Constant { value }) + .with_bias(false); + let linear = config.init(); + + let input = Tensor::::ones(Shape::new([1, 2])); + let result = linear.forward(input); + let expected_result = Tensor::::from_data([[4., 4., 4.]]); + + assert_eq!(result.into_data(), expected_result.into_data()); + } + + #[test] + fn test_linear_forward_with_bias() { + TestBackend::seed(0); + + let value = 2.; + let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); + let linear = config.init(); + + let input = Tensor::::ones(Shape::new([1, 2])); + let result = linear.forward(input); + let expected_result = Tensor::::from_data([[6., 6., 6.]]); + + assert_eq!(result.into_data(), expected_result.into_data()); + } } diff --git a/burn-core/src/nn/loss/binary_cross_entropy.rs b/burn-core/src/nn/loss/binary_cross_entropy.rs index 506e34ab71..dcca0373b5 100644 --- a/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -7,170 +7,170 @@ use burn_tensor::{backend::Backend, Int, Tensor}; /// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss). #[derive(Config, Debug)] pub struct BinaryCrossEntropyLossConfig { - /// Create weighted binary cross-entropy. - /// - /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, - /// - /// # Pre-conditions - /// - The order of the weight vector should correspond to the label integer assignment. - /// - Targets assigned negative Int's will not be allowed. - pub weights: Option<[f32; 2]>, - - /// Create binary cross-entropy with label smoothing. - /// - /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. - /// Alpha = 0 would be the same as default. - smoothing: Option, - - /// Create binary cross-entropy with probabilities as input instead of logits. - /// - #[config(default = true)] - logits: bool, + /// Create weighted binary cross-entropy. + /// + /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, + /// + /// # Pre-conditions + /// - The order of the weight vector should correspond to the label integer assignment. + /// - Targets assigned negative Int's will not be allowed. + pub weights: Option<[f32; 2]>, + + /// Create binary cross-entropy with label smoothing. + /// + /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. + /// Alpha = 0 would be the same as default. + smoothing: Option, + + /// Create binary cross-entropy with probabilities as input instead of logits. + /// + #[config(default = true)] + logits: bool, } impl BinaryCrossEntropyLossConfig { - /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss). 
- pub fn init(&self) -> BinaryCrossEntropyLoss { - self.assertions(); - BinaryCrossEntropyLoss { - weights: self - .weights - .as_ref() - .map(|e| Tensor::::from_floats(e.as_slice())), - smoothing: self.smoothing, - logits: self.logits, - } + /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss). + pub fn init(&self) -> BinaryCrossEntropyLoss { + self.assertions(); + BinaryCrossEntropyLoss { + weights: self + .weights + .as_ref() + .map(|e| Tensor::::from_floats(e.as_slice())), + smoothing: self.smoothing, + logits: self.logits, } - - fn assertions(&self) { - if let Some(alpha) = self.smoothing { - assert!( - (0.0..=1.).contains(&alpha), - "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", - alpha - ); - }; - if let Some(weights) = self.weights.as_ref() { - assert!( - weights.iter().all(|e| e > &0.), - "Weights of cross-entropy have to be positive." - ); - } + } + + fn assertions(&self) { + if let Some(alpha) = self.smoothing { + assert!( + (0.0..=1.).contains(&alpha), + "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", + alpha + ); + }; + if let Some(weights) = self.weights.as_ref() { + assert!( + weights.iter().all(|e| e > &0.), + "Weights of cross-entropy have to be positive." + ); } + } } /// Calculate the cross entropy loss from the input logits and the targets. #[derive(Module, Debug)] pub struct BinaryCrossEntropyLoss { - /// Weights for cross-entropy. - pub weights: Option>, - smoothing: Option, - logits: bool, + /// Weights for cross-entropy. + pub weights: Option>, + smoothing: Option, + logits: bool, } impl Default for BinaryCrossEntropyLoss { - fn default() -> Self { - BinaryCrossEntropyLossConfig::new().init() - } + fn default() -> Self { + BinaryCrossEntropyLossConfig::new().init() + } } impl BinaryCrossEntropyLoss { - /// Compute the criterion on the input tensor. - /// - /// # Shapes - /// - /// - logits: `[batch_size]` - /// - targets: `[batch_size]` - pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { - Self::assertions(logits.clone(), targets.clone()); - let mut targets_float = targets.clone().float(); - if let Some(alpha) = self.smoothing { - targets_float = targets_float * (1. - alpha) + alpha / 2.; - } - let logits = if self.logits { sigmoid(logits) } else { logits }; - let loss = targets_float.clone() * logits.clone().log() - + (targets_float.clone().neg() + 1.) * (logits.neg() + 1.).log(); - - match &self.weights { - Some(weights) => { - let weights = weights.clone().gather(0, targets); - let loss = loss * weights.clone(); - loss.neg().sum() / weights.sum() - } - None => loss.mean().neg(), - } + /// Compute the criterion on the input tensor. + /// + /// # Shapes + /// + /// - logits: `[batch_size]` + /// - targets: `[batch_size]` + pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { + Self::assertions(logits.clone(), targets.clone()); + let mut targets_float = targets.clone().float(); + if let Some(alpha) = self.smoothing { + targets_float = targets_float * (1. - alpha) + alpha / 2.; } - - fn assertions(logits: Tensor, targets: Tensor) { - let [logits_height] = logits.dims(); - let [targets_height] = targets.dims(); - assert!( - logits_height == targets_height, - "Shape of targets ({}) should correspond to outer shape of logits ({}).", - targets_height, - logits_height - ); + let logits = if self.logits { sigmoid(logits) } else { logits }; + let loss = targets_float.clone() * logits.clone().log() + + (targets_float.clone().neg() + 1.) 
* (logits.neg() + 1.).log(); + + match &self.weights { + Some(weights) => { + let weights = weights.clone().gather(0, targets); + let loss = loss * weights.clone(); + loss.neg().sum() / weights.sum() + } + None => loss.mean().neg(), } + } + + fn assertions(logits: Tensor, targets: Tensor) { + let [logits_height] = logits.dims(); + let [targets_height] = targets.dims(); + assert!( + logits_height == targets_height, + "Shape of targets ({}) should correspond to outer shape of logits ({}).", + targets_height, + logits_height + ); + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::{activation::sigmoid, Data, Distribution}; - - #[test] - fn test_binary_cross_entropy() { - let [batch_size] = [4]; - let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); - - let loss_1 = BinaryCrossEntropyLossConfig::new() - .init() - .forward(logits.clone(), targets.clone()); - let logits = sigmoid(logits); - let loss_2 = targets.clone().float() * logits.clone().log() - + (-targets.float() + 1) * (-logits + 1).log(); - let loss_2 = loss_2.mean().neg(); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_binary_cross_entropy_with_weights() { - let [batch_size] = [4]; - let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); - let weights = [3., 7.]; - - let loss_1 = BinaryCrossEntropyLossConfig::new() - .with_weights(Some(weights)) - .init() - .forward(logits.clone(), targets.clone()); - let logits = sigmoid(logits); - let loss_2 = targets.clone().float() * logits.clone().log() - + (-targets.float() + 1) * (-logits + 1).log(); - - let loss_2 = loss_2 * Tensor::from_floats([3., 7., 3., 7.]); - let loss_2 = loss_2.neg().sum() / (3. + 3. + 7. + 7.); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_binary_cross_entropy_with_smoothing() { - let [batch_size] = [4]; - let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); - - let loss_1 = BinaryCrossEntropyLossConfig::new() - .with_smoothing(Some(0.1)) - .init() - .forward(logits.clone(), targets.clone()); - - let logits = sigmoid(logits); - let targets = targets.float() * (1. 
- 0.1) + 0.1 / 2.; - let loss_2 = targets.clone() * logits.clone().log() + (-targets + 1) * (-logits + 1).log(); - let loss_2 = loss_2.mean().neg(); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::{activation::sigmoid, Data, Distribution}; + + #[test] + fn test_binary_cross_entropy() { + let [batch_size] = [4]; + let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); + + let loss_1 = BinaryCrossEntropyLossConfig::new() + .init() + .forward(logits.clone(), targets.clone()); + let logits = sigmoid(logits); + let loss_2 = + targets.clone().float() * logits.clone().log() + (-targets.float() + 1) * (-logits + 1).log(); + let loss_2 = loss_2.mean().neg(); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_binary_cross_entropy_with_weights() { + let [batch_size] = [4]; + let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); + let weights = [3., 7.]; + + let loss_1 = BinaryCrossEntropyLossConfig::new() + .with_weights(Some(weights)) + .init() + .forward(logits.clone(), targets.clone()); + let logits = sigmoid(logits); + let loss_2 = + targets.clone().float() * logits.clone().log() + (-targets.float() + 1) * (-logits + 1).log(); + + let loss_2 = loss_2 * Tensor::from_floats([3., 7., 3., 7.]); + let loss_2 = loss_2.neg().sum() / (3. + 3. + 7. + 7.); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_binary_cross_entropy_with_smoothing() { + let [batch_size] = [4]; + let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); + + let loss_1 = BinaryCrossEntropyLossConfig::new() + .with_smoothing(Some(0.1)) + .init() + .forward(logits.clone(), targets.clone()); + + let logits = sigmoid(logits); + let targets = targets.float() * (1. - 0.1) + 0.1 / 2.; + let loss_2 = targets.clone() * logits.clone().log() + (-targets + 1) * (-logits + 1).log(); + let loss_2 = loss_2.mean().neg(); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/loss/cross_entropy.rs b/burn-core/src/nn/loss/cross_entropy.rs index 99e6fd5cc6..f493984965 100644 --- a/burn-core/src/nn/loss/cross_entropy.rs +++ b/burn-core/src/nn/loss/cross_entropy.rs @@ -9,382 +9,377 @@ use burn_tensor::{backend::Backend, Bool, Int, Tensor}; /// Configuration to create a [Cross-entropy loss](CrossEntropyLoss). #[derive(Config, Debug)] pub struct CrossEntropyLossConfig { - /// Create padded cross entropy. - /// - /// Prevents pad tokens from impacting loss calculation. - pad_tokens: Option>, - - /// Create weighted cross-entropy. - /// - /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, - /// - /// # Pre-conditions - /// - The order of the weight vector should correspond to the label integer assignment. - /// - Targets assigned negative Int's will not be allowed. - weights: Option>, - - /// Create cross-entropy with label smoothing. - /// - /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. - /// Alpha = 0 would be the same as default. - smoothing: Option, - - /// Create cross-entropy with probabilities as input instead of logits. - /// - #[config(default = true)] - logits: bool, + /// Create padded cross entropy. 
+ /// + /// Prevents pad tokens from impacting loss calculation. + pad_tokens: Option>, + + /// Create weighted cross-entropy. + /// + /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, + /// + /// # Pre-conditions + /// - The order of the weight vector should correspond to the label integer assignment. + /// - Targets assigned negative Int's will not be allowed. + weights: Option>, + + /// Create cross-entropy with label smoothing. + /// + /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. + /// Alpha = 0 would be the same as default. + smoothing: Option, + + /// Create cross-entropy with probabilities as input instead of logits. + /// + #[config(default = true)] + logits: bool, } impl CrossEntropyLossConfig { - /// Initialize [Cross-entropy loss](CrossEntropyLoss). - pub fn init(&self) -> CrossEntropyLoss { - self.assertions(); - CrossEntropyLoss { - pad_tokens: self.pad_tokens.clone(), - weights: self - .weights - .as_ref() - .map(|e| Tensor::::from_floats(e.as_slice())), - smoothing: self.smoothing, - logits: self.logits, - } + /// Initialize [Cross-entropy loss](CrossEntropyLoss). + pub fn init(&self) -> CrossEntropyLoss { + self.assertions(); + CrossEntropyLoss { + pad_tokens: self.pad_tokens.clone(), + weights: self + .weights + .as_ref() + .map(|e| Tensor::::from_floats(e.as_slice())), + smoothing: self.smoothing, + logits: self.logits, } - - fn assertions(&self) { - if let Some(alpha) = self.smoothing { - assert!( - (0.0..=1.).contains(&alpha), - "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", - alpha - ); - }; - if let Some(weights) = self.weights.as_ref() { - assert!( - weights.iter().all(|e| e > &0.), - "Weights of cross-entropy have to be positive." - ); - } + } + + fn assertions(&self) { + if let Some(alpha) = self.smoothing { + assert!( + (0.0..=1.).contains(&alpha), + "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", + alpha + ); + }; + if let Some(weights) = self.weights.as_ref() { + assert!( + weights.iter().all(|e| e > &0.), + "Weights of cross-entropy have to be positive." + ); } + } } /// Calculate the cross entropy loss from the input logits and the targets. #[derive(Module, Debug)] pub struct CrossEntropyLoss { - pad_tokens: Option>, - /// Weights for cross-entropy. - pub weights: Option>, - smoothing: Option, - logits: bool, + pad_tokens: Option>, + /// Weights for cross-entropy. + pub weights: Option>, + smoothing: Option, + logits: bool, } impl Default for CrossEntropyLoss { - fn default() -> Self { - CrossEntropyLossConfig::new().init() - } + fn default() -> Self { + CrossEntropyLossConfig::new().init() + } } impl CrossEntropyLoss { - /// For backward compatibility. - pub fn new(pad_index: Option) -> Self { - CrossEntropyLossConfig::new() - .with_pad_tokens(pad_index.map(|e| vec![e])) - .init() - } - - /// Compute the criterion on the input tensor. - /// - /// # Shapes - /// - /// - logits: `[batch_size, num_targets]` - /// - targets: `[batch_size]` - pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { - Self::assertions(logits.clone(), targets.clone()); - match self.smoothing { - Some(alpha) => self.forward_smoothed(logits, targets, alpha), - _ => self.forward_default(logits, targets), - } + /// For backward compatibility. + pub fn new(pad_index: Option) -> Self { + CrossEntropyLossConfig::new() + .with_pad_tokens(pad_index.map(|e| vec![e])) + .init() + } + + /// Compute the criterion on the input tensor. 
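// The label-smoothing rule documented above, y_smoothed = y * (1 - a) + a / nr_classes,
// applied to a single one-hot row. An illustrative sketch only; alpha = 0.05 and five
// classes mirror the smoothing tests further down.
fn smooth_one_hot(one_hot: &[f32], alpha: f32) -> Vec<f32> {
    let nr_classes = one_hot.len() as f32;
    one_hot
        .iter()
        .map(|&y| y * (1.0 - alpha) + alpha / nr_classes)
        .collect()
}

fn main() {
    let smoothed = smooth_one_hot(&[0.0, 0.0, 1.0, 0.0, 0.0], 0.05);
    let expected = [0.01, 0.01, 0.96, 0.01, 0.01];
    for i in 0..expected.len() {
        // The hard 1.0 becomes 0.96; each hard 0.0 becomes alpha / nr_classes = 0.01.
        assert!((smoothed[i] - expected[i]).abs() < 1e-6);
    }
}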
+ /// + /// # Shapes + /// + /// - logits: `[batch_size, num_targets]` + /// - targets: `[batch_size]` + pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { + Self::assertions(logits.clone(), targets.clone()); + match self.smoothing { + Some(alpha) => self.forward_smoothed(logits, targets, alpha), + _ => self.forward_default(logits, targets), } - - fn forward_smoothed( - &self, - logits: Tensor, - targets: Tensor, - alpha: f32, - ) -> Tensor { - let mask = self.padding_mask(&targets); - let tensor = if self.logits { - log_softmax(logits, 1) - } else { - logits.log() - }; - let [batch_size, nr_classes] = tensor.dims(); + } + + fn forward_smoothed( + &self, + logits: Tensor, + targets: Tensor, + alpha: f32, + ) -> Tensor { + let mask = self.padding_mask(&targets); + let tensor = if self.logits { + log_softmax(logits, 1) + } else { + logits.log() + }; + let [batch_size, nr_classes] = tensor.dims(); + let tensor = + tensor * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha); + + match &self.weights { + Some(weights) => { let tensor = tensor - * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha); - - match &self.weights { - Some(weights) => { - let tensor = tensor - * weights - .clone() - .reshape([1, nr_classes]) - .repeat(0, batch_size); - let weights = weights.clone().gather(0, targets); - let tensor = Self::apply_mask_2d(tensor, mask); - tensor.sum().neg() / weights.sum() - } - None => { - let tensor = Self::apply_mask_2d(tensor, mask); - tensor.sum_dim(1).mean().neg() - } - } - } - - fn forward_default(&self, logits: Tensor, targets: Tensor) -> Tensor { - let [batch_size] = targets.dims(); - - let mask = self.padding_mask(&targets); - let tensor = log_softmax(logits, 1); - let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1])); - - match &self.weights { - Some(weights) => { - let weights = weights.clone().gather(0, targets); - let tensor = tensor.reshape([batch_size]) * weights.clone(); - let tensor = Self::apply_mask_1d(tensor, mask); - tensor.sum().neg() / weights.sum() - } - None => { - let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask); - tensor.mean().neg() - } - } + * weights + .clone() + .reshape([1, nr_classes]) + .repeat(0, batch_size); + let weights = weights.clone().gather(0, targets); + let tensor = Self::apply_mask_2d(tensor, mask); + tensor.sum().neg() / weights.sum() + } + None => { + let tensor = Self::apply_mask_2d(tensor, mask); + tensor.sum_dim(1).mean().neg() + } } - - fn compute_smoothed_targets( - shape: [usize; 2], - targets: Tensor, - alpha: f32, - ) -> Tensor { - let [batch_size, nr_classes] = shape; - let device = &targets.device(); - let targets_matrix = Tensor::::zeros_device(shape, device).scatter( - 1, - targets.reshape([batch_size, 1]), - Tensor::ones_device([batch_size, 1], device), - ); - targets_matrix * (1. 
- alpha) + alpha / nr_classes as f32 + } + + fn forward_default(&self, logits: Tensor, targets: Tensor) -> Tensor { + let [batch_size] = targets.dims(); + + let mask = self.padding_mask(&targets); + let tensor = log_softmax(logits, 1); + let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1])); + + match &self.weights { + Some(weights) => { + let weights = weights.clone().gather(0, targets); + let tensor = tensor.reshape([batch_size]) * weights.clone(); + let tensor = Self::apply_mask_1d(tensor, mask); + tensor.sum().neg() / weights.sum() + } + None => { + let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask); + tensor.mean().neg() + } } - - fn padding_mask(&self, targets: &Tensor) -> Option> { - let mut mask = None; - if let Some(pad_tokens) = &self.pad_tokens { - let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int(); - for x in pad_tokens { - res = res + targets.clone().equal_elem(*x as i64).int(); - } - mask = Some(res.greater_elem(0)); - } - - mask + } + + fn compute_smoothed_targets( + shape: [usize; 2], + targets: Tensor, + alpha: f32, + ) -> Tensor { + let [batch_size, nr_classes] = shape; + let device = &targets.device(); + let targets_matrix = Tensor::::zeros_device(shape, device).scatter( + 1, + targets.reshape([batch_size, 1]), + Tensor::ones_device([batch_size, 1], device), + ); + targets_matrix * (1. - alpha) + alpha / nr_classes as f32 + } + + fn padding_mask(&self, targets: &Tensor) -> Option> { + let mut mask = None; + if let Some(pad_tokens) = &self.pad_tokens { + let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int(); + for x in pad_tokens { + res = res + targets.clone().equal_elem(*x as i64).int(); + } + mask = Some(res.greater_elem(0)); } - fn apply_mask_1d(mut tensor: Tensor, mask: Option>) -> Tensor { - if let Some(mask) = mask { - tensor = tensor.mask_fill(mask, 0); - } + mask + } - tensor + fn apply_mask_1d(mut tensor: Tensor, mask: Option>) -> Tensor { + if let Some(mask) = mask { + tensor = tensor.mask_fill(mask, 0); } - fn apply_mask_2d(mut tensor: Tensor, mask: Option>) -> Tensor { - if let Some(mask) = mask { - let [batch_size, nr_classes] = tensor.dims(); - tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0); - } + tensor + } - tensor + fn apply_mask_2d(mut tensor: Tensor, mask: Option>) -> Tensor { + if let Some(mask) = mask { + let [batch_size, nr_classes] = tensor.dims(); + tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0); } - fn assertions(logits: Tensor, targets: Tensor) { - let [logits_height, _] = logits.dims(); - let [targets_height] = targets.dims(); - assert!( - logits_height == targets_height, - "Shape of targets ({}) should correspond to outer shape of logits ({}).", - targets_height, - logits_height - ); - } + tensor + } + + fn assertions(logits: Tensor, targets: Tensor) { + let [logits_height, _] = logits.dims(); + let [targets_height] = targets.dims(); + assert!( + logits_height == targets_height, + "Shape of targets ({}) should correspond to outer shape of logits ({}).", + targets_height, + logits_height + ); + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::{loss::cross_entropy_with_logits, Data, Distribution}; - - macro_rules! 
setup { - () => {{ - let [batch_size, num_targets] = [4, 5]; - let logits = Tensor::::random( - [batch_size, num_targets], - Distribution::Normal(0., 1.0), - ); - let targets = Tensor::::from_data(Data::from([2, 0, 4, 1])); - let targets_logits = Tensor::::from_data(Data::from([ - [0.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - ])); - (logits, targets, targets_logits) - }}; - } - - macro_rules! setup_padded { - () => {{ - let [batch_size, num_targets, pad_index] = [4, 5, 1]; - let logits = Tensor::::random( - [batch_size, num_targets], - Distribution::Normal(0., 1.0), - ); - let targets = Tensor::::from_data( - Data::::from([2, 0, 4, pad_index as i64]).convert(), - ); - let targets_logits = Tensor::::from_data(Data::from([ - [0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ])); - (logits, targets, targets_logits) - }}; - } - - #[test] - fn test_cross_entropy_loss_with_weights() { - let (logits, targets, targets_logits) = setup!(); - let weights = vec![1.0, 2., 3., 4., 5.]; - let loss_1 = CrossEntropyLossConfig::new() - .with_weights(Some(weights.clone())) - .init() - .forward(logits.clone(), targets); - let tensor = log_softmax(logits, 1); - let loss_2 = tensor - * targets_logits - * Tensor::::from_floats(weights.as_slice()) - .unsqueeze() - .repeat(0, 4); - let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_with_weights_and_alpha_zero() { - let (logits, targets, _) = setup!(); - let weights = vec![1.0, 2., 3., 4., 5.]; - let loss_1 = CrossEntropyLossConfig::new() - .with_weights(Some(weights.clone())) - .init() - .forward(logits.clone(), targets.clone()); - let loss_2 = CrossEntropyLossConfig::new() - .with_weights(Some(weights.clone())) - .with_smoothing(Some(0.)) - .init() - .forward(logits.clone(), targets); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_cross_entropy_loss() { - let (logits, targets, targets_logits) = setup!(); - let loss_1 = CrossEntropyLossConfig::new() - .init() - .forward(logits.clone(), targets); - let loss_2 = cross_entropy_with_logits(logits, targets_logits); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_alpha_equal_zero() { - let (logits, targets, _) = setup!(); - let loss_1 = CrossEntropyLossConfig::new() - .init() - .forward(logits.clone(), targets.clone()); - let loss_2 = CrossEntropyLossConfig::new() - .with_smoothing(Some(0.)) - .init() - .forward(logits, targets); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_cross_entropy_loss_with_pad_token() { - let (logits, targets, targets_logits) = setup_padded!(); - let pad_index = 1; - - let loss_1 = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![pad_index, 2])) - .init() - .forward(logits.clone(), targets); - let loss_2 = cross_entropy_with_logits(logits, targets_logits); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_with_zero_alpha_and_pad_token() { - let (logits, targets, _) = setup_padded!(); - let pad_index = 1; - - let loss_1 = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![pad_index, 2])) - .init() - .forward(logits.clone(), targets.clone()); - let loss_2 = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![pad_index, 2])) - 
.with_smoothing(Some(0.)) - .init() - .forward(logits.clone(), targets); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_target_conversion() { - let (logits, targets, _) = setup!(); - let smoothed_targets = - CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05); - let targets_logits = Tensor::::from_data(Data::from([ - [0.01, 0.01, 0.96, 0.01, 0.01], - [0.96, 0.01, 0.01, 0.01, 0.01], - [0.01, 0.01, 0.01, 0.01, 0.96], - [0.01, 0.96, 0.01, 0.01, 0.01], - ])); - smoothed_targets - .into_data() - .assert_approx_eq(&targets_logits.into_data(), 3); - } - - #[test] - fn test_label_smoothing() { - let (logits, targets, _) = setup!(); - let loss_1 = CrossEntropyLossConfig::new() - .with_smoothing(Some(0.05)) - .init() - .forward(logits.clone(), targets); - let targets_logits = Tensor::::from_data(Data::from([ - [0.01, 0.01, 0.96, 0.01, 0.01], - [0.96, 0.01, 0.01, 0.01, 0.01], - [0.01, 0.01, 0.01, 0.01, 0.96], - [0.01, 0.96, 0.01, 0.01, 0.01], - ])); - - let x = log_softmax(logits, 1); - let loss_2 = (x * targets_logits).sum_dim(1).mean().neg(); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::{loss::cross_entropy_with_logits, Data, Distribution}; + + macro_rules! setup { + () => {{ + let [batch_size, num_targets] = [4, 5]; + let logits = + Tensor::::random([batch_size, num_targets], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data(Data::from([2, 0, 4, 1])); + let targets_logits = Tensor::::from_data(Data::from([ + [0.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ])); + (logits, targets, targets_logits) + }}; + } + + macro_rules! setup_padded { + () => {{ + let [batch_size, num_targets, pad_index] = [4, 5, 1]; + let logits = + Tensor::::random([batch_size, num_targets], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data( + Data::::from([2, 0, 4, pad_index as i64]).convert(), + ); + let targets_logits = Tensor::::from_data(Data::from([ + [0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ])); + (logits, targets, targets_logits) + }}; + } + + #[test] + fn test_cross_entropy_loss_with_weights() { + let (logits, targets, targets_logits) = setup!(); + let weights = vec![1.0, 2., 3., 4., 5.]; + let loss_1 = CrossEntropyLossConfig::new() + .with_weights(Some(weights.clone())) + .init() + .forward(logits.clone(), targets); + let tensor = log_softmax(logits, 1); + let loss_2 = tensor + * targets_logits + * Tensor::::from_floats(weights.as_slice()) + .unsqueeze() + .repeat(0, 4); + let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. 
+ 5.); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_with_weights_and_alpha_zero() { + let (logits, targets, _) = setup!(); + let weights = vec![1.0, 2., 3., 4., 5.]; + let loss_1 = CrossEntropyLossConfig::new() + .with_weights(Some(weights.clone())) + .init() + .forward(logits.clone(), targets.clone()); + let loss_2 = CrossEntropyLossConfig::new() + .with_weights(Some(weights.clone())) + .with_smoothing(Some(0.)) + .init() + .forward(logits.clone(), targets); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_cross_entropy_loss() { + let (logits, targets, targets_logits) = setup!(); + let loss_1 = CrossEntropyLossConfig::new() + .init() + .forward(logits.clone(), targets); + let loss_2 = cross_entropy_with_logits(logits, targets_logits); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_alpha_equal_zero() { + let (logits, targets, _) = setup!(); + let loss_1 = CrossEntropyLossConfig::new() + .init() + .forward(logits.clone(), targets.clone()); + let loss_2 = CrossEntropyLossConfig::new() + .with_smoothing(Some(0.)) + .init() + .forward(logits, targets); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_cross_entropy_loss_with_pad_token() { + let (logits, targets, targets_logits) = setup_padded!(); + let pad_index = 1; + + let loss_1 = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![pad_index, 2])) + .init() + .forward(logits.clone(), targets); + let loss_2 = cross_entropy_with_logits(logits, targets_logits); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_with_zero_alpha_and_pad_token() { + let (logits, targets, _) = setup_padded!(); + let pad_index = 1; + + let loss_1 = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![pad_index, 2])) + .init() + .forward(logits.clone(), targets.clone()); + let loss_2 = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![pad_index, 2])) + .with_smoothing(Some(0.)) + .init() + .forward(logits.clone(), targets); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_target_conversion() { + let (logits, targets, _) = setup!(); + let smoothed_targets = CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05); + let targets_logits = Tensor::::from_data(Data::from([ + [0.01, 0.01, 0.96, 0.01, 0.01], + [0.96, 0.01, 0.01, 0.01, 0.01], + [0.01, 0.01, 0.01, 0.01, 0.96], + [0.01, 0.96, 0.01, 0.01, 0.01], + ])); + smoothed_targets + .into_data() + .assert_approx_eq(&targets_logits.into_data(), 3); + } + + #[test] + fn test_label_smoothing() { + let (logits, targets, _) = setup!(); + let loss_1 = CrossEntropyLossConfig::new() + .with_smoothing(Some(0.05)) + .init() + .forward(logits.clone(), targets); + let targets_logits = Tensor::::from_data(Data::from([ + [0.01, 0.01, 0.96, 0.01, 0.01], + [0.96, 0.01, 0.01, 0.01, 0.01], + [0.01, 0.01, 0.01, 0.01, 0.96], + [0.01, 0.96, 0.01, 0.01, 0.01], + ])); + + let x = log_softmax(logits, 1); + let loss_2 = (x * targets_logits).sum_dim(1).mean().neg(); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/loss/mse.rs b/burn-core/src/nn/loss/mse.rs index 4ab95cd7d3..277bea3c36 100644 --- a/burn-core/src/nn/loss/mse.rs +++ b/burn-core/src/nn/loss/mse.rs @@ -6,74 +6,74 @@ use burn_tensor::{backend::Backend, Tensor}; /// Calculate the mean 
squared error loss from the input logits and the targets. #[derive(Clone, Debug)] pub struct MSELoss { - backend: PhantomData, + backend: PhantomData, } impl Default for MSELoss { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl MSELoss { - /// Create the criterion. - pub fn new() -> Self { - Self { - backend: PhantomData, - } + /// Create the criterion. + pub fn new() -> Self { + Self { + backend: PhantomData, } + } - /// Compute the criterion on the input tensor. - /// - /// # Shapes - /// - /// - logits: [batch_size, num_targets] - /// - targets: [batch_size, num_targets] - pub fn forward( - &self, - logits: Tensor, - targets: Tensor, - reduction: Reduction, - ) -> Tensor { - let tensor = self.forward_no_reduction(logits, targets); - match reduction { - Reduction::Mean | Reduction::Auto => tensor.mean(), - Reduction::Sum => tensor.sum(), - } + /// Compute the criterion on the input tensor. + /// + /// # Shapes + /// + /// - logits: [batch_size, num_targets] + /// - targets: [batch_size, num_targets] + pub fn forward( + &self, + logits: Tensor, + targets: Tensor, + reduction: Reduction, + ) -> Tensor { + let tensor = self.forward_no_reduction(logits, targets); + match reduction { + Reduction::Mean | Reduction::Auto => tensor.mean(), + Reduction::Sum => tensor.sum(), } + } - /// Compute the criterion on the input tensor without reducing. - pub fn forward_no_reduction( - &self, - logits: Tensor, - targets: Tensor, - ) -> Tensor { - logits.sub(targets).powf(2.0) - } + /// Compute the criterion on the input tensor without reducing. + pub fn forward_no_reduction( + &self, + logits: Tensor, + targets: Tensor, + ) -> Tensor { + logits.sub(targets).powf(2.0) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; + use super::*; + use crate::TestBackend; + use burn_tensor::Data; - #[test] - fn test_mse_loss() { - let logits = Tensor::::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]])); + #[test] + fn test_mse_loss() { + let logits = Tensor::::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]])); - let targets = Tensor::::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]])); + let targets = Tensor::::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]])); - let mse = MSELoss::new(); - let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone()); - let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto); - let loss_sum = mse.forward(logits, targets, Reduction::Sum); + let mse = MSELoss::new(); + let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone()); + let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto); + let loss_sum = mse.forward(logits, targets, Reduction::Sum); - assert_eq!( - loss_no_reduction.into_data(), - Data::from([[1.0, 1.0], [0.0, 4.0]]) - ); - assert_eq!(loss.into_data(), Data::from([1.5])); - assert_eq!(loss_sum.into_data(), Data::from([6.0])); - } + assert_eq!( + loss_no_reduction.into_data(), + Data::from([[1.0, 1.0], [0.0, 4.0]]) + ); + assert_eq!(loss.into_data(), Data::from([1.5])); + assert_eq!(loss_sum.into_data(), Data::from([6.0])); + } } diff --git a/burn-core/src/nn/loss/reduction.rs b/burn-core/src/nn/loss/reduction.rs index 499b171537..32107afd26 100644 --- a/burn-core/src/nn/loss/reduction.rs +++ b/burn-core/src/nn/loss/reduction.rs @@ -1,11 +1,11 @@ /// The reduction type for the loss. pub enum Reduction { - /// The mean of the losses will be returned. - Mean, + /// The mean of the losses will be returned. 
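// A plain-arithmetic sketch of the numbers asserted in the MSE test above:
// element-wise (logit - target)^2, then the Mean/Auto or Sum reduction listed
// below. Illustrative only; no burn types involved.
fn squared_errors(logits: &[f32], targets: &[f32]) -> Vec<f32> {
    logits
        .iter()
        .zip(targets)
        .map(|(l, t)| (l - t) * (l - t))
        .collect()
}

fn main() {
    let per_element = squared_errors(&[1.0, 2.0, 3.0, 4.0], &[2.0, 1.0, 3.0, 2.0]);
    assert_eq!(per_element, vec![1.0, 1.0, 0.0, 4.0]);
    let mean: f32 = per_element.iter().sum::<f32>() / per_element.len() as f32; // Reduction::Mean / Auto
    let sum: f32 = per_element.iter().sum(); // Reduction::Sum
    assert_eq!(mean, 1.5);
    assert_eq!(sum, 6.0);
}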
+ Mean, - /// The sum of the losses will be returned. - Sum, + /// The sum of the losses will be returned. + Sum, - /// The mean of the losses will be returned. - Auto, + /// The mean of the losses will be returned. + Auto, } diff --git a/burn-core/src/nn/norm/batch.rs b/burn-core/src/nn/norm/batch.rs index d48ff76bff..bbc37b30f2 100644 --- a/burn-core/src/nn/norm/batch.rs +++ b/burn-core/src/nn/norm/batch.rs @@ -1,22 +1,22 @@ use crate as burn; use crate::{ - config::Config, - module::{Module, Param, RunningState}, - tensor::{backend::Backend, Tensor}, + config::Config, + module::{Module, Param, RunningState}, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [BatchNorm](BatchNorm) layer. #[derive(Config, Debug)] pub struct BatchNormConfig { - /// The number of features. - pub num_features: usize, - /// A value required for numerical stability. Default: 1e-5 - #[config(default = 1e-5)] - pub epsilon: f64, - /// Momentum used to update the metrics. Default: 0.1 - #[config(default = 0.1)] - pub momentum: f64, + /// The number of features. + pub num_features: usize, + /// A value required for numerical stability. Default: 1e-5 + #[config(default = 1e-5)] + pub epsilon: f64, + /// Momentum used to update the metrics. Default: 0.1 + #[config(default = 0.1)] + pub momentum: f64, } /// Applies Batch Normalization over a tensor as described in the paper [Batch Normalization](https://arxiv.org/abs/1502.03167) @@ -24,359 +24,361 @@ pub struct BatchNormConfig { /// `Y = norm(X) * γ + β` #[derive(Module, Debug)] pub struct BatchNorm { - gamma: Param>, - beta: Param>, - running_mean: RunningState>, - running_var: RunningState>, - momentum: f64, - epsilon: f64, + gamma: Param>, + beta: Param>, + running_mean: RunningState>, + running_var: RunningState>, + momentum: f64, + epsilon: f64, } impl BatchNormConfig { - /// Initialize a new [batch norm](BatchNorm) module. - pub fn init(&self) -> BatchNorm { - let gamma = Tensor::ones([self.num_features]); - let beta = Tensor::zeros([self.num_features]); - - let running_mean = Tensor::zeros([self.num_features]); - let running_var = Tensor::ones([self.num_features]); - - BatchNorm { - gamma: Param::from(gamma), - beta: Param::from(beta), - running_mean: RunningState::new(running_mean), - running_var: RunningState::new(running_var), - momentum: self.momentum, - epsilon: self.epsilon, - } + /// Initialize a new [batch norm](BatchNorm) module. + pub fn init(&self) -> BatchNorm { + let gamma = Tensor::ones([self.num_features]); + let beta = Tensor::zeros([self.num_features]); + + let running_mean = Tensor::zeros([self.num_features]); + let running_var = Tensor::ones([self.num_features]); + + BatchNorm { + gamma: Param::from(gamma), + beta: Param::from(beta), + running_mean: RunningState::new(running_mean), + running_var: RunningState::new(running_var), + momentum: self.momentum, + epsilon: self.epsilon, } - - /// Initialize a new [batch norm](BatchNorm) module with a [record](BatchNormRecord). - pub fn init_with( - &self, - record: BatchNormRecord, - ) -> BatchNorm { - BatchNorm { - gamma: record.gamma, - beta: record.beta, - running_mean: RunningState::from_record(record.running_mean), - running_var: RunningState::from_record(record.running_var), - momentum: self.momentum, - epsilon: self.epsilon, - } + } + + /// Initialize a new [batch norm](BatchNorm) module with a [record](BatchNormRecord). 
+ pub fn init_with( + &self, + record: BatchNormRecord, + ) -> BatchNorm { + BatchNorm { + gamma: record.gamma, + beta: record.beta, + running_mean: RunningState::from_record(record.running_mean), + running_var: RunningState::from_record(record.running_var), + momentum: self.momentum, + epsilon: self.epsilon, } + } } impl BatchNorm { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[batch_size, channels, ...]` - /// - output: `[batch_size, channels, ...]` - pub fn forward(&self, input: Tensor) -> Tensor { - // Should be move to a compilation error when const generic support that kind of - // validation. https://github.com/rust-lang/rust/issues/76560 - if D + 2 != DI { - panic!("BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor", D, D+2, DI); - } - - match B::ad_enabled() { - true => self.forward_train(input), - false => self.forward_inference(input), - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[batch_size, channels, ...]` + /// - output: `[batch_size, channels, ...]` + pub fn forward(&self, input: Tensor) -> Tensor { + // Should be move to a compilation error when const generic support that kind of + // validation. https://github.com/rust-lang/rust/issues/76560 + if D + 2 != DI { + panic!("BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor", D, D+2, DI); } - fn forward_inference(&self, input: Tensor) -> Tensor { - let channels = input.dims()[1]; - let mean = self.running_mean.value(); - let var = self.running_var.value(); - - let mut shape = [1; DI]; - shape[1] = channels; - - self.forward_shared(input, mean.reshape(shape), var.reshape(shape)) + match B::ad_enabled() { + true => self.forward_train(input), + false => self.forward_inference(input), } + } - fn forward_train(&self, input: Tensor) -> Tensor { - let dims = input.dims(); - let batch_size = dims[0]; - let channels = dims[1]; - - let mut shape_unsqueeze = [1; DI]; - let mut flatten_size = batch_size; - shape_unsqueeze[1] = channels; - - for dim in dims.iter().take(DI).skip(2) { - flatten_size *= dim; - } - - let mean = input - .clone() - .swap_dims(0, 1) - .reshape([channels, flatten_size]) - .mean_dim(1) - .reshape(shape_unsqueeze); - - let var = input - .clone() - .sub(mean.clone()) - .powf(2.0) - .swap_dims(0, 1) - .reshape([channels, flatten_size]) - .mean_dim(1) - .reshape(shape_unsqueeze); - - let running_mean = self.running_mean.value_sync(); - let running_var = self.running_var.value_sync(); - - let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add( - mean.clone() - .detach() - .mul_scalar(self.momentum) - .reshape([channels]), - ); - let running_var = running_var.mul_scalar(1.0 - self.momentum).add( - var.clone() - .detach() - .mul_scalar(self.momentum) - .reshape([channels]), - ); - - self.running_mean.update(running_mean.detach()); - self.running_var.update(running_var.detach()); - - self.forward_shared(input, mean, var) - } + fn forward_inference(&self, input: Tensor) -> Tensor { + let channels = input.dims()[1]; + let mean = self.running_mean.value(); + let var = self.running_var.value(); - fn forward_shared( - &self, - x: Tensor, - mean: Tensor, - var: Tensor, - ) -> Tensor { - let channels = x.dims()[1]; - let mut shape = [1; DI]; - shape[1] = channels; + let mut shape = [1; DI]; + shape[1] = channels; - let std = var.add_scalar(self.epsilon).sqrt(); 
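// The per-channel normalization performed by `forward_shared`, reduced to one
// scalar (an illustrative sketch; `mean` and `var` stand in for the batch or
// running statistics, `gamma`/`beta` for the learned affine parameters).
fn batch_norm_scalar(x: f64, mean: f64, var: f64, gamma: f64, beta: f64, epsilon: f64) -> f64 {
    (x - mean) / (var + epsilon).sqrt() * gamma + beta
}

fn main() {
    // With the default gamma = 1 and beta = 0 this is plain standardization.
    let y = batch_norm_scalar(0.9601, 0.5, 0.09, 1.0, 0.0, 1e-5);
    assert!((y - (0.4601 / (0.09f64 + 1e-5).sqrt())).abs() < 1e-12);
}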
+ self.forward_shared(input, mean.reshape(shape), var.reshape(shape)) + } - let x = x.sub(mean); - let x = x.div(std); + fn forward_train(&self, input: Tensor) -> Tensor { + let dims = input.dims(); + let batch_size = dims[0]; + let channels = dims[1]; - let x = x.mul(self.gamma.val().reshape(shape)); + let mut shape_unsqueeze = [1; DI]; + let mut flatten_size = batch_size; + shape_unsqueeze[1] = channels; - x.add(self.beta.val().reshape(shape)) + for dim in dims.iter().take(DI).skip(2) { + flatten_size *= dim; } + + let mean = input + .clone() + .swap_dims(0, 1) + .reshape([channels, flatten_size]) + .mean_dim(1) + .reshape(shape_unsqueeze); + + let var = input + .clone() + .sub(mean.clone()) + .powf(2.0) + .swap_dims(0, 1) + .reshape([channels, flatten_size]) + .mean_dim(1) + .reshape(shape_unsqueeze); + + let running_mean = self.running_mean.value_sync(); + let running_var = self.running_var.value_sync(); + + let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add( + mean + .clone() + .detach() + .mul_scalar(self.momentum) + .reshape([channels]), + ); + let running_var = running_var.mul_scalar(1.0 - self.momentum).add( + var + .clone() + .detach() + .mul_scalar(self.momentum) + .reshape([channels]), + ); + + self.running_mean.update(running_mean.detach()); + self.running_var.update(running_var.detach()); + + self.forward_shared(input, mean, var) + } + + fn forward_shared( + &self, + x: Tensor, + mean: Tensor, + var: Tensor, + ) -> Tensor { + let channels = x.dims()[1]; + let mut shape = [1; DI]; + shape[1] = channels; + + let std = var.add_scalar(self.epsilon).sqrt(); + + let x = x.sub(mean); + let x = x.div(std); + + let x = x.mul(self.gamma.val().reshape(shape)); + + x.add(self.beta.val().reshape(shape)) + } } #[cfg(feature = "std")] #[cfg(test)] mod tests_1d { - use super::*; - use crate::{module::AutodiffModule, TestAutodiffBackend}; - use burn_tensor::Data; - - #[test] - fn batch_norm_forward_train() { - let module = BatchNormConfig::new(3).init::(); - - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [ - [1.1483e+00, 3.7521e-01], - [1.6272e-03, 7.5067e-01], - [1.6204e+00, -4.5168e-02], - ], - [ - [6.8856e-02, -1.5923e+00], - [-1.6318e+00, 8.7949e-01], - [-5.3368e-01, -1.0416e+00], - ], - ]), - 2, - ); - } - - #[test] - fn batch_norm_forward_inference() { - let module = BatchNormConfig::new(3).init::(); - - module.forward(input_tensor()); - let module = module.valid(); - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]], - [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]], - ]), - 2, - ); - } - - fn input_tensor() -> Tensor { - Tensor::::from_floats([ - [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]], - [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]], - ]) - } + use super::*; + use crate::{module::AutodiffModule, TestAutodiffBackend}; + use burn_tensor::Data; + + #[test] + fn batch_norm_forward_train() { + let module = BatchNormConfig::new(3).init::(); + + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [1.1483e+00, 3.7521e-01], + [1.6272e-03, 7.5067e-01], + [1.6204e+00, -4.5168e-02], + ], + [ + [6.8856e-02, -1.5923e+00], + [-1.6318e+00, 8.7949e-01], + [-5.3368e-01, -1.0416e+00], + ], + ]), + 2, + ); + } + + #[test] + fn batch_norm_forward_inference() { + let module = BatchNormConfig::new(3).init::(); + + 
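// The running-statistics update in `forward_train` above, reduced to a scalar
// exponential moving average (illustrative sketch; momentum = 0.1 is the config
// default).
fn update_running(running: f64, batch_stat: f64, momentum: f64) -> f64 {
    running * (1.0 - momentum) + batch_stat * momentum
}

fn main() {
    // Starting from the initial running mean of 0.0, a single batch whose mean is
    // roughly 0.5 moves the estimate to roughly 0.05 -- the scale asserted by
    // `batch_norm_running_mean` further down.
    assert!((update_running(0.0, 0.5, 0.1) - 0.05).abs() < 1e-12);
}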
module.forward(input_tensor()); + let module = module.valid(); + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]], + [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]], + ]), + 2, + ); + } + + fn input_tensor() -> Tensor { + Tensor::::from_floats([ + [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]], + [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]], + ]) + } } #[cfg(feature = "std")] #[cfg(test)] mod tests_2d { - use super::*; - use crate::{module::AutodiffModule, TestAutodiffBackend}; - use burn_tensor::Data; - - #[test] - fn batch_norm_forward_train() { - let module = BatchNormConfig::new(3).init::(); - - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [ - [[1.5136, 0.7506], [-1.2216, 0.1477]], - [[0.3135, 1.2252], [-0.4150, 0.6130]], - [[1.4186, 0.3372], [-1.5183, 1.5262]], - ], - [ - [[0.4483, -1.1914], [-1.2010, 0.7537]], - [[-1.6752, 1.3822], [-0.5058, -0.9381]], - [[0.0200, -0.3097], [-0.5715, -0.9026]], - ], - ]), - 2, - ); - } - - #[test] - fn batch_norm_forward_inference() { - let module = BatchNormConfig::new(3).init::(); - - module.forward(input_tensor()); - let module = module.valid(); - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [ - [[0.9538, 0.7103], [0.0808, 0.5179]], - [[0.6015, 0.8910], [0.3703, 0.6966]], - [[0.9171, 0.6912], [0.3037, 0.9395]], - ], - [ - [[0.6138, 0.0904], [0.0874, 0.7113]], - [[-0.0297, 0.9408], [0.3415, 0.2042]], - [[0.6250, 0.5561], [0.5013, 0.4323]], - ], - ]), - 2, - ); - } - - #[test] - fn batch_norm_running_mean() { - let module = BatchNormConfig::new(3).init::(); - - let _output = module.forward(input_tensor()); - - let running_mean = module.running_mean.value_sync(); - - running_mean - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([0.0499, 0.0532, 0.0656]), 2); - } - - #[test] - fn batch_norm_running_var() { - let module = BatchNormConfig::new(3).init::(); - - let _output = module.forward(input_tensor()); - - let running_var = module.running_var.value_sync(); - - running_var - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([0.9106, 0.9105, 0.9045]), 2); - } - - #[test] - fn batch_norm_running_mean_inner_module() { - let module = BatchNormConfig::new(3).init::(); - - let _output = module.forward(input_tensor()); - - let module_valid = module.valid(); - let running_mean = module_valid.running_mean.value(); - let running_mean_after = module.running_mean.value(); - - running_mean_after - .into_data() - .assert_approx_eq(&running_mean.into_data(), 3); - } - - #[test] - fn batch_norm_grads() { - let module = BatchNormConfig::new(3).init::(); - let input = input_tensor().require_grad(); - - let output = module.forward(input.clone()); - - let grads = output.backward(); - - module - .gamma - .grad(&grads) - .unwrap() - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([0.0000e+00, -5.9035e-07, -6.0011e-07]), 3); - - module - .beta - .grad(&grads) - .unwrap() - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([8., 8., 8.]), 3); - - input.grad(&grads).unwrap().into_data().assert_approx_eq( - &Data::from([ - [ - [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], - [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]], - [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]], - ], - [ - [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], - [[-4.0807e-07, 3.3673e-07], 
[-1.2323e-07, -2.2854e-07]], - [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]], - ], - ]), - 4, - ); - } - - fn input_tensor() -> Tensor { - Tensor::::from_floats([ - [ - [[0.9601, 0.7277], [0.1270, 0.5441]], - [[0.6272, 0.9034], [0.4066, 0.7179]], - [[0.9378, 0.7230], [0.3544, 0.9591]], - ], - [ - [[0.6356, 0.1362], [0.1333, 0.7287]], - [[0.0249, 0.9509], [0.3791, 0.2481]], - [[0.6600, 0.5945], [0.5424, 0.4767]], - ], - ]) - } + use super::*; + use crate::{module::AutodiffModule, TestAutodiffBackend}; + use burn_tensor::Data; + + #[test] + fn batch_norm_forward_train() { + let module = BatchNormConfig::new(3).init::(); + + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [[1.5136, 0.7506], [-1.2216, 0.1477]], + [[0.3135, 1.2252], [-0.4150, 0.6130]], + [[1.4186, 0.3372], [-1.5183, 1.5262]], + ], + [ + [[0.4483, -1.1914], [-1.2010, 0.7537]], + [[-1.6752, 1.3822], [-0.5058, -0.9381]], + [[0.0200, -0.3097], [-0.5715, -0.9026]], + ], + ]), + 2, + ); + } + + #[test] + fn batch_norm_forward_inference() { + let module = BatchNormConfig::new(3).init::(); + + module.forward(input_tensor()); + let module = module.valid(); + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [[0.9538, 0.7103], [0.0808, 0.5179]], + [[0.6015, 0.8910], [0.3703, 0.6966]], + [[0.9171, 0.6912], [0.3037, 0.9395]], + ], + [ + [[0.6138, 0.0904], [0.0874, 0.7113]], + [[-0.0297, 0.9408], [0.3415, 0.2042]], + [[0.6250, 0.5561], [0.5013, 0.4323]], + ], + ]), + 2, + ); + } + + #[test] + fn batch_norm_running_mean() { + let module = BatchNormConfig::new(3).init::(); + + let _output = module.forward(input_tensor()); + + let running_mean = module.running_mean.value_sync(); + + running_mean + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([0.0499, 0.0532, 0.0656]), 2); + } + + #[test] + fn batch_norm_running_var() { + let module = BatchNormConfig::new(3).init::(); + + let _output = module.forward(input_tensor()); + + let running_var = module.running_var.value_sync(); + + running_var + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([0.9106, 0.9105, 0.9045]), 2); + } + + #[test] + fn batch_norm_running_mean_inner_module() { + let module = BatchNormConfig::new(3).init::(); + + let _output = module.forward(input_tensor()); + + let module_valid = module.valid(); + let running_mean = module_valid.running_mean.value(); + let running_mean_after = module.running_mean.value(); + + running_mean_after + .into_data() + .assert_approx_eq(&running_mean.into_data(), 3); + } + + #[test] + fn batch_norm_grads() { + let module = BatchNormConfig::new(3).init::(); + let input = input_tensor().require_grad(); + + let output = module.forward(input.clone()); + + let grads = output.backward(); + + module + .gamma + .grad(&grads) + .unwrap() + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([0.0000e+00, -5.9035e-07, -6.0011e-07]), 3); + + module + .beta + .grad(&grads) + .unwrap() + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([8., 8., 8.]), 3); + + input.grad(&grads).unwrap().into_data().assert_approx_eq( + &Data::from([ + [ + [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], + [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]], + [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]], + ], + [ + [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], + [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]], + [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]], + ], + ]), + 
4, + ); + } + + fn input_tensor() -> Tensor { + Tensor::::from_floats([ + [ + [[0.9601, 0.7277], [0.1270, 0.5441]], + [[0.6272, 0.9034], [0.4066, 0.7179]], + [[0.9378, 0.7230], [0.3544, 0.9591]], + ], + [ + [[0.6356, 0.1362], [0.1333, 0.7287]], + [[0.0249, 0.9509], [0.3791, 0.2481]], + [[0.6600, 0.5945], [0.5424, 0.4767]], + ], + ]) + } } diff --git a/burn-core/src/nn/norm/layer.rs b/burn-core/src/nn/norm/layer.rs index 54bed1a7ff..f8a757abf8 100644 --- a/burn-core/src/nn/norm/layer.rs +++ b/burn-core/src/nn/norm/layer.rs @@ -9,11 +9,11 @@ use crate::tensor::Tensor; /// Configuration to create a [LayerNorm](LayerNorm) layer. #[derive(Config)] pub struct LayerNormConfig { - /// The size of the input features. - pub d_model: usize, - /// A value required for numerical stability. Default: 1e-5 - #[config(default = 1e-5)] - pub epsilon: f64, + /// The size of the input features. + pub d_model: usize, + /// A value required for numerical stability. Default: 1e-5 + #[config(default = 1e-5)] + pub epsilon: f64, } /// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450). @@ -21,112 +21,112 @@ pub struct LayerNormConfig { /// `Y = norm(X) * γ + β` #[derive(Module, Debug)] pub struct LayerNorm { - gamma: Param>, - beta: Param>, - epsilon: f64, + gamma: Param>, + beta: Param>, + epsilon: f64, } impl LayerNormConfig { - /// Initialize a new [layer norm](LayerNorm) module. - pub fn init(&self) -> LayerNorm { - let gamma = Tensor::ones([self.d_model]); - let beta = Tensor::zeros([self.d_model]); - - LayerNorm { - gamma: Param::from(gamma), - beta: Param::from(beta), - epsilon: self.epsilon, - } + /// Initialize a new [layer norm](LayerNorm) module. + pub fn init(&self) -> LayerNorm { + let gamma = Tensor::ones([self.d_model]); + let beta = Tensor::zeros([self.d_model]); + + LayerNorm { + gamma: Param::from(gamma), + beta: Param::from(beta), + epsilon: self.epsilon, } - - /// Initialize a new [layer norm](LayerNorm) module with a [record](LayerNormRecord). - pub fn init_with(&self, record: LayerNormRecord) -> LayerNorm { - LayerNorm { - gamma: record.gamma, - beta: record.beta, - epsilon: self.epsilon, - } + } + + /// Initialize a new [layer norm](LayerNorm) module with a [record](LayerNormRecord). + pub fn init_with(&self, record: LayerNormRecord) -> LayerNorm { + LayerNorm { + gamma: record.gamma, + beta: record.beta, + epsilon: self.epsilon, } + } } impl LayerNorm { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any, d_model]` - /// - output: `[..., any, d_model]` - pub fn forward(&self, input: Tensor) -> Tensor { - let (var, mean) = input.clone().var_mean_bias(D - 1); - - let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon)); - - input_normalized - .mul(self.gamma.val().unsqueeze()) - .add(self.beta.val().unsqueeze()) - } + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: `[..., any, d_model]` + /// - output: `[..., any, d_model]` + pub fn forward(&self, input: Tensor) -> Tensor { + let (var, mean) = input.clone().var_mean_bias(D - 1); + + let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon)); + + input_normalized + .mul(self.gamma.val().unsqueeze()) + .add(self.beta.val().unsqueeze()) + } } #[cfg(test)] mod tests { - use super::*; - use burn_tensor::Data; - - #[cfg(feature = "std")] - use crate::{TestAutodiffBackend, TestBackend}; - - #[cfg(not(feature = "std"))] - use crate::TestBackend; - - #[test] - fn layer_norm_forward() { - let module = LayerNormConfig::new(10).init::(); - let input = Tensor::from_data(Data::from([[ - -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728, - ]])); - - let output = module.forward(input); - - output.to_data().assert_approx_eq( - &Data::from([[ - -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915, - ]]), - 3, - ); - } - - #[cfg(feature = "std")] - #[test] - fn layer_norm_backward() { - let module = LayerNormConfig::new(2).init::(); - let tensor_1 = - Tensor::::from_data(Data::from([[0.0, 1.0], [3.0, 4.0]])) - .require_grad(); - let tensor_2 = - Tensor::::from_data(Data::from([[6.0, 7.0], [9.0, 10.0]])) - .require_grad(); - - let x = tensor_1.clone().matmul(tensor_2.clone()); - - let output = module.forward(x); - let grads = output.backward(); - - let tensor_1_grad = tensor_1.grad(&grads).unwrap(); - let tensor_2_grad = tensor_2.grad(&grads).unwrap(); - let gamma_grad = module.gamma.grad(&grads).unwrap(); - let beta_grad = module.beta.grad(&grads).unwrap(); - - gamma_grad - .to_data() - .assert_approx_eq(&Data::from([-2.0, 2.0]), 3); - beta_grad - .to_data() - .assert_approx_eq(&Data::from([2.0, 2.0]), 3); - tensor_1_grad - .to_data() - .assert_approx_eq(&Data::zeros(tensor_1_grad.shape()), 3); - tensor_2_grad - .to_data() - .assert_approx_eq(&Data::zeros(tensor_2_grad.shape()), 3); - } + use super::*; + use burn_tensor::Data; + + #[cfg(feature = "std")] + use crate::{TestAutodiffBackend, TestBackend}; + + #[cfg(not(feature = "std"))] + use crate::TestBackend; + + #[test] + fn layer_norm_forward() { + let module = LayerNormConfig::new(10).init::(); + let input = Tensor::from_data(Data::from([[ + -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728, + ]])); + + let output = module.forward(input); + + output.to_data().assert_approx_eq( + &Data::from([[ + -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915, + ]]), + 3, + ); + } + + #[cfg(feature = "std")] + #[test] + fn layer_norm_backward() { + let module = LayerNormConfig::new(2).init::(); + let tensor_1 = + Tensor::::from_data(Data::from([[0.0, 1.0], [3.0, 4.0]])) + .require_grad(); + let tensor_2 = + Tensor::::from_data(Data::from([[6.0, 7.0], [9.0, 10.0]])) + .require_grad(); + + let x = tensor_1.clone().matmul(tensor_2.clone()); + + let output = module.forward(x); + let grads = output.backward(); + + let tensor_1_grad = tensor_1.grad(&grads).unwrap(); + let tensor_2_grad = tensor_2.grad(&grads).unwrap(); + let gamma_grad = module.gamma.grad(&grads).unwrap(); + let beta_grad = module.beta.grad(&grads).unwrap(); + + gamma_grad + .to_data() + .assert_approx_eq(&Data::from([-2.0, 2.0]), 3); + beta_grad + .to_data() + .assert_approx_eq(&Data::from([2.0, 2.0]), 3); + tensor_1_grad + .to_data() + .assert_approx_eq(&Data::zeros(tensor_1_grad.shape()), 3); + tensor_2_grad + .to_data() + 
.assert_approx_eq(&Data::zeros(tensor_2_grad.shape()), 3); + } } diff --git a/burn-core/src/nn/padding.rs b/burn-core/src/nn/padding.rs index db27bf4486..0c7decb9c8 100644 --- a/burn-core/src/nn/padding.rs +++ b/burn-core/src/nn/padding.rs @@ -8,62 +8,62 @@ use crate::module::Module; /// Padding configuration for 1D operators. #[derive(Module, Config, Debug, PartialEq)] pub enum PaddingConfig1d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. - Same, - /// Same as no padding. - Valid, - /// Applies the specified amount of padding to all inputs. - Explicit(usize), + /// Dynamically calculate the amount of padding necessary to ensure that the output size will be + /// the same as the input. + Same, + /// Same as no padding. + Valid, + /// Applies the specified amount of padding to all inputs. + Explicit(usize), } impl PaddingConfig1d { - pub(crate) fn calculate_padding_1d( - &self, - length: usize, - kernel_size: usize, - stride: usize, - ) -> usize { - let same_padding = || calculate_conv_padding(kernel_size, stride, length, length); - match self { - Self::Valid => 0, - Self::Same => same_padding(), - Self::Explicit(value) => *value, - } + pub(crate) fn calculate_padding_1d( + &self, + length: usize, + kernel_size: usize, + stride: usize, + ) -> usize { + let same_padding = || calculate_conv_padding(kernel_size, stride, length, length); + match self { + Self::Valid => 0, + Self::Same => same_padding(), + Self::Explicit(value) => *value, } + } } /// Padding configuration for 2D operators. #[derive(Module, Config, Debug, PartialEq)] pub enum PaddingConfig2d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. - Same, - /// Same as no padding. - Valid, - /// Applies the specified amount of padding to all inputs. - Explicit(usize, usize), + /// Dynamically calculate the amount of padding necessary to ensure that the output size will be + /// the same as the input. + Same, + /// Same as no padding. + Valid, + /// Applies the specified amount of padding to all inputs. + Explicit(usize, usize), } impl PaddingConfig2d { - pub(crate) fn calculate_padding_2d( - &self, - height: usize, - width: usize, - kernel_size: &[usize; 2], - stride: &[usize; 2], - ) -> [usize; 2] { - let same_padding = || { - let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height); - let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width); + pub(crate) fn calculate_padding_2d( + &self, + height: usize, + width: usize, + kernel_size: &[usize; 2], + stride: &[usize; 2], + ) -> [usize; 2] { + let same_padding = || { + let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height); + let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width); - [p1, p2] - }; + [p1, p2] + }; - match self { - Self::Same => same_padding(), - Self::Valid => [0, 0], - Self::Explicit(v1, v2) => [*v1, *v2], - } + match self { + Self::Same => same_padding(), + Self::Valid => [0, 0], + Self::Explicit(v1, v2) => [*v1, *v2], } + } } diff --git a/burn-core/src/nn/pool/adaptive_avg_pool1d.rs b/burn-core/src/nn/pool/adaptive_avg_pool1d.rs index 2bd321f575..547c2d40ed 100644 --- a/burn-core/src/nn/pool/adaptive_avg_pool1d.rs +++ b/burn-core/src/nn/pool/adaptive_avg_pool1d.rs @@ -9,33 +9,33 @@ use burn_tensor::module::adaptive_avg_pool1d; /// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer. 
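// A sketch of the window arithmetic usually behind 1D adaptive average pooling:
// each output position i averages input[floor(i * L / out) .. ceil((i + 1) * L / out)].
// This assumes the PyTorch-style definition and is not code from this crate;
// `adaptive_avg_pool1d_ref` is an illustrative reference only.
fn adaptive_avg_pool1d_ref(input: &[f32], output_size: usize) -> Vec<f32> {
    let len = input.len();
    (0..output_size)
        .map(|i| {
            let start = i * len / output_size; // floor(i * L / out)
            let end = ((i + 1) * len + output_size - 1) / output_size; // ceil((i + 1) * L / out)
            let window = &input[start..end];
            window.iter().sum::<f32>() / window.len() as f32
        })
        .collect()
}

fn main() {
    // Pooling six values down to three averages non-overlapping pairs here.
    assert_eq!(
        adaptive_avg_pool1d_ref(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3),
        vec![1.5, 3.5, 5.5]
    );
}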
#[derive(Config)] pub struct AdaptiveAvgPool1dConfig { - /// The size of the output. - pub output_size: usize, + /// The size of the output. + pub output_size: usize, } /// Applies a 1D adaptive avg pooling over input tensors. #[derive(Module, Debug, Clone)] pub struct AdaptiveAvgPool1d { - output_size: usize, + output_size: usize, } impl AdaptiveAvgPool1dConfig { - /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module. - pub fn init(&self) -> AdaptiveAvgPool1d { - AdaptiveAvgPool1d { - output_size: self.output_size, - } + /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module. + pub fn init(&self) -> AdaptiveAvgPool1d { + AdaptiveAvgPool1d { + output_size: self.output_size, } + } } impl AdaptiveAvgPool1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, length], - /// - output: [batch_size, channels, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - adaptive_avg_pool1d(input, self.output_size) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, length], + /// - output: [batch_size, channels, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + adaptive_avg_pool1d(input, self.output_size) + } } diff --git a/burn-core/src/nn/pool/adaptive_avg_pool2d.rs b/burn-core/src/nn/pool/adaptive_avg_pool2d.rs index 1aba65648d..b11cf6f362 100644 --- a/burn-core/src/nn/pool/adaptive_avg_pool2d.rs +++ b/burn-core/src/nn/pool/adaptive_avg_pool2d.rs @@ -9,33 +9,33 @@ use burn_tensor::module::adaptive_avg_pool2d; /// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer. #[derive(Config)] pub struct AdaptiveAvgPool2dConfig { - /// The size of the output. - pub output_size: [usize; 2], + /// The size of the output. + pub output_size: [usize; 2], } /// Applies a 2D adaptive avg pooling over input tensors. #[derive(Module, Debug, Clone)] pub struct AdaptiveAvgPool2d { - output_size: [usize; 2], + output_size: [usize; 2], } impl AdaptiveAvgPool2dConfig { - /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module. - pub fn init(&self) -> AdaptiveAvgPool2d { - AdaptiveAvgPool2d { - output_size: self.output_size, - } + /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module. + pub fn init(&self) -> AdaptiveAvgPool2d { + AdaptiveAvgPool2d { + output_size: self.output_size, } + } } impl AdaptiveAvgPool2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, height_in, width_in], - /// - output: [batch_size, channels, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - adaptive_avg_pool2d(input, self.output_size) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, height_in, width_in], + /// - output: [batch_size, channels, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + adaptive_avg_pool2d(input, self.output_size) + } } diff --git a/burn-core/src/nn/pool/avg_pool1d.rs b/burn-core/src/nn/pool/avg_pool1d.rs index 50596b5ba3..d4e912d49d 100644 --- a/burn-core/src/nn/pool/avg_pool1d.rs +++ b/burn-core/src/nn/pool/avg_pool1d.rs @@ -10,17 +10,17 @@ use burn_tensor::module::avg_pool1d; /// Configuration to create a [1D avg pooling](AvgPool1d) layer. #[derive(Config)] pub struct AvgPool1dConfig { - /// The size of the kernel. - pub kernel_size: usize, - /// The stride. 
- #[config(default = "1")] - pub stride: usize, - /// The padding configuration. - #[config(default = "PaddingConfig1d::Valid")] - pub padding: PaddingConfig1d, - /// If the padding is counted in the denominator when computing the average. - #[config(default = "true")] - count_include_pad: bool, + /// The size of the kernel. + pub kernel_size: usize, + /// The stride. + #[config(default = "1")] + pub stride: usize, + /// The padding configuration. + #[config(default = "PaddingConfig1d::Valid")] + pub padding: PaddingConfig1d, + /// If the padding is counted in the denominator when computing the average. + #[config(default = "true")] + count_include_pad: bool, } /// Applies a 1D avg pooling over input tensors. @@ -40,43 +40,43 @@ pub struct AvgPool1dConfig { #[derive(Module, Debug, Clone)] pub struct AvgPool1d { - stride: usize, - kernel_size: usize, - padding: PaddingConfig1d, - count_include_pad: bool, + stride: usize, + kernel_size: usize, + padding: PaddingConfig1d, + count_include_pad: bool, } impl AvgPool1dConfig { - /// Initialize a new [avg pool 1d](AvgPool1d) module. - pub fn init(&self) -> AvgPool1d { - AvgPool1d { - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - count_include_pad: self.count_include_pad, - } + /// Initialize a new [avg pool 1d](AvgPool1d) module. + pub fn init(&self) -> AvgPool1d { + AvgPool1d { + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + count_include_pad: self.count_include_pad, } + } } impl AvgPool1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, length_in], - /// - output: [batch_size, channels, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels, length] = input.dims(); - let padding = self - .padding - .calculate_padding_1d(length, self.kernel_size, self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, length_in], + /// - output: [batch_size, channels, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels, length] = input.dims(); + let padding = self + .padding + .calculate_padding_1d(length, self.kernel_size, self.stride); - avg_pool1d( - input, - self.kernel_size, - self.stride, - padding, - self.count_include_pad, - ) - } + avg_pool1d( + input, + self.kernel_size, + self.stride, + padding, + self.count_include_pad, + ) + } } diff --git a/burn-core/src/nn/pool/avg_pool2d.rs b/burn-core/src/nn/pool/avg_pool2d.rs index f3bb2b60ec..a8be01e92b 100644 --- a/burn-core/src/nn/pool/avg_pool2d.rs +++ b/burn-core/src/nn/pool/avg_pool2d.rs @@ -10,17 +10,17 @@ use burn_tensor::module::avg_pool2d; /// Configuration to create a [2D avg pooling](AvgPool2d) layer. #[derive(Config, Debug)] pub struct AvgPool2dConfig { - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The strides. - #[config(default = "[1, 1]")] - pub strides: [usize; 2], - /// The padding configuration. - #[config(default = "PaddingConfig2d::Valid")] - pub padding: PaddingConfig2d, - /// If the padding is counted in the denominator when computing the average. - #[config(default = "true")] - count_include_pad: bool, + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The strides. + #[config(default = "[1, 1]")] + pub strides: [usize; 2], + /// The padding configuration. 
+ #[config(default = "PaddingConfig2d::Valid")] + pub padding: PaddingConfig2d, + /// If the padding is counted in the denominator when computing the average. + #[config(default = "true")] + count_include_pad: bool, } /// Applies a 2D avg pooling over input tensors. @@ -39,43 +39,44 @@ pub struct AvgPool2dConfig { /// [Issue 636](https://github.com/burn-rs/burn/issues/636) #[derive(Module, Debug, Clone)] pub struct AvgPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: PaddingConfig2d, - count_include_pad: bool, + stride: [usize; 2], + kernel_size: [usize; 2], + padding: PaddingConfig2d, + count_include_pad: bool, } impl AvgPool2dConfig { - /// Initialize a new [avg pool 2d](AvgPool2d) module. - pub fn init(&self) -> AvgPool2d { - AvgPool2d { - stride: self.strides, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - count_include_pad: self.count_include_pad, - } + /// Initialize a new [avg pool 2d](AvgPool2d) module. + pub fn init(&self) -> AvgPool2d { + AvgPool2d { + stride: self.strides, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + count_include_pad: self.count_include_pad, } + } } impl AvgPool2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, height_in, width_in], - /// - output: [batch_size, channels, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels_in, height_in, width_in] = input.dims(); - let padding = - self.padding - .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, height_in, width_in], + /// - output: [batch_size, channels, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, height_in, width_in] = input.dims(); + let padding = + self + .padding + .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); - avg_pool2d( - input, - self.kernel_size, - self.stride, - padding, - self.count_include_pad, - ) - } + avg_pool2d( + input, + self.kernel_size, + self.stride, + padding, + self.count_include_pad, + ) + } } diff --git a/burn-core/src/nn/pool/max_pool1d.rs b/burn-core/src/nn/pool/max_pool1d.rs index ca7a2bf01c..7a7ef99918 100644 --- a/burn-core/src/nn/pool/max_pool1d.rs +++ b/burn-core/src/nn/pool/max_pool1d.rs @@ -10,53 +10,53 @@ use burn_tensor::module::max_pool1d; /// Configuration to create a [1D max pooling](MaxPool1d) layer. #[derive(Config)] pub struct MaxPool1dConfig { - /// The size of the kernel. - pub kernel_size: usize, - /// The stride. - #[config(default = "1")] - pub stride: usize, - /// The padding configuration. - #[config(default = "PaddingConfig1d::Valid")] - pub padding: PaddingConfig1d, - /// The dilation. - #[config(default = "1")] - pub dilation: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The stride. + #[config(default = "1")] + pub stride: usize, + /// The padding configuration. + #[config(default = "PaddingConfig1d::Valid")] + pub padding: PaddingConfig1d, + /// The dilation. + #[config(default = "1")] + pub dilation: usize, } /// Applies a 1D max pooling over input tensors. 
#[derive(Module, Debug, Clone)] pub struct MaxPool1d { - stride: usize, - kernel_size: usize, - padding: PaddingConfig1d, - dilation: usize, + stride: usize, + kernel_size: usize, + padding: PaddingConfig1d, + dilation: usize, } impl MaxPool1dConfig { - /// Initialize a new [max pool 1d](MaxPool1d) module. - pub fn init(&self) -> MaxPool1d { - MaxPool1d { - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, - } + /// Initialize a new [max pool 1d](MaxPool1d) module. + pub fn init(&self) -> MaxPool1d { + MaxPool1d { + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, } + } } impl MaxPool1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, length_in], - /// - output: [batch_size, channels, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels, length] = input.dims(); - let padding = self - .padding - .calculate_padding_1d(length, self.kernel_size, self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, length_in], + /// - output: [batch_size, channels, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels, length] = input.dims(); + let padding = self + .padding + .calculate_padding_1d(length, self.kernel_size, self.stride); - max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation) - } + max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation) + } } diff --git a/burn-core/src/nn/pool/max_pool2d.rs b/burn-core/src/nn/pool/max_pool2d.rs index f1aebc19ed..33ccd24ba7 100644 --- a/burn-core/src/nn/pool/max_pool2d.rs +++ b/burn-core/src/nn/pool/max_pool2d.rs @@ -10,53 +10,54 @@ use burn_tensor::module::max_pool2d; /// Configuration to create an [2D max pooling](MaxPool2d) layer. #[derive(Debug, Config)] pub struct MaxPool2dConfig { - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The strides. - #[config(default = "[1, 1]")] - pub strides: [usize; 2], - /// The padding configuration. - #[config(default = "PaddingConfig2d::Valid")] - pub padding: PaddingConfig2d, - /// The dilation. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The strides. + #[config(default = "[1, 1]")] + pub strides: [usize; 2], + /// The padding configuration. + #[config(default = "PaddingConfig2d::Valid")] + pub padding: PaddingConfig2d, + /// The dilation. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], } /// Applies a 2D max pooling over input tensors. #[derive(Module, Debug, Clone)] pub struct MaxPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: PaddingConfig2d, - dilation: [usize; 2], + stride: [usize; 2], + kernel_size: [usize; 2], + padding: PaddingConfig2d, + dilation: [usize; 2], } impl MaxPool2dConfig { - /// Initialize a new [max pool 2d](MaxPool2d) module. - pub fn init(&self) -> MaxPool2d { - MaxPool2d { - stride: self.strides, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, - } + /// Initialize a new [max pool 2d](MaxPool2d) module. 
+ pub fn init(&self) -> MaxPool2d { + MaxPool2d { + stride: self.strides, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, } + } } impl MaxPool2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, height_in, width_in], - /// - output: [batch_size, channels, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels_in, height_in, width_in] = input.dims(); - let padding = - self.padding - .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, height_in, width_in], + /// - output: [batch_size, channels, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, height_in, width_in] = input.dims(); + let padding = + self + .padding + .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); - max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation) - } + max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation) + } } diff --git a/burn-core/src/nn/pos_encoding.rs b/burn-core/src/nn/pos_encoding.rs index e4a935ac96..45d7b2900a 100644 --- a/burn-core/src/nn/pos_encoding.rs +++ b/burn-core/src/nn/pos_encoding.rs @@ -12,16 +12,16 @@ use libm::{cosf, expf, logf, sinf}; /// Configuration to create an [PositionalEncoding](PositionalEncoding) layer. #[derive(Config)] pub struct PositionalEncodingConfig { - /// Maximum sequence size to use. - #[config(default = "5_000")] - max_sequence_size: usize, + /// Maximum sequence size to use. + #[config(default = "5_000")] + max_sequence_size: usize, - /// The size of each vector. - d_model: usize, + /// The size of each vector. + d_model: usize, - /// Max time scale to use. - #[config(default = "10_000")] - max_timescale: usize, + /// Max time scale to use. + #[config(default = "10_000")] + max_timescale: usize, } /// Positional encoding layer for transformer models. @@ -38,55 +38,55 @@ pub struct PositionalEncodingConfig { /// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) #[derive(Module, Debug)] pub struct PositionalEncoding { - sinusoids: Tensor, + sinusoids: Tensor, } impl PositionalEncodingConfig { - /// Initialize a new [PositionalEncoding](PositionalEncoding) module. - pub fn init(&self) -> PositionalEncoding { - let sinusoids = - generate_sinusoids::(self.max_sequence_size, self.d_model, self.max_timescale) - .unsqueeze::<3>(); - - PositionalEncoding { sinusoids } - } + /// Initialize a new [PositionalEncoding](PositionalEncoding) module. + pub fn init(&self) -> PositionalEncoding { + let sinusoids = + generate_sinusoids::(self.max_sequence_size, self.d_model, self.max_timescale) + .unsqueeze::<3>(); + + PositionalEncoding { sinusoids } + } } impl PositionalEncoding { - /// Applies the forward pass on the input tensor by adding the sinusoids to the input. - /// - /// # Shapes - /// - /// * input: [batch_size, seq_length, d_model] - /// * output: [batch_size, seq_length, d_model] - /// - /// - /// # Panics - /// - /// * Panics if the input sequence length is greater than the maximum sequence size. - /// * Panics if the input d_model is not equal to the d_model of the sinusoids. 
- pub fn forward(&self, input: Tensor) -> Tensor { - let [_, seq_length, d_model_input] = input.dims(); - - let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims(); - - assert!( - max_sequence_size >= seq_length, - "max_sequence_size({}) must be greater or equal than length({seq_length})", - max_sequence_size, - ); - - assert!( - d_model_input == d_model, - "d_model({}) of the input must be equal to d_model of encoding({})", - d_model_input, - d_model, - ); - - let slices = [0..batch_size, 0..seq_length, 0..d_model]; - - input.add(self.sinusoids.clone().slice(slices)) - } + /// Applies the forward pass on the input tensor by adding the sinusoids to the input. + /// + /// # Shapes + /// + /// * input: [batch_size, seq_length, d_model] + /// * output: [batch_size, seq_length, d_model] + /// + /// + /// # Panics + /// + /// * Panics if the input sequence length is greater than the maximum sequence size. + /// * Panics if the input d_model is not equal to the d_model of the sinusoids. + pub fn forward(&self, input: Tensor) -> Tensor { + let [_, seq_length, d_model_input] = input.dims(); + + let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims(); + + assert!( + max_sequence_size >= seq_length, + "max_sequence_size({}) must be greater or equal than length({seq_length})", + max_sequence_size, + ); + + assert!( + d_model_input == d_model, + "d_model({}) of the input must be equal to d_model of encoding({})", + d_model_input, + d_model, + ); + + let slices = [0..batch_size, 0..seq_length, 0..d_model]; + + input.add(self.sinusoids.clone().slice(slices)) + } } /// Returns sinusoids for positional embedding introduced in @@ -106,124 +106,124 @@ impl PositionalEncoding { /// /// A tensor of shape [length, d_model] containing the sinusoids. 
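// Editor's note (reading aid, not part of this patch): the loop below realizes the closed
// form from "Attention Is All You Need",
//   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
//   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
// Each step of the inner loop handles one (sin, cos) pair: with k = 2i,
// `div_term = exp(k * -ln(10000) / d_model)` equals `10000^(-2i / d_model)`, so
// `sin(div_term * pos)` and `cos(div_term * pos)` are exactly the two entries above for
// position `pos` (the outer loop variable `i`).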
pub fn generate_sinusoids( - length: usize, - d_model: usize, - max_timescale: usize, + length: usize, + d_model: usize, + max_timescale: usize, ) -> Tensor { - assert!(d_model % 2 == 0, "d_model must be even"); - assert!( - max_timescale >= length, - "max_timescale must be greater than length" - ); - - // Calculate the increment for the logarithmic timescale - let log_timescale_increment = -logf(max_timescale as f32) / d_model as f32; - - // Create a vector to hold the sinusoids - let mut scaled_time_sin_cos = Vec::with_capacity(length); - - // Loop over each position in the sequence - for i in 0..length { - // Create a vector to hold the sinusoids for this position - let mut row = Vec::with_capacity(d_model / 2); - // Loop over each dimension of the sinusoids - for k in (0..d_model).step_by(2) { - // Calculate the division term for this dimension - let div_term = expf(k as f32 * log_timescale_increment); - // Calculate the sine and cosine values for this dimension and position - row.push(sinf(div_term * i as f32)); - row.push(cosf(div_term * i as f32)); - } - - // Add the sinusoids for this position to the vector - scaled_time_sin_cos.push(row); + assert!(d_model % 2 == 0, "d_model must be even"); + assert!( + max_timescale >= length, + "max_timescale must be greater than length" + ); + + // Calculate the increment for the logarithmic timescale + let log_timescale_increment = -logf(max_timescale as f32) / d_model as f32; + + // Create a vector to hold the sinusoids + let mut scaled_time_sin_cos = Vec::with_capacity(length); + + // Loop over each position in the sequence + for i in 0..length { + // Create a vector to hold the sinusoids for this position + let mut row = Vec::with_capacity(d_model / 2); + // Loop over each dimension of the sinusoids + for k in (0..d_model).step_by(2) { + // Calculate the division term for this dimension + let div_term = expf(k as f32 * log_timescale_increment); + // Calculate the sine and cosine values for this dimension and position + row.push(sinf(div_term * i as f32)); + row.push(cosf(div_term * i as f32)); } - // Convert the sinusoids to a tensor and return it - let data = Data::new( - scaled_time_sin_cos.into_iter().flatten().collect(), - [length, d_model].into(), - ); + // Add the sinusoids for this position to the vector + scaled_time_sin_cos.push(row); + } + + // Convert the sinusoids to a tensor and return it + let data = Data::new( + scaled_time_sin_cos.into_iter().flatten().collect(), + [length, d_model].into(), + ); - Tensor::::from_data(data.convert()) + Tensor::::from_data(data.convert()) } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - - #[test] - fn test_module() { - let d_model = 6; - let length = 3; - - // expected to broadcast - let batch_size = 2; - - let pe = PositionalEncodingConfig::new(d_model).init::(); - - // Use a tensor of zeros as input for easy verification of the output - // The output should be the sinusoids broadcasted to the input shape - let tensor = Tensor::zeros([batch_size, length, d_model]); - - let output = pe.forward(tensor); - - assert_eq!(output.shape().dims, [batch_size, length, d_model]); - - let expected = Tensor::::from_floats([ - [ - [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], - [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], - [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], - ], - [ - [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], - [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], - [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], - ], 
- ]); - - output.to_data().assert_approx_eq(&expected.to_data(), 5); - } - - #[test] - fn test_generate_sinusoids() { - let sinusoids = generate_sinusoids::(12, 6, 10_000); - - // The values are taken from the pytorch reference implementation - let expected = Tensor::::from_floats([ - [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], - [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], - [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], - [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998], - [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996], - [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994], - [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992], - [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989], - [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985], - [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981], - [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977], - [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972], - ]); - sinusoids.to_data().assert_approx_eq(&expected.to_data(), 5); - } - - #[test] - #[should_panic] - fn d_model_input_should_match() { - let d_model = 8; - let pe = PositionalEncodingConfig::new(d_model).init::(); - let input = Tensor::zeros([1, 5, 10]); - let _output = pe.forward(input); - } - - #[test] - #[should_panic] - fn input_length_should_be_less_than_max_len() { - let d_model = 8; - let pe = PositionalEncodingConfig::new(d_model).init::(); - let input = Tensor::zeros([1, 6_000, d_model]); - let _output = pe.forward(input); - } + use super::*; + use crate::TestBackend; + + #[test] + fn test_module() { + let d_model = 6; + let length = 3; + + // expected to broadcast + let batch_size = 2; + + let pe = PositionalEncodingConfig::new(d_model).init::(); + + // Use a tensor of zeros as input for easy verification of the output + // The output should be the sinusoids broadcasted to the input shape + let tensor = Tensor::zeros([batch_size, length, d_model]); + + let output = pe.forward(tensor); + + assert_eq!(output.shape().dims, [batch_size, length, d_model]); + + let expected = Tensor::::from_floats([ + [ + [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], + [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], + [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], + ], + [ + [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], + [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], + [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], + ], + ]); + + output.to_data().assert_approx_eq(&expected.to_data(), 5); + } + + #[test] + fn test_generate_sinusoids() { + let sinusoids = generate_sinusoids::(12, 6, 10_000); + + // The values are taken from the pytorch reference implementation + let expected = Tensor::::from_floats([ + [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], + [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], + [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], + [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998], + [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996], + [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994], + [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992], + [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989], + [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985], + [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981], + [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977], + [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972], + ]); + sinusoids.to_data().assert_approx_eq(&expected.to_data(), 5); + } + + 
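// Editor's note (illustrative spot-check, not part of this patch): the expected constants
// above can be verified by hand. For d_model = 6 and position 1:
//   k = 0 -> div_term = 1                     -> sin(1) ~ 0.84147, cos(1) ~ 0.54030
//   k = 2 -> div_term = 10000^(-1/3) ~ 0.0464 -> sin ~ 0.04640,   cos ~ 0.99892
//   k = 4 -> div_term = 10000^(-2/3) ~ 0.0022 -> sin ~ 0.00215,   cos ~ 1.00000
// which is the second row asserted in both tests above. The sketch below uses std float
// methods for brevity, whereas the module itself goes through libm.
#[test]
fn sinusoid_second_row_by_hand() {
    let d_model = 6.0_f32;
    let log_timescale_increment = -(10_000.0_f32).ln() / d_model;
    let position = 1.0_f32;
    let expected = [(0.84147, 0.54030), (0.04640, 0.99892), (0.00215, 1.00000)];
    for (pair, (exp_sin, exp_cos)) in expected.into_iter().enumerate() {
        let div_term = ((2 * pair) as f32 * log_timescale_increment).exp();
        assert!((f32::sin(div_term * position) - exp_sin).abs() < 1e-4);
        assert!((f32::cos(div_term * position) - exp_cos).abs() < 1e-4);
    }
}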
#[test] + #[should_panic] + fn d_model_input_should_match() { + let d_model = 8; + let pe = PositionalEncodingConfig::new(d_model).init::(); + let input = Tensor::zeros([1, 5, 10]); + let _output = pe.forward(input); + } + + #[test] + #[should_panic] + fn input_length_should_be_less_than_max_len() { + let d_model = 8; + let pe = PositionalEncodingConfig::new(d_model).init::(); + let input = Tensor::zeros([1, 6_000, d_model]); + let _output = pe.forward(input); + } } diff --git a/burn-core/src/nn/relu.rs b/burn-core/src/nn/relu.rs index 92e260c7ee..a84d7431e3 100644 --- a/burn-core/src/nn/relu.rs +++ b/burn-core/src/nn/relu.rs @@ -11,17 +11,17 @@ use crate::tensor::Tensor; pub struct ReLU {} impl ReLU { - /// Create the module. - pub fn new() -> Self { - Self {} - } - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any]` - /// - output: `[..., any]` - pub fn forward(&self, input: Tensor) -> Tensor { - crate::tensor::activation::relu(input) - } + /// Create the module. + pub fn new() -> Self { + Self {} + } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[..., any]` + /// - output: `[..., any]` + pub fn forward(&self, input: Tensor) -> Tensor { + crate::tensor::activation::relu(input) + } } diff --git a/burn-core/src/nn/rnn/gate_controller.rs b/burn-core/src/nn/rnn/gate_controller.rs index c063301ac8..6a07c2b80b 100644 --- a/burn-core/src/nn/rnn/gate_controller.rs +++ b/burn-core/src/nn/rnn/gate_controller.rs @@ -15,73 +15,73 @@ use burn_tensor::backend::Backend; /// the gate's output. #[derive(Module, Debug)] pub struct GateController { - /// Represents the affine transformation applied to input vector - pub(crate) input_transform: Linear, - /// Represents the affine transformation applied to the hidden state - pub(crate) hidden_transform: Linear, + /// Represents the affine transformation applied to input vector + pub(crate) input_transform: Linear, + /// Represents the affine transformation applied to the hidden state + pub(crate) hidden_transform: Linear, } impl GateController { - /// Initialize a new [gate_controller](GateController) module. - pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self { - Self { - input_transform: LinearConfig { - d_input, - d_output, - bias, - initializer: initializer.clone(), - } - .init(), - hidden_transform: LinearConfig { - d_input: d_output, - d_output, - bias, - initializer, - } - .init(), - } + /// Initialize a new [gate_controller](GateController) module. + pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self { + Self { + input_transform: LinearConfig { + d_input, + d_output, + bias, + initializer: initializer.clone(), + } + .init(), + hidden_transform: LinearConfig { + d_input: d_output, + d_output, + bias, + initializer, + } + .init(), } + } - /// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord). - pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord) -> Self { - let l1 = LinearConfig::init_with(linear_config, record.input_transform); - let l2 = LinearConfig::init_with(linear_config, record.hidden_transform); + /// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord). 
+ pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord) -> Self { + let l1 = LinearConfig::init_with(linear_config, record.input_transform); + let l2 = LinearConfig::init_with(linear_config, record.hidden_transform); - Self { - input_transform: l1, - hidden_transform: l2, - } + Self { + input_transform: l1, + hidden_transform: l2, } + } - /// Used to initialize a gate controller with known weight layers, - /// allowing for predictable behavior. Used only for testing in - /// lstm. - #[cfg(test)] - pub fn create_with_weights( - d_input: usize, - d_output: usize, - bias: bool, - initializer: Initializer, - input_record: crate::nn::LinearRecord, - hidden_record: crate::nn::LinearRecord, - ) -> Self { - let l1 = LinearConfig { - d_input, - d_output, - bias, - initializer: initializer.clone(), - } - .init_with(input_record); - let l2 = LinearConfig { - d_input, - d_output, - bias, - initializer, - } - .init_with(hidden_record); - Self { - input_transform: l1, - hidden_transform: l2, - } + /// Used to initialize a gate controller with known weight layers, + /// allowing for predictable behavior. Used only for testing in + /// lstm. + #[cfg(test)] + pub fn create_with_weights( + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + input_record: crate::nn::LinearRecord, + hidden_record: crate::nn::LinearRecord, + ) -> Self { + let l1 = LinearConfig { + d_input, + d_output, + bias, + initializer: initializer.clone(), } + .init_with(input_record); + let l2 = LinearConfig { + d_input, + d_output, + bias, + initializer, + } + .init_with(hidden_record); + Self { + input_transform: l1, + hidden_transform: l2, + } + } } diff --git a/burn-core/src/nn/rnn/gru.rs b/burn-core/src/nn/rnn/gru.rs index bd2c4649f0..fa05c77229 100644 --- a/burn-core/src/nn/rnn/gru.rs +++ b/burn-core/src/nn/rnn/gru.rs @@ -14,266 +14,256 @@ use super::gate_controller::GateController; /// The configuration for a [gru](Gru) module. #[derive(Config)] pub struct GruConfig { - /// The size of the input features. - pub d_input: usize, - /// The size of the hidden state. - pub d_hidden: usize, - /// If a bias should be applied during the Gru transformation. - pub bias: bool, - /// Gru initializer - #[config(default = "Initializer::XavierNormal{gain:1.0}")] - pub initializer: Initializer, + /// The size of the input features. + pub d_input: usize, + /// The size of the hidden state. + pub d_hidden: usize, + /// If a bias should be applied during the Gru transformation. + pub bias: bool, + /// Gru initializer + #[config(default = "Initializer::XavierNormal{gain:1.0}")] + pub initializer: Initializer, } /// The Gru module. This implementation is for a unidirectional, stateless, Gru. #[derive(Module, Debug)] pub struct Gru { - update_gate: GateController, - reset_gate: GateController, - new_gate: GateController, - d_hidden: usize, + update_gate: GateController, + reset_gate: GateController, + new_gate: GateController, + d_hidden: usize, } impl GruConfig { - /// Initialize a new [gru](Gru) module. - pub fn init(&self) -> Gru { - let d_output = self.d_hidden; + /// Initialize a new [gru](Gru) module. 
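// Editor's note (illustrative sketch, not part of this patch): as the fields above show,
// a `Gru` owns three `GateController`s (update, reset, new), and each gate holds two
// `Linear` transforms: `d_input -> d_hidden` for the input and `d_hidden -> d_hidden` for
// the hidden state, plus one bias vector per transform when `bias` is enabled. A quick way
// to reason about model size:
fn gru_parameter_count(d_input: usize, d_hidden: usize, bias: bool) -> usize {
    let per_gate = d_input * d_hidden          // input_transform weights
        + d_hidden * d_hidden                  // hidden_transform weights
        + if bias { 2 * d_hidden } else { 0 }; // two bias vectors
    3 * per_gate
}
// E.g. `GruConfig::new(64, 1024, true)` (as in the batched test below) comes to
// 3 * (64*1024 + 1024*1024 + 2*1024) = 3_348_480 parameters.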
+ pub fn init(&self) -> Gru { + let d_output = self.d_hidden; - let update_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let reset_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let new_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); + let update_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let reset_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let new_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); - Gru { - update_gate, - reset_gate, - new_gate, - d_hidden: self.d_hidden, - } + Gru { + update_gate, + reset_gate, + new_gate, + d_hidden: self.d_hidden, } + } - /// Initialize a new [gru](Gru) module. - pub fn init_with(self, record: GruRecord) -> Gru { - let linear_config = LinearConfig { - d_input: self.d_input, - d_output: self.d_hidden, - bias: self.bias, - initializer: self.initializer.clone(), - }; + /// Initialize a new [gru](Gru) module. + pub fn init_with(self, record: GruRecord) -> Gru { + let linear_config = LinearConfig { + d_input: self.d_input, + d_output: self.d_hidden, + bias: self.bias, + initializer: self.initializer.clone(), + }; - Gru { - update_gate: gate_controller::GateController::new_with( - &linear_config, - record.update_gate, - ), - reset_gate: gate_controller::GateController::new_with( - &linear_config, - record.reset_gate, - ), - new_gate: gate_controller::GateController::new_with(&linear_config, record.new_gate), - d_hidden: self.d_hidden, - } + Gru { + update_gate: gate_controller::GateController::new_with(&linear_config, record.update_gate), + reset_gate: gate_controller::GateController::new_with(&linear_config, record.reset_gate), + new_gate: gate_controller::GateController::new_with(&linear_config, record.new_gate), + d_hidden: self.d_hidden, } + } } impl Gru { - /// Applies the forward pass on the input tensor. This GRU implementation - /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size]. - /// - /// Parameters: - /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. - /// state: An optional tensor representing an initial cell state with the same dimensions - /// as batched_input. If none is provided, one will be generated. - /// - /// Returns: - /// The resulting state tensor, with shape [batch_size, sequence_length, hidden_size]. - pub fn forward( - &self, - batched_input: Tensor, - state: Option>, - ) -> Tensor { - let [batch_size, seq_length, _] = batched_input.shape().dims; + /// Applies the forward pass on the input tensor. This GRU implementation + /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size]. + /// + /// Parameters: + /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. + /// state: An optional tensor representing an initial cell state with the same dimensions + /// as batched_input. If none is provided, one will be generated. + /// + /// Returns: + /// The resulting state tensor, with shape [batch_size, sequence_length, hidden_size]. 
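// Editor's note (reading aid, not part of this patch): the timestep loop below implements
// the standard GRU cell,
//   z_t = sigmoid(Wz*x_t + Uz*h_{t-1} + bz)        (update gate)
//   r_t = sigmoid(Wr*x_t + Ur*h_{t-1} + br)        (reset gate)
//   g_t = tanh(Wn*x_t + Un*(r_t . h_{t-1}) + bn)   (candidate state)
//   h_t = (1 - z_t) . g_t + z_t . h_{t-1}
// where `.` is element-wise multiplication and each `W*x + U*h + b` term is computed by
// the `gate_product` helper further below.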
+ pub fn forward(&self, batched_input: Tensor, state: Option>) -> Tensor { + let [batch_size, seq_length, _] = batched_input.shape().dims; - let mut hidden_state = match state { - Some(state) => state, - None => Tensor::zeros([batch_size, seq_length, self.d_hidden]), - }; + let mut hidden_state = match state { + Some(state) => state, + None => Tensor::zeros([batch_size, seq_length, self.d_hidden]), + }; - for (t, (input_t, hidden_t)) in batched_input - .iter_dim(1) - .zip(hidden_state.clone().iter_dim(1)) - .enumerate() - { - let input_t = input_t.squeeze(1); - let hidden_t = hidden_t.squeeze(1); - // u(pdate)g(ate) tensors - let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate); - let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t) + for (t, (input_t, hidden_t)) in batched_input + .iter_dim(1) + .zip(hidden_state.clone().iter_dim(1)) + .enumerate() + { + let input_t = input_t.squeeze(1); + let hidden_t = hidden_t.squeeze(1); + // u(pdate)g(ate) tensors + let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate); + let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t) - // r(eset)g(ate) tensors - let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate); - let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t) - let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate + // r(eset)g(ate) tensors + let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate); + let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t) + let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate - // n(ew)g(ate) tensor - let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate); - let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t) + // n(ew)g(ate) tensor + let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate); + let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t) - // calculate linear interpolation between previous hidden state and candidate state: - // g(t) * (1 - z(t)) + z(t) * hidden_t - let state_vector = candidate_state + // calculate linear interpolation between previous hidden state and candidate state: + // g(t) * (1 - z(t)) + z(t) * hidden_t + let state_vector = candidate_state .clone() .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1) + update_values.clone().mul(hidden_t); - let current_shape = state_vector.shape().dims; - let unsqueezed_shape = [current_shape[0], 1, current_shape[1]]; - let reshaped_state_vector = state_vector.reshape(unsqueezed_shape); - hidden_state = hidden_state.slice_assign( - [0..batch_size, t..(t + 1), 0..self.d_hidden], - reshaped_state_vector, - ); - } - - hidden_state + let current_shape = state_vector.shape().dims; + let unsqueezed_shape = [current_shape[0], 1, current_shape[1]]; + let reshaped_state_vector = state_vector.reshape(unsqueezed_shape); + hidden_state = hidden_state.slice_assign( + [0..batch_size, t..(t + 1), 0..self.d_hidden], + reshaped_state_vector, + ); } - /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. 
- /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: - /// Wx = weight matrix for the connection to input vector X - /// Wh = weight matrix for the connection to hidden state H - /// X = input vector - /// H = hidden state - /// b = bias terms - fn gate_product( - &self, - input: &Tensor, - hidden: &Tensor, - gate: &GateController, - ) -> Tensor { - let input_product = input.clone().matmul(gate.input_transform.weight.val()); - let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); + hidden_state + } + + /// Helper function for performing weighted matrix product for a gate and adds + /// bias, if any. + /// + /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Wx = weight matrix for the connection to input vector X + /// Wh = weight matrix for the connection to hidden state H + /// X = input vector + /// H = hidden state + /// b = bias terms + fn gate_product( + &self, + input: &Tensor, + hidden: &Tensor, + gate: &GateController, + ) -> Tensor { + let input_product = input.clone().matmul(gate.input_transform.weight.val()); + let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); - let input_bias = gate - .input_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - let hidden_bias = gate - .hidden_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); + let input_bias = gate + .input_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + let hidden_bias = gate + .hidden_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { - input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() - } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, - } + match (input_bias, hidden_bias) { + (Some(input_bias), Some(hidden_bias)) => { + input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() + } + (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, + (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), + (None, None) => input_product + hidden_product, } + } } #[cfg(test)] mod tests { - use super::*; - use crate::{module::Param, nn::LinearRecord, TestBackend}; - use burn_tensor::{Data, Distribution}; + use super::*; + use crate::{module::Param, nn::LinearRecord, TestBackend}; + use burn_tensor::{Data, Distribution}; - /// Test forward pass with simple input vector. - /// - /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 - /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 - /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 - /// - /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 - #[test] - fn tests_forward_single_input_single_feature() { - TestBackend::seed(0); - let config = GruConfig::new(1, 1, false); - let mut gru = config.init::(); + /// Test forward pass with simple input vector. 
+ /// + /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 + /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 + /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 + /// + /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 + #[test] + fn tests_forward_single_input_single_feature() { + TestBackend::seed(0); + let config = GruConfig::new(1, 1, false); + let mut gru = config.init::(); - fn create_gate_controller( - weights: f32, - biases: f32, - d_input: usize, - d_output: usize, - bias: bool, - initializer: Initializer, - ) -> GateController { - let record = LinearRecord { - weight: Param::from(Tensor::from_data(Data::from([[weights]]))), - bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), - }; - gate_controller::GateController::create_with_weights( - d_input, - d_output, - bias, - initializer, - record.clone(), - record, - ) - } + fn create_gate_controller( + weights: f32, + biases: f32, + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + ) -> GateController { + let record = LinearRecord { + weight: Param::from(Tensor::from_data(Data::from([[weights]]))), + bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), + }; + gate_controller::GateController::create_with_weights( + d_input, + d_output, + bias, + initializer, + record.clone(), + record, + ) + } - gru.update_gate = create_gate_controller( - 0.5, - 0.0, - 1, - 1, - false, - Initializer::XavierNormal { gain: 1.0 }, - ); - gru.reset_gate = create_gate_controller( - 0.6, - 0.0, - 1, - 1, - false, - Initializer::XavierNormal { gain: 1.0 }, - ); - gru.new_gate = create_gate_controller( - 0.7, - 0.0, - 1, - 1, - false, - Initializer::XavierNormal { gain: 1.0 }, - ); + gru.update_gate = create_gate_controller( + 0.5, + 0.0, + 1, + 1, + false, + Initializer::XavierNormal { gain: 1.0 }, + ); + gru.reset_gate = create_gate_controller( + 0.6, + 0.0, + 1, + 1, + false, + Initializer::XavierNormal { gain: 1.0 }, + ); + gru.new_gate = create_gate_controller( + 0.7, + 0.0, + 1, + 1, + false, + Initializer::XavierNormal { gain: 1.0 }, + ); - let input = Tensor::::from_data(Data::from([[[0.1]]])); + let input = Tensor::::from_data(Data::from([[[0.1]]])); - let state = gru.forward(input, None); + let state = gru.forward(input, None); - let output = state.select(0, Tensor::arange(0..1)).squeeze(0); + let output = state.select(0, Tensor::arange(0..1)).squeeze(0); - output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3); - } + output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3); + } - #[test] - fn test_batched_forward_pass() { - let gru = GruConfig::new(64, 1024, true).init::(); - let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); + #[test] + fn test_batched_forward_pass() { + let gru = GruConfig::new(64, 1024, true).init::(); + let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); - let hidden_state = gru.forward(batched_input, None); + let hidden_state = gru.forward(batched_input, None); - assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); - } + assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); + } } diff --git a/burn-core/src/nn/rnn/lstm.rs b/burn-core/src/nn/rnn/lstm.rs index b49db35642..00b4b51e6f 100644 --- a/burn-core/src/nn/rnn/lstm.rs +++ b/burn-core/src/nn/rnn/lstm.rs @@ -14,323 +14,314 @@ use super::gate_controller::GateController; /// The configuration for a [lstm](Lstm) module. #[derive(Config)] pub struct LstmConfig { - /// The size of the input features. - pub d_input: usize, - /// The size of the hidden state. 
- pub d_hidden: usize, - /// If a bias should be applied during the Lstm transformation. - pub bias: bool, - /// Lstm initializer - #[config(default = "Initializer::XavierNormal{gain:1.0}")] - pub initializer: Initializer, + /// The size of the input features. + pub d_input: usize, + /// The size of the hidden state. + pub d_hidden: usize, + /// If a bias should be applied during the Lstm transformation. + pub bias: bool, + /// Lstm initializer + #[config(default = "Initializer::XavierNormal{gain:1.0}")] + pub initializer: Initializer, } /// The Lstm module. This implementation is for a unidirectional, stateless, Lstm. #[derive(Module, Debug)] pub struct Lstm { - input_gate: GateController, - forget_gate: GateController, - output_gate: GateController, - cell_gate: GateController, - d_hidden: usize, + input_gate: GateController, + forget_gate: GateController, + output_gate: GateController, + cell_gate: GateController, + d_hidden: usize, } impl LstmConfig { - /// Initialize a new [lstm](Lstm) module. - pub fn init(&self) -> Lstm { - let d_output = self.d_hidden; - - let input_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let forget_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let output_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let cell_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - - Lstm { - input_gate, - forget_gate, - output_gate, - cell_gate, - d_hidden: self.d_hidden, - } + /// Initialize a new [lstm](Lstm) module. + pub fn init(&self) -> Lstm { + let d_output = self.d_hidden; + + let input_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let forget_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let output_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let cell_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + + Lstm { + input_gate, + forget_gate, + output_gate, + cell_gate, + d_hidden: self.d_hidden, } - - /// Initialize a new [lstm](Lstm) module with a [record](LstmRecord). - pub fn init_with(&self, record: LstmRecord) -> Lstm { - let linear_config = LinearConfig { - d_input: self.d_input, - d_output: self.d_hidden, - bias: self.bias, - initializer: self.initializer.clone(), - }; - - Lstm { - input_gate: gate_controller::GateController::new_with( - &linear_config, - record.input_gate, - ), - forget_gate: gate_controller::GateController::new_with( - &linear_config, - record.forget_gate, - ), - output_gate: gate_controller::GateController::new_with( - &linear_config, - record.output_gate, - ), - cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate), - d_hidden: self.d_hidden, - } + } + + /// Initialize a new [lstm](Lstm) module with a [record](LstmRecord). 
+ pub fn init_with(&self, record: LstmRecord) -> Lstm { + let linear_config = LinearConfig { + d_input: self.d_input, + d_output: self.d_hidden, + bias: self.bias, + initializer: self.initializer.clone(), + }; + + Lstm { + input_gate: gate_controller::GateController::new_with(&linear_config, record.input_gate), + forget_gate: gate_controller::GateController::new_with(&linear_config, record.forget_gate), + output_gate: gate_controller::GateController::new_with(&linear_config, record.output_gate), + cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate), + d_hidden: self.d_hidden, } + } } impl Lstm { - /// Applies the forward pass on the input tensor. This LSTM implementation - /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), - /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size]. - /// - /// Parameters: - /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. - /// state: An optional tuple of tensors representing the initial cell state and hidden state. - /// Each state tensor has shape [batch_size, hidden_size]. - /// If no initial state is provided, these tensors are initialized to zeros. - /// - /// Returns: - /// A tuple of tensors, where the first tensor represents the cell states and - /// the second tensor represents the hidden states for each sequence element. - /// Both output tensors have the shape [batch_size, sequence_length, hidden_size]. - pub fn forward( - &self, - batched_input: Tensor, - state: Option<(Tensor, Tensor)>, - ) -> (Tensor, Tensor) { - let [batch_size, seq_length, _] = batched_input.shape().dims; - let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); - let mut batched_hidden_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); - - let (mut cell_state, mut hidden_state) = match state { - Some((cell_state, hidden_state)) => (cell_state, hidden_state), - None => ( - Tensor::zeros([batch_size, self.d_hidden]), - Tensor::zeros([batch_size, self.d_hidden]), - ), - }; - - for (t, input_t) in batched_input.iter_dim(1).enumerate() { - let input_t = input_t.squeeze(1); - // f(orget)g(ate) tensors - let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate); - let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state - - // i(nput)g(ate) tensors - let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate); - let add_values = activation::sigmoid(biased_ig_input_sum); - - // o(output)g(ate) tensors - let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate); - let output_values = activation::sigmoid(biased_og_input_sum); - - // c(ell)g(ate) tensors - let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate); - let candidate_cell_values = biased_cg_input_sum.tanh(); - - cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; - hidden_state = output_values * cell_state.clone().tanh(); - - let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]]; - - let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape); - let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape); - - // store the state for this timestep - batched_cell_state = batched_cell_state.slice_assign( - [0..batch_size, t..(t + 1), 0..self.d_hidden], - 
unsqueezed_cell_state.clone(), - ); - batched_hidden_state = batched_hidden_state.slice_assign( - [0..batch_size, t..(t + 1), 0..self.d_hidden], - unsqueezed_hidden_state.clone(), - ); - } - - (batched_cell_state, batched_hidden_state) + /// Applies the forward pass on the input tensor. This LSTM implementation + /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), + /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size]. + /// + /// Parameters: + /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. + /// state: An optional tuple of tensors representing the initial cell state and hidden state. + /// Each state tensor has shape [batch_size, hidden_size]. + /// If no initial state is provided, these tensors are initialized to zeros. + /// + /// Returns: + /// A tuple of tensors, where the first tensor represents the cell states and + /// the second tensor represents the hidden states for each sequence element. + /// Both output tensors have the shape [batch_size, sequence_length, hidden_size]. + pub fn forward( + &self, + batched_input: Tensor, + state: Option<(Tensor, Tensor)>, + ) -> (Tensor, Tensor) { + let [batch_size, seq_length, _] = batched_input.shape().dims; + let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); + let mut batched_hidden_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); + + let (mut cell_state, mut hidden_state) = match state { + Some((cell_state, hidden_state)) => (cell_state, hidden_state), + None => ( + Tensor::zeros([batch_size, self.d_hidden]), + Tensor::zeros([batch_size, self.d_hidden]), + ), + }; + + for (t, input_t) in batched_input.iter_dim(1).enumerate() { + let input_t = input_t.squeeze(1); + // f(orget)g(ate) tensors + let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate); + let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state + + // i(nput)g(ate) tensors + let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate); + let add_values = activation::sigmoid(biased_ig_input_sum); + + // o(output)g(ate) tensors + let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate); + let output_values = activation::sigmoid(biased_og_input_sum); + + // c(ell)g(ate) tensors + let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate); + let candidate_cell_values = biased_cg_input_sum.tanh(); + + cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; + hidden_state = output_values * cell_state.clone().tanh(); + + let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]]; + + let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape); + let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape); + + // store the state for this timestep + batched_cell_state = batched_cell_state.slice_assign( + [0..batch_size, t..(t + 1), 0..self.d_hidden], + unsqueezed_cell_state.clone(), + ); + batched_hidden_state = batched_hidden_state.slice_assign( + [0..batch_size, t..(t + 1), 0..self.d_hidden], + unsqueezed_hidden_state.clone(), + ); } - /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. 
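// Editor's note (reading aid, not part of this patch): the timestep loop in `forward`
// above implements the standard LSTM cell,
//   f_t  = sigmoid(Wf*x_t + Uf*h_{t-1} + bf)   (forget gate)
//   i_t  = sigmoid(Wi*x_t + Ui*h_{t-1} + bi)   (input gate)
//   o_t  = sigmoid(Wo*x_t + Uo*h_{t-1} + bo)   (output gate)
//   c~_t = tanh(Wc*x_t + Uc*h_{t-1} + bc)      (candidate cell state)
//   C_t  = f_t . C_{t-1} + i_t . c~_t
//   h_t  = o_t . tanh(C_t)
// where `.` is element-wise multiplication; each `W*x + U*h + b` term is the
// `gate_product` helper documented below. This matches the hand calculation in
// `test_forward_single_input_single_feature` further down.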
- /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: - /// Wx = weight matrix for the connection to input vector X - /// Wh = weight matrix for the connection to hidden state H - /// X = input vector - /// H = hidden state - /// b = bias terms - fn gate_product( - &self, - input: &Tensor, - hidden: &Tensor, - gate: &GateController, - ) -> Tensor { - let input_product = input.clone().matmul(gate.input_transform.weight.val()); - let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); - - let input_bias = gate - .input_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - let hidden_bias = gate - .hidden_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { - input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() - } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, - } + (batched_cell_state, batched_hidden_state) + } + + /// Helper function for performing weighted matrix product for a gate and adds + /// bias, if any. + /// + /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Wx = weight matrix for the connection to input vector X + /// Wh = weight matrix for the connection to hidden state H + /// X = input vector + /// H = hidden state + /// b = bias terms + fn gate_product( + &self, + input: &Tensor, + hidden: &Tensor, + gate: &GateController, + ) -> Tensor { + let input_product = input.clone().matmul(gate.input_transform.weight.val()); + let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); + + let input_bias = gate + .input_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + let hidden_bias = gate + .hidden_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + + match (input_bias, hidden_bias) { + (Some(input_bias), Some(hidden_bias)) => { + input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() + } + (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, + (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), + (None, None) => input_product + hidden_product, } + } } #[cfg(test)] mod tests { - use super::*; - use crate::{module::Param, nn::LinearRecord, TestBackend}; - use burn_tensor::{Data, Distribution}; - - #[test] - fn test_with_uniform_initializer() { - TestBackend::seed(0); - - let config = LstmConfig::new(5, 5, false) - .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 }); - let lstm = config.init::(); - - let gate_to_data = - |gate: GateController| gate.input_transform.weight.val().to_data(); - - gate_to_data(lstm.input_gate).assert_within_range(0..1); - gate_to_data(lstm.forget_gate).assert_within_range(0..1); - gate_to_data(lstm.output_gate).assert_within_range(0..1); - gate_to_data(lstm.cell_gate).assert_within_range(0..1); + use super::*; + use crate::{module::Param, nn::LinearRecord, TestBackend}; + use burn_tensor::{Data, Distribution}; + + #[test] + fn test_with_uniform_initializer() { + TestBackend::seed(0); + + let config = + LstmConfig::new(5, 5, false).with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 }); + let lstm = config.init::(); + + let gate_to_data = + |gate: GateController| gate.input_transform.weight.val().to_data(); + + 
gate_to_data(lstm.input_gate).assert_within_range(0..1); + gate_to_data(lstm.forget_gate).assert_within_range(0..1); + gate_to_data(lstm.output_gate).assert_within_range(0..1); + gate_to_data(lstm.cell_gate).assert_within_range(0..1); + } + + /// Test forward pass with simple input vector. + /// + /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928 + /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725 + /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723 + /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937 + + /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 + /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 + #[test] + fn test_forward_single_input_single_feature() { + TestBackend::seed(0); + let config = LstmConfig::new(1, 1, false); + let mut lstm = config.init::(); + + fn create_gate_controller( + weights: f32, + biases: f32, + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + ) -> GateController { + let record = LinearRecord { + weight: Param::from(Tensor::from_data(Data::from([[weights]]))), + bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), + }; + gate_controller::GateController::create_with_weights( + d_input, + d_output, + bias, + initializer, + record.clone(), + record, + ) } - /// Test forward pass with simple input vector. - /// - /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928 - /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725 - /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723 - /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937 - - /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 - /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 - #[test] - fn test_forward_single_input_single_feature() { - TestBackend::seed(0); - let config = LstmConfig::new(1, 1, false); - let mut lstm = config.init::(); - - fn create_gate_controller( - weights: f32, - biases: f32, - d_input: usize, - d_output: usize, - bias: bool, - initializer: Initializer, - ) -> GateController { - let record = LinearRecord { - weight: Param::from(Tensor::from_data(Data::from([[weights]]))), - bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), - }; - gate_controller::GateController::create_with_weights( - d_input, - d_output, - bias, - initializer, - record.clone(), - record, - ) - } - - lstm.input_gate = create_gate_controller( - 0.5, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - lstm.forget_gate = create_gate_controller( - 0.7, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - lstm.cell_gate = create_gate_controller( - 0.9, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - lstm.output_gate = create_gate_controller( - 1.1, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - - // single timestep with single feature - let input = Tensor::::from_data(Data::from([[[0.1]]])); - - let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None); - let cell_state = cell_state_batch.select(0, Tensor::arange(0..1)).squeeze(0); - let hidden_state = hidden_state_batch - .select(0, Tensor::arange(0..1)) - .squeeze(0); - cell_state - .to_data() - .assert_approx_eq(&Data::from([[0.046]]), 3); - hidden_state - .to_data() - .assert_approx_eq(&Data::from([[0.024]]), 3) - } - - #[test] - fn test_batched_forward_pass() { - let lstm = 
LstmConfig::new(64, 1024, true).init::(); - let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); - - let (cell_state, hidden_state) = lstm.forward(batched_input, None); - - assert_eq!(cell_state.shape().dims, [8, 10, 1024]); - assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); - } + lstm.input_gate = create_gate_controller( + 0.5, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + lstm.forget_gate = create_gate_controller( + 0.7, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + lstm.cell_gate = create_gate_controller( + 0.9, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + lstm.output_gate = create_gate_controller( + 1.1, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + + // single timestep with single feature + let input = Tensor::::from_data(Data::from([[[0.1]]])); + + let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None); + let cell_state = cell_state_batch.select(0, Tensor::arange(0..1)).squeeze(0); + let hidden_state = hidden_state_batch + .select(0, Tensor::arange(0..1)) + .squeeze(0); + cell_state + .to_data() + .assert_approx_eq(&Data::from([[0.046]]), 3); + hidden_state + .to_data() + .assert_approx_eq(&Data::from([[0.024]]), 3) + } + + #[test] + fn test_batched_forward_pass() { + let lstm = LstmConfig::new(64, 1024, true).init::(); + let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); + + let (cell_state, hidden_state) = lstm.forward(batched_input, None); + + assert_eq!(cell_state.shape().dims, [8, 10, 1024]); + assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); + } } diff --git a/burn-core/src/nn/transformer/decoder.rs b/burn-core/src/nn/transformer/decoder.rs index db5afd5b74..3030fecf75 100644 --- a/burn-core/src/nn/transformer/decoder.rs +++ b/burn-core/src/nn/transformer/decoder.rs @@ -2,43 +2,49 @@ use alloc::vec::Vec; use burn_tensor::Bool; use crate::{ - self as burn, - nn::{attention::MhaCache, cache::TensorCache, Initializer}, + self as burn, + nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ - config::Config, - module::Module, - nn::{ - attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, - Dropout, DropoutConfig, LayerNorm, LayerNormConfig, - }, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{ + attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, + Dropout, DropoutConfig, LayerNorm, LayerNormConfig, + }, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [Transformer Decoder](TransformerDecoder) layer. #[derive(Config)] pub struct TransformerDecoderConfig { - /// The size of the model. - pub d_model: usize, - /// The size of the position-wise feed-forward network. - pub d_ff: usize, - /// The number of attention heads. - pub n_heads: usize, - /// The number of layers. - pub n_layers: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - pub dropout: f64, - /// Layer norm will be applied first instead of after the other modules. - #[config(default = false)] - pub norm_first: bool, - /// The type of function used to initialize neural network parameters - #[config( - default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" - )] - pub initializer: Initializer, + /// The size of the model. + pub d_model: usize, + /// The size of the position-wise feed-forward network. 
+ pub d_ff: usize, + /// The number of attention heads. + pub n_heads: usize, + /// The number of layers. + pub n_layers: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + pub dropout: f64, + /// Layer norm will be applied first instead of after the other modules. + #[config(default = false)] + pub norm_first: bool, + /// Use "quiet softmax" instead of regular softmax. + /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + pub quiet_softmax: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] + pub initializer: Initializer, } /// The transformer decoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). @@ -48,404 +54,400 @@ pub struct TransformerDecoderConfig { /// - layers: transformer decoder layers with `d_model` input and output features. #[derive(Module, Debug)] pub struct TransformerDecoder { - layers: Vec>, + layers: Vec>, } impl TransformerDecoderConfig { - /// Initialize a new [Transformer Decoder](TransformerDecoder) module. - pub fn init(&self) -> TransformerDecoder { - let layers = (0..self.n_layers) - .map(|_| TransformerDecoderLayer::new(self)) - .collect::>(); - - TransformerDecoder { layers } - } - - /// Initialize a new [Transformer Decoder](TransformerDecoder) module with a record. - /// - /// # Params - /// - /// - record: the record to initialize the module with. - pub fn init_with( - &self, - record: TransformerDecoderRecord, - ) -> TransformerDecoder { - TransformerDecoder { - layers: record - .layers - .into_iter() - .map(|record| TransformerDecoderLayer::new_with(self, record)) - .collect(), - } + /// Initialize a new [Transformer Decoder](TransformerDecoder) module. + pub fn init(&self) -> TransformerDecoder { + let layers = (0..self.n_layers) + .map(|_| TransformerDecoderLayer::new(self)) + .collect::>(); + + TransformerDecoder { layers } + } + + /// Initialize a new [Transformer Decoder](TransformerDecoder) module with a record. + /// + /// # Params + /// + /// - record: the record to initialize the module with. + pub fn init_with( + &self, + record: TransformerDecoderRecord, + ) -> TransformerDecoder { + TransformerDecoder { + layers: record + .layers + .into_iter() + .map(|record| TransformerDecoderLayer::new_with(self, record)) + .collect(), } + } } /// [Transformer Decoder](TransformerDecoder) forward pass input argument. #[derive(Debug)] pub struct TransformerDecoderInput { - target: Tensor, - target_mask_pad: Option>, - target_mask_attn: Option>, - memory: Tensor, - memory_mask_pad: Option>, - memory_mask_attn: Option>, + target: Tensor, + target_mask_pad: Option>, + target_mask_attn: Option>, + memory: Tensor, + memory_mask_pad: Option>, + memory_mask_attn: Option>, } impl TransformerDecoderInput { - /// Create a [transformer decoder](TransformerDecoder) input argument. - pub fn new(target: Tensor, memory: Tensor) -> Self { - Self { - target, - target_mask_pad: None, - target_mask_attn: None, - memory, - memory_mask_pad: None, - memory_mask_attn: None, - } - } - - /// Register the memory padding mask. 
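For context on the `quiet_softmax` flag added to this config: it swaps the attention softmax for "softmax one", which adds a constant 1 to the denominator so that a head whose scores are all strongly negative can assign (near-)zero weight everywhere instead of being forced to sum to 1. A minimal scalar sketch of the quantity being computed — illustrative only, not the tensor implementation used by `MultiHeadAttention`:

```rust
/// Quiet softmax (softmax_1) over a slice of attention scores:
/// like softmax, but with an extra `1` in the denominator.
fn quiet_softmax(scores: &[f32]) -> Vec<f32> {
    // Shift by max(0, max(scores)) for numerical stability; the implicit
    // extra logit is 0, so it takes part in the maximum as well.
    let m = scores.iter().copied().fold(0.0_f32, f32::max);
    let exps: Vec<f32> = scores.iter().map(|&s| (s - m).exp()).collect();
    let denom = (-m).exp() + exps.iter().sum::<f32>();
    exps.iter().map(|&e| e / denom).collect()
}
```

With the ordinary softmax the weights always sum to exactly 1; here they sum to at most 1, which is what allows a head to "stay quiet".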
- pub fn memory_mask_pad(mut self, mask_pad: Tensor) -> Self { - self.memory_mask_pad = Some(mask_pad); - self - } - - /// Register the memory attention mask. - pub fn memory_mask_attn(mut self, mask_attn: Tensor) -> Self { - self.memory_mask_attn = Some(mask_attn); - self - } - - /// Register the target padding mask. - pub fn target_mask_pad(mut self, mask_pad: Tensor) -> Self { - self.target_mask_pad = Some(mask_pad); - self - } - - /// Register the target attention mask. - pub fn target_mask_attn(mut self, mask_attn: Tensor) -> Self { - self.target_mask_attn = Some(mask_attn); - self + /// Create a [transformer decoder](TransformerDecoder) input argument. + pub fn new(target: Tensor, memory: Tensor) -> Self { + Self { + target, + target_mask_pad: None, + target_mask_attn: None, + memory, + memory_mask_pad: None, + memory_mask_attn: None, } + } + + /// Register the memory padding mask. + pub fn memory_mask_pad(mut self, mask_pad: Tensor) -> Self { + self.memory_mask_pad = Some(mask_pad); + self + } + + /// Register the memory attention mask. + pub fn memory_mask_attn(mut self, mask_attn: Tensor) -> Self { + self.memory_mask_attn = Some(mask_attn); + self + } + + /// Register the target padding mask. + pub fn target_mask_pad(mut self, mask_pad: Tensor) -> Self { + self.target_mask_pad = Some(mask_pad); + self + } + + /// Register the target attention mask. + pub fn target_mask_attn(mut self, mask_attn: Tensor) -> Self { + self.target_mask_attn = Some(mask_attn); + self + } } /// [Transformer Decoder](TransformerDecoder) layer module. #[derive(Module, Debug)] pub struct TransformerDecoderLayer { - cross_attn: MultiHeadAttention, - self_attn: MultiHeadAttention, - pwff: PositionWiseFeedForward, - norm_1: LayerNorm, - norm_2: LayerNorm, - norm_3: LayerNorm, - dropout: Dropout, - norm_first: bool, + cross_attn: MultiHeadAttention, + self_attn: MultiHeadAttention, + pwff: PositionWiseFeedForward, + norm_1: LayerNorm, + norm_2: LayerNorm, + norm_3: LayerNorm, + dropout: Dropout, + norm_first: bool, } struct TransformerDecoderLayerAutoregressiveCache { - cross_attn: MhaCache, - self_attn: MhaCache, - pwff: TensorCache, - norm_1: TensorCache, - norm_2: TensorCache, - norm_3: TensorCache, + cross_attn: MhaCache, + self_attn: MhaCache, + pwff: TensorCache, + norm_1: TensorCache, + norm_2: TensorCache, + norm_3: TensorCache, } impl TransformerDecoderLayerAutoregressiveCache { - fn empty() -> Self { - Self { - cross_attn: MhaCache::autoregressive_cross_attention(), - self_attn: MhaCache::autoregressive(), - pwff: TensorCache::empty(), - norm_1: TensorCache::empty(), - norm_2: TensorCache::empty(), - norm_3: TensorCache::empty(), - } + fn empty() -> Self { + Self { + cross_attn: MhaCache::autoregressive_cross_attention(), + self_attn: MhaCache::autoregressive(), + pwff: TensorCache::empty(), + norm_1: TensorCache::empty(), + norm_2: TensorCache::empty(), + norm_3: TensorCache::empty(), } + } } /// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer. /// /// To be used during inference when decoding tokens. 
pub struct TransformerDecoderAutoregressiveCache { - layers: Vec>, + layers: Vec>, } impl TransformerDecoderAutoregressiveCache { - fn empty(num_layers: usize) -> Self { - Self { - layers: (0..num_layers) - .map(|_| TransformerDecoderLayerAutoregressiveCache::empty()) - .collect(), - } + fn empty(num_layers: usize) -> Self { + Self { + layers: (0..num_layers) + .map(|_| TransformerDecoderLayerAutoregressiveCache::empty()) + .collect(), } + } } impl TransformerDecoderLayer { - fn new(config: &TransformerDecoderConfig) -> Self { - let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init(); - - let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init(); - let norm_1 = LayerNormConfig::new(config.d_model).init(); - let norm_2 = LayerNormConfig::new(config.d_model).init(); - let norm_3 = LayerNormConfig::new(config.d_model).init(); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - .with_dropout(config.dropout) - .init(); - - Self { - cross_attn, - self_attn, - norm_1, - norm_2, - norm_3, - pwff, - dropout, - norm_first: config.norm_first, - } + fn new(config: &TransformerDecoderConfig) -> Self { + let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init(); + + let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init(); + let norm_1 = LayerNormConfig::new(config.d_model).init(); + let norm_2 = LayerNormConfig::new(config.d_model).init(); + let norm_3 = LayerNormConfig::new(config.d_model).init(); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) + .with_dropout(config.dropout) + .init(); + + Self { + cross_attn, + self_attn, + norm_1, + norm_2, + norm_3, + pwff, + dropout, + norm_first: config.norm_first, + } + } + + fn new_with(config: &TransformerDecoderConfig, record: TransformerDecoderLayerRecord) -> Self { + let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init_with(record.self_attn); + let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init_with(record.cross_attn); + let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); + let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); + let norm_3 = LayerNormConfig::new(config.d_model).init_with(record.norm_3); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) + .with_dropout(config.dropout) + .init_with(record.pwff); + + Self { + cross_attn, + self_attn, + norm_1, + norm_2, + norm_3, + pwff, + dropout, + norm_first: config.norm_first, } + } - fn new_with( - config: &TransformerDecoderConfig, - record: 
TransformerDecoderLayerRecord, - ) -> Self { - let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init_with(record.self_attn); - let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init_with(record.cross_attn); - let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); - let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); - let norm_3 = LayerNormConfig::new(config.d_model).init_with(record.norm_3); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - .with_dropout(config.dropout) - .init_with(record.pwff); - - Self { - cross_attn, - self_attn, - norm_1, - norm_2, - norm_3, - pwff, - dropout, - norm_first: config.norm_first, - } + fn forward(&self, mut input: TransformerDecoderInput) -> TransformerDecoderInput { + let mut x_0 = input.target; + + if self.norm_first { + x_0 = self.norm_3.forward(x_0); } - fn forward(&self, mut input: TransformerDecoderInput) -> TransformerDecoderInput { - let mut x_0 = input.target; - - if self.norm_first { - x_0 = self.norm_3.forward(x_0); - } - - let mut self_attn_input = MhaInput::self_attn(x_0.clone()); - if let Some(mask_pad) = &input.target_mask_pad { - self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); - } - if let Some(mask_attn) = &input.target_mask_attn { - self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); - } - - let x_1 = self.self_attn.forward(self_attn_input); - let x_1 = self.dropout.forward(x_1.context) + x_0; - let x_1 = self.norm_1.forward(x_1); - - let mut cross_attn_input = - MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); - if let Some(mask_pad) = &input.memory_mask_pad { - cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone()); - } - if let Some(mask_attn) = &input.memory_mask_attn { - cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone()); - } - - let x_2 = self.cross_attn.forward(cross_attn_input); - let x_2 = self.dropout.forward(x_2.context) + x_1; - let x_2 = self.norm_2.forward(x_2); - - let x_3 = self.pwff.forward(x_2.clone()); - let mut x_3 = self.dropout.forward(x_3) + x_2; - - if !self.norm_first { - x_3 = self.norm_3.forward(x_3) - } - - input.target = x_3; - input + let mut self_attn_input = MhaInput::self_attn(x_0.clone()); + if let Some(mask_pad) = &input.target_mask_pad { + self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); + } + if let Some(mask_attn) = &input.target_mask_attn { + self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); } - fn forward_autoregressive_inference( - &self, - mut input: TransformerDecoderInput, - cache: &mut TransformerDecoderLayerAutoregressiveCache, - ) -> TransformerDecoderInput { - let mut x_0 = input.target; - - if self.norm_first { - x_0 = cache - .norm_3 - .forward_autoregressive(x_0, 1, |x| self.norm_3.forward(x)); - } - - let mut self_attn_input = MhaInput::self_attn(x_0.clone()); - if let Some(mask_pad) = &input.target_mask_pad { - self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); - } - if let Some(mask_attn) = &input.target_mask_attn { - self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); - } - - let x_1 = self - .self_attn - .forward_cache(self_attn_input, &mut cache.self_attn); - let x_1 = self.dropout.forward(x_1.context) + x_0; - 
let x_1 = cache - .norm_1 - .forward_autoregressive(x_1, 1, |x| self.norm_1.forward(x)); - - let mut mha_input = MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); - if let Some(mask_pad) = &input.memory_mask_pad { - mha_input = mha_input.mask_pad(mask_pad.clone()); - } - if let Some(mask_attn) = &input.memory_mask_attn { - mha_input = mha_input.mask_attn(mask_attn.clone()); - } - - let x_2 = self - .cross_attn - .forward_cache(mha_input, &mut cache.cross_attn); - let x_2 = self.dropout.forward(x_2.context) + x_1; - let x_2 = cache - .norm_2 - .forward_autoregressive(x_2, 1, |x| self.norm_2.forward(x)); - - let x_3 = cache - .pwff - .forward_autoregressive(x_2.clone(), 1, |x| self.pwff.forward(x)); - let mut x_3 = self.dropout.forward(x_3) + x_2; - - if !self.norm_first { - x_3 = cache - .norm_3 - .forward_autoregressive(x_3, 1, |x| self.norm_3.forward(x)); - } - - input.target = x_3; - input + let x_1 = self.self_attn.forward(self_attn_input); + let x_1 = self.dropout.forward(x_1.context) + x_0; + let x_1 = self.norm_1.forward(x_1); + + let mut cross_attn_input = + MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); + if let Some(mask_pad) = &input.memory_mask_pad { + cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone()); + } + if let Some(mask_attn) = &input.memory_mask_attn { + cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone()); } -} -impl TransformerDecoder { - /// Applies the forward pass. - pub fn forward(&self, mut input: TransformerDecoderInput) -> Tensor { - for layer in self.layers.iter() { - input = layer.forward(input); - } + let x_2 = self.cross_attn.forward(cross_attn_input); + let x_2 = self.dropout.forward(x_2.context) + x_1; + let x_2 = self.norm_2.forward(x_2); - input.target + let x_3 = self.pwff.forward(x_2.clone()); + let mut x_3 = self.dropout.forward(x_3) + x_2; + + if !self.norm_first { + x_3 = self.norm_3.forward(x_3) } - /// Applies the forward pass on the input using autoregressive cache. - pub fn forward_autoregressive_inference( - &self, - mut input: TransformerDecoderInput, - cache: &mut TransformerDecoderAutoregressiveCache, - ) -> Tensor { - for i in 0..self.layers.len() { - let layer = self.layers.get(i).unwrap(); - let cache = cache.layers.get_mut(i).unwrap(); + input.target = x_3; + input + } + + fn forward_autoregressive_inference( + &self, + mut input: TransformerDecoderInput, + cache: &mut TransformerDecoderLayerAutoregressiveCache, + ) -> TransformerDecoderInput { + let mut x_0 = input.target; + + if self.norm_first { + x_0 = cache + .norm_3 + .forward_autoregressive(x_0, 1, |x| self.norm_3.forward(x)); + } - input = layer.forward_autoregressive_inference(input, cache); - } + let mut self_attn_input = MhaInput::self_attn(x_0.clone()); + if let Some(mask_pad) = &input.target_mask_pad { + self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); + } + if let Some(mask_attn) = &input.target_mask_attn { + self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); + } - input.target + let x_1 = self + .self_attn + .forward_cache(self_attn_input, &mut cache.self_attn); + let x_1 = self.dropout.forward(x_1.context) + x_0; + let x_1 = cache + .norm_1 + .forward_autoregressive(x_1, 1, |x| self.norm_1.forward(x)); + + let mut mha_input = MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); + if let Some(mask_pad) = &input.memory_mask_pad { + mha_input = mha_input.mask_pad(mask_pad.clone()); } - /// Create an empty autoregressive cache. 
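Summarizing the data flow implemented by `TransformerDecoderLayer::forward` above in the default post-norm configuration (`norm_first = false`), with m the encoder memory, Drop the dropout layer, and padding/attention masks omitted; CrossAttn takes (query, key, value):

```latex
\begin{aligned}
x_1 &= \mathrm{LN}_1\big(x_0 + \mathrm{Drop}(\mathrm{SelfAttn}(x_0))\big) \\
x_2 &= \mathrm{LN}_2\big(x_1 + \mathrm{Drop}(\mathrm{CrossAttn}(x_1,\, m,\, m))\big) \\
y   &= \mathrm{LN}_3\big(x_2 + \mathrm{Drop}(\mathrm{FFN}(x_2))\big)
\end{aligned}
```

When `norm_first = true`, LN_3 is instead applied to x_0 before self-attention and the final normalization is skipped, as in the code above.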
- pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache { - TransformerDecoderAutoregressiveCache::empty(self.layers.len()) + if let Some(mask_attn) = &input.memory_mask_attn { + mha_input = mha_input.mask_attn(mask_attn.clone()); } + + let x_2 = self + .cross_attn + .forward_cache(mha_input, &mut cache.cross_attn); + let x_2 = self.dropout.forward(x_2.context) + x_1; + let x_2 = cache + .norm_2 + .forward_autoregressive(x_2, 1, |x| self.norm_2.forward(x)); + + let x_3 = cache + .pwff + .forward_autoregressive(x_2.clone(), 1, |x| self.pwff.forward(x)); + let mut x_3 = self.dropout.forward(x_3) + x_2; + + if !self.norm_first { + x_3 = cache + .norm_3 + .forward_autoregressive(x_3, 1, |x| self.norm_3.forward(x)); + } + + input.target = x_3; + input + } } -#[cfg(test)] -mod tests { - use super::*; - use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; - use burn_tensor::Distribution; - - #[test] - fn test_autoregressive_norm_last() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - TestBackend::seed(0); - - test_autoregressive( - TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers) - .with_norm_first(false), - ) +impl TransformerDecoder { + /// Applies the forward pass. + pub fn forward(&self, mut input: TransformerDecoderInput) -> Tensor { + for layer in self.layers.iter() { + input = layer.forward(input); } - #[test] - fn test_autoregressive_norm_first() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - TestBackend::seed(0); + input.target + } + + /// Applies the forward pass on the input using autoregressive cache. + pub fn forward_autoregressive_inference( + &self, + mut input: TransformerDecoderInput, + cache: &mut TransformerDecoderAutoregressiveCache, + ) -> Tensor { + for i in 0..self.layers.len() { + let layer = self.layers.get(i).unwrap(); + let cache = cache.layers.get_mut(i).unwrap(); - test_autoregressive( - TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), - ) + input = layer.forward_autoregressive_inference(input, cache); } - fn test_autoregressive(config: TransformerDecoderConfig) { - let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; - let transformer = config.init(); - - let memory = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - ); - let target = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - ); - let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device()); - let input = TransformerDecoderInput::new(target.clone(), memory.clone()) - .target_mask_attn(mask_attn); - - // Normal forward using masking. - let output_1 = transformer.forward(input); - - // Forward using the autoregressive cache. - let mut output_2 = Vec::new(); - let mut cache = transformer.new_autoregressive_cache(); - - for i in 1..seq_length + 1 { - let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]); - - let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device()); - let input = TransformerDecoderInput::new(target.clone(), memory.clone()) - .target_mask_attn(mask_attn); - let next_tok = transformer // Greedy sampling - .forward_autoregressive_inference(input, &mut cache) - .slice([0..batch_size, i - 1..i, 0..d_model]); - output_2.push(next_tok); - } - - let output_2 = Tensor::cat(output_2, 1); - - // Should produce the same tokens. 
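A hypothetical end-to-end usage sketch of the decoder with the new flag enabled. The `demo` helper and the concrete sizes are illustrative; `with_quiet_softmax` is the setter generated by `#[derive(Config)]` for the new field:

```rust
use burn::nn::transformer::{TransformerDecoderConfig, TransformerDecoderInput};
use burn::tensor::{backend::Backend, Distribution, Tensor};

/// Build a small decoder with quiet softmax enabled and run one forward pass.
fn demo<B: Backend>() -> Tensor<B, 3> {
    let [batch, seq, d_model] = [2, 16, 64];
    let decoder = TransformerDecoderConfig::new(d_model, 4 * d_model, 4, 2)
        .with_quiet_softmax(true) // use softmax_1 inside every attention layer
        .init::<B>();

    let target = Tensor::<B, 3>::random([batch, seq, d_model], Distribution::Default);
    let memory = Tensor::<B, 3>::random([batch, seq, d_model], Distribution::Default);

    decoder.forward(TransformerDecoderInput::new(target, memory))
}
```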
- output_1 - .into_data() - .assert_approx_eq(&output_2.into_data(), 3); + input.target + } + /// Create an empty autoregressive cache. + pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache { + TransformerDecoderAutoregressiveCache::empty(self.layers.len()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; + use burn_tensor::Distribution; + + #[test] + fn test_autoregressive_norm_last() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + TestBackend::seed(0); + + test_autoregressive( + TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(false), + ) + } + + #[test] + fn test_autoregressive_norm_first() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + TestBackend::seed(0); + + test_autoregressive( + TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), + ) + } + + fn test_autoregressive(config: TransformerDecoderConfig) { + let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; + let transformer = config.init(); + + let memory = + Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); + let target = + Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); + let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device()); + let input = + TransformerDecoderInput::new(target.clone(), memory.clone()).target_mask_attn(mask_attn); + + // Normal forward using masking. + let output_1 = transformer.forward(input); + + // Forward using the autoregressive cache. + let mut output_2 = Vec::new(); + let mut cache = transformer.new_autoregressive_cache(); + + for i in 1..seq_length + 1 { + let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]); + + let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device()); + let input = + TransformerDecoderInput::new(target.clone(), memory.clone()).target_mask_attn(mask_attn); + let next_tok = transformer // Greedy sampling + .forward_autoregressive_inference(input, &mut cache) + .slice([0..batch_size, i - 1..i, 0..d_model]); + output_2.push(next_tok); } + + let output_2 = Tensor::cat(output_2, 1); + + // Should produce the same tokens. + output_1 + .into_data() + .assert_approx_eq(&output_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/transformer/encoder.rs b/burn-core/src/nn/transformer/encoder.rs index 3d5c17d601..82bef97fb5 100644 --- a/burn-core/src/nn/transformer/encoder.rs +++ b/burn-core/src/nn/transformer/encoder.rs @@ -2,43 +2,49 @@ use alloc::vec::Vec; use burn_tensor::Bool; use crate::{ - self as burn, - nn::{attention::MhaCache, cache::TensorCache, Initializer}, + self as burn, + nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ - config::Config, - module::Module, - nn::{ - attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, - Dropout, DropoutConfig, LayerNorm, LayerNormConfig, - }, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{ + attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, + Dropout, DropoutConfig, LayerNorm, LayerNormConfig, + }, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [Transformer Encoder](TransformerEncoder) layer. #[derive(Config)] pub struct TransformerEncoderConfig { - /// The size of the model. 
- pub d_model: usize, - /// The size of the position-wise feed-forward network. - pub d_ff: usize, - /// The number of attention heads. - pub n_heads: usize, - /// The number of layers. - pub n_layers: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - pub dropout: f64, - /// Layer norm will be applied first instead of after the other modules. - #[config(default = false)] - pub norm_first: bool, - /// The type of function used to initialize neural network parameters - #[config( - default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" - )] - pub initializer: Initializer, + /// The size of the model. + pub d_model: usize, + /// The size of the position-wise feed-forward network. + pub d_ff: usize, + /// The number of attention heads. + pub n_heads: usize, + /// The number of layers. + pub n_layers: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + pub dropout: f64, + /// Layer norm will be applied first instead of after the other modules. + #[config(default = false)] + pub norm_first: bool, + /// Use "quiet softmax" instead of regular softmax. + /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + pub quiet_softmax: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] + pub initializer: Initializer, } /// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). @@ -48,338 +54,334 @@ pub struct TransformerEncoderConfig { /// - layers: transformer encoder layers with `d_model` input and output features. #[derive(Module, Debug)] pub struct TransformerEncoder { - layers: Vec>, + layers: Vec>, } /// [Transformer Encoder](TransformerEncoder) forward pass input argument. #[derive(Debug)] pub struct TransformerEncoderInput { - tensor: Tensor, - mask_pad: Option>, - mask_attn: Option>, + tensor: Tensor, + mask_pad: Option>, + mask_attn: Option>, } impl TransformerEncoderInput { - /// Create a [transformer encoder](TransformerEncoder) input argument. - pub fn new(tensor: Tensor) -> Self { - Self { - tensor, - mask_pad: None, - mask_attn: None, - } - } - - /// Register the padding mask. - pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { - self.mask_pad = Some(mask_pad); - self - } - - /// Register the attention mask. - pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { - self.mask_attn = Some(mask_attn); - self + /// Create a [transformer encoder](TransformerEncoder) input argument. + pub fn new(tensor: Tensor) -> Self { + Self { + tensor, + mask_pad: None, + mask_attn: None, } + } + + /// Register the padding mask. + pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { + self.mask_pad = Some(mask_pad); + self + } + + /// Register the attention mask. + pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { + self.mask_attn = Some(mask_attn); + self + } } impl TransformerEncoderConfig { - /// Initialize a new [transformer encoder](TransformerEncoder) module. 
- pub fn init(&self) -> TransformerEncoder { - let layers = (0..self.n_layers) - .map(|_| TransformerEncoderLayer::new(self)) - .collect::>(); - - TransformerEncoder { layers } - } - /// Initialize a new [transformer encoder](TransformerEncoder) module with a - /// [record](TransformerEncoderRecord). - pub fn init_with( - &self, - record: TransformerEncoderRecord, - ) -> TransformerEncoder { - TransformerEncoder { - layers: record - .layers - .into_iter() - .map(|record| TransformerEncoderLayer::new_with(self, record)) - .collect(), - } + /// Initialize a new [transformer encoder](TransformerEncoder) module. + pub fn init(&self) -> TransformerEncoder { + let layers = (0..self.n_layers) + .map(|_| TransformerEncoderLayer::new(self)) + .collect::>(); + + TransformerEncoder { layers } + } + /// Initialize a new [transformer encoder](TransformerEncoder) module with a + /// [record](TransformerEncoderRecord). + pub fn init_with( + &self, + record: TransformerEncoderRecord, + ) -> TransformerEncoder { + TransformerEncoder { + layers: record + .layers + .into_iter() + .map(|record| TransformerEncoderLayer::new_with(self, record)) + .collect(), } + } } impl TransformerEncoder { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - tensor: `[batch_size, seq_length, d_model]` - /// - output: `[batch_size, seq_length, d_model]` - pub fn forward(&self, input: TransformerEncoderInput) -> Tensor { - let mut x = input.tensor; - - for layer in self.layers.iter() { - x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone()); - } - - x - } - /// Applies the forward pass on the input tensor using autoregressive cache. - /// - /// # Shapes - /// - /// - tensor: `[batch_size, seq_length, d_model]` - /// - output: `[batch_size, seq_length, d_model]` - pub fn forward_autoregressive_inference( - &self, - input: TransformerEncoderInput, - cache: &mut TransformerEncoderAutoregressiveCache, - ) -> Tensor { - let mut x = input.tensor; - - for i in 0..self.layers.len() { - let layer = self.layers.get(i).unwrap(); - let cache = cache.layers.get_mut(i).unwrap(); - - x = layer.forward_autoregressive_inference( - x, - input.mask_pad.clone(), - input.mask_attn.clone(), - cache, - ); - } - - x + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - tensor: `[batch_size, seq_length, d_model]` + /// - output: `[batch_size, seq_length, d_model]` + pub fn forward(&self, input: TransformerEncoderInput) -> Tensor { + let mut x = input.tensor; + + for layer in self.layers.iter() { + x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone()); } - /// Create an empty autoregressive cache. - pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache { - TransformerEncoderAutoregressiveCache::empty(self.layers.len()) + x + } + /// Applies the forward pass on the input tensor using autoregressive cache. + /// + /// # Shapes + /// + /// - tensor: `[batch_size, seq_length, d_model]` + /// - output: `[batch_size, seq_length, d_model]` + pub fn forward_autoregressive_inference( + &self, + input: TransformerEncoderInput, + cache: &mut TransformerEncoderAutoregressiveCache, + ) -> Tensor { + let mut x = input.tensor; + + for i in 0..self.layers.len() { + let layer = self.layers.get(i).unwrap(); + let cache = cache.layers.get_mut(i).unwrap(); + + x = layer.forward_autoregressive_inference( + x, + input.mask_pad.clone(), + input.mask_attn.clone(), + cache, + ); } + + x + } + + /// Create an empty autoregressive cache. 
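Likewise for the encoder, a hypothetical sketch combining the new flag with a causal attention mask. The `encode` helper and hyperparameters are illustrative; `generate_autoregressive_mask` is the same helper used by the tests in this file:

```rust
use burn::nn::attention::generate_autoregressive_mask;
use burn::nn::transformer::{TransformerEncoderConfig, TransformerEncoderInput};
use burn::tensor::{backend::Backend, Tensor};

/// Encode a [batch, seq, d_model] tensor causally with quiet-softmax attention.
/// Note: d_model must be divisible by the number of heads (4 here).
fn encode<B: Backend>(tensor: Tensor<B, 3>) -> Tensor<B, 3> {
    let [batch, seq, d_model] = tensor.shape().dims;
    let encoder = TransformerEncoderConfig::new(d_model, 4 * d_model, 4, 2)
        .with_quiet_softmax(true)
        .init::<B>();

    let mask_attn = generate_autoregressive_mask::<B>(batch, seq, &tensor.device());
    let input = TransformerEncoderInput::new(tensor).mask_attn(mask_attn);

    encoder.forward(input)
}
```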
+ pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache { + TransformerEncoderAutoregressiveCache::empty(self.layers.len()) + } } /// Transformer encoder layer module. #[derive(Module, Debug)] pub struct TransformerEncoderLayer { - mha: MultiHeadAttention, - pwff: PositionWiseFeedForward, - norm_1: LayerNorm, - norm_2: LayerNorm, - dropout: Dropout, - norm_first: bool, + mha: MultiHeadAttention, + pwff: PositionWiseFeedForward, + norm_1: LayerNorm, + norm_2: LayerNorm, + dropout: Dropout, + norm_first: bool, } impl TransformerEncoderLayer { - fn new_with( - config: &TransformerEncoderConfig, - record: TransformerEncoderLayerRecord, - ) -> Self { - let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init_with(record.mha); - let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); - let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init_with(record.pwff); - - Self { - mha, - norm_1, - norm_2, - pwff, - dropout, - norm_first: config.norm_first, - } + fn new_with(config: &TransformerEncoderConfig, record: TransformerEncoderLayerRecord) -> Self { + let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init_with(record.mha); + let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); + let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .init_with(record.pwff); + + Self { + mha, + norm_1, + norm_2, + pwff, + dropout, + norm_first: config.norm_first, + } + } + fn new(config: &TransformerEncoderConfig) -> Self { + let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init(); + let norm_1 = LayerNormConfig::new(config.d_model).init(); + let norm_2 = LayerNormConfig::new(config.d_model).init(); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .init(); + + Self { + mha, + norm_1, + norm_2, + pwff, + dropout, + norm_first: config.norm_first, + } + } + + fn forward( + &self, + mut input: Tensor, + mask_pad: Option>, + mask_attn: Option>, + ) -> Tensor { + if self.norm_first { + input = self.norm_2.forward(input) } - fn new(config: &TransformerEncoderConfig) -> Self { - let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init(); - let norm_1 = LayerNormConfig::new(config.d_model).init(); - let norm_2 = LayerNormConfig::new(config.d_model).init(); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - 
.with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init(); - - Self { - mha, - norm_1, - norm_2, - pwff, - dropout, - norm_first: config.norm_first, - } + + let mut input_mhs = MhaInput::self_attn(input.clone()); + + if let Some(mask_pad) = mask_pad { + input_mhs = input_mhs.mask_pad(mask_pad); } - fn forward( - &self, - mut input: Tensor, - mask_pad: Option>, - mask_attn: Option>, - ) -> Tensor { - if self.norm_first { - input = self.norm_2.forward(input) - } + if let Some(mask_attn) = mask_attn { + input_mhs = input_mhs.mask_attn(mask_attn); + } + + let x_1 = self.mha.forward(input_mhs); + let x_1 = self.dropout.forward(x_1.context) + input; + let x_1 = self.norm_1.forward(x_1); - let mut input_mhs = MhaInput::self_attn(input.clone()); + let x_2 = self.pwff.forward(x_1.clone()); + let mut x_2 = self.dropout.forward(x_2) + x_1; - if let Some(mask_pad) = mask_pad { - input_mhs = input_mhs.mask_pad(mask_pad); - } + if !self.norm_first { + x_2 = self.norm_2.forward(x_2) + } - if let Some(mask_attn) = mask_attn { - input_mhs = input_mhs.mask_attn(mask_attn); - } + x_2 + } - let x_1 = self.mha.forward(input_mhs); - let x_1 = self.dropout.forward(x_1.context) + input; - let x_1 = self.norm_1.forward(x_1); + fn forward_autoregressive_inference( + &self, + mut input: Tensor, + mask_pad: Option>, + mask_attn: Option>, + cache: &mut TransformerEncoderLayerAutoregressiveCache, + ) -> Tensor { + if self.norm_first { + input = cache + .norm_2 + .forward_autoregressive(input, 1, |input| self.norm_2.forward(input)); + } - let x_2 = self.pwff.forward(x_1.clone()); - let mut x_2 = self.dropout.forward(x_2) + x_1; + let mut input_mhs = MhaInput::self_attn(input.clone()); - if !self.norm_first { - x_2 = self.norm_2.forward(x_2) - } + if let Some(mask_pad) = mask_pad { + input_mhs = input_mhs.mask_pad(mask_pad); + } - x_2 + if let Some(mask_attn) = mask_attn { + input_mhs = input_mhs.mask_attn(mask_attn); } - fn forward_autoregressive_inference( - &self, - mut input: Tensor, - mask_pad: Option>, - mask_attn: Option>, - cache: &mut TransformerEncoderLayerAutoregressiveCache, - ) -> Tensor { - if self.norm_first { - input = cache - .norm_2 - .forward_autoregressive(input, 1, |input| self.norm_2.forward(input)); - } - - let mut input_mhs = MhaInput::self_attn(input.clone()); - - if let Some(mask_pad) = mask_pad { - input_mhs = input_mhs.mask_pad(mask_pad); - } - - if let Some(mask_attn) = mask_attn { - input_mhs = input_mhs.mask_attn(mask_attn); - } - - let x_1 = self.mha.forward_cache(input_mhs, &mut cache.mha); - let x_1 = self.dropout.forward(x_1.context) + input; - let x_1 = cache - .norm_1 - .forward_autoregressive(x_1, 1, |x_1| self.norm_1.forward(x_1)); - - let x_2 = cache - .pwff - .forward_autoregressive(x_1.clone(), 1, |x_1| self.pwff.forward(x_1)); - let mut x_2 = self.dropout.forward(x_2) + x_1; - - if !self.norm_first { - x_2 = cache - .norm_2 - .forward_autoregressive(x_2, 1, |x_2| self.norm_2.forward(x_2)); - } - - x_2 + let x_1 = self.mha.forward_cache(input_mhs, &mut cache.mha); + let x_1 = self.dropout.forward(x_1.context) + input; + let x_1 = cache + .norm_1 + .forward_autoregressive(x_1, 1, |x_1| self.norm_1.forward(x_1)); + + let x_2 = cache + .pwff + .forward_autoregressive(x_1.clone(), 1, |x_1| self.pwff.forward(x_1)); + let mut x_2 = self.dropout.forward(x_2) + x_1; + + if !self.norm_first { + x_2 = cache + .norm_2 + .forward_autoregressive(x_2, 1, |x_2| self.norm_2.forward(x_2)); } + + x_2 + } } struct TransformerEncoderLayerAutoregressiveCache { - mha: 
MhaCache, - pwff: TensorCache, - norm_1: TensorCache, - norm_2: TensorCache, + mha: MhaCache, + pwff: TensorCache, + norm_1: TensorCache, + norm_2: TensorCache, } impl TransformerEncoderLayerAutoregressiveCache { - fn empty() -> Self { - Self { - mha: MhaCache::autoregressive(), - pwff: TensorCache::empty(), - norm_1: TensorCache::empty(), - norm_2: TensorCache::empty(), - } + fn empty() -> Self { + Self { + mha: MhaCache::autoregressive(), + pwff: TensorCache::empty(), + norm_1: TensorCache::empty(), + norm_2: TensorCache::empty(), } + } } /// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer. /// /// To be used during inference when decoding tokens. pub struct TransformerEncoderAutoregressiveCache { - layers: Vec>, + layers: Vec>, } impl TransformerEncoderAutoregressiveCache { - fn empty(num_layers: usize) -> Self { - Self { - layers: (0..num_layers) - .map(|_| TransformerEncoderLayerAutoregressiveCache::empty()) - .collect(), - } + fn empty(num_layers: usize) -> Self { + Self { + layers: (0..num_layers) + .map(|_| TransformerEncoderLayerAutoregressiveCache::empty()) + .collect(), } + } } #[cfg(test)] mod tests { - use super::*; - use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; - use burn_tensor::Distribution; - - #[test] - fn test_autoregressive_norm_last() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - test_autoregressive( - TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers) - .with_norm_first(false), - ) + use super::*; + use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; + use burn_tensor::Distribution; + + #[test] + fn test_autoregressive_norm_last() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + test_autoregressive( + TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(false), + ) + } + + #[test] + fn test_autoregressive_norm_first() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + test_autoregressive( + TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), + ) + } + + fn test_autoregressive(config: TransformerEncoderConfig) { + let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; + let transformer = config.init(); + + let tensor = + Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); + let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); + let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn); + + let output_1 = transformer.forward(input); + let mut output_2 = Vec::new(); + let mut cache = transformer.new_autoregressive_cache(); + + for i in 1..seq_length + 1 { + let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); + let input = TransformerEncoderInput::new(tensor.clone()); + let next_tok = transformer + .forward_autoregressive_inference(input, &mut cache) + .slice([0..batch_size, i - 1..i, 0..d_model]); + output_2.push(next_tok); } - #[test] - fn test_autoregressive_norm_first() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - test_autoregressive( - TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), - ) - } + let output_2 = Tensor::cat(output_2, 1); - fn test_autoregressive(config: TransformerEncoderConfig) { - let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; - let transformer = config.init(); - - let tensor = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - ); - let 
mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); - let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn); - - let output_1 = transformer.forward(input); - let mut output_2 = Vec::new(); - let mut cache = transformer.new_autoregressive_cache(); - - for i in 1..seq_length + 1 { - let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); - let input = TransformerEncoderInput::new(tensor.clone()); - let next_tok = transformer - .forward_autoregressive_inference(input, &mut cache) - .slice([0..batch_size, i - 1..i, 0..d_model]); - output_2.push(next_tok); - } - - let output_2 = Tensor::cat(output_2, 1); - - output_1 - .into_data() - .assert_approx_eq(&output_2.into_data(), 3); - } + output_1 + .into_data() + .assert_approx_eq(&output_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/transformer/pwff.rs b/burn-core/src/nn/transformer/pwff.rs index 1307ba8b80..207db3910f 100644 --- a/burn-core/src/nn/transformer/pwff.rs +++ b/burn-core/src/nn/transformer/pwff.rs @@ -2,27 +2,25 @@ use crate as burn; use crate::nn::Initializer; use crate::{ - config::Config, - module::Module, - nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU}, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU}, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer. #[derive(Config)] pub struct PositionWiseFeedForwardConfig { - /// The size of the input and output features. - pub d_model: usize, - /// The size of the hidden inner features. - pub d_ff: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - pub dropout: f64, - /// The type of function used to initialize neural network parameters - #[config( - default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" - )] - pub initializer: Initializer, + /// The size of the input and output features. + pub d_model: usize, + /// The size of the hidden inner features. + pub d_ff: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + pub dropout: f64, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] + pub initializer: Initializer, } /// Applies the position-wise feed-forward network to the input tensor. @@ -33,53 +31,53 @@ pub struct PositionWiseFeedForwardConfig { /// - linear outer: Linear layer with `d_ff` input features and `d_model` output features. #[derive(Module, Debug)] pub struct PositionWiseFeedForward { - linear_inner: Linear, - linear_outer: Linear, - dropout: Dropout, - gelu: GELU, + linear_inner: Linear, + linear_outer: Linear, + dropout: Dropout, + gelu: GELU, } impl PositionWiseFeedForwardConfig { - /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module. - pub fn init(&self) -> PositionWiseFeedForward { - PositionWiseFeedForward { - linear_inner: LinearConfig::new(self.d_model, self.d_ff) - .with_initializer(self.initializer.clone()) - .init(), - linear_outer: LinearConfig::new(self.d_ff, self.d_model) - .with_initializer(self.initializer.clone()) - .init(), - dropout: DropoutConfig::new(self.dropout).init(), - gelu: GELU::new(), - } + /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module. 
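The position-wise feed-forward block below applies the same two-layer MLP independently at every position; in symbols, with `linear_inner` weights of shape `[d_model, d_ff]` and `linear_outer` weights of shape `[d_ff, d_model]`:

```latex
\mathrm{FFN}(x) = \mathrm{Dropout}\!\big(\mathrm{GELU}(x\,W_{\text{inner}} + b_{\text{inner}})\big)\,W_{\text{outer}} + b_{\text{outer}}
```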
+ pub fn init(&self) -> PositionWiseFeedForward { + PositionWiseFeedForward { + linear_inner: LinearConfig::new(self.d_model, self.d_ff) + .with_initializer(self.initializer.clone()) + .init(), + linear_outer: LinearConfig::new(self.d_ff, self.d_model) + .with_initializer(self.initializer.clone()) + .init(), + dropout: DropoutConfig::new(self.dropout).init(), + gelu: GELU::new(), } - /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module with a - /// [record](PositionWiseFeedForwardRecord). - pub fn init_with( - &self, - record: PositionWiseFeedForwardRecord, - ) -> PositionWiseFeedForward { - PositionWiseFeedForward { - linear_inner: LinearConfig::new(self.d_model, self.d_ff).init_with(record.linear_inner), - linear_outer: LinearConfig::new(self.d_ff, self.d_model).init_with(record.linear_outer), - dropout: DropoutConfig::new(self.dropout).init(), - gelu: GELU::new(), - } + } + /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module with a + /// [record](PositionWiseFeedForwardRecord). + pub fn init_with( + &self, + record: PositionWiseFeedForwardRecord, + ) -> PositionWiseFeedForward { + PositionWiseFeedForward { + linear_inner: LinearConfig::new(self.d_model, self.d_ff).init_with(record.linear_inner), + linear_outer: LinearConfig::new(self.d_ff, self.d_model).init_with(record.linear_outer), + dropout: DropoutConfig::new(self.dropout).init(), + gelu: GELU::new(), } + } } impl PositionWiseFeedForward { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - tensor: `[batch_size, seq_length, d_model]` - /// - output: `[batch_size, seq_length, d_model]` - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.linear_inner.forward(input); - let x = self.gelu.forward(x); - let x = self.dropout.forward(x); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - tensor: `[batch_size, seq_length, d_model]` + /// - output: `[batch_size, seq_length, d_model]` + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.linear_inner.forward(input); + let x = self.gelu.forward(x); + let x = self.dropout.forward(x); - self.linear_outer.forward(x) - } + self.linear_outer.forward(x) + } } diff --git a/burn-core/src/nn/unfold.rs b/burn-core/src/nn/unfold.rs index 26711622e3..3ad588b86a 100644 --- a/burn-core/src/nn/unfold.rs +++ b/burn-core/src/nn/unfold.rs @@ -10,50 +10,50 @@ use burn_tensor::Tensor; /// Configuration to create an [unfold 4D](Unfold4d) layer. #[derive(Config, Debug)] pub struct Unfold4dConfig { - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The stride of the convolution. - #[config(default = "[1, 1]")] - pub stride: [usize; 2], - /// Spacing between kernel elements. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], - /// The padding configuration. - #[config(default = "[0, 0]")] - pub padding: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The stride of the convolution. + #[config(default = "[1, 1]")] + pub stride: [usize; 2], + /// Spacing between kernel elements. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], + /// The padding configuration. + #[config(default = "[0, 0]")] + pub padding: [usize; 2], } /// Four-dimensional unfolding. #[derive(Module, Clone, Debug)] pub struct Unfold4d { - config: Unfold4dConfig, + config: Unfold4dConfig, } impl Unfold4dConfig { - /// Initialize a new [unfold 4k](Unfold4d) module. 
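For the `Unfold4dConfig` above, the "number of blocks" in the output shape documented on `Unfold4d::forward` below follows the usual im2col/unfold arithmetic (a reference formula, assuming the standard sliding-window convention):

```latex
\text{blocks} = \prod_{d \in \{h,\,w\}} \left\lfloor \frac{\text{size}_d + 2\,\text{padding}_d - \text{dilation}_d\,(\text{kernel\_size}_d - 1) - 1}{\text{stride}_d} + 1 \right\rfloor
```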
- pub fn init(&self) -> Unfold4d { - Unfold4d { - config: self.clone(), - } + /// Initialize a new [unfold 4k](Unfold4d) module. + pub fn init(&self) -> Unfold4d { + Unfold4d { + config: self.clone(), } + } } impl Unfold4d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// input: `[batch_size, channels_in, height, width]`, - /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, - pub fn forward(&self, input: Tensor) -> Tensor { - unfold4d( - input, - self.config.kernel_size, - UnfoldOptions::new( - self.config.stride, - self.config.padding, - self.config.dilation, - ), - ) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// input: `[batch_size, channels_in, height, width]`, + /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, + pub fn forward(&self, input: Tensor) -> Tensor { + unfold4d( + input, + self.config.kernel_size, + UnfoldOptions::new( + self.config.stride, + self.config.padding, + self.config.dilation, + ), + ) + } } diff --git a/burn-core/src/optim/adagrad.rs b/burn-core/src/optim/adagrad.rs index bdb8dd4a0a..52ee2c0c1b 100644 --- a/burn-core/src/optim/adagrad.rs +++ b/burn-core/src/optim/adagrad.rs @@ -1,11 +1,11 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use super::{ - decay::{WeightDecay, WeightDecayConfig}, - Optimizer, SimpleOptimizer, + decay::{WeightDecay, WeightDecayConfig}, + Optimizer, SimpleOptimizer, }; use crate::config::Config; use crate::optim::adaptor::OptimizerAdaptor; @@ -15,263 +15,263 @@ use burn_tensor::backend::Backend; /// AdaGrad configuration. #[derive(Config)] pub struct AdaGradConfig { - #[config(default = 0.)] - lr_decay: f64, - #[config(default = 1e-5)] - epsilon: f32, - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + #[config(default = 0.)] + lr_decay: f64, + #[config(default = 1e-5)] + epsilon: f32, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Gradient Clipping](GradientClippingConfig) config. + grad_clipping: Option, } /// AdaGrad optimizer pub struct AdaGrad { - lr_decay: LRDecay, - weight_decay: Option>, + lr_decay: LRDecay, + weight_decay: Option>, } /// AdaGrad state. 
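The optimizer below implements AdaGrad with a learning-rate decay schedule; the per-parameter update carried out by `LRDecay::transform` further down is, with g_t the (optionally weight-decayed) gradient, λ = `lr_decay`, and ε = `epsilon`:

```latex
\begin{aligned}
G_t &= G_{t-1} + g_t^{\,2} \\
\eta_t &= \frac{\eta}{1 + (t - 1)\,\lambda} \\
\theta_t &= \theta_{t-1} - \eta_t\,\frac{g_t}{\sqrt{G_t} + \epsilon}
\end{aligned}
```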
#[derive(Record, Clone, new)] pub struct AdaGradState { - lr_decay: LRDecayState, + lr_decay: LRDecayState, } impl SimpleOptimizer for AdaGrad { - type State = AdaGradState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - let mut state_lr_decay = None; + type State = AdaGradState; + + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + let mut state_lr_decay = None; + + if let Some(state) = state { + state_lr_decay = Some(state.lr_decay); + } - if let Some(state) = state { - state_lr_decay = Some(state.lr_decay); - } + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); + } - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); - } + let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay); - let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay); + let state = AdaGradState::new(state_lr_decay); - let state = AdaGradState::new(state_lr_decay); + (tensor - grad, Some(state)) + } - (tensor - grad, Some(state)) - } - - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.lr_decay = state.lr_decay.to_device(device); - state - } + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.lr_decay = state.lr_decay.to_device(device); + state + } } impl AdaGradConfig { - /// Initialize AdaGrad optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>(&self) -> impl Optimizer { - let optim = AdaGrad { - lr_decay: LRDecay { - lr_decay: self.lr_decay, - epsilon: self.epsilon, - }, - weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), - }; - - let mut optim = OptimizerAdaptor::from(optim); - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); - } - optim + /// Initialize AdaGrad optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>(&self) -> impl Optimizer { + let optim = AdaGrad { + lr_decay: LRDecay { + lr_decay: self.lr_decay, + epsilon: self.epsilon, + }, + weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), + }; + + let mut optim = OptimizerAdaptor::from(optim); + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); } + optim + } } /// Learning rate decay state (also includes sum state). #[derive(Record, new, Clone)] pub struct LRDecayState { - time: usize, - sum: Tensor, + time: usize, + sum: Tensor, } struct LRDecay { - lr_decay: f64, - epsilon: f32, + lr_decay: f64, + epsilon: f32, } impl LRDecay { - pub fn transform( - &self, - grad: Tensor, - lr: LearningRate, - lr_decay_state: Option>, - ) -> (Tensor, LRDecayState) { - let state = if let Some(mut state) = lr_decay_state { - state.sum = state.sum.add(grad.clone().powf(2.)); - state.time += 1; - state - } else { - LRDecayState::new(1, grad.clone().powf(2.)) - }; - - let new_lr = lr / (1. + (state.time as f64 - 1.) 
* self.lr_decay); - - let grad = grad - .div(state.sum.clone().sqrt().add_scalar(self.epsilon)) - .mul_scalar(new_lr); - - (grad, state) - } + pub fn transform( + &self, + grad: Tensor, + lr: LearningRate, + lr_decay_state: Option>, + ) -> (Tensor, LRDecayState) { + let state = if let Some(mut state) = lr_decay_state { + state.sum = state.sum.add(grad.clone().powf(2.)); + state.time += 1; + state + } else { + LRDecayState::new(1, grad.clone().powf(2.)) + }; + + let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay); + + let grad = grad + .div(state.sum.clone().sqrt().add_scalar(self.epsilon)) + .mul_scalar(new_lr); + + (grad, state) + } } impl LRDecayState { - /// Move state to device. - /// - /// # Arguments - /// - /// * `device` - Device to move state to. - /// - /// # Returns - /// - /// Returns state moved to device. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.sum = self.sum.to_device(device); - self - } + /// Move state to device. + /// + /// # Arguments + /// + /// * `device` - Device to move state to. + /// + /// # Returns + /// + /// Returns state moved to device. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.sum = self.sum.to_device(device); + self + } } #[cfg(test)] mod tests { - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - - const LEARNING_RATE: LearningRate = 0.01; - - #[test] - fn test_adagrad_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_adagrad(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - BinFileRecorder::::default() - .record(optimizer.to_record(), "/tmp/test_optim".into()) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_adagrad(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - const ASSERT_PRECISION: usize = 6; - - #[test] - fn test_adagrad_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdaGradConfig::new() - .with_epsilon(1e-8) - .with_lr_decay(0.5) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = 
linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711], - [ - 0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756, - ], - [ - -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538, - ], - [ - -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964, - ], - [ - 0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504, - ], - [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895], - ]); - let bias_expected = Data::from([ - -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - fn create_adagrad( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> - { - let config = AdaGradConfig::new(); - AdaGrad { - lr_decay: LRDecay { - lr_decay: config.lr_decay, - epsilon: config.epsilon, - }, - weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), - } - .into() + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, TestAutodiffBackend, TestBackend}; + + const LEARNING_RATE: LearningRate = 0.01; + + #[test] + fn test_adagrad_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_adagrad(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + BinFileRecorder::::default() + .record(optimizer.to_record(), "/tmp/test_optim".into()) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_adagrad(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + const ASSERT_PRECISION: usize = 6; + + #[test] + fn test_adagrad_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + 
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdaGradConfig::new() + .with_epsilon(1e-8) + .with_lr_decay(0.5) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711], + [ + 0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756, + ], + [ + -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538, + ], + [ + -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964, + ], + [ + 0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504, + ], + [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895], + ]); + let bias_expected = Data::from([ + -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + fn create_adagrad( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> + { + let config = AdaGradConfig::new(); + AdaGrad { + lr_decay: LRDecay { + lr_decay: config.lr_decay, + epsilon: config.epsilon, + }, + weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), } + .into() + } } diff --git a/burn-core/src/optim/adam.rs b/burn-core/src/optim/adam.rs index e43bd077fd..6a02ca0b56 100644 --- a/burn-core/src/optim/adam.rs +++ b/burn-core/src/optim/adam.rs @@ -1,11 +1,11 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use super::{ - decay::{WeightDecay, WeightDecayConfig}, - Optimizer, SimpleOptimizer, + decay::{WeightDecay, WeightDecayConfig}, + Optimizer, SimpleOptimizer, }; use crate::config::Config; use crate::optim::adaptor::OptimizerAdaptor; @@ -15,338 +15,337 @@ use burn_tensor::{backend::Backend, ElementConversion}; /// Adam configuration. #[derive(Config)] pub struct AdamConfig { - /// Parameter for Adam. - #[config(default = 0.9)] - beta_1: f32, - /// Parameter for Adam. - #[config(default = 0.999)] - beta_2: f32, - /// A value required for numerical stability. - #[config(default = 1e-5)] - epsilon: f32, - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + /// Parameter for Adam. + #[config(default = 0.9)] + beta_1: f32, + /// Parameter for Adam. + #[config(default = 0.999)] + beta_2: f32, + /// A value required for numerical stability. + #[config(default = 1e-5)] + epsilon: f32, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Gradient Clipping](GradientClippingConfig) config. 
+ grad_clipping: Option, } /// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf). pub struct Adam { - momentum: AdaptiveMomentum, - weight_decay: Option>, + momentum: AdaptiveMomentum, + weight_decay: Option>, } /// Adam state. #[derive(Record, Clone, new)] pub struct AdamState { - momentum: AdaptiveMomentumState, + momentum: AdaptiveMomentumState, } impl SimpleOptimizer for Adam { - type State = AdamState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - let mut state_momentum = None; - - if let Some(state) = state { - state_momentum = Some(state.momentum); - } + type State = AdamState; + + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + let mut state_momentum = None; + + if let Some(state) = state { + state_momentum = Some(state.momentum); + } - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); - } + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); + } - let (grad, state_momentum) = self.momentum.transform(grad, state_momentum); + let (grad, state_momentum) = self.momentum.transform(grad, state_momentum); - let state = AdamState::new(state_momentum); - let delta = grad.mul_scalar(lr); + let state = AdamState::new(state_momentum); + let delta = grad.mul_scalar(lr); - (tensor - delta, Some(state)) - } + (tensor - delta, Some(state)) + } - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.momentum = state.momentum.to_device(device); - state - } + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.momentum = state.momentum.to_device(device); + state + } } impl AdamConfig { - /// Initialize Adam optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>(&self) -> impl Optimizer { - let optim = Adam { - momentum: AdaptiveMomentum { - beta_1: self.beta_1, - beta_2: self.beta_2, - epsilon: self.epsilon, - }, - weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), - }; - - let mut optim = OptimizerAdaptor::from(optim); - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); - } - optim + /// Initialize Adam optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>(&self) -> impl Optimizer { + let optim = Adam { + momentum: AdaptiveMomentum { + beta_1: self.beta_1, + beta_2: self.beta_2, + epsilon: self.epsilon, + }, + weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), + }; + + let mut optim = OptimizerAdaptor::from(optim); + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); } + optim + } } /// Adaptive momentum state. 
#[derive(Record, new, Clone)] pub struct AdaptiveMomentumState { - time: usize, - moment_1: Tensor, - moment_2: Tensor, + time: usize, + moment_1: Tensor, + moment_2: Tensor, } struct AdaptiveMomentum { - beta_1: f32, - beta_2: f32, - epsilon: f32, + beta_1: f32, + beta_2: f32, + epsilon: f32, } impl AdaptiveMomentum { - pub fn transform( - &self, - grad: Tensor, - momentum_state: Option>, - ) -> (Tensor, AdaptiveMomentumState) { - let state = if let Some(mut state) = momentum_state { - let factor = 1.0 - self.beta_1; - state.moment_1 = state - .moment_1 - .mul_scalar(self.beta_1) - .add(grad.clone().mul_scalar(factor)); - - let factor = 1.0 - self.beta_2; - state.moment_2 = state - .moment_2 - .mul_scalar(self.beta_2) - .add(grad.powf(2.0).mul_scalar(factor)); - - state.time += 1; - - state - } else { - let factor = 1.0 - self.beta_1; - let moment_1 = grad.clone().mul_scalar(factor); - - let factor = 1.0 - self.beta_2; - let moment_2 = grad.powf(2.0).mul_scalar(factor); - - AdaptiveMomentumState::new(1, moment_1, moment_2) - }; - - let time = (state.time as i32).elem(); - let moment_1_corrected = state - .moment_1 - .clone() - .div_scalar(1f32 - self.beta_1.powi(time)); - let moment_2_corrected = state - .moment_2 - .clone() - .div_scalar(1f32 - self.beta_2.powi(time)); - - let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); - - (grad, state) - } + pub fn transform( + &self, + grad: Tensor, + momentum_state: Option>, + ) -> (Tensor, AdaptiveMomentumState) { + let state = if let Some(mut state) = momentum_state { + let factor = 1.0 - self.beta_1; + state.moment_1 = state + .moment_1 + .mul_scalar(self.beta_1) + .add(grad.clone().mul_scalar(factor)); + + let factor = 1.0 - self.beta_2; + state.moment_2 = state + .moment_2 + .mul_scalar(self.beta_2) + .add(grad.powf(2.0).mul_scalar(factor)); + + state.time += 1; + + state + } else { + let factor = 1.0 - self.beta_1; + let moment_1 = grad.clone().mul_scalar(factor); + + let factor = 1.0 - self.beta_2; + let moment_2 = grad.powf(2.0).mul_scalar(factor); + + AdaptiveMomentumState::new(1, moment_1, moment_2) + }; + + let time = (state.time as i32).elem(); + let moment_1_corrected = state + .moment_1 + .clone() + .div_scalar(1f32 - self.beta_1.powi(time)); + let moment_2_corrected = state + .moment_2 + .clone() + .div_scalar(1f32 - self.beta_2.powi(time)); + + let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); + + (grad, state) + } } impl AdaptiveMomentumState { - /// Move state to device. - /// - /// # Arguments - /// - /// * `device` - Device to move state to. - /// - /// # Returns - /// - /// Returns state moved to device. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.moment_1 = self.moment_1.to_device(device); - self.moment_2 = self.moment_2.to_device(device); - self - } + /// Move state to device. + /// + /// # Arguments + /// + /// * `device` - Device to move state to. + /// + /// # Returns + /// + /// Returns state moved to device. 
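The bias correction is the only subtle part of the transform just above; the same arithmetic on plain f64 scalars, per parameter element and ignoring weight decay and clipping (illustration only):

    /// One Adam moment update with bias correction, element-wise.
    fn adam_delta(grad: f64, m: &mut f64, v: &mut f64, t: i32, beta_1: f64, beta_2: f64, epsilon: f64) -> f64 {
        *m = beta_1 * *m + (1.0 - beta_1) * grad;
        *v = beta_2 * *v + (1.0 - beta_2) * grad * grad;
        // Undo the bias toward zero caused by initializing both moments at zero.
        let m_hat = *m / (1.0 - beta_1.powi(t));
        let v_hat = *v / (1.0 - beta_2.powi(t));
        m_hat / (v_hat.sqrt() + epsilon)
    }

    fn main() {
        let (mut m, mut v) = (0.0, 0.0);
        let delta = adam_delta(0.1, &mut m, &mut v, 1, 0.9, 0.999, 1e-5);
        // At t = 1 the corrected ratio is grad / (|grad| + epsilon), close to 1 here.
        println!("first-step delta: {delta:.6}");
    }

The caller then applies param - lr * delta, which is what SimpleOptimizer::step does with the returned tensor.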
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.moment_1 = self.moment_1.to_device(device); + self.moment_2 = self.moment_2.to_device(device); + self + } } #[cfg(test)] mod tests { - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - - const LEARNING_RATE: LearningRate = 0.01; - - #[test] - fn test_adam_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_adam(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - BinFileRecorder::::default() - .record(optimizer.to_record(), "/tmp/test_optim".into()) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_adam(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - const ASSERT_PRECISION: usize = 2; - - #[test] - fn test_adam_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(Some(WeightDecayConfig::new(0.5))) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154], - [ - 0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133, - ], - [ - -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047, - ], - [ - -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651, - ], - [ - 0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343, - ], - [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346], - ]); - let bias_expected = Data::from([ - -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - 
weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - #[test] - fn test_adam_optimizer_no_nan() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - - let x = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(Some(WeightDecayConfig::new(0.5))) - .init(); - - let grads = linear.forward(x.clone()).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - assert!(!state_updated.weight.to_data().value[0].is_nan()); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - fn create_adam( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> - { - let config = AdamConfig::new(); - Adam { - momentum: AdaptiveMomentum { - beta_1: config.beta_1, - beta_2: config.beta_2, - epsilon: config.epsilon, - }, - weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), - } - .into() + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, TestAutodiffBackend, TestBackend}; + + const LEARNING_RATE: LearningRate = 0.01; + + #[test] + fn test_adam_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_adam(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + BinFileRecorder::::default() + .record(optimizer.to_record(), "/tmp/test_optim".into()) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_adam(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + const ASSERT_PRECISION: usize = 2; + + #[test] + fn test_adam_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + 
]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(Some(WeightDecayConfig::new(0.5))) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154], + [ + 0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133, + ], + [ + -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047, + ], + [ + -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651, + ], + [ + 0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343, + ], + [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346], + ]); + let bias_expected = Data::from([ + -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + #[test] + fn test_adam_optimizer_no_nan() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + + let x = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(Some(WeightDecayConfig::new(0.5))) + .init(); + + let grads = linear.forward(x.clone()).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + assert!(!state_updated.weight.to_data().value[0].is_nan()); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + fn create_adam( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> { + let config = AdamConfig::new(); + Adam { + momentum: AdaptiveMomentum { + beta_1: config.beta_1, + beta_2: 
config.beta_2, + epsilon: config.epsilon, + }, + weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), } + .into() + } } diff --git a/burn-core/src/optim/adamw.rs b/burn-core/src/optim/adamw.rs index befbeb88cb..8f8441b489 100644 --- a/burn-core/src/optim/adamw.rs +++ b/burn-core/src/optim/adamw.rs @@ -1,6 +1,6 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use std::marker::PhantomData; @@ -13,356 +13,355 @@ use burn_tensor::{backend::Backend, ElementConversion}; /// AdamW configuration. #[derive(Config)] pub struct AdamWConfig { - /// Parameter for AdamW. - #[config(default = 0.9)] - beta_1: f32, - /// Parameter for AdamW. - #[config(default = 0.999)] - beta_2: f32, - /// A value required for numerical stability. - #[config(default = 1e-5)] - epsilon: f32, - /// Weight decay config. - #[config(default = 1e-4)] - weight_decay: f32, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + /// Parameter for AdamW. + #[config(default = 0.9)] + beta_1: f32, + /// Parameter for AdamW. + #[config(default = 0.999)] + beta_2: f32, + /// A value required for numerical stability. + #[config(default = 1e-5)] + epsilon: f32, + /// Weight decay config. + #[config(default = 1e-4)] + weight_decay: f32, + /// [Gradient Clipping](GradientClippingConfig) config. + grad_clipping: Option, } /// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101). pub struct AdamW { - momentum: AdaptiveMomentumW, - weight_decay: f32, - _phantom: PhantomData, + momentum: AdaptiveMomentumW, + weight_decay: f32, + _phantom: PhantomData, } /// AdamW state. #[derive(Record, Clone, new)] pub struct AdamWState { - momentum: AdaptiveMomentumWState, + momentum: AdaptiveMomentumWState, } impl SimpleOptimizer for AdamW { - type State = AdamWState; - - /// A single optimization step for any tensor that represents the parameters of a model. - fn step( - &self, - // Learning rate. - lr: LearningRate, - // Any tensor that represents the parameters of a model. - tensor: Tensor, - // Gradient of the loss w.r.t. the parameters. - grad: Tensor, - // State of the optimizer. - state: Option>, - ) -> (Tensor, Option>) { - let tensor_updated = tensor.clone() - tensor.mul_scalar(lr).mul_scalar(self.weight_decay); - - let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum)); - - let state = AdamWState { - momentum: momentum_state, - }; - - (tensor_updated - raw_delta.mul_scalar(lr), Some(state)) - } - - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.momentum = state.momentum.to_device(device); - state - } + type State = AdamWState; + + /// A single optimization step for any tensor that represents the parameters of a model. + fn step( + &self, + // Learning rate. + lr: LearningRate, + // Any tensor that represents the parameters of a model. + tensor: Tensor, + // Gradient of the loss w.r.t. the parameters. + grad: Tensor, + // State of the optimizer. 
+ state: Option>, + ) -> (Tensor, Option>) { + let tensor_updated = tensor.clone() - tensor.mul_scalar(lr).mul_scalar(self.weight_decay); + + let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum)); + + let state = AdamWState { + momentum: momentum_state, + }; + + (tensor_updated - raw_delta.mul_scalar(lr), Some(state)) + } + + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.momentum = state.momentum.to_device(device); + state + } } impl AdamWConfig { - /// Initialize AdamW optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>(&self) -> impl Optimizer { - let optim = AdamW { - momentum: AdaptiveMomentumW { - beta_1: self.beta_1, - beta_2: self.beta_2, - epsilon: self.epsilon, - }, - weight_decay: self.weight_decay, - _phantom: Default::default(), - }; - - let mut optim = OptimizerAdaptor::from(optim); - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); - } - optim + /// Initialize AdamW optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>(&self) -> impl Optimizer { + let optim = AdamW { + momentum: AdaptiveMomentumW { + beta_1: self.beta_1, + beta_2: self.beta_2, + epsilon: self.epsilon, + }, + weight_decay: self.weight_decay, + _phantom: Default::default(), + }; + + let mut optim = OptimizerAdaptor::from(optim); + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); } + optim + } } /// Adaptive momentum state. #[derive(Record, new, Clone)] pub struct AdaptiveMomentumWState { - time: usize, - moment_1: Tensor, - moment_2: Tensor, + time: usize, + moment_1: Tensor, + moment_2: Tensor, } struct AdaptiveMomentumW { - beta_1: f32, - beta_2: f32, - epsilon: f32, + beta_1: f32, + beta_2: f32, + epsilon: f32, } impl AdaptiveMomentumW { - pub fn transform( - &self, - grad: Tensor, - state: Option>, - ) -> (Tensor, AdaptiveMomentumWState) { - let state = if let Some(mut state) = state { - // Update first moment estimate. - let factor = 1.0 - self.beta_1; - state.moment_1 = state - .moment_1 - .mul_scalar(self.beta_1) - .add(grad.clone().mul_scalar(factor)); - - // Update second moment estimate. - let factor = 1.0 - self.beta_2; - state.moment_2 = state - .moment_2 - .mul_scalar(self.beta_2) - .add(grad.powf(2.0).mul_scalar(factor)); - - // Update time. - state.time += 1; - - state - } else { - // Initialize first moment estimate. - let factor = 1.0 - self.beta_1; - let moment_1 = grad.clone().mul_scalar(factor); - - // Initialize second moment estimate. - let factor = 1.0 - self.beta_2; - let moment_2 = grad.powf(2.0).mul_scalar(factor); - - AdaptiveMomentumWState::new(1, moment_1, moment_2) - }; - - let time: i32 = (state.time as i32).elem(); - - // Compute bias-corrected first and second moment estimates. - let moment_1_corrected = state - .moment_1 - .clone() - .div_scalar(1f32 - self.beta_1.powi(time)); - - let moment_2_corrected = state - .moment_2 - .clone() - .div_scalar(1f32 - self.beta_2.powi(time)); - - // Compute update delta. This still needs to be scaled by the learning rate. 
- let update_delta = - moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); - - ( - update_delta, - AdaptiveMomentumWState::new(state.time, state.moment_1, state.moment_2), - ) - } + pub fn transform( + &self, + grad: Tensor, + state: Option>, + ) -> (Tensor, AdaptiveMomentumWState) { + let state = if let Some(mut state) = state { + // Update first moment estimate. + let factor = 1.0 - self.beta_1; + state.moment_1 = state + .moment_1 + .mul_scalar(self.beta_1) + .add(grad.clone().mul_scalar(factor)); + + // Update second moment estimate. + let factor = 1.0 - self.beta_2; + state.moment_2 = state + .moment_2 + .mul_scalar(self.beta_2) + .add(grad.powf(2.0).mul_scalar(factor)); + + // Update time. + state.time += 1; + + state + } else { + // Initialize first moment estimate. + let factor = 1.0 - self.beta_1; + let moment_1 = grad.clone().mul_scalar(factor); + + // Initialize second moment estimate. + let factor = 1.0 - self.beta_2; + let moment_2 = grad.powf(2.0).mul_scalar(factor); + + AdaptiveMomentumWState::new(1, moment_1, moment_2) + }; + + let time: i32 = (state.time as i32).elem(); + + // Compute bias-corrected first and second moment estimates. + let moment_1_corrected = state + .moment_1 + .clone() + .div_scalar(1f32 - self.beta_1.powi(time)); + + let moment_2_corrected = state + .moment_2 + .clone() + .div_scalar(1f32 - self.beta_2.powi(time)); + + // Compute update delta. This still needs to be scaled by the learning rate. + let update_delta = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); + + ( + update_delta, + AdaptiveMomentumWState::new(state.time, state.moment_1, state.moment_2), + ) + } } impl AdaptiveMomentumWState { - /// Move state to device. - /// - /// # Arguments - /// - /// * `device` - Device to move state to. - /// - /// # Returns - /// - /// Returns state moved to device. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.moment_1 = self.moment_1.to_device(device); - self.moment_2 = self.moment_2.to_device(device); - self - } + /// Move state to device. + /// + /// # Arguments + /// + /// * `device` - Device to move state to. + /// + /// # Returns + /// + /// Returns state moved to device. 
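Reading step and transform together, AdamW decouples the weight decay from the adaptive term: the decay multiplies the parameter directly and never enters the moment estimates. A scalar sketch of one full parameter update (plain f64, illustration only):

    /// One AdamW parameter update, element-wise: the decoupled decay from `step`
    /// plus the bias-corrected moments from `transform`.
    fn adamw_step(
        param: f64, grad: f64, m: &mut f64, v: &mut f64, t: i32,
        lr: f64, weight_decay: f64, beta_1: f64, beta_2: f64, epsilon: f64,
    ) -> f64 {
        // Decoupled decay: shrink the parameter itself, not the gradient.
        let decayed = param - param * lr * weight_decay;

        *m = beta_1 * *m + (1.0 - beta_1) * grad;
        *v = beta_2 * *v + (1.0 - beta_2) * grad * grad;
        let m_hat = *m / (1.0 - beta_1.powi(t));
        let v_hat = *v / (1.0 - beta_2.powi(t));

        decayed - lr * m_hat / (v_hat.sqrt() + epsilon)
    }

This is the behavioral difference from the Adam path above: there, weight decay is added to the gradient and therefore flows through the adaptive denominator.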
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.moment_1 = self.moment_1.to_device(device); + self.moment_2 = self.moment_2.to_device(device); + self + } } #[cfg(test)] mod tests { - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - use tempfile::TempDir; - - const LEARNING_RATE: LearningRate = 0.01; - - #[test] - fn test_adamw_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_adamw(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - let temp_dir = TempDir::new().unwrap(); - BinFileRecorder::::default() - .record(optimizer.to_record(), temp_dir.path().join("test_optim")) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_adamw(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - - const ASSERT_PRECISION: usize = 2; - - #[test] - fn test_adamw_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamWConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(0.5) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534], - [ - 0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182, - ], - [ - -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981, - ], - [ - -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081, - ], - [ - 0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993, - ], - [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580], - ]); - let bias_expected = Data::from([ - -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - 
bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - #[test] - fn test_adam_optimizer_no_nan() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - - let x = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamWConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(0.5) - .init(); - - let grads = linear.forward(x.clone()).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - assert!(!state_updated.weight.to_data().value[0].is_nan()); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - fn create_adamw( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> - { - let config = AdamWConfig::new(); - AdamW { - momentum: AdaptiveMomentumW { - beta_1: config.beta_1, - beta_2: config.beta_2, - epsilon: config.epsilon, - }, - weight_decay: config.weight_decay, - _phantom: Default::default(), - } - .into() + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, TestAutodiffBackend, TestBackend}; + use tempfile::TempDir; + + const LEARNING_RATE: LearningRate = 0.01; + + #[test] + fn test_adamw_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_adamw(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + let temp_dir = TempDir::new().unwrap(); + BinFileRecorder::::default() + .record(optimizer.to_record(), temp_dir.path().join("test_optim")) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_adamw(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + + const ASSERT_PRECISION: usize = 2; + + #[test] + fn test_adamw_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, 
-0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamWConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(0.5) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534], + [ + 0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182, + ], + [ + -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981, + ], + [ + -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081, + ], + [ + 0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993, + ], + [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580], + ]); + let bias_expected = Data::from([ + -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + #[test] + fn test_adam_optimizer_no_nan() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + + let x = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamWConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(0.5) + .init(); + + let grads = linear.forward(x.clone()).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + assert!(!state_updated.weight.to_data().value[0].is_nan()); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + fn create_adamw( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> + { + let config = AdamWConfig::new(); + 
AdamW { + momentum: AdaptiveMomentumW { + beta_1: config.beta_1, + beta_2: config.beta_2, + epsilon: config.epsilon, + }, + weight_decay: config.weight_decay, + _phantom: Default::default(), } + .into() + } } diff --git a/burn-core/src/optim/base.rs b/burn-core/src/optim/base.rs index 5fc54fede3..3602efb57e 100644 --- a/burn-core/src/optim/base.rs +++ b/burn-core/src/optim/base.rs @@ -7,19 +7,19 @@ use crate::LearningRate; /// General trait to optimize [module](AutodiffModule). pub trait Optimizer: Send + Sync where - M: AutodiffModule, - B: AutodiffBackend, + M: AutodiffModule, + B: AutodiffBackend, { - /// Optimizer associative type to be used when saving and loading the state. - type Record: Record; + /// Optimizer associative type to be used when saving and loading the state. + type Record: Record; - /// Perform the optimizer step using the given learning rate and gradients. - /// The updated module is returned. - fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M; + /// Perform the optimizer step using the given learning rate and gradients. + /// The updated module is returned. + fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M; - /// Get the current state of the optimizer as a [record](Record). - fn to_record(&self) -> Self::Record; + /// Get the current state of the optimizer as a [record](Record). + fn to_record(&self) -> Self::Record; - /// Load the state of the optimizer as a [record](Record). - fn load_record(self, record: Self::Record) -> Self; + /// Load the state of the optimizer as a [record](Record). + fn load_record(self, record: Self::Record) -> Self; } diff --git a/burn-core/src/optim/decay.rs b/burn-core/src/optim/decay.rs index eb1653990d..a7ca4d8dcb 100644 --- a/burn-core/src/optim/decay.rs +++ b/burn-core/src/optim/decay.rs @@ -9,60 +9,60 @@ use crate::tensor::{ElementConversion, Tensor}; /// Configuration to create [weight decay](WeightDecay). #[derive(Config)] pub struct WeightDecayConfig { - /// L2 penalty. - pub penalty: f64, + /// L2 penalty. + pub penalty: f64, } /// State of [weight decay](WeightDecay). #[derive(Record, Clone, new)] pub struct WeightDecayState { - pub(crate) grad_last_step: Tensor, + pub(crate) grad_last_step: Tensor, } /// Weight decay implementation that transforms gradients. pub struct WeightDecay { - penalty: B::FloatElem, + penalty: B::FloatElem, } impl WeightDecay { - /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig). - pub fn new(config: &WeightDecayConfig) -> Self { - Self { - penalty: config.penalty.elem(), - } + /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig). + pub fn new(config: &WeightDecayConfig) -> Self { + Self { + penalty: config.penalty.elem(), } + } - /// Transforms a gradient. - /// - /// # Arguments - /// - /// * `grad` - Gradient to transform. - /// * `tensor` - Tensor param of the last iteration. - /// - /// # Returns - /// - /// * `grad` - Transformed gradient. - pub fn transform( - &self, - grad: Tensor, - tensor: Tensor, - ) -> Tensor { - tensor.mul_scalar(self.penalty).add(grad) - } + /// Transforms a gradient. + /// + /// # Arguments + /// + /// * `grad` - Gradient to transform. + /// * `tensor` - Tensor param of the last iteration. + /// + /// # Returns + /// + /// * `grad` - Transformed gradient. + pub fn transform( + &self, + grad: Tensor, + tensor: Tensor, + ) -> Tensor { + tensor.mul_scalar(self.penalty).add(grad) + } } impl WeightDecayState { - /// Moves the state to a device. 
- /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.grad_last_step = self.grad_last_step.to_device(device); - self - } + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.grad_last_step = self.grad_last_step.to_device(device); + self + } } diff --git a/burn-core/src/optim/grad_accum.rs b/burn-core/src/optim/grad_accum.rs index e0e455bdf7..6655525b4f 100644 --- a/burn-core/src/optim/grad_accum.rs +++ b/burn-core/src/optim/grad_accum.rs @@ -8,115 +8,116 @@ use super::GradientsParams; /// Accumulate gradients into a single [Gradients](AutodiffBackend::Gradients) object. pub struct GradientsAccumulator { - grads: GradientsParams, - phantom: PhantomData, + grads: GradientsParams, + phantom: PhantomData, } impl Default for GradientsAccumulator { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl GradientsAccumulator { - /// Create a new gradients accumulator. - pub fn new() -> Self { - Self { - grads: GradientsParams::new(), - phantom: PhantomData, - } + /// Create a new gradients accumulator. + pub fn new() -> Self { + Self { + grads: GradientsParams::new(), + phantom: PhantomData, } + } } impl GradientsAccumulator { - /// Accumulate the given gradients for each parameter in the given module. - pub fn accumulate(&mut self, module: &M, grads: GradientsParams) - where - M: AutodiffModule, - { - let mut visitor = ModuleGradsAccumulator::::new(&mut self.grads, grads); - module.visit(&mut visitor); - } - - /// Return the accumulated gradients and reset the accumulator state. - pub fn grads(&mut self) -> GradientsParams { - let mut grads = GradientsParams::new(); - core::mem::swap(&mut self.grads, &mut grads); - - grads - } + /// Accumulate the given gradients for each parameter in the given module. + pub fn accumulate(&mut self, module: &M, grads: GradientsParams) + where + M: AutodiffModule, + { + let mut visitor = ModuleGradsAccumulator::::new(&mut self.grads, grads); + module.visit(&mut visitor); + } + + /// Return the accumulated gradients and reset the accumulator state. 
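A typical use of this accumulator is to sum gradients over several micro-batches and then take a single optimizer step. A minimal sketch, assuming the usual burn facade re-exports; the accumulated_step name and the pre-computed losses vector are illustrative:

    use burn::module::AutodiffModule;
    use burn::optim::{GradientsAccumulator, GradientsParams, Optimizer};
    use burn::tensor::{backend::AutodiffBackend, Tensor};
    use burn::LearningRate;

    /// Accumulate gradients from several micro-batch losses, then apply one step.
    fn accumulated_step<B, M, O>(
        model: M,
        mut optim: O,
        lr: LearningRate,
        losses: Vec<Tensor<B, 1>>,
    ) -> M
    where
        B: AutodiffBackend,
        M: AutodiffModule<B>,
        O: Optimizer<M, B>,
    {
        let mut accumulator = GradientsAccumulator::new();
        for loss in losses {
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            accumulator.accumulate(&model, grads);
        }
        // `grads()` returns the summed gradients and resets the accumulator;
        // scale the learning rate if an average over micro-batches is wanted.
        let grads = accumulator.grads();
        optim.step(lr, model, grads)
    }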
+ pub fn grads(&mut self) -> GradientsParams { + let mut grads = GradientsParams::new(); + core::mem::swap(&mut self.grads, &mut grads); + + grads + } } #[derive(new)] struct ModuleGradsAccumulator<'a, M> { - grads: &'a mut GradientsParams, - grads_new: GradientsParams, - phantom: PhantomData, + grads: &'a mut GradientsParams, + grads_new: GradientsParams, + phantom: PhantomData, } impl<'a, B: AutodiffBackend, M: AutodiffModule> ModuleVisitor - for ModuleGradsAccumulator<'a, M> + for ModuleGradsAccumulator<'a, M> { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { - let grad_updated = match self.grads_new.remove::(id) { - Some(new) => match self.grads.remove::(id) { - Some(grad) => grad.add(new), - None => new, - }, - None => match self.grads.remove::(id) { - Some(grad) => grad, - None => return, - }, - }; - - self.grads - .register::(id.clone(), grad_updated); - } + fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + let grad_updated = match self.grads_new.remove::(id) { + Some(new) => match self.grads.remove::(id) { + Some(grad) => grad.add(new), + None => new, + }, + None => match self.grads.remove::(id) { + Some(grad) => grad, + None => return, + }, + }; + + self + .grads + .register::(id.clone(), grad_updated); + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - nn::{Linear, LinearConfig}, - TestAutodiffBackend, - }; - use burn_tensor::Distribution; - - #[test] - fn test_accumulate_gradients_one_step() { - let mut accumulator = GradientsAccumulator::new(); - let layer = layer(); - let loss = layer.forward(random_tensor()); - let grads = GradientsParams::from_grads(loss.backward(), &layer); - - accumulator.accumulate(&layer, grads); - - let grads = accumulator.grads(); - assert!(!grads.is_empty()) - } - - #[test] - fn test_accumulate_gradients_two_steps() { - let mut accumulator = GradientsAccumulator::new(); - let layer = layer(); - let loss_1 = layer.forward(random_tensor()); - let loss_2 = layer.forward(random_tensor()); - let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer); - let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer); - - accumulator.accumulate(&layer, grads_1); - accumulator.accumulate(&layer, grads_2); - - let grads = accumulator.grads(); - assert_eq!(grads.len(), 2) - } - - fn layer() -> Linear { - LinearConfig::new(20, 20).with_bias(true).init() - } - - fn random_tensor() -> Tensor { - Tensor::::random([2, 20], Distribution::Default) - } + use super::*; + use crate::{ + nn::{Linear, LinearConfig}, + TestAutodiffBackend, + }; + use burn_tensor::Distribution; + + #[test] + fn test_accumulate_gradients_one_step() { + let mut accumulator = GradientsAccumulator::new(); + let layer = layer(); + let loss = layer.forward(random_tensor()); + let grads = GradientsParams::from_grads(loss.backward(), &layer); + + accumulator.accumulate(&layer, grads); + + let grads = accumulator.grads(); + assert!(!grads.is_empty()) + } + + #[test] + fn test_accumulate_gradients_two_steps() { + let mut accumulator = GradientsAccumulator::new(); + let layer = layer(); + let loss_1 = layer.forward(random_tensor()); + let loss_2 = layer.forward(random_tensor()); + let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer); + let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer); + + accumulator.accumulate(&layer, grads_1); + accumulator.accumulate(&layer, grads_2); + + let grads = accumulator.grads(); + assert_eq!(grads.len(), 2) + } + + fn layer() -> Linear { + LinearConfig::new(20, 20).with_bias(true).init() + } + + 
fn random_tensor() -> Tensor { + Tensor::::random([2, 20], Distribution::Default) + } } diff --git a/burn-core/src/optim/grads.rs b/burn-core/src/optim/grads.rs index 81c3eb265c..79b05ae1d1 100644 --- a/burn-core/src/optim/grads.rs +++ b/burn-core/src/optim/grads.rs @@ -1,7 +1,7 @@ use burn_tensor::{ - backend::{AutodiffBackend, Backend}, - container::TensorContainer, - Tensor, + backend::{AutodiffBackend, Backend}, + container::TensorContainer, + Tensor, }; use crate::module::{AutodiffModule, ParamId}; @@ -11,115 +11,115 @@ use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter}; /// Data type that contains gradients for parameters. #[derive(Default)] pub struct GradientsParams { - container: TensorContainer, + container: TensorContainer, } impl GradientsParams { - /// Creates a new [GradientsParams](GradientsParams). - pub fn new() -> Self { - Self::default() - } - - /// Get the gradients for the given [parameter id](ParamId). - /// - /// # Notes - /// - /// You should use [remove](GradientsParams::remove) if you want to get the gradients - /// only one time. - pub fn get(&self, id: &ParamId) -> Option> - where - B: Backend, - { - self.container.get(id) - } - - /// Remove the gradients for the given [parameter id](ParamId). - pub fn remove(&mut self, id: &ParamId) -> Option> - where - B: Backend, - { - self.container.remove(id) - } - - /// Register a gradients tensor for the given [parameter id](ParamId). - /// - /// # Notes - /// - /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced. - pub fn register(&mut self, id: ParamId, value: Tensor) - where - B: Backend, - { - self.container.register(id, value) - } - - /// The number of gradients tensors registered. - pub fn len(&self) -> usize { - self.container.len() - } - - /// If any tensor is contained. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Change the device of each tensor gradients registered for the given [module](AutodiffModule). - pub fn to_device>( - mut self, - device: &B::Device, - module: &M, - ) -> Self { - let mut visitor = GradientsParamsChangeDevice::::new(device, &mut self); - module.visit(&mut visitor); - self - } - - /// Extract each tensor gradients for the given [module](AutodiffModule). - pub fn from_grads>( - grads: B::Gradients, - module: &M, - ) -> Self { - let mut grads_params = GradientsParams::new(); - let mut visitor = GradientsParamsConverter::::new(grads, &mut grads_params); - - module.visit(&mut visitor); - grads_params - } + /// Creates a new [GradientsParams](GradientsParams). + pub fn new() -> Self { + Self::default() + } + + /// Get the gradients for the given [parameter id](ParamId). + /// + /// # Notes + /// + /// You should use [remove](GradientsParams::remove) if you want to get the gradients + /// only one time. + pub fn get(&self, id: &ParamId) -> Option> + where + B: Backend, + { + self.container.get(id) + } + + /// Remove the gradients for the given [parameter id](ParamId). + pub fn remove(&mut self, id: &ParamId) -> Option> + where + B: Backend, + { + self.container.remove(id) + } + + /// Register a gradients tensor for the given [parameter id](ParamId). + /// + /// # Notes + /// + /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced. + pub fn register(&mut self, id: ParamId, value: Tensor) + where + B: Backend, + { + self.container.register(id, value) + } + + /// The number of gradients tensors registered. 
+ pub fn len(&self) -> usize { + self.container.len() + } + + /// If any tensor is contained. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Change the device of each tensor gradients registered for the given [module](AutodiffModule). + pub fn to_device>( + mut self, + device: &B::Device, + module: &M, + ) -> Self { + let mut visitor = GradientsParamsChangeDevice::::new(device, &mut self); + module.visit(&mut visitor); + self + } + + /// Extract each tensor gradients for the given [module](AutodiffModule). + pub fn from_grads>( + grads: B::Gradients, + module: &M, + ) -> Self { + let mut grads_params = GradientsParams::new(); + let mut visitor = GradientsParamsConverter::::new(grads, &mut grads_params); + + module.visit(&mut visitor); + grads_params + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - module::{list_param_ids, Module}, - nn::{Linear, LinearConfig}, - TestAutodiffBackend, - }; - use burn_tensor::{backend::Backend, Distribution}; - - #[test] - fn test_convert_grads() { - let layer_1 = layer(); - let mut layer_2 = layer_1.clone(); - layer_2 = layer_2.fork(&::Device::default()); - let loss_1 = layer_1.forward(random_tensor()); - let loss_2 = layer_2.forward(random_tensor()); - let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1); - let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2); - - let param_ids_1 = list_param_ids(&layer_1); - let param_ids_2 = list_param_ids(&layer_2); - - assert_eq!(param_ids_1, param_ids_2); - assert_eq!(grads_1.len(), param_ids_1.len()); - assert_eq!(grads_2.len(), param_ids_2.len()); - } - - fn layer() -> Linear { - LinearConfig::new(20, 20).with_bias(true).init() - } - - fn random_tensor() -> Tensor { - Tensor::::random([2, 20], Distribution::Default) - } + use super::*; + use crate::{ + module::{list_param_ids, Module}, + nn::{Linear, LinearConfig}, + TestAutodiffBackend, + }; + use burn_tensor::{backend::Backend, Distribution}; + + #[test] + fn test_convert_grads() { + let layer_1 = layer(); + let mut layer_2 = layer_1.clone(); + layer_2 = layer_2.fork(&::Device::default()); + let loss_1 = layer_1.forward(random_tensor()); + let loss_2 = layer_2.forward(random_tensor()); + let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1); + let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2); + + let param_ids_1 = list_param_ids(&layer_1); + let param_ids_2 = list_param_ids(&layer_2); + + assert_eq!(param_ids_1, param_ids_2); + assert_eq!(grads_1.len(), param_ids_1.len()); + assert_eq!(grads_2.len(), param_ids_2.len()); + } + + fn layer() -> Linear { + LinearConfig::new(20, 20).with_bias(true).init() + } + + fn random_tensor() -> Tensor { + Tensor::::random([2, 20], Distribution::Default) + } } diff --git a/burn-core/src/optim/momentum.rs b/burn-core/src/optim/momentum.rs index ef2cb174f8..95f695ab77 100644 --- a/burn-core/src/optim/momentum.rs +++ b/burn-core/src/optim/momentum.rs @@ -8,86 +8,87 @@ use burn_tensor::backend::Backend; /// Configuration to create [momentum](Momentum). #[derive(Config)] pub struct MomentumConfig { - /// Momemtum factor - #[config(default = 0.9)] - pub momentum: f64, - /// Dampening factor. - #[config(default = 0.1)] - pub dampening: f64, - /// Enables Nesterov momentum, see [On the importance of initialization and - /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf). 
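
A quick sketch of the GradientsParams life cycle documented above, before the patch moves on to the momentum optimizer below. It is not part of this patch; the function name `gradients_params_lifecycle` is illustrative, and the default device plus the two-parameter Linear layer mirror the tests in this file.

    use crate::nn::LinearConfig;
    use crate::optim::GradientsParams;
    use crate::TestAutodiffBackend;
    use burn_tensor::{backend::Backend, Distribution, Tensor};

    fn gradients_params_lifecycle() {
        let layer = LinearConfig::new(20, 20).init::<TestAutodiffBackend>();
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 20], Distribution::Default);

        // Extract one gradient tensor per tracked parameter from the backward pass.
        let grads = GradientsParams::from_grads(layer.forward(x).backward(), &layer);

        // A Linear layer with bias registers exactly two parameters.
        assert_eq!(grads.len(), 2);
        assert!(!grads.is_empty());

        // Gradients can follow the module to another device before the optimizer
        // step; the default device is used here only to exercise the call.
        let device = <TestAutodiffBackend as Backend>::Device::default();
        let grads = grads.to_device(&device, &layer);
        let _ = grads;
    }
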
- #[config(default = false)] - pub nesterov: bool, + /// Momemtum factor + #[config(default = 0.9)] + pub momentum: f64, + /// Dampening factor. + #[config(default = 0.1)] + pub dampening: f64, + /// Enables Nesterov momentum, see [On the importance of initialization and + /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf). + #[config(default = false)] + pub nesterov: bool, } /// State of [momentum](Momentum). #[derive(Record, Clone, new)] pub struct MomentumState { - velocity: Tensor, + velocity: Tensor, } /// Momemtum implementation that transforms gradients. pub struct Momentum { - momentum: B::FloatElem, - dampening: f64, - nesterov: bool, + momentum: B::FloatElem, + dampening: f64, + nesterov: bool, } impl Momentum { - /// Creates a new [momentum](Momentum) from a [config](MomentumConfig). - pub fn new(config: &MomentumConfig) -> Self { - Self { - momentum: config.momentum.elem(), - dampening: config.dampening, - nesterov: config.nesterov, - } + /// Creates a new [momentum](Momentum) from a [config](MomentumConfig). + pub fn new(config: &MomentumConfig) -> Self { + Self { + momentum: config.momentum.elem(), + dampening: config.dampening, + nesterov: config.nesterov, } + } - /// Transforms a gradient. - /// - /// # Arguments - /// - /// * `grad` - Gradient to transform. - /// * `state` - State of the optimizer. - /// - /// # Returns - /// - /// * `grad` - Transformed gradient. - /// * `state` - State of the optimizer. - pub fn transform( - &self, - grad: Tensor, - state: Option>, - ) -> (Tensor, MomentumState) { - let velocity = if let Some(state) = state { - grad.clone() - .mul_scalar(1.0 - self.dampening) - .add(state.velocity.mul_scalar(self.momentum)) - } else { - grad.clone() - }; + /// Transforms a gradient. + /// + /// # Arguments + /// + /// * `grad` - Gradient to transform. + /// * `state` - State of the optimizer. + /// + /// # Returns + /// + /// * `grad` - Transformed gradient. + /// * `state` - State of the optimizer. + pub fn transform( + &self, + grad: Tensor, + state: Option>, + ) -> (Tensor, MomentumState) { + let velocity = if let Some(state) = state { + grad + .clone() + .mul_scalar(1.0 - self.dampening) + .add(state.velocity.mul_scalar(self.momentum)) + } else { + grad.clone() + }; - let grad = match self.nesterov { - true => velocity.clone().mul_scalar(self.momentum).add(grad), - false => velocity.clone(), - }; + let grad = match self.nesterov { + true => velocity.clone().mul_scalar(self.momentum).add(grad), + false => velocity.clone(), + }; - (grad, MomentumState::new(velocity)) - } + (grad, MomentumState::new(velocity)) + } } impl MomentumState { - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.velocity = self.velocity.to_device(device); - self - } + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. 
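
Before the MomentumState::to_device helper that follows, here is the recurrence Momentum::transform implements, traced with plain f64 scalars for a single parameter. This is a sketch, not part of this patch; the values and the function name `momentum_scalar_demo` are made up.

    fn momentum_scalar_demo() {
        let (momentum, dampening, nesterov) = (0.9_f64, 0.1_f64, true);
        let mut velocity: Option<f64> = None;

        for grad in [1.0_f64, 0.5, 0.25] {
            // velocity_t = momentum * velocity_{t-1} + (1 - dampening) * grad_t,
            // with the very first velocity taken to be the raw gradient.
            let v = match velocity {
                Some(prev) => grad * (1.0 - dampening) + prev * momentum,
                None => grad,
            };
            // Nesterov look-ahead: the applied update is grad + momentum * velocity_t;
            // otherwise the velocity itself is used.
            let update = if nesterov { v * momentum + grad } else { v };
            println!("velocity = {v:.4}, update = {update:.4}");
            velocity = Some(v);
        }
    }
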
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.velocity = self.velocity.to_device(device); + self + } } diff --git a/burn-core/src/optim/rmsprop.rs b/burn-core/src/optim/rmsprop.rs index ffe683db34..72ea2b2e80 100644 --- a/burn-core/src/optim/rmsprop.rs +++ b/burn-core/src/optim/rmsprop.rs @@ -1,11 +1,11 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use super::{ - decay::{WeightDecay, WeightDecayConfig}, - SimpleOptimizer, + decay::{WeightDecay, WeightDecayConfig}, + SimpleOptimizer, }; use crate::config::Config; use crate::optim::adaptor::OptimizerAdaptor; @@ -15,510 +15,509 @@ use burn_tensor::backend::Backend; /// Configuration to create the [RMSProp](RMSProp) optimizer. #[derive(Config)] pub struct RMSPropConfig { - /// Smoothing constant. - #[config(default = 0.99)] - alpha: f32, - /// momentum for RMSProp. - #[config(default = 0.9)] - momentum: f32, - /// A value required for numerical stability. - #[config(default = 1e-5)] - epsilon: f32, - /// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance - #[config(default = false)] - centered: bool, - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + /// Smoothing constant. + #[config(default = 0.99)] + alpha: f32, + /// momentum for RMSProp. + #[config(default = 0.9)] + momentum: f32, + /// A value required for numerical stability. + #[config(default = 1e-5)] + epsilon: f32, + /// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance + #[config(default = false)] + centered: bool, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Gradient Clipping](GradientClippingConfig) config. + grad_clipping: Option, } impl RMSPropConfig { - /// Initialize RMSProp optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>( - &self, - ) -> OptimizerAdaptor, M, B> { - let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); - - let mut optim = OptimizerAdaptor::from(RMSProp { - alpha: self.alpha, - centered: self.centered, - weight_decay, - momentum: RMSPropMomentum { - momentum: self.momentum, - epsilon: self.epsilon, - }, - }); - - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); - } - - optim + /// Initialize RMSProp optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>( + &self, + ) -> OptimizerAdaptor, M, B> { + let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); + + let mut optim = OptimizerAdaptor::from(RMSProp { + alpha: self.alpha, + centered: self.centered, + weight_decay, + momentum: RMSPropMomentum { + momentum: self.momentum, + epsilon: self.epsilon, + }, + }); + + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); } + + optim + } } /// Optimizer that implements stochastic gradient descent with momentum. /// The optimizer can be configured with [RMSPropConfig](RMSPropConfig). 
pub struct RMSProp { - alpha: f32, - // epsilon: f32, - centered: bool, - // momentum: Option>, - momentum: RMSPropMomentum, - weight_decay: Option>, + alpha: f32, + // epsilon: f32, + centered: bool, + // momentum: Option>, + momentum: RMSPropMomentum, + weight_decay: Option>, } impl SimpleOptimizer for RMSProp { - type State = RMSPropState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - // fetch state for params - let mut state_square_avg = None; - let mut state_centered = None; - let mut state_momentum = None; - if let Some(state) = state { - state_square_avg = Some(state.square_avg); - state_centered = Some(state.centered); - state_momentum = state.momentum; - } - - // weight_decay transform - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); - } - - // square_avg transform - let (grad, state_square_avg) = - SquareAvgState::transform(self.alpha, grad, state_square_avg); - - // centered transform - let (grad, state_square_avg, state_centered) = CenteredState::transform( - self.alpha, - self.centered, - grad, - state_square_avg, - state_centered, - ); - - // momentum transform - let (grad, state_centered, state_momentum) = - self.momentum - .transform(grad, state_centered, state_momentum); - - // transition state - let state = RMSPropState::new(state_square_avg, state_centered, state_momentum); - - // tensor param transform - let delta = grad.mul_scalar(lr); - (tensor - delta, Some(state)) + type State = RMSPropState; + + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + // fetch state for params + let mut state_square_avg = None; + let mut state_centered = None; + let mut state_momentum = None; + if let Some(state) = state { + state_square_avg = Some(state.square_avg); + state_centered = Some(state.centered); + state_momentum = state.momentum; } - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.square_avg = state.square_avg.to_device(device); - state.centered = state.centered.to_device(device); - state.momentum = state.momentum.map(|momentum| momentum.to_device(device)); - state + // weight_decay transform + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); } + + // square_avg transform + let (grad, state_square_avg) = SquareAvgState::transform(self.alpha, grad, state_square_avg); + + // centered transform + let (grad, state_square_avg, state_centered) = CenteredState::transform( + self.alpha, + self.centered, + grad, + state_square_avg, + state_centered, + ); + + // momentum transform + let (grad, state_centered, state_momentum) = + self + .momentum + .transform(grad, state_centered, state_momentum); + + // transition state + let state = RMSPropState::new(state_square_avg, state_centered, state_momentum); + + // tensor param transform + let delta = grad.mul_scalar(lr); + (tensor - delta, Some(state)) + } + + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.square_avg = state.square_avg.to_device(device); + state.centered = state.centered.to_device(device); + state.momentum = state.momentum.map(|momentum| momentum.to_device(device)); + state + } } /// State of [RMSProp](RMSProp) #[derive(Record, Clone, new)] pub struct RMSPropState { - square_avg: SquareAvgState, - centered: CenteredState, - momentum: Option>, + square_avg: SquareAvgState, + 
centered: CenteredState, + momentum: Option>, } /// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct SquareAvgState { - square_avg: Tensor, + square_avg: Tensor, } impl SquareAvgState { - /// transform [SquareAvgState] to the next step - fn transform(alpha: f32, grad: Tensor, state: Option) -> (Tensor, Self) { - match state { - Some(state) => { - let square_avg = state - .square_avg - .mul_scalar(alpha) - .add(grad.clone().powf(2.).mul_scalar(1. - alpha)); - (grad, Self { square_avg }) - } - _ => { - let square_avg = grad.clone().powf(2.).mul_scalar(1. - alpha); - (grad, Self { square_avg }) - } - } - } - - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.square_avg = self.square_avg.to_device(device); - self + /// transform [SquareAvgState] to the next step + fn transform(alpha: f32, grad: Tensor, state: Option) -> (Tensor, Self) { + match state { + Some(state) => { + let square_avg = state + .square_avg + .mul_scalar(alpha) + .add(grad.clone().powf(2.).mul_scalar(1. - alpha)); + (grad, Self { square_avg }) + } + _ => { + let square_avg = grad.clone().powf(2.).mul_scalar(1. - alpha); + (grad, Self { square_avg }) + } } + } + + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.square_avg = self.square_avg.to_device(device); + self + } } /// [CenteredState](CenteredState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct CenteredState { - grad_avg: Option>, - avg: Tensor, + grad_avg: Option>, + avg: Tensor, } impl CenteredState { - /// transform [CenteredState] to the next step - fn transform( - alpha: f32, - centered: bool, - grad: Tensor, - square_avg_state: SquareAvgState, - centered_state: Option, - ) -> (Tensor, SquareAvgState, Self) { - if centered { - let grad_avg_constant = grad.clone().mul_scalar(1. - alpha); - let grad_avg = match centered_state { - Some(state) => state - .grad_avg - .map_or(grad_avg_constant.clone(), move |grad_avg| { - grad_avg.mul_scalar(alpha).add(grad_avg_constant) - }), - _ => grad_avg_constant, - }; - let avg = square_avg_state - .square_avg - .clone() - .sub(grad_avg.clone().powf(2.)); - - ( - grad, - square_avg_state, - Self { - grad_avg: Some(grad_avg), - avg, - }, - ) - } else { - ( - grad, - square_avg_state.clone(), - Self { - grad_avg: None, - avg: square_avg_state.square_avg, - }, - ) - } - } - - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device)); - self.avg = self.avg.to_device(device); - self + /// transform [CenteredState] to the next step + fn transform( + alpha: f32, + centered: bool, + grad: Tensor, + square_avg_state: SquareAvgState, + centered_state: Option, + ) -> (Tensor, SquareAvgState, Self) { + if centered { + let grad_avg_constant = grad.clone().mul_scalar(1. 
- alpha); + let grad_avg = match centered_state { + Some(state) => state + .grad_avg + .map_or(grad_avg_constant.clone(), move |grad_avg| { + grad_avg.mul_scalar(alpha).add(grad_avg_constant) + }), + _ => grad_avg_constant, + }; + let avg = square_avg_state + .square_avg + .clone() + .sub(grad_avg.clone().powf(2.)); + + ( + grad, + square_avg_state, + Self { + grad_avg: Some(grad_avg), + avg, + }, + ) + } else { + ( + grad, + square_avg_state.clone(), + Self { + grad_avg: None, + avg: square_avg_state.square_avg, + }, + ) } + } + + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device)); + self.avg = self.avg.to_device(device); + self + } } /// [RMSPropMomentum](RMSPropMomentum) is to store config status for optimizer. /// (, which is stored in [optimizer](RMSProp) itself and not passed in during `step()` calculation) pub struct RMSPropMomentum { - momentum: f32, - epsilon: f32, + momentum: f32, + epsilon: f32, } impl RMSPropMomentum { - /// transform [grad](Tensor) and [RMSPropMomentumState] to the next step - fn transform( - &self, - grad: Tensor, - centered_state: CenteredState, - momentum_state: Option>, - ) -> ( - Tensor, - CenteredState, - Option>, - ) { - let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon)); - - if self.momentum > 0. { - let buf = match momentum_state { - Some(state) => state.buf.mul_scalar(self.momentum).add(grad), - _ => grad, - }; - ( - buf.clone(), - centered_state, - Some(RMSPropMomentumState { buf }), - ) - } else { - (grad, centered_state, None) - } + /// transform [grad](Tensor) and [RMSPropMomentumState] to the next step + fn transform( + &self, + grad: Tensor, + centered_state: CenteredState, + momentum_state: Option>, + ) -> ( + Tensor, + CenteredState, + Option>, + ) { + let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon)); + + if self.momentum > 0. { + let buf = match momentum_state { + Some(state) => state.buf.mul_scalar(self.momentum).add(grad), + _ => grad, + }; + ( + buf.clone(), + centered_state, + Some(RMSPropMomentumState { buf }), + ) + } else { + (grad, centered_state, None) } + } } /// [RMSPropMomentumState](RMSPropMomentumState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct RMSPropMomentumState { - buf: Tensor, + buf: Tensor, } impl RMSPropMomentumState { - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.buf = self.buf.to_device(device); - self - } + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. 
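
Putting the pieces above together (SquareAvgState, CenteredState and RMSPropMomentum), one centered RMSProp step for a single scalar parameter looks like the sketch below. It is not part of this patch: the numbers and the function name `rmsprop_scalar_step` are illustrative, the optional weight decay is omitted, and the remaining RMSPropMomentumState::to_device helper follows right after.

    fn rmsprop_scalar_step() {
        let (alpha, epsilon, momentum, lr) = (0.99_f64, 1e-5_f64, 0.9_f64, 0.01_f64);
        let (mut param, grad) = (1.0_f64, 0.5_f64);

        // First step, so every running statistic starts from zero or the raw gradient.
        let square_avg = (1.0 - alpha) * grad * grad; // SquareAvgState
        let grad_avg = (1.0 - alpha) * grad; // CenteredState (centered = true)
        let avg = square_avg - grad_avg * grad_avg;

        // RMSPropMomentum: normalize by sqrt(avg) + epsilon, then apply momentum.
        let normalized = grad / (avg.sqrt() + epsilon);
        let buf = momentum * 0.0 + normalized; // no previous momentum buffer yet

        // Parameter update mirrors `tensor - grad.mul_scalar(lr)` in `step` above.
        param -= lr * buf;
        println!("square_avg = {square_avg:.6}, avg = {avg:.6}, param = {param:.6}");
    }
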
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.buf = self.buf.to_device(device); + self + } } #[cfg(test)] mod tests { - use burn_tensor::Shape; - - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - use tempfile::TempDir; - - const LEARNING_RATE: LearningRate = 0.01; - const ASSERT_PRECISION: usize = 6; - - #[test] - fn test_rmsprop_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_rmsprop(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - let temp_dir = TempDir::new().unwrap(); - BinFileRecorder::::default() - .record(optimizer.to_record(), temp_dir.path().join("test_optim")) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_rmsprop(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - - /// used for test differences and debug - #[test] - fn test_rmsprop_optimizer_with_numbers_basic() { - let linear = given_linear_layer( - Data::from([ - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - ]), - Data::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = RMSPropConfig::new() - .with_alpha(0.99) - .with_epsilon(1e-8) - .with_weight_decay(WeightDecayConfig::new(0.05).into()) - .with_momentum(0.9) - .with_centered(false) - .init(); - - // println!("linear is {:?}", linear); - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - // println!("linear is {:?}", linear); - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - // println!("linear is {:?}", linear); - let state_updated = linear.into_record(); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - // println!("\nweight_updated\n{:?}", weight_updated); - // println!("\nbias_updated\n{:?}", bias_updated); - - let weights_expected = Data::from([ - [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937], - [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809], - [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881], - [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366], - [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005], - [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710], - ]); - let bias_expected = - Data::from([0.239199, 0.239199, 
0.239199, 0.239199, 0.239199, 0.239199]); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - #[test] - fn test_rmsprop_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = RMSPropConfig::new() - .with_alpha(0.99) - .with_epsilon(1e-8) - .with_weight_decay(WeightDecayConfig::new(0.05).into()) - .with_momentum(0.9) - .with_centered(false) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [ - -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779, - ], - [ - -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207, - ], - [ - -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967, - ], - [ - -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997, - ], - [ - 0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912, - ], - [ - -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126, - ], - ]); - let bias_expected = Data::from([ - -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - // println!("\nweight_updated\n{:?}", weight_updated); - // println!("\nbias_updated\n{:?}", bias_updated); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - #[allow(dead_code)] - fn create_random_tensor() -> Tensor { - Tensor::::random(Shape::new([2, 20]), Distribution::Default) - } - - fn create_rmsprop( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> - { - RMSPropConfig { - alpha: 0.99, - epsilon: 1e-9, - centered: false, - weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), - momentum: 0.9, - grad_clipping: None, - } - .init() + use burn_tensor::Shape; + + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, 
TestAutodiffBackend, TestBackend}; + use tempfile::TempDir; + + const LEARNING_RATE: LearningRate = 0.01; + const ASSERT_PRECISION: usize = 6; + + #[test] + fn test_rmsprop_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_rmsprop(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + let temp_dir = TempDir::new().unwrap(); + BinFileRecorder::::default() + .record(optimizer.to_record(), temp_dir.path().join("test_optim")) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_rmsprop(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + + /// used for test differences and debug + #[test] + fn test_rmsprop_optimizer_with_numbers_basic() { + let linear = given_linear_layer( + Data::from([ + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + ]), + Data::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = RMSPropConfig::new() + .with_alpha(0.99) + .with_epsilon(1e-8) + .with_weight_decay(WeightDecayConfig::new(0.05).into()) + .with_momentum(0.9) + .with_centered(false) + .init(); + + // println!("linear is {:?}", linear); + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + // println!("linear is {:?}", linear); + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + // println!("linear is {:?}", linear); + let state_updated = linear.into_record(); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + // println!("\nweight_updated\n{:?}", weight_updated); + // println!("\nbias_updated\n{:?}", bias_updated); + + let weights_expected = Data::from([ + [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937], + [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809], + [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881], + [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366], + [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005], + [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710], + ]); + let bias_expected = Data::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + #[test] + fn test_rmsprop_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, 
-0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = RMSPropConfig::new() + .with_alpha(0.99) + .with_epsilon(1e-8) + .with_weight_decay(WeightDecayConfig::new(0.05).into()) + .with_momentum(0.9) + .with_centered(false) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [ + -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779, + ], + [ + -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207, + ], + [ + -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967, + ], + [ + -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997, + ], + [ + 0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912, + ], + [ + -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126, + ], + ]); + let bias_expected = Data::from([ + -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + // println!("\nweight_updated\n{:?}", weight_updated); + // println!("\nbias_updated\n{:?}", bias_updated); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + #[allow(dead_code)] + fn create_random_tensor() -> Tensor { + Tensor::::random(Shape::new([2, 20]), Distribution::Default) + } + + fn create_rmsprop( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> + { + RMSPropConfig { + alpha: 0.99, + epsilon: 1e-9, + centered: false, + weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), + momentum: 0.9, + grad_clipping: None, } + .init() + } } diff --git a/burn-core/src/optim/sgd.rs b/burn-core/src/optim/sgd.rs index b2ed4a4b75..b5d1526264 100644 --- a/burn-core/src/optim/sgd.rs +++ b/burn-core/src/optim/sgd.rs @@ -14,163 +14,163 @@ use burn_tensor::backend::{AutodiffBackend, Backend}; /// Configuration to create the [Sgd](Sgd) optimizer. #[derive(Config)] pub struct SgdConfig { - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Momentum](MomentumConfig) config. - momentum: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - gradient_clipping: Option, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Momentum](MomentumConfig) config. 
+ momentum: Option, + /// [Gradient Clipping](GradientClippingConfig) config. + gradient_clipping: Option, } /// Optimizer that implements stochastic gradient descent with momentum. /// /// The optimizer can be configured with [SgdConfig](SgdConfig). pub struct Sgd { - momentum: Option>, - weight_decay: Option>, + momentum: Option>, + weight_decay: Option>, } /// State of [Sgd](Sgd). #[derive(Record, Clone, new)] pub struct SgdState { - momentum: Option>, + momentum: Option>, } impl SgdConfig { - /// Creates a new [SgdConfig](SgdConfig) with default values. - pub fn init>( - &self, - ) -> OptimizerAdaptor, M, B> { - let momentum = self.momentum.as_ref().map(Momentum::new); - let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); - - let mut optim = OptimizerAdaptor::from(Sgd { - momentum, - weight_decay, - }); - if let Some(config) = &self.gradient_clipping { - optim = optim.with_grad_clipping(config.init()); - } - optim + /// Creates a new [SgdConfig](SgdConfig) with default values. + pub fn init>( + &self, + ) -> OptimizerAdaptor, M, B> { + let momentum = self.momentum.as_ref().map(Momentum::new); + let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); + + let mut optim = OptimizerAdaptor::from(Sgd { + momentum, + weight_decay, + }); + if let Some(config) = &self.gradient_clipping { + optim = optim.with_grad_clipping(config.init()); } + optim + } } impl SimpleOptimizer for Sgd { - type State = SgdState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - let mut state_momemtum = None; - - if let Some(state) = state { - state_momemtum = state.momentum; - } - - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); - } - - if let Some(momentum) = &self.momentum { - let (grad_out, state) = momentum.transform(grad, state_momemtum); - state_momemtum = Some(state); - grad = grad_out; - } - - let state = SgdState::new(state_momemtum); - let delta = grad.mul_scalar(lr); - - (tensor - delta, Some(state)) + type State = SgdState; + + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + let mut state_momemtum = None; + + if let Some(state) = state { + state_momemtum = state.momentum; } - fn to_device(mut state: Self::State, device: &B::Device) -> Self::State { - state.momentum = state.momentum.map(|state| state.to_device(device)); - state + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); } -} -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - grad_clipping::GradientClipping, - nn::{Linear, LinearConfig}, - optim::{GradientsParams, Optimizer}, - tensor::{Distribution, Shape}, - TestAutodiffBackend, TestBackend, - }; - - const LEARNING_RATE: LearningRate = 0.02; - - #[test] - fn with_updated_params_should_have_state() { - let layer = layer(); - let mut optim = sgd_with_all(); - let loss = layer.forward(random_tensor()); - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &layer); - let _layer = optim.step(LEARNING_RATE, layer, grads); - - let record = optim.to_record(); - - assert!(!record.is_empty()); + if let Some(momentum) = &self.momentum { + let (grad_out, state) = momentum.transform(grad, state_momemtum); + state_momemtum = Some(state); + grad = grad_out; } - #[test] - fn without_updated_params_should_not_have_state() { - let optim = sgd_with_all(); - let record = optim.to_record(); - 
assert!(record.is_empty()); - } + let state = SgdState::new(state_momemtum); + let delta = grad.mul_scalar(lr); - #[test] - fn can_attach_gradient_clipping() { - let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5)); - assert!(optim.has_gradient_clipping()); - } - - #[test] - fn should_load_state() { - let layer = layer(); - let mut optim = sgd_with_all(); - let loss = layer.forward(random_tensor()); - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &layer); - let _layer = optim.step(LEARNING_RATE, layer, grads); - - let record = optim.to_record(); - let optim_new = sgd_with_all(); - let record_new = optim_new.to_record(); - let optim_new = optim_new.load_record(record.clone()); - let state_restored = optim_new.to_record(); - - assert_ne!(record.len(), record_new.len()); - assert_eq!(record.len(), state_restored.len()); - } + (tensor - delta, Some(state)) + } - fn random_tensor() -> Tensor { - Tensor::::random(Shape::new([2, 20]), Distribution::Default) - } - - fn layer() -> Linear { - LinearConfig::new(20, 20).with_bias(true).init() - } + fn to_device(mut state: Self::State, device: &B::Device) -> Self::State { + state.momentum = state.momentum.map(|state| state.to_device(device)); + state + } +} - fn sgd_with_all( - ) -> OptimizerAdaptor, Linear, TestAutodiffBackend> { - SgdConfig { - weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), - momentum: Some(MomentumConfig { - momentum: 0.9, - dampening: 0.1, - nesterov: true, - }), - gradient_clipping: None, - } - .init() +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + grad_clipping::GradientClipping, + nn::{Linear, LinearConfig}, + optim::{GradientsParams, Optimizer}, + tensor::{Distribution, Shape}, + TestAutodiffBackend, TestBackend, + }; + + const LEARNING_RATE: LearningRate = 0.02; + + #[test] + fn with_updated_params_should_have_state() { + let layer = layer(); + let mut optim = sgd_with_all(); + let loss = layer.forward(random_tensor()); + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &layer); + let _layer = optim.step(LEARNING_RATE, layer, grads); + + let record = optim.to_record(); + + assert!(!record.is_empty()); + } + + #[test] + fn without_updated_params_should_not_have_state() { + let optim = sgd_with_all(); + let record = optim.to_record(); + assert!(record.is_empty()); + } + + #[test] + fn can_attach_gradient_clipping() { + let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5)); + assert!(optim.has_gradient_clipping()); + } + + #[test] + fn should_load_state() { + let layer = layer(); + let mut optim = sgd_with_all(); + let loss = layer.forward(random_tensor()); + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &layer); + let _layer = optim.step(LEARNING_RATE, layer, grads); + + let record = optim.to_record(); + let optim_new = sgd_with_all(); + let record_new = optim_new.to_record(); + let optim_new = optim_new.load_record(record.clone()); + let state_restored = optim_new.to_record(); + + assert_ne!(record.len(), record_new.len()); + assert_eq!(record.len(), state_restored.len()); + } + + fn random_tensor() -> Tensor { + Tensor::::random(Shape::new([2, 20]), Distribution::Default) + } + + fn layer() -> Linear { + LinearConfig::new(20, 20).with_bias(true).init() + } + + fn sgd_with_all( + ) -> OptimizerAdaptor, Linear, TestAutodiffBackend> { + SgdConfig { + weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), + momentum: Some(MomentumConfig { + momentum: 0.9, + dampening: 
0.1, + nesterov: true, + }), + gradient_clipping: None, } + .init() + } } diff --git a/burn-core/src/optim/simple/adaptor.rs b/burn-core/src/optim/simple/adaptor.rs index 0b44c84183..996d9b6d27 100644 --- a/burn-core/src/optim/simple/adaptor.rs +++ b/burn-core/src/optim/simple/adaptor.rs @@ -1,9 +1,9 @@ use super::{record::AdaptorRecord, SimpleOptimizer}; use crate::{ - grad_clipping::GradientClipping, - module::{AutodiffModule, ModuleMapper, ParamId}, - optim::{GradientsParams, Optimizer}, - LearningRate, + grad_clipping::GradientClipping, + module::{AutodiffModule, ModuleMapper, ParamId}, + optim::{GradientsParams, Optimizer}, + LearningRate, }; use burn_tensor::{backend::AutodiffBackend, Tensor}; use core::marker::PhantomData; @@ -13,143 +13,143 @@ use hashbrown::HashMap; /// an [optimizer](Optimizer). pub struct OptimizerAdaptor where - O: SimpleOptimizer, - M: AutodiffModule, - B: AutodiffBackend, + O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, { - optim: O, - records: HashMap>, - module: PhantomData, - grad_clipping: Option, + optim: O, + records: HashMap>, + module: PhantomData, + grad_clipping: Option, } impl From for OptimizerAdaptor where - B: AutodiffBackend, - M: AutodiffModule, - O: SimpleOptimizer, + B: AutodiffBackend, + M: AutodiffModule, + O: SimpleOptimizer, { - fn from(optim: O) -> Self { - Self { - optim, - records: HashMap::new(), - module: PhantomData, - grad_clipping: None, - } + fn from(optim: O) -> Self { + Self { + optim, + records: HashMap::new(), + module: PhantomData, + grad_clipping: None, } + } } impl OptimizerAdaptor where - O: SimpleOptimizer, - M: AutodiffModule, - B: AutodiffBackend, + O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, { - /// Sets the gradient clipping. - /// - /// # Arguments - /// - /// * `gradient_clipping` - The gradient clipping. - /// - /// # Returns - /// - /// The optimizer. - pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self { - self.grad_clipping = Some(gradient_clipping); - self - } - - #[cfg(test)] - pub(crate) fn has_gradient_clipping(&self) -> bool { - self.grad_clipping.is_some() - } + /// Sets the gradient clipping. + /// + /// # Arguments + /// + /// * `gradient_clipping` - The gradient clipping. + /// + /// # Returns + /// + /// The optimizer. 
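
The with_grad_clipping setter documented above (its body follows just below) is how gradient clipping ends up on the adaptor, whether attached directly or through the SgdConfig field shown earlier in this patch. A minimal end-to-end sketch, not part of this patch and modeled on the sgd tests; the function name `sgd_one_step`, the hyper-parameter values and the module paths for the decay and momentum configs are assumptions about this crate's layout.

    use crate::grad_clipping::GradientClipping;
    use crate::nn::{Linear, LinearConfig};
    use crate::optim::decay::WeightDecayConfig;
    use crate::optim::momentum::MomentumConfig;
    use crate::optim::{GradientsParams, Optimizer, SgdConfig};
    use crate::TestAutodiffBackend;
    use burn_tensor::{Distribution, Tensor};

    fn sgd_one_step() {
        let mut optim = SgdConfig::new()
            .with_weight_decay(Some(WeightDecayConfig::new(0.05)))
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .init()
            // Attach value-based clipping through the adaptor, as in the tests.
            .with_grad_clipping(GradientClipping::Value(0.5));

        let layer: Linear<TestAutodiffBackend> = LinearConfig::new(20, 20).init();
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 20], Distribution::Default);
        let grads = GradientsParams::from_grads(layer.forward(x).backward(), &layer);

        // One update; the adaptor keeps the per-parameter momentum state internally.
        let layer = optim.step(0.02, layer, grads);
        let _ = layer;
    }
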
+ pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self { + self.grad_clipping = Some(gradient_clipping); + self + } + + #[cfg(test)] + pub(crate) fn has_gradient_clipping(&self) -> bool { + self.grad_clipping.is_some() + } } impl Optimizer for OptimizerAdaptor where - B: AutodiffBackend, - M: AutodiffModule, - O: SimpleOptimizer, + B: AutodiffBackend, + M: AutodiffModule, + O: SimpleOptimizer, { - type Record = HashMap>; - - fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M { - let mut mapper = SimpleOptimizerMapper::::new( - &self.optim, - &mut self.records, - &mut grads, - lr, - self.grad_clipping.as_ref(), - ); - module.map(&mut mapper) - } - - fn to_record(&self) -> Self::Record { - self.records.clone() - } - - fn load_record(mut self, record: Self::Record) -> Self { - self.records = record; - self - } + type Record = HashMap>; + + fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M { + let mut mapper = SimpleOptimizerMapper::::new( + &self.optim, + &mut self.records, + &mut grads, + lr, + self.grad_clipping.as_ref(), + ); + module.map(&mut mapper) + } + + fn to_record(&self) -> Self::Record { + self.records.clone() + } + + fn load_record(mut self, record: Self::Record) -> Self { + self.records = record; + self + } } #[derive(new)] struct SimpleOptimizerMapper<'a, M, B, O> where - M: AutodiffModule, - B: AutodiffBackend, - O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, + O: SimpleOptimizer, { - optimizer: &'a O, - records: &'a mut HashMap>, - grads: &'a mut GradientsParams, - lr: LearningRate, - phantom: PhantomData, - grad_clipping: Option<&'a GradientClipping>, + optimizer: &'a O, + records: &'a mut HashMap>, + grads: &'a mut GradientsParams, + lr: LearningRate, + phantom: PhantomData, + grad_clipping: Option<&'a GradientClipping>, } impl<'a, M, B, O> ModuleMapper for SimpleOptimizerMapper<'a, M, B, O> where - M: AutodiffModule, - B: AutodiffBackend, - O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, + O: SimpleOptimizer, { - fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor { - let grad = self.grads.remove(id); - - if let Some(grad) = grad { - let device = grad.device(); - let is_require_grad = tensor.is_require_grad(); - let (key, record) = self.records.remove_entry(id).unzip(); - - let clipped_grad = if let Some(g_clipping) = self.grad_clipping { - g_clipping.clip_gradient(grad) - } else { - grad - }; - - let (tensor, state) = self.optimizer.step( - self.lr, - tensor.inner(), - clipped_grad, - record.map(|record| O::to_device(record.into_state(), &device)), - ); - - if let Some(state) = state { - self.records.insert( - key.unwrap_or_else(|| id.clone()), - AdaptorRecord::from_state(state), - ); - } - - let mut tensor = Tensor::from_inner(tensor); - if is_require_grad { - tensor = tensor.require_grad(); - } - return tensor; - } - - tensor + fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor { + let grad = self.grads.remove(id); + + if let Some(grad) = grad { + let device = grad.device(); + let is_require_grad = tensor.is_require_grad(); + let (key, record) = self.records.remove_entry(id).unzip(); + + let clipped_grad = if let Some(g_clipping) = self.grad_clipping { + g_clipping.clip_gradient(grad) + } else { + grad + }; + + let (tensor, state) = self.optimizer.step( + self.lr, + tensor.inner(), + clipped_grad, + record.map(|record| O::to_device(record.into_state(), &device)), + ); + + if let Some(state) = state { + self.records.insert( + 
key.unwrap_or_else(|| id.clone()), + AdaptorRecord::from_state(state), + ); + } + + let mut tensor = Tensor::from_inner(tensor); + if is_require_grad { + tensor = tensor.require_grad(); + } + return tensor; } + + tensor + } } diff --git a/burn-core/src/optim/simple/base.rs b/burn-core/src/optim/simple/base.rs index 5737960ad9..97d2fb961d 100644 --- a/burn-core/src/optim/simple/base.rs +++ b/burn-core/src/optim/simple/base.rs @@ -8,26 +8,26 @@ use burn_tensor::{backend::Backend, Tensor}; /// module parameter structure, handle tracked and untracked tensors, and the likes. pub trait SimpleOptimizer: Send + Sync where - B: Backend, + B: Backend, { - /// The state of the optimizer. It also implements [record](Record), so that it can be saved. - type State: Record + Clone + 'static; + /// The state of the optimizer. It also implements [record](Record), so that it can be saved. + type State: Record + Clone + 'static; - /// The optimizer step is performed for one tensor at a time with its gradient and state. - /// - /// Note that the state is passed as parameter, so implementations don't have to handle - /// the saving and loading of recorded states. - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - grad: Tensor, - state: Option>, - ) -> (Tensor, Option>); + /// The optimizer step is performed for one tensor at a time with its gradient and state. + /// + /// Note that the state is passed as parameter, so implementations don't have to handle + /// the saving and loading of recorded states. + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + grad: Tensor, + state: Option>, + ) -> (Tensor, Option>); - /// Change the device of the state. - /// - /// This function will be called accordindly to have the state on the same device as the - /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called. - fn to_device(state: Self::State, device: &B::Device) -> Self::State; + /// Change the device of the state. + /// + /// This function will be called accordindly to have the state on the same device as the + /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called. + fn to_device(state: Self::State, device: &B::Device) -> Self::State; } diff --git a/burn-core/src/optim/simple/record/base.rs b/burn-core/src/optim/simple/record/base.rs index e0bc9199d5..75b2196fa9 100644 --- a/burn-core/src/optim/simple/record/base.rs +++ b/burn-core/src/optim/simple/record/base.rs @@ -1,7 +1,7 @@ use super::{AdaptorRecordItemV1, AdaptorRecordV1}; use crate::{ - optim::SimpleOptimizer, - record::{PrecisionSettings, Record}, + optim::SimpleOptimizer, + record::{PrecisionSettings, Record}, }; use burn_tensor::backend::Backend; use serde::{Deserialize, Serialize}; @@ -10,76 +10,76 @@ use serde::{Deserialize, Serialize}; /// /// Records are versioned for backward compatibility, so old records can be loaded. pub enum AdaptorRecord, B: Backend> { - /// Version 1. - V1(AdaptorRecordV1), + /// Version 1. + V1(AdaptorRecordV1), } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub enum AdaptorRecordItem, B: Backend, S: PrecisionSettings> { - /// Version 1. - V1(AdaptorRecordItemV1), + /// Version 1. 
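
The SimpleOptimizer trait above boils an optimizer down to a per-tensor `step` plus a `to_device` hook for its state, and the versioned records below exist to persist that state. As a sketch of the contract (not part of this patch, with the generic signatures inferred from the surrounding code), a hypothetical minimal optimizer that applies a plain gradient step and keeps the last gradient as its state could look like this:

    use crate::optim::SimpleOptimizer;
    use crate::LearningRate;
    use burn_tensor::{backend::Backend, Tensor};

    /// Hypothetical example, not part of burn: plain gradient descent that keeps
    /// the last raw gradient around purely to show how state flows through.
    struct FixedStepSgd;

    impl<B: Backend> SimpleOptimizer<B> for FixedStepSgd {
        type State<const D: usize> = Tensor<B, D>;

        fn step<const D: usize>(
            &self,
            lr: LearningRate,
            tensor: Tensor<B, D>,
            grad: Tensor<B, D>,
            _state: Option<Self::State<D>>,
        ) -> (Tensor<B, D>, Option<Self::State<D>>) {
            let delta = grad.clone().mul_scalar(lr);
            (tensor - delta, Some(grad))
        }

        fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D> {
            // Called by the adaptor so that state, gradient and parameter share a device.
            state.to_device(device)
        }
    }
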
+ V1(AdaptorRecordItemV1), } impl Record for AdaptorRecord where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - type Item = AdaptorRecordItem; + type Item = AdaptorRecordItem; - fn into_item(self) -> Self::Item { - match self { - AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()), - } + fn into_item(self) -> Self::Item { + match self { + AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()), } + } - fn from_item(item: Self::Item) -> Self { - match item { - AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item)), - } + fn from_item(item: Self::Item) -> Self { + match item { + AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item)), } + } } impl Clone for AdaptorRecord where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - fn clone(&self) -> Self { - match self { - AdaptorRecord::V1(record) => Self::V1(record.clone()), - } + fn clone(&self) -> Self { + match self { + AdaptorRecord::V1(record) => Self::V1(record.clone()), } + } } impl AdaptorRecord where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - /// Converts the record into the optimizer state. - /// - /// # Returns - /// - /// The optimizer state. - pub fn into_state(self) -> O::State { - match self { - AdaptorRecord::V1(record) => record.into_state(), - } + /// Converts the record into the optimizer state. + /// + /// # Returns + /// + /// The optimizer state. + pub fn into_state(self) -> O::State { + match self { + AdaptorRecord::V1(record) => record.into_state(), } + } - /// Converts the optimizer state into the record. - /// - /// # Arguments - /// - /// * `state`: The optimizer state. - /// - /// # Returns - /// - /// The record. - pub fn from_state(state: O::State) -> Self { - Self::V1(AdaptorRecordV1::from_state(state)) - } + /// Converts the optimizer state into the record. + /// + /// # Arguments + /// + /// * `state`: The optimizer state. + /// + /// # Returns + /// + /// The record. + pub fn from_state(state: O::State) -> Self { + Self::V1(AdaptorRecordV1::from_state(state)) + } } diff --git a/burn-core/src/optim/simple/record/v1.rs b/burn-core/src/optim/simple/record/v1.rs index 9c47403473..7c721f6347 100644 --- a/burn-core/src/optim/simple/record/v1.rs +++ b/burn-core/src/optim/simple/record/v1.rs @@ -1,6 +1,6 @@ use crate::{ - optim::SimpleOptimizer, - record::{PrecisionSettings, Record}, + optim::SimpleOptimizer, + record::{PrecisionSettings, Record}, }; use burn_tensor::backend::Backend; use core::any::Any; @@ -8,178 +8,178 @@ use serde::{Deserialize, Serialize}; /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. pub enum AdaptorRecordV1, B: Backend> { - /// Rank 1. - Rank1(O::State<1>), + /// Rank 1. + Rank1(O::State<1>), - /// Rank 2. - Rank2(O::State<2>), + /// Rank 2. + Rank2(O::State<2>), - /// Rank 3. - Rank3(O::State<3>), + /// Rank 3. + Rank3(O::State<3>), - /// Rank 4. - Rank4(O::State<4>), + /// Rank 4. + Rank4(O::State<4>), - /// Rank 5. - Rank5(O::State<5>), + /// Rank 5. + Rank5(O::State<5>), - /// Rank 6. - Rank6(O::State<6>), + /// Rank 6. + Rank6(O::State<6>), - /// Rank 7. - Rank7(O::State<7>), + /// Rank 7. + Rank7(O::State<7>), - /// Rank 8. - Rank8(O::State<8>), + /// Rank 8. 
+ Rank8(O::State<8>), } impl, B: Backend> Clone for AdaptorRecordV1 { - fn clone(&self) -> Self { - match self { - AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()), - AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()), - AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()), - AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()), - AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()), - AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()), - AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()), - AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()), - } + fn clone(&self) -> Self { + match self { + AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()), + AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()), + AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()), + AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()), + AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()), + AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()), + AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()), + AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()), } + } } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub enum AdaptorRecordItemV1, B: Backend, S: PrecisionSettings> { - /// Rank 1. - Rank1( as Record>::Item), + /// Rank 1. + Rank1( as Record>::Item), - /// Rank 2. - Rank2( as Record>::Item), + /// Rank 2. + Rank2( as Record>::Item), - /// Rank 3. - Rank3( as Record>::Item), + /// Rank 3. + Rank3( as Record>::Item), - /// Rank 4. - Rank4( as Record>::Item), + /// Rank 4. + Rank4( as Record>::Item), - /// Rank 5. - Rank5( as Record>::Item), + /// Rank 5. + Rank5( as Record>::Item), - /// Rank 6. - Rank6( as Record>::Item), + /// Rank 6. + Rank6( as Record>::Item), - /// Rank 7. - Rank7( as Record>::Item), + /// Rank 7. + Rank7( as Record>::Item), - /// Rank 8. - Rank8( as Record>::Item), + /// Rank 8. + Rank8( as Record>::Item), } impl AdaptorRecordV1 where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - /// Convert the record into the state. - /// - /// # Returns - /// - /// The state. - /// - /// # Panics - /// - /// Panics if the state dimension is not supported. - pub fn into_state(self) -> O::State { - let boxed_state: Box = match self { - AdaptorRecordV1::Rank1(s) => Box::new(s), - AdaptorRecordV1::Rank2(s) => Box::new(s), - AdaptorRecordV1::Rank3(s) => Box::new(s), - AdaptorRecordV1::Rank4(s) => Box::new(s), - AdaptorRecordV1::Rank5(s) => Box::new(s), - AdaptorRecordV1::Rank6(s) => Box::new(s), - AdaptorRecordV1::Rank7(s) => Box::new(s), - AdaptorRecordV1::Rank8(s) => Box::new(s), - }; - let state = boxed_state - .downcast::>() - .expect("Unsupported state dimension, dimension up to 8 are supported."); - *state - } - - /// Convert the state into the record. - /// - /// # Arguments - /// - /// * `state`: The state. - /// - /// # Returns - /// - /// The record. 
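// --- Editor's illustrative sketch (not part of the patch) ---------------------
// `into_state` above (and `from_state` just below) erase the const rank behind
// `Box<dyn Any>` and recover it with `downcast`. The same pattern, self-contained,
// with plain arrays standing in for the optimizer states:
use core::any::Any;

fn erase<const D: usize>(value: [usize; D]) -> Box<dyn Any> {
    Box::new(value)
}

fn recover<const D: usize>(boxed: Box<dyn Any>) -> [usize; D] {
    *boxed
        .downcast::<[usize; D]>()
        .expect("requested rank must match the rank that was boxed")
}

fn main() {
    let erased = erase([1, 2, 3]);
    let back: [usize; 3] = recover(erased);
    assert_eq!(back, [1, 2, 3]);
}
// ------------------------------------------------------------------------------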
- pub fn from_state(state: O::State) -> Self { - let state: Box = Box::new(state); - - match D { - 1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()), - 2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()), - 3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()), - 4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()), - 5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()), - 6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()), - 7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()), - 8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()), - _ => panic!("Unsupported state dimension, dimension up to 8 are supported."), - } + /// Convert the record into the state. + /// + /// # Returns + /// + /// The state. + /// + /// # Panics + /// + /// Panics if the state dimension is not supported. + pub fn into_state(self) -> O::State { + let boxed_state: Box = match self { + AdaptorRecordV1::Rank1(s) => Box::new(s), + AdaptorRecordV1::Rank2(s) => Box::new(s), + AdaptorRecordV1::Rank3(s) => Box::new(s), + AdaptorRecordV1::Rank4(s) => Box::new(s), + AdaptorRecordV1::Rank5(s) => Box::new(s), + AdaptorRecordV1::Rank6(s) => Box::new(s), + AdaptorRecordV1::Rank7(s) => Box::new(s), + AdaptorRecordV1::Rank8(s) => Box::new(s), + }; + let state = boxed_state + .downcast::>() + .expect("Unsupported state dimension, dimension up to 8 are supported."); + *state + } + + /// Convert the state into the record. + /// + /// # Arguments + /// + /// * `state`: The state. + /// + /// # Returns + /// + /// The record. + pub fn from_state(state: O::State) -> Self { + let state: Box = Box::new(state); + + match D { + 1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()), + 2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()), + 3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()), + 4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()), + 5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()), + 6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()), + 7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()), + 8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()), + _ => panic!("Unsupported state dimension, dimension up to 8 are supported."), } + } } impl Record for AdaptorRecordV1 where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - type Item = AdaptorRecordItemV1; - - fn into_item(self) -> Self::Item { - match self { - AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()), - AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()), - AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()), - AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()), - AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()), - AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()), - AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()), - AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()), - } + type Item = AdaptorRecordItemV1; + + fn into_item(self) -> Self::Item { + match self { + AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()), + AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()), + AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()), + AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()), + AdaptorRecordV1::Rank5(record) => 
AdaptorRecordItemV1::Rank5(record.into_item()), + AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()), + AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()), + AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()), } - - fn from_item(item: Self::Item) -> Self { - match item { - AdaptorRecordItemV1::Rank1(item) => { - AdaptorRecordV1::Rank1( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank2(item) => { - AdaptorRecordV1::Rank2( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank3(item) => { - AdaptorRecordV1::Rank3( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank4(item) => { - AdaptorRecordV1::Rank4( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank5(item) => { - AdaptorRecordV1::Rank5( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank6(item) => { - AdaptorRecordV1::Rank6( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank7(item) => { - AdaptorRecordV1::Rank7( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank8(item) => { - AdaptorRecordV1::Rank8( as Record>::from_item(item)) - } - } + } + + fn from_item(item: Self::Item) -> Self { + match item { + AdaptorRecordItemV1::Rank1(item) => { + AdaptorRecordV1::Rank1( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank2(item) => { + AdaptorRecordV1::Rank2( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank3(item) => { + AdaptorRecordV1::Rank3( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank4(item) => { + AdaptorRecordV1::Rank4( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank5(item) => { + AdaptorRecordV1::Rank5( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank6(item) => { + AdaptorRecordV1::Rank6( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank7(item) => { + AdaptorRecordV1::Rank7( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank8(item) => { + AdaptorRecordV1::Rank8( as Record>::from_item(item)) + } } + } } diff --git a/burn-core/src/optim/visitor.rs b/burn-core/src/optim/visitor.rs index 1631fcf74e..54c864fcbb 100644 --- a/burn-core/src/optim/visitor.rs +++ b/burn-core/src/optim/visitor.rs @@ -5,40 +5,42 @@ use core::marker::PhantomData; #[derive(new)] pub struct GradientsParamsConverter<'a, M: AutodiffModule, B: AutodiffBackend> { - grads: B::Gradients, - grads_params: &'a mut GradientsParams, - phatom: PhantomData, + grads: B::Gradients, + grads_params: &'a mut GradientsParams, + phatom: PhantomData, } #[derive(new)] pub struct GradientsParamsChangeDevice<'a, M: AutodiffModule, B: AutodiffBackend> { - device: &'a B::Device, - grads: &'a mut GradientsParams, - phatom: PhantomData, + device: &'a B::Device, + grads: &'a mut GradientsParams, + phatom: PhantomData, } impl<'a, B, M> ModuleVisitor for GradientsParamsConverter<'a, M, B> where - B: AutodiffBackend, - M: AutodiffModule, + B: AutodiffBackend, + M: AutodiffModule, { - fn visit(&mut self, id: &ParamId, tensor: &Tensor) { - if let Some(grad) = tensor.grad_remove(&mut self.grads) { - self.grads_params - .register::(id.clone(), grad); - } + fn visit(&mut self, id: &ParamId, tensor: &Tensor) { + if let Some(grad) = tensor.grad_remove(&mut self.grads) { + self + .grads_params + .register::(id.clone(), grad); } + } } impl<'a, B, M> ModuleVisitor for GradientsParamsChangeDevice<'a, M, B> where - B: AutodiffBackend, - M: AutodiffModule, + B: AutodiffBackend, + M: AutodiffModule, { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { - if let Some(grad) = 
self.grads.remove::(id) { - self.grads - .register::(id.clone(), grad.to_device(self.device)); - } + fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + if let Some(grad) = self.grads.remove::(id) { + self + .grads + .register::(id.clone(), grad.to_device(self.device)); } + } } diff --git a/burn-core/src/record/base.rs b/burn-core/src/record/base.rs index 0c1633e8ec..522075d44f 100644 --- a/burn-core/src/record/base.rs +++ b/burn-core/src/record/base.rs @@ -5,12 +5,12 @@ use serde::{de::DeserializeOwned, Serialize}; /// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings). pub trait Record: Send + Sync { - /// Type of the item that can be serialized and deserialized. - type Item: Serialize + DeserializeOwned; + /// Type of the item that can be serialized and deserialized. + type Item: Serialize + DeserializeOwned; - /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings). - fn into_item(self) -> Self::Item; + /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings). + fn into_item(self) -> Self::Item; - /// Convert the given item into a record. - fn from_item(item: Self::Item) -> Self; + /// Convert the given item into a record. + fn from_item(item: Self::Item) -> Self; } diff --git a/burn-core/src/record/file.rs b/burn-core/src/record/file.rs index 00c9b0192e..9cbbb16dcc 100644 --- a/burn-core/src/record/file.rs +++ b/burn-core/src/record/file.rs @@ -7,10 +7,10 @@ use std::{fs::File, path::PathBuf}; /// Recorder trait specialized to save and load data to and from files. pub trait FileRecorder: - Recorder + Recorder { - /// File extension of the format used by the recorder. - fn file_extension() -> &'static str; + /// File extension of the format used by the recorder. + fn file_extension() -> &'static str; } /// Default [file recorder](FileRecorder). @@ -19,361 +19,360 @@ pub type DefaultFileRecorder = NamedMpkGzFileRecorder; /// File recorder using the [bincode format](bincode). #[derive(new, Debug, Default, Clone)] pub struct BinFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [bincode format](bincode) compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct BinGzFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [json format](serde_json) compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct JsonGzFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using [pretty json format](serde_json) for easy readability. #[derive(new, Debug, Default, Clone)] pub struct PrettyJsonFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [named msgpack](rmp_serde) format compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct NamedMpkGzFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [named msgpack](rmp_serde) format. 
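// --- Editor's illustrative sketch (not part of the patch) ---------------------
// Rough usage of the file recorders declared above: the path is passed without
// an extension and each recorder appends its own `file_extension()` (".bin",
// ".json.gz", ".mpk.gz", ...). `MyModel` and the path are placeholders, and the
// generic parameters follow the editor's reading of the stripped signatures.
use burn::module::Module;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;

fn save_model<B: Backend>(model: MyModel<B>) {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    recorder
        .record(model.into_record(), "/tmp/my_model".into())
        .expect("writing the record should succeed");
    // The file ends up at /tmp/my_model.bin.
}
// ------------------------------------------------------------------------------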
#[derive(new, Debug, Default, Clone)] pub struct NamedMpkFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } impl FileRecorder for BinGzFileRecorder { - fn file_extension() -> &'static str { - "bin.gz" - } + fn file_extension() -> &'static str { + "bin.gz" + } } impl FileRecorder for BinFileRecorder { - fn file_extension() -> &'static str { - "bin" - } + fn file_extension() -> &'static str { + "bin" + } } impl FileRecorder for JsonGzFileRecorder { - fn file_extension() -> &'static str { - "json.gz" - } + fn file_extension() -> &'static str { + "json.gz" + } } impl FileRecorder for PrettyJsonFileRecorder { - fn file_extension() -> &'static str { - "json" - } + fn file_extension() -> &'static str { + "json" + } } impl FileRecorder for NamedMpkGzFileRecorder { - fn file_extension() -> &'static str { - "mpk.gz" - } + fn file_extension() -> &'static str { + "mpk.gz" + } } impl FileRecorder for NamedMpkFileRecorder { - fn file_extension() -> &'static str { - "mpk" - } + fn file_extension() -> &'static str { + "mpk" + } } macro_rules! str2reader { - ( + ( $file:expr ) => {{ - $file.set_extension(::file_extension()); - let path = $file.as_path(); - - File::open(path) - .map_err(|err| match err.kind() { - std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), - _ => RecorderError::Unknown(err.to_string()), - }) - .map(|file| BufReader::new(file)) - }}; + $file.set_extension(::file_extension()); + let path = $file.as_path(); + + File::open(path) + .map_err(|err| match err.kind() { + std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), + _ => RecorderError::Unknown(err.to_string()), + }) + .map(|file| BufReader::new(file)) + }}; } macro_rules! str2writer { - ( + ( $file:expr ) => {{ - $file.set_extension(::file_extension()); - let path = $file.as_path(); - - if path.exists() { - log::info!("File exists, replacing"); - std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?; - } - - File::create(path) - .map_err(|err| match err.kind() { - std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), - _ => RecorderError::Unknown(err.to_string()), - }) - .map(|file| BufWriter::new(file)) - }}; -} + $file.set_extension(::file_extension()); + let path = $file.as_path(); -impl Recorder for BinGzFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let config = bin_config(); - let writer = str2writer!(file)?; - let mut writer = GzEncoder::new(writer, Compression::default()); - - bincode::serde::encode_into_std_write(&item, &mut writer, config) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) + if path.exists() { + log::info!("File exists, replacing"); + std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?; } - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let mut reader = GzDecoder::new(reader); - let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; + File::create(path) + .map_err(|err| match err.kind() { + std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), + _ => RecorderError::Unknown(err.to_string()), + }) + .map(|file| BufWriter::new(file)) + }}; +} - Ok(state) - } +impl Recorder for BinGzFileRecorder { + type Settings = S; 
+ type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let config = bin_config(); + let writer = str2writer!(file)?; + let mut writer = GzEncoder::new(writer, Compression::default()); + + bincode::serde::encode_into_std_write(&item, &mut writer, config) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let mut reader = GzDecoder::new(reader); + let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } impl Recorder for BinFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let config = bin_config(); - let mut writer = str2writer!(file)?; - bincode::serde::encode_into_std_write(&item, &mut writer, config) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let mut reader = str2reader!(file)?; - let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let config = bin_config(); + let mut writer = str2writer!(file)?; + bincode::serde::encode_into_std_write(&item, &mut writer, config) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let mut reader = str2reader!(file)?; + let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(state) + } } impl Recorder for JsonGzFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let writer = str2writer!(file)?; - let writer = GzEncoder::new(writer, Compression::default()); - serde_json::to_writer(writer, &item) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let reader = GzDecoder::new(reader); - let state = serde_json::from_reader(reader) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let writer = str2writer!(file)?; + let writer = GzEncoder::new(writer, Compression::default()); + serde_json::to_writer(writer, &item).map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let reader = GzDecoder::new(reader); + let state = + serde_json::from_reader(reader).map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } impl Recorder for 
PrettyJsonFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let writer = str2writer!(file)?; - serde_json::to_writer_pretty(writer, &item) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let state = serde_json::from_reader(reader) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let writer = str2writer!(file)?; + serde_json::to_writer_pretty(writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let state = + serde_json::from_reader(reader).map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } impl Recorder for NamedMpkGzFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let writer = str2writer!(file)?; - let mut writer = GzEncoder::new(writer, Compression::default()); - rmp_serde::encode::write_named(&mut writer, &item) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let reader = GzDecoder::new(reader); - let state = rmp_serde::decode::from_read(reader) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let writer = str2writer!(file)?; + let mut writer = GzEncoder::new(writer, Compression::default()); + rmp_serde::encode::write_named(&mut writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let reader = GzDecoder::new(reader); + let state = rmp_serde::decode::from_read(reader) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } impl Recorder for NamedMpkFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let mut writer = str2writer!(file)?; - - rmp_serde::encode::write_named(&mut writer, &item) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let state = rmp_serde::decode::from_read(reader) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let mut writer = str2writer!(file)?; + + 
rmp_serde::encode::write_named(&mut writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let state = rmp_serde::decode::from_read(reader) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } #[cfg(test)] mod tests { - use burn_tensor::backend::Backend; - - use super::*; - use crate::{ - module::Module, - nn::{ - conv::{Conv2d, Conv2dConfig}, - Linear, LinearConfig, - }, - record::{BinBytesRecorder, FullPrecisionSettings}, - TestBackend, - }; - - use crate as burn; - - static FILE_PATH: &str = "/tmp/burn_test_file_recorder"; - - #[test] - fn test_can_save_and_load_jsongz_format() { - test_can_save_and_load(JsonGzFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_bin_format() { - test_can_save_and_load(BinFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_bingz_format() { - test_can_save_and_load(BinGzFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_pretty_json_format() { - test_can_save_and_load(PrettyJsonFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_mpkgz_format() { - test_can_save_and_load(NamedMpkGzFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_mpk_format() { - test_can_save_and_load(NamedMpkFileRecorder::::default()) - } - - fn test_can_save_and_load(recorder: Recorder) { - let model_before = create_model(); - recorder - .record(model_before.clone().into_record(), FILE_PATH.into()) - .unwrap(); - - let model_after = create_model().load_record(recorder.load(FILE_PATH.into()).unwrap()); - - let byte_recorder = BinBytesRecorder::::default(); - let model_bytes_before = byte_recorder - .record(model_before.into_record(), ()) - .unwrap(); - let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap(); - - assert_eq!(model_bytes_after, model_bytes_before); - } - - #[derive(Module, Debug)] - pub struct Model { - conv2d1: Conv2d, - linear1: Linear, - phantom: core::marker::PhantomData, - } - - pub fn create_model() -> Model { - let conv2d1 = Conv2dConfig::new([1, 8], [3, 3]).init(); - - let linear1 = LinearConfig::new(32, 32).with_bias(true).init(); - - Model { - conv2d1, - linear1, - phantom: core::marker::PhantomData, - } + use burn_tensor::backend::Backend; + + use super::*; + use crate::{ + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + Linear, LinearConfig, + }, + record::{BinBytesRecorder, FullPrecisionSettings}, + TestBackend, + }; + + use crate as burn; + + static FILE_PATH: &str = "/tmp/burn_test_file_recorder"; + + #[test] + fn test_can_save_and_load_jsongz_format() { + test_can_save_and_load(JsonGzFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_bin_format() { + test_can_save_and_load(BinFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_bingz_format() { + test_can_save_and_load(BinGzFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_pretty_json_format() { + test_can_save_and_load(PrettyJsonFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_mpkgz_format() { + test_can_save_and_load(NamedMpkGzFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_mpk_format() { + test_can_save_and_load(NamedMpkFileRecorder::::default()) + } + + fn test_can_save_and_load(recorder: Recorder) { + let model_before = create_model(); + recorder + .record(model_before.clone().into_record(), FILE_PATH.into()) + 
.unwrap(); + + let model_after = create_model().load_record(recorder.load(FILE_PATH.into()).unwrap()); + + let byte_recorder = BinBytesRecorder::::default(); + let model_bytes_before = byte_recorder + .record(model_before.into_record(), ()) + .unwrap(); + let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap(); + + assert_eq!(model_bytes_after, model_bytes_before); + } + + #[derive(Module, Debug)] + pub struct Model { + conv2d1: Conv2d, + linear1: Linear, + phantom: core::marker::PhantomData, + } + + pub fn create_model() -> Model { + let conv2d1 = Conv2dConfig::new([1, 8], [3, 3]).init(); + + let linear1 = LinearConfig::new(32, 32).with_bias(true).init(); + + Model { + conv2d1, + linear1, + phantom: core::marker::PhantomData, } + } } diff --git a/burn-core/src/record/memory.rs b/burn-core/src/record/memory.rs index 545ad87e64..f96f1f8b17 100644 --- a/burn-core/src/record/memory.rs +++ b/burn-core/src/record/memory.rs @@ -9,42 +9,42 @@ use serde::{de::DeserializeOwned, Serialize}; /// This is especially useful in no_std environment where weights are stored directly in /// compiled binaries. pub trait BytesRecorder: - Recorder, LoadArgs = Vec> + Recorder, LoadArgs = Vec> { } /// In memory recorder using the [bincode format](bincode). #[derive(new, Debug, Default, Clone)] pub struct BinBytesRecorder { - _settings: core::marker::PhantomData, + _settings: core::marker::PhantomData, } impl BytesRecorder for BinBytesRecorder {} impl Recorder for BinBytesRecorder { - type Settings = S; - type RecordArgs = (); - type RecordOutput = Vec; - type LoadArgs = Vec; - - fn save_item( - &self, - item: I, - _args: Self::RecordArgs, - ) -> Result { - Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap()) - } - fn load_item(&self, args: Self::LoadArgs) -> Result { - let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap(); - Ok(state) - } + type Settings = S; + type RecordArgs = (); + type RecordOutput = Vec; + type LoadArgs = Vec; + + fn save_item( + &self, + item: I, + _args: Self::RecordArgs, + ) -> Result { + Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap()) + } + fn load_item(&self, args: Self::LoadArgs) -> Result { + let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap(); + Ok(state) + } } #[cfg(feature = "std")] /// In memory recorder using the [Named MessagePack](rmp_serde). 
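// --- Editor's illustrative sketch (not part of the patch) ---------------------
// The doc comment above motivates `BytesRecorder` for no_std targets, where the
// weights ship inside the binary itself. A rough sketch: `Model` is a placeholder
// for a type built with the `Module` derive (as in the tests in this file), and
// "model.bin" is a hypothetical file produced ahead of time with a `BytesRecorder`.
use burn::module::Module;
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;

fn load_embedded<B: Backend>(model: Model<B>) -> Model<B> {
    // Bytes baked into the executable at compile time.
    static WEIGHTS: &[u8] = include_bytes!("model.bin");

    let record = BinBytesRecorder::<FullPrecisionSettings>::default()
        .load(WEIGHTS.to_vec())
        .expect("embedded bytes should contain a valid record");
    model.load_record(record)
}
// ------------------------------------------------------------------------------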
#[derive(new, Debug, Default, Clone)] pub struct NamedMpkBytesRecorder { - _settings: core::marker::PhantomData, + _settings: core::marker::PhantomData, } #[cfg(feature = "std")] @@ -52,53 +52,53 @@ impl BytesRecorder for NamedMpkBytesRecorder {} #[cfg(feature = "std")] impl Recorder for NamedMpkBytesRecorder { - type Settings = S; - type RecordArgs = (); - type RecordOutput = Vec; - type LoadArgs = Vec; - - fn save_item( - &self, - item: I, - _args: Self::RecordArgs, - ) -> Result { - rmp_serde::encode::to_vec_named(&item).map_err(|e| RecorderError::Unknown(e.to_string())) - } - fn load_item(&self, args: Self::LoadArgs) -> Result { - rmp_serde::decode::from_slice(&args).map_err(|e| RecorderError::Unknown(e.to_string())) - } + type Settings = S; + type RecordArgs = (); + type RecordOutput = Vec; + type LoadArgs = Vec; + + fn save_item( + &self, + item: I, + _args: Self::RecordArgs, + ) -> Result { + rmp_serde::encode::to_vec_named(&item).map_err(|e| RecorderError::Unknown(e.to_string())) + } + fn load_item(&self, args: Self::LoadArgs) -> Result { + rmp_serde::decode::from_slice(&args).map_err(|e| RecorderError::Unknown(e.to_string())) + } } #[cfg(test)] mod tests { - use super::*; - use crate::{module::Module, nn, record::FullPrecisionSettings, TestBackend}; - - #[test] - fn test_can_save_and_load_bin_format() { - test_can_save_and_load(BinBytesRecorder::::default()) - } - - #[cfg(feature = "std")] - #[test] - fn test_can_save_and_load_named_mpk_format() { - test_can_save_and_load(NamedMpkBytesRecorder::::default()) - } - - fn test_can_save_and_load(recorder: Recorder) { - let model1 = create_model(); - let model2 = create_model(); - let bytes1 = recorder.record(model1.into_record(), ()).unwrap(); - let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap(); - - let model2_after = model2.load_record(recorder.load(bytes1.clone()).unwrap()); - let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap(); - - assert_ne!(bytes1, bytes2); - assert_eq!(bytes1, bytes2_after); - } - - pub fn create_model() -> nn::Linear { - nn::LinearConfig::new(32, 32).with_bias(true).init() - } + use super::*; + use crate::{module::Module, nn, record::FullPrecisionSettings, TestBackend}; + + #[test] + fn test_can_save_and_load_bin_format() { + test_can_save_and_load(BinBytesRecorder::::default()) + } + + #[cfg(feature = "std")] + #[test] + fn test_can_save_and_load_named_mpk_format() { + test_can_save_and_load(NamedMpkBytesRecorder::::default()) + } + + fn test_can_save_and_load(recorder: Recorder) { + let model1 = create_model(); + let model2 = create_model(); + let bytes1 = recorder.record(model1.into_record(), ()).unwrap(); + let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap(); + + let model2_after = model2.load_record(recorder.load(bytes1.clone()).unwrap()); + let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap(); + + assert_ne!(bytes1, bytes2); + assert_eq!(bytes1, bytes2_after); + } + + pub fn create_model() -> nn::Linear { + nn::LinearConfig::new(32, 32).with_bias(true).init() + } } diff --git a/burn-core/src/record/primitive.rs b/burn-core/src/record/primitive.rs index 507635b205..c24cd1046b 100644 --- a/burn-core/src/record/primitive.rs +++ b/burn-core/src/record/primitive.rs @@ -13,123 +13,124 @@ use burn_tensor::{DataSerialize, Element}; use hashbrown::HashMap; impl Record for () { - type Item = (); + type Item = (); - fn into_item(self) -> Self::Item {} + fn into_item(self) -> Self::Item {} - fn from_item(_item: Self::Item) -> Self 
{} + fn from_item(_item: Self::Item) -> Self {} } impl Record for Vec { - type Item = Vec>; + type Item = Vec>; - fn into_item(self) -> Self::Item { - self.into_iter().map(Record::into_item).collect() - } + fn into_item(self) -> Self::Item { + self.into_iter().map(Record::into_item).collect() + } - fn from_item(item: Self::Item) -> Self { - item.into_iter().map(Record::from_item).collect() - } + fn from_item(item: Self::Item) -> Self { + item.into_iter().map(Record::from_item).collect() + } } impl Record for Option { - type Item = Option>; + type Item = Option>; - fn into_item(self) -> Self::Item { - self.map(Record::into_item) - } + fn into_item(self) -> Self::Item { + self.map(Record::into_item) + } - fn from_item(item: Self::Item) -> Self { - item.map(Record::from_item) - } + fn from_item(item: Self::Item) -> Self { + item.map(Record::from_item) + } } impl Record for [T; N] { - type Item = Vec>; - - fn into_item(self) -> Self::Item { - self.map(Record::into_item).into_iter().collect() - } - - fn from_item(item: Self::Item) -> Self { - item.into_iter() - .map(Record::from_item) - .collect::>() - .try_into() - .unwrap_or_else(|_| panic!("An arrar of size {N}")) - } + type Item = Vec>; + + fn into_item(self) -> Self::Item { + self.map(Record::into_item).into_iter().collect() + } + + fn from_item(item: Self::Item) -> Self { + item + .into_iter() + .map(Record::from_item) + .collect::>() + .try_into() + .unwrap_or_else(|_| panic!("An arrar of size {N}")) + } } impl Record for HashMap { - type Item = HashMap>; - - fn into_item(self) -> Self::Item { - let mut items = HashMap::with_capacity(self.len()); - self.into_iter().for_each(|(id, record)| { - items.insert(id.to_string(), record.into_item()); - }); - items - } - - fn from_item(item: Self::Item) -> Self { - let mut record = HashMap::with_capacity(item.len()); - item.into_iter().for_each(|(id, item)| { - record.insert(ParamId::from(id), T::from_item(item)); - }); - record - } + type Item = HashMap>; + + fn into_item(self) -> Self::Item { + let mut items = HashMap::with_capacity(self.len()); + self.into_iter().for_each(|(id, record)| { + items.insert(id.to_string(), record.into_item()); + }); + items + } + + fn from_item(item: Self::Item) -> Self { + let mut record = HashMap::with_capacity(item.len()); + item.into_iter().for_each(|(id, item)| { + record.insert(ParamId::from(id), T::from_item(item)); + }); + record + } } impl Record for DataSerialize { - type Item = DataSerialize; + type Item = DataSerialize; - fn into_item(self) -> Self::Item { - self.convert() - } + fn into_item(self) -> Self::Item { + self.convert() + } - fn from_item(item: Self::Item) -> Self { - item.convert() - } + fn from_item(item: Self::Item) -> Self { + item.convert() + } } /// (De)serialize parameters into a clean format. #[derive(new, Debug, Clone, Serialize, Deserialize)] pub struct ParamSerde { - id: String, - param: T, + id: String, + param: T, } impl Record for Param> { - type Item = ParamSerde>; - - fn into_item(self) -> Self::Item { - ParamSerde::new(self.id.into_string(), self.value.into_item()) - } - - fn from_item(item: Self::Item) -> Self { - Param::new( - ParamId::from(item.id), - Tensor::from_item(item.param).require_grad(), // Same behavior as when we create a new - // Param from a tensor. 
- ) - } + type Item = ParamSerde>; + + fn into_item(self) -> Self::Item { + ParamSerde::new(self.id.into_string(), self.value.into_item()) + } + + fn from_item(item: Self::Item) -> Self { + Param::new( + ParamId::from(item.id), + Tensor::from_item(item.param).require_grad(), // Same behavior as when we create a new + // Param from a tensor. + ) + } } // Type that can be serialized as is without any conversion. macro_rules! primitive { - ($type:ty) => { - impl Record for $type { - type Item = $type; - - fn into_item(self) -> Self::Item { - self - } - - fn from_item(item: Self::Item) -> Self { - item - } - } - }; + ($type:ty) => { + impl Record for $type { + type Item = $type; + + fn into_item(self) -> Self::Item { + self + } + + fn from_item(item: Self::Item) -> Self { + item + } + } + }; } // General Types diff --git a/burn-core/src/record/recorder.rs b/burn-core/src/record/recorder.rs index 8c76199ff3..d278881e85 100644 --- a/burn-core/src/record/recorder.rs +++ b/burn-core/src/record/recorder.rs @@ -8,148 +8,148 @@ use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record}; #[cfg(feature = "std")] use super::{ - BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings, - PrettyJsonFileRecorder, + BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings, + PrettyJsonFileRecorder, }; /// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned). pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone { - /// Type of the settings used by the recorder. - type Settings: PrecisionSettings; - - /// Arguments used to record objects. - type RecordArgs: Clone; - - /// Record output type. - type RecordOutput; - - /// Arguments used to load recorded objects. - type LoadArgs: Clone; - - /// Records an item. - /// - /// # Arguments - /// - /// * `record` - The item to record. - /// * `args` - Arguments used to record the item. - /// - /// # Returns - /// - /// The output of the recording. - fn record( - &self, - record: R, - args: Self::RecordArgs, - ) -> Result { - let item = record.into_item::(); - let item = BurnRecord::new::(item); - - self.save_item(item, args) - } - - /// Load an item from the given arguments. - fn load(&self, args: Self::LoadArgs) -> Result { - let item: BurnRecord> = - self.load_item(args.clone()).map_err(|err| { - if let Ok(record) = self.load_item::(args.clone()) { - let mut message = "Unable to load record.".to_string(); - let metadata = recorder_metadata::(); - if metadata.float != record.metadata.float { - message += format!( - "\nMetadata has a different float type: Actual {:?}, Expected {:?}", - record.metadata.float, metadata.float - ) - .as_str(); - } - if metadata.int != record.metadata.int { - message += format!( - "\nMetadata has a different int type: Actual {:?}, Expected {:?}", - record.metadata.int, metadata.int - ) - .as_str(); - } - if metadata.format != record.metadata.format { - message += format!( - "\nMetadata has a different format: Actual {:?}, Expected {:?}", - record.metadata.format, metadata.format - ) - .as_str(); - } - if metadata.version != record.metadata.version { - message += format!( - "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}", - record.metadata.version, metadata.version - ) - .as_str(); - } - - message += format!("\nError: {:?}", err).as_str(); - - return RecorderError::Unknown(message); - } - - err - })?; - - Ok(R::from_item(item.item)) - } + /// Type of the settings used by the recorder. 
+ type Settings: PrecisionSettings; + + /// Arguments used to record objects. + type RecordArgs: Clone; + + /// Record output type. + type RecordOutput; + + /// Arguments used to load recorded objects. + type LoadArgs: Clone; + + /// Records an item. + /// + /// # Arguments + /// + /// * `record` - The item to record. + /// * `args` - Arguments used to record the item. + /// + /// # Returns + /// + /// The output of the recording. + fn record( + &self, + record: R, + args: Self::RecordArgs, + ) -> Result { + let item = record.into_item::(); + let item = BurnRecord::new::(item); + + self.save_item(item, args) + } + + /// Load an item from the given arguments. + fn load(&self, args: Self::LoadArgs) -> Result { + let item: BurnRecord> = + self.load_item(args.clone()).map_err(|err| { + if let Ok(record) = self.load_item::(args.clone()) { + let mut message = "Unable to load record.".to_string(); + let metadata = recorder_metadata::(); + if metadata.float != record.metadata.float { + message += format!( + "\nMetadata has a different float type: Actual {:?}, Expected {:?}", + record.metadata.float, metadata.float + ) + .as_str(); + } + if metadata.int != record.metadata.int { + message += format!( + "\nMetadata has a different int type: Actual {:?}, Expected {:?}", + record.metadata.int, metadata.int + ) + .as_str(); + } + if metadata.format != record.metadata.format { + message += format!( + "\nMetadata has a different format: Actual {:?}, Expected {:?}", + record.metadata.format, metadata.format + ) + .as_str(); + } + if metadata.version != record.metadata.version { + message += format!( + "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}", + record.metadata.version, metadata.version + ) + .as_str(); + } + + message += format!("\nError: {:?}", err).as_str(); + + return RecorderError::Unknown(message); + } - /// Saves an item. - /// - /// This method is used by [record](Recorder::record) to save the item. - /// - /// # Arguments - /// - /// * `item` - Item to save. - /// * `args` - Arguments to use to save the item. - /// - /// # Returns - /// - /// The output of the save operation. - fn save_item( - &self, - item: I, - args: Self::RecordArgs, - ) -> Result; - - /// Loads an item. - /// - /// This method is used by [load](Recorder::load) to load the item. - /// - /// # Arguments - /// - /// * `args` - Arguments to use to load the item. - /// - /// # Returns - /// - /// The loaded item. - fn load_item(&self, args: Self::LoadArgs) -> Result; + err + })?; + + Ok(R::from_item(item.item)) + } + + /// Saves an item. + /// + /// This method is used by [record](Recorder::record) to save the item. + /// + /// # Arguments + /// + /// * `item` - Item to save. + /// * `args` - Arguments to use to save the item. + /// + /// # Returns + /// + /// The output of the save operation. + fn save_item( + &self, + item: I, + args: Self::RecordArgs, + ) -> Result; + + /// Loads an item. + /// + /// This method is used by [load](Recorder::load) to load the item. + /// + /// # Arguments + /// + /// * `args` - Arguments to use to load the item. + /// + /// # Returns + /// + /// The loaded item. 
+ fn load_item(&self, args: Self::LoadArgs) -> Result; } fn recorder_metadata() -> BurnMetadata { - BurnMetadata::new( - type_name::<::FloatElem>().to_string(), - type_name::<::IntElem>().to_string(), - type_name::().to_string(), - env!("CARGO_PKG_VERSION").to_string(), - format!("{:?}", R::Settings::default()), - ) + BurnMetadata::new( + type_name::<::FloatElem>().to_string(), + type_name::<::IntElem>().to_string(), + type_name::().to_string(), + env!("CARGO_PKG_VERSION").to_string(), + format!("{:?}", R::Settings::default()), + ) } /// Error that can occur when using a [Recorder](Recorder). #[derive(Debug)] pub enum RecorderError { - /// File not found. - FileNotFound(String), + /// File not found. + FileNotFound(String), - /// Other error. - Unknown(String), + /// Other error. + Unknown(String), } impl core::fmt::Display for RecorderError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("{self:?}").as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(format!("{self:?}").as_str()) + } } // TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765) @@ -157,60 +157,60 @@ impl core::fmt::Display for RecorderError { impl std::error::Error for RecorderError {} pub(crate) fn bin_config() -> bincode::config::Configuration { - bincode::config::standard() + bincode::config::standard() } /// Metadata of a record. #[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct BurnMetadata { - /// Float type used to record the item. - pub float: String, + /// Float type used to record the item. + pub float: String, - /// Int type used to record the item. - pub int: String, + /// Int type used to record the item. + pub int: String, - /// Format used to record the item. - pub format: String, + /// Format used to record the item. + pub format: String, - /// Burn record version used to record the item. - pub version: String, + /// Burn record version used to record the item. + pub version: String, - /// Settings used to record the item. - pub settings: String, + /// Settings used to record the item. + pub settings: String, } /// Record that can be saved by a [Recorder](Recorder). #[derive(Serialize, Deserialize, Debug)] pub struct BurnRecord { - /// Metadata of the record. - pub metadata: BurnMetadata, + /// Metadata of the record. + pub metadata: BurnMetadata, - /// Item to record. - pub item: I, + /// Item to record. + pub item: I, } impl BurnRecord { - /// Creates a new record. - /// - /// # Arguments - /// - /// * `item` - Item to record. - /// - /// # Returns - /// - /// The new record. - pub fn new(item: I) -> Self { - let metadata = recorder_metadata::(); - - Self { metadata, item } - } + /// Creates a new record. + /// + /// # Arguments + /// + /// * `item` - Item to record. + /// + /// # Returns + /// + /// The new record. + pub fn new(item: I) -> Self { + let metadata = recorder_metadata::(); + + Self { metadata, item } + } } /// Record that can be saved by a [Recorder](Recorder) without the item. #[derive(new, Debug, Serialize, Deserialize)] pub struct BurnRecordNoItem { - /// Metadata of the record. - pub metadata: BurnMetadata, + /// Metadata of the record. + pub metadata: BurnMetadata, } /// Default recorder. 
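// --- Editor's illustrative sketch (not part of the patch) ---------------------
// Handling the two `RecorderError` variants above when loading a checkpoint;
// the recorder choice and path handling are the editor's assumptions.
use burn::record::{DefaultFileRecorder, FullPrecisionSettings, Record, Recorder, RecorderError};

fn try_load<R: Record>(path: &str) -> Option<R> {
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    match recorder.load(path.into()) {
        Ok(record) => Some(record),
        // A missing file is expected on the first run: start from scratch.
        Err(RecorderError::FileNotFound(_)) => None,
        // Anything else (corrupt data, mismatched precision settings, ...) is fatal.
        Err(RecorderError::Unknown(msg)) => panic!("failed to load record: {msg}"),
    }
}
// ------------------------------------------------------------------------------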
@@ -252,45 +252,45 @@ pub type DebugRecordSettings = PrettyJsonFileRecorder; #[cfg(all(test, feature = "std"))] mod tests { - static FILE_PATH: &str = "/tmp/burn_test_record"; + static FILE_PATH: &str = "/tmp/burn_test_record"; - use super::*; - use burn_tensor::ElementConversion; + use super::*; + use burn_tensor::ElementConversion; - #[test] - #[should_panic] - fn err_when_invalid_item() { - #[derive(new, Serialize, Deserialize)] - struct Item { - value: S::FloatElem, - } + #[test] + #[should_panic] + fn err_when_invalid_item() { + #[derive(new, Serialize, Deserialize)] + struct Item { + value: S::FloatElem, + } - impl Record for Item { - type Item = Item; + impl Record for Item { + type Item = Item; - fn into_item(self) -> Self::Item { - Item { - value: self.value.elem(), - } - } + fn into_item(self) -> Self::Item { + Item { + value: self.value.elem(), + } + } - fn from_item(item: Self::Item) -> Self { - Item { - value: item.value.elem(), - } - } + fn from_item(item: Self::Item) -> Self { + Item { + value: item.value.elem(), } + } + } - let item = Item::::new(16.elem()); + let item = Item::::new(16.elem()); - // Serialize in f32. - let recorder = DefaultFileRecorder::::new(); - recorder.record(item, FILE_PATH.into()).unwrap(); + // Serialize in f32. + let recorder = DefaultFileRecorder::::new(); + recorder.record(item, FILE_PATH.into()).unwrap(); - // Can't deserialize f32 into f16. - let recorder = DefaultFileRecorder::::new(); - recorder - .load::>(FILE_PATH.into()) - .unwrap(); - } + // Can't deserialize f32 into f16. + let recorder = DefaultFileRecorder::::new(); + recorder + .load::>(FILE_PATH.into()) + .unwrap(); + } } diff --git a/burn-core/src/record/settings.rs b/burn-core/src/record/settings.rs index a59c6ec331..202a5fb183 100644 --- a/burn-core/src/record/settings.rs +++ b/burn-core/src/record/settings.rs @@ -3,13 +3,13 @@ use serde::{de::DeserializeOwned, Serialize}; /// Settings allowing to control the precision when (de)serializing items. pub trait PrecisionSettings: - Send + Sync + core::fmt::Debug + core::default::Default + Clone + Send + Sync + core::fmt::Debug + core::default::Default + Clone { - /// Float element type. - type FloatElem: Element + Serialize + DeserializeOwned; + /// Float element type. + type FloatElem: Element + Serialize + DeserializeOwned; - /// Integer element type. - type IntElem: Element + Serialize + DeserializeOwned; + /// Integer element type. + type IntElem: Element + Serialize + DeserializeOwned; } /// Default precision settings. @@ -25,16 +25,16 @@ pub struct HalfPrecisionSettings; pub struct DoublePrecisionSettings; impl PrecisionSettings for FullPrecisionSettings { - type FloatElem = f32; - type IntElem = f32; + type FloatElem = f32; + type IntElem = f32; } impl PrecisionSettings for DoublePrecisionSettings { - type FloatElem = f64; - type IntElem = i64; + type FloatElem = f64; + type IntElem = i64; } impl PrecisionSettings for HalfPrecisionSettings { - type FloatElem = half::f16; - type IntElem = i16; + type FloatElem = half::f16; + type IntElem = i16; } diff --git a/burn-core/src/record/tensor.rs b/burn-core/src/record/tensor.rs index 70badf2169..e60897fa68 100644 --- a/burn-core/src/record/tensor.rs +++ b/burn-core/src/record/tensor.rs @@ -6,129 +6,129 @@ use serde::{Deserialize, Serialize}; /// using the given [record settings](RecordSettings). 
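// --- Editor's illustrative sketch (not part of the patch) ---------------------
// The tensor `Record` impls below convert the data to the settings' float element
// on save (`into_data().convert().serialize()`) and back to the backend element on
// load, so the same tensor can be stored at full or half precision. A rough sketch,
// with the generic signatures reconstructed by the editor:
use burn::record::{HalfPrecisionSettings, Record};
use burn::tensor::{backend::Backend, Tensor};

fn half_precision_roundtrip<B: Backend>(tensor: Tensor<B, 2>) -> Tensor<B, 2> {
    // The serialized item holds `half::f16` values under `HalfPrecisionSettings`.
    let item = tensor.into_item::<HalfPrecisionSettings>();
    // Loading converts back to the backend float element, keeping any f16 rounding.
    Tensor::from_item(item)
}
// ------------------------------------------------------------------------------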
#[derive(new, Clone, Debug)] pub struct FloatTensorSerde { - data: DataSerialize, + data: DataSerialize, } /// This struct implements serde to lazily serialize and deserialize an int tensor /// using the given [record settings](RecordSettings). #[derive(new, Clone, Debug)] pub struct IntTensorSerde { - data: DataSerialize, + data: DataSerialize, } /// This struct implements serde to lazily serialize and deserialize an bool tensor. #[derive(new, Clone, Debug)] pub struct BoolTensorSerde { - data: DataSerialize, + data: DataSerialize, } // --- SERDE IMPLEMENTATIONS --- // impl Serialize for FloatTensorSerde { - fn serialize(&self, serializer: Se) -> Result - where - Se: serde::Serializer, - { - self.data.serialize(serializer) - } + fn serialize(&self, serializer: Se) -> Result + where + Se: serde::Serializer, + { + self.data.serialize(serializer) + } } impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde { - fn deserialize(deserializer: De) -> Result - where - De: serde::Deserializer<'de>, - { - let data = DataSerialize::::deserialize(deserializer)?; - - Ok(Self::new(data)) - } + fn deserialize(deserializer: De) -> Result + where + De: serde::Deserializer<'de>, + { + let data = DataSerialize::::deserialize(deserializer)?; + + Ok(Self::new(data)) + } } impl Serialize for IntTensorSerde { - fn serialize(&self, serializer: Se) -> Result - where - Se: serde::Serializer, - { - self.data.serialize(serializer) - } + fn serialize(&self, serializer: Se) -> Result + where + Se: serde::Serializer, + { + self.data.serialize(serializer) + } } impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde { - fn deserialize(deserializer: De) -> Result - where - De: serde::Deserializer<'de>, - { - let data = DataSerialize::::deserialize(deserializer)?; - Ok(Self::new(data)) - } + fn deserialize(deserializer: De) -> Result + where + De: serde::Deserializer<'de>, + { + let data = DataSerialize::::deserialize(deserializer)?; + Ok(Self::new(data)) + } } impl Serialize for BoolTensorSerde { - fn serialize(&self, serializer: Se) -> Result - where - Se: serde::Serializer, - { - self.data.serialize(serializer) - } + fn serialize(&self, serializer: Se) -> Result + where + Se: serde::Serializer, + { + self.data.serialize(serializer) + } } impl<'de> Deserialize<'de> for BoolTensorSerde { - fn deserialize(deserializer: De) -> Result - where - De: serde::Deserializer<'de>, - { - let data = DataSerialize::::deserialize(deserializer)?; - - Ok(Self::new(data)) - } + fn deserialize(deserializer: De) -> Result + where + De: serde::Deserializer<'de>, + { + let data = DataSerialize::::deserialize(deserializer)?; + + Ok(Self::new(data)) + } } // --- RECORD IMPLEMENTATIONS --- // impl Record for Tensor { - type Item = FloatTensorSerde; + type Item = FloatTensorSerde; - fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording float tensors isn't yet supported on wasm."); + fn into_item(self) -> Self::Item { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + todo!("Recording float tensors isn't yet supported on wasm."); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - FloatTensorSerde::new(self.into_data().convert().serialize()) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + FloatTensorSerde::new(self.into_data().convert().serialize()) + } - fn from_item(item: Self::Item) -> Self { - Tensor::from_data(item.data.convert::()) - } + fn from_item(item: Self::Item) -> Self { + 
Tensor::from_data(item.data.convert::()) + } } impl Record for Tensor { - type Item = IntTensorSerde; + type Item = IntTensorSerde; - fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording int tensors isn't yet supported on wasm."); + fn into_item(self) -> Self::Item { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + todo!("Recording int tensors isn't yet supported on wasm."); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - IntTensorSerde::new(self.into_data().convert().serialize()) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + IntTensorSerde::new(self.into_data().convert().serialize()) + } - fn from_item(item: Self::Item) -> Self { - Tensor::from_data(item.data.convert()) - } + fn from_item(item: Self::Item) -> Self { + Tensor::from_data(item.data.convert()) + } } impl Record for Tensor { - type Item = BoolTensorSerde; + type Item = BoolTensorSerde; - fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording bool tensors isn't yet supported on wasm."); + fn into_item(self) -> Self::Item { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + todo!("Recording bool tensors isn't yet supported on wasm."); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - BoolTensorSerde::new(self.into_data().serialize()) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + BoolTensorSerde::new(self.into_data().serialize()) + } - fn from_item(item: Self::Item) -> Self { - Tensor::from_data(item.data) - } + fn from_item(item: Self::Item) -> Self { + Tensor::from_data(item.data) + } } diff --git a/burn-core/tests/derive_config.rs b/burn-core/tests/derive_config.rs index dec636ee91..227d85f336 100644 --- a/burn-core/tests/derive_config.rs +++ b/burn-core/tests/derive_config.rs @@ -6,102 +6,102 @@ pub struct TestEmptyStructConfig {} #[derive(Config, Debug, PartialEq)] pub struct TestStructConfig { - int: i32, - #[config(default = 2)] - int_default: i32, - float: f32, - #[config(default = 2.0)] - float_default: f32, - string: String, - other_config: TestEmptyStructConfig, + int: i32, + #[config(default = 2)] + int_default: i32, + float: f32, + #[config(default = 2.0)] + float_default: f32, + string: String, + other_config: TestEmptyStructConfig, } #[derive(Config, Debug, PartialEq)] pub enum TestEnumConfig { - None, - Single(f32), - Multiple(f32, String), - Named { first: f32, second: String }, + None, + Single(f32), + Multiple(f32, String), + Named { first: f32, second: String }, } #[cfg(feature = "std")] #[test] fn struct_config_should_impl_serde() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - let file_path = "/tmp/test_struct_config.json"; + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); + let file_path = "/tmp/test_struct_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestStructConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestStructConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[test] fn struct_config_should_impl_clone() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - assert_eq!(config, config.clone()); + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), 
TestEmptyStructConfig::new()); + assert_eq!(config, config.clone()); } #[test] fn struct_config_should_impl_display() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - assert_eq!(burn::config::config_to_json(&config), config.to_string()); + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); + assert_eq!(burn::config::config_to_json(&config), config.to_string()); } #[cfg(feature = "std")] #[test] fn enum_config_no_value_should_impl_serde() { - let config = TestEnumConfig::None; - let file_path = "/tmp/test_enum_no_value_config.json"; + let config = TestEnumConfig::None; + let file_path = "/tmp/test_enum_no_value_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestEnumConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestEnumConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[cfg(feature = "std")] #[test] fn enum_config_one_value_should_impl_serde() { - let config = TestEnumConfig::Single(42.0); - let file_path = "/tmp/test_enum_one_value_config.json"; + let config = TestEnumConfig::Single(42.0); + let file_path = "/tmp/test_enum_one_value_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestEnumConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestEnumConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[cfg(feature = "std")] #[test] fn enum_config_multiple_values_should_impl_serde() { - let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); - let file_path = "/tmp/test_enum_multiple_values_config.json"; + let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); + let file_path = "/tmp/test_enum_multiple_values_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestEnumConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestEnumConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[test] fn enum_config_should_impl_clone() { - let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); - assert_eq!(config, config.clone()); + let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); + assert_eq!(config, config.clone()); } #[test] fn enum_config_should_impl_display() { - let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); - assert_eq!(burn::config::config_to_json(&config), config.to_string()); + let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); + assert_eq!(burn::config::config_to_json(&config), config.to_string()); } #[test] fn struct_config_can_load_binary() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - let binary = config_to_json(&config).as_bytes().to_vec(); + let binary = config_to_json(&config).as_bytes().to_vec(); - let config_loaded = TestStructConfig::load_binary(&binary).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestStructConfig::load_binary(&binary).unwrap(); + assert_eq!(config, config_loaded); } diff --git a/burn-core/tests/derive_module.rs b/burn-core/tests/derive_module.rs index 87beafc422..427bd668b3 100644 --- a/burn-core/tests/derive_module.rs +++ b/burn-core/tests/derive_module.rs @@ 
-9,139 +9,139 @@ pub type TestAutodiffBackend = burn_autodiff::Autodiff; #[derive(Module, Debug)] pub struct ModuleBasic { - weight_basic: Param>, + weight_basic: Param>, } impl ModuleBasic { - fn new() -> Self { - let weight_basic = Tensor::random(Shape::new([20, 20]), Distribution::Default); - Self { - weight_basic: Param::from(weight_basic), - } + fn new() -> Self { + let weight_basic = Tensor::random(Shape::new([20, 20]), Distribution::Default); + Self { + weight_basic: Param::from(weight_basic), } + } } #[derive(Module, Debug)] pub struct ModuleComposed { - weight: Param>, - basic: ModuleBasic, + weight: Param>, + basic: ModuleBasic, } impl ModuleComposed { - fn new() -> Self { - let weight = Tensor::random(Shape::new([20, 20]), Distribution::Default); - Self { - weight: Param::from(weight), - basic: ModuleBasic::new(), - } + fn new() -> Self { + let weight = Tensor::random(Shape::new([20, 20]), Distribution::Default); + Self { + weight: Param::from(weight), + basic: ModuleBasic::new(), } + } } mod state { - use super::*; - - #[test] - fn should_load_from_record_basic() { - let module_1 = ModuleBasic::::new(); - let mut module_2 = ModuleBasic::::new(); - let state_1 = module_1.clone().into_record(); - - assert_ne!( - module_1.weight_basic.to_data(), - module_2.weight_basic.to_data() - ); - - module_2 = module_2.load_record(state_1); - - assert_eq!( - module_1.weight_basic.to_data(), - module_2.weight_basic.to_data() - ); - } - - #[test] - fn should_load_from_record_compose() { - let module_1 = ModuleComposed::::new(); - let mut module_2 = ModuleComposed::::new(); - assert_ne!(module_1.weight.to_data(), module_2.weight.to_data()); - assert_ne!( - module_1.basic.weight_basic.to_data(), - module_2.basic.weight_basic.to_data() - ); - - let state_1 = module_1.clone().into_record(); - module_2 = module_2.load_record(state_1); - - assert_eq!(module_1.weight.to_data(), module_2.weight.to_data()); - assert_eq!( - module_1.basic.weight_basic.to_data(), - module_2.basic.weight_basic.to_data() - ); - } + use super::*; + + #[test] + fn should_load_from_record_basic() { + let module_1 = ModuleBasic::::new(); + let mut module_2 = ModuleBasic::::new(); + let state_1 = module_1.clone().into_record(); + + assert_ne!( + module_1.weight_basic.to_data(), + module_2.weight_basic.to_data() + ); + + module_2 = module_2.load_record(state_1); + + assert_eq!( + module_1.weight_basic.to_data(), + module_2.weight_basic.to_data() + ); + } + + #[test] + fn should_load_from_record_compose() { + let module_1 = ModuleComposed::::new(); + let mut module_2 = ModuleComposed::::new(); + assert_ne!(module_1.weight.to_data(), module_2.weight.to_data()); + assert_ne!( + module_1.basic.weight_basic.to_data(), + module_2.basic.weight_basic.to_data() + ); + + let state_1 = module_1.clone().into_record(); + module_2 = module_2.load_record(state_1); + + assert_eq!(module_1.weight.to_data(), module_2.weight.to_data()); + assert_eq!( + module_1.basic.weight_basic.to_data(), + module_2.basic.weight_basic.to_data() + ); + } } mod num_params { - use super::*; - - #[test] - fn should_calculate_num_params_basic() { - let module = ModuleBasic::::new(); - assert_eq!(20 * 20, module.num_params()); - } - - #[test] - fn should_output_state_composed() { - let module = ModuleComposed::::new(); - assert_eq!(2 * 20 * 20, module.num_params()); - } + use super::*; + + #[test] + fn should_calculate_num_params_basic() { + let module = ModuleBasic::::new(); + assert_eq!(20 * 20, module.num_params()); + } + + #[test] + fn 
should_output_state_composed() { + let module = ModuleComposed::::new(); + assert_eq!(2 * 20 * 20, module.num_params()); + } } #[cfg(feature = "std")] mod require_grad { - use burn_tensor::backend::AutodiffBackend; + use burn_tensor::backend::AutodiffBackend; - use super::*; + use super::*; - #[test] - fn should_have_grad_by_default() { - let module = ModuleBasic::::new(); - let mut grads = calculate_grads(&module); + #[test] + fn should_have_grad_by_default() { + let module = ModuleBasic::::new(); + let mut grads = calculate_grads(&module); - let grad_x = module.weight_basic.grad_remove(&mut grads); + let grad_x = module.weight_basic.grad_remove(&mut grads); - assert!(grad_x.is_some()); - } + assert!(grad_x.is_some()); + } - #[test] - fn should_have_no_grad_after_no_grad() { - let module = ModuleBasic::::new().no_grad(); - let mut grads = calculate_grads(&module); + #[test] + fn should_have_no_grad_after_no_grad() { + let module = ModuleBasic::::new().no_grad(); + let mut grads = calculate_grads(&module); - let grad_x = module.weight_basic.grad_remove(&mut grads); + let grad_x = module.weight_basic.grad_remove(&mut grads); - assert!(grad_x.is_none()); - } + assert!(grad_x.is_none()); + } - #[test] - fn should_have_grad_when_from_record() { - let module = ModuleBasic::::new(); - let record = ModuleBasicRecord { - weight_basic: module.weight_basic.clone(), // Even when param is no_grad, - }; - let module = module.load_record(record); - let mut grads = calculate_grads(&module); + #[test] + fn should_have_grad_when_from_record() { + let module = ModuleBasic::::new(); + let record = ModuleBasicRecord { + weight_basic: module.weight_basic.clone(), // Even when param is no_grad, + }; + let module = module.load_record(record); + let mut grads = calculate_grads(&module); - let grad_x = module.weight_basic.grad_remove(&mut grads); + let grad_x = module.weight_basic.grad_remove(&mut grads); - assert!(grad_x.is_some()); - } + assert!(grad_x.is_some()); + } - fn calculate_grads( - module: &ModuleBasic, - ) -> ::Gradients { - let x = Tensor::ones([20, 20]).require_grad(); - let y = module.weight_basic.val().matmul(x); + fn calculate_grads( + module: &ModuleBasic, + ) -> ::Gradients { + let x = Tensor::ones([20, 20]).require_grad(); + let y = module.weight_basic.val().matmul(x); - y.backward() - } + y.backward() + } } diff --git a/burn-core/tests/derive_record.rs b/burn-core/tests/derive_record.rs index c0a6653731..d11bd58181 100644 --- a/burn-core/tests/derive_record.rs +++ b/burn-core/tests/derive_record.rs @@ -7,11 +7,11 @@ use burn_tensor::Tensor; // It compiles #[derive(Record)] pub struct TestWithBackendRecord { - tensor: Tensor, + tensor: Tensor, } // It compiles #[derive(Record)] pub struct TestWithoutBackendRecord { - tensor: usize, + tensor: usize, } diff --git a/burn-core/tests/record_resilience.rs b/burn-core/tests/record_resilience.rs index 9bf4d651cb..b021dc5a72 100644 --- a/burn-core/tests/record_resilience.rs +++ b/burn-core/tests/record_resilience.rs @@ -1,298 +1,290 @@ #[cfg(feature = "std")] mod tests { - use burn::{ - module::Module, - nn, - record::{ - BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings, - PrettyJsonFileRecorder, RecorderError, - }, + use burn::{ + module::Module, + nn, + record::{ + BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings, + PrettyJsonFileRecorder, RecorderError, + }, + }; + use burn_core as burn; + use burn_tensor::backend::Backend; + use std::path::PathBuf; + + type TestBackend = burn_ndarray::NdArray; + + 
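    // The structs below model a record whose Rust definition changes between saving
    // and loading: `Model` is the baseline, while the `ModelNew*` variants add an
    // optional field, add a constant field, or reorder the fields. The tests that
    // follow exercise each change against the available `FileRecorder` formats; the
    // self-describing (named-field) formats tolerate every change, whereas the
    // binary recorder is expected to panic on a new optional field and on a changed
    // field order (see the `#[should_panic]` tests), presumably because its encoding
    // is positional rather than keyed by field name.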
#[derive(Module, Debug)] + pub struct Model { + single_const: f32, + linear1: nn::Linear, + array_const: [usize; 2], + linear2: nn::Linear, + } + + #[derive(Module, Debug)] + pub struct ModelNewOptionalField { + single_const: f32, + linear1: nn::Linear, + array_const: [usize; 2], + linear2: nn::Linear, + new_field: Option, + } + + #[derive(Module, Debug)] + pub struct ModelNewConstantField { + single_const: f32, + linear1: nn::Linear, + array_const: [usize; 2], + linear2: nn::Linear, + new_field: usize, + } + + #[derive(Module, Debug)] + pub struct ModelNewFieldOrders { + array_const: [usize; 2], + linear2: nn::Linear, + single_const: f32, + linear1: nn::Linear, + } + + #[test] + fn deserialize_with_new_optional_field_works_with_default_file_recorder() { + deserialize_with_new_optional_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_optional_field_works_with_default_file_recorder() { + deserialize_with_removed_optional_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_constant_field_works_with_default_file_recorder() { + deserialize_with_new_constant_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_constant_field_works_with_default_file_recorder() { + deserialize_with_removed_constant_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_field_order_works_with_default_file_recorder() { + deserialize_with_new_field_order( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + #[test] + fn deserialize_with_new_optional_field_works_with_pretty_json() { + deserialize_with_new_optional_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_optional_field_works_with_pretty_json() { + deserialize_with_removed_optional_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_constant_field_works_with_pretty_json() { + deserialize_with_new_constant_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_constant_field_works_with_pretty_json() { + deserialize_with_removed_constant_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_field_order_works_with_pretty_json() { + deserialize_with_new_field_order( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + #[should_panic] + fn deserialize_with_new_optional_field_doesnt_works_with_bin_file_recorder() { + deserialize_with_new_optional_field("bin", BinFileRecorder::::new()) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() { + deserialize_with_removed_optional_field("bin", BinFileRecorder::::new()) + .unwrap(); + } + + #[test] + fn deserialize_with_new_constant_field_works_with_bin_file_recorder() { + deserialize_with_new_constant_field("bin", BinFileRecorder::::new()) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() { + deserialize_with_removed_constant_field("bin", BinFileRecorder::::new()) + .unwrap(); + } + + #[test] + #[should_panic] + fn deserialize_with_new_field_order_works_with_bin_file_recorder() { + deserialize_with_new_field_order("bin", BinFileRecorder::::new()) + .unwrap(); + } + + 
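    // A condensed sketch of the round trip the helpers below perform, with the type
    // parameters that were elided above spelled out. It assumes the record types
    // generated by the `Module` derive follow the usual `<Name>Record` naming
    // (e.g. `ModelNewOptionalFieldRecord`); the file path is illustrative only.
    #[allow(dead_code)]
    fn sketch_round_trip() -> Result<(), RecorderError> {
        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
        let file_path: PathBuf = "/tmp/record_resilience_sketch".into();

        let model = Model {
            single_const: 32.0,
            linear1: nn::LinearConfig::new(20, 20).init::<TestBackend>(),
            array_const: [2, 2],
            linear2: nn::LinearConfig::new(20, 20).init::<TestBackend>(),
        };

        // Save the record produced by the original definition...
        recorder
            .record(model.into_record(), file_path.clone())
            .unwrap();

        // ...then try to load it back as the record of the evolved definition.
        let result = recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(file_path.clone());
        std::fs::remove_file(file_path).ok();

        result?;
        Ok(())
    }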
fn deserialize_with_new_optional_field(name: &str, recorder: R) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_new_optional_field-{name}").into(); + let model = Model { + single_const: 32.0, + linear1: nn::LinearConfig::new(20, 20).init::(), + array_const: [2, 2], + linear2: nn::LinearConfig::new(20, 20).init::(), }; - use burn_core as burn; - use burn_tensor::backend::Backend; - use std::path::PathBuf; - - type TestBackend = burn_ndarray::NdArray; - - #[derive(Module, Debug)] - pub struct Model { - single_const: f32, - linear1: nn::Linear, - array_const: [usize; 2], - linear2: nn::Linear, - } - - #[derive(Module, Debug)] - pub struct ModelNewOptionalField { - single_const: f32, - linear1: nn::Linear, - array_const: [usize; 2], - linear2: nn::Linear, - new_field: Option, - } - - #[derive(Module, Debug)] - pub struct ModelNewConstantField { - single_const: f32, - linear1: nn::Linear, - array_const: [usize; 2], - linear2: nn::Linear, - new_field: usize, - } - - #[derive(Module, Debug)] - pub struct ModelNewFieldOrders { - array_const: [usize; 2], - linear2: nn::Linear, - single_const: f32, - linear1: nn::Linear, - } - - #[test] - fn deserialize_with_new_optional_field_works_with_default_file_recorder() { - deserialize_with_new_optional_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_optional_field_works_with_default_file_recorder() { - deserialize_with_removed_optional_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_constant_field_works_with_default_file_recorder() { - deserialize_with_new_constant_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_constant_field_works_with_default_file_recorder() { - deserialize_with_removed_constant_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_field_order_works_with_default_file_recorder() { - deserialize_with_new_field_order( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - #[test] - fn deserialize_with_new_optional_field_works_with_pretty_json() { - deserialize_with_new_optional_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_optional_field_works_with_pretty_json() { - deserialize_with_removed_optional_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_constant_field_works_with_pretty_json() { - deserialize_with_new_constant_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_constant_field_works_with_pretty_json() { - deserialize_with_removed_constant_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_field_order_works_with_pretty_json() { - deserialize_with_new_field_order( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - #[should_panic] - fn deserialize_with_new_optional_field_doesnt_works_with_bin_file_recorder() { - deserialize_with_new_optional_field("bin", BinFileRecorder::::new()) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() { - deserialize_with_removed_optional_field( - "bin", - BinFileRecorder::::new(), - ) - .unwrap(); - } - - 
#[test] - fn deserialize_with_new_constant_field_works_with_bin_file_recorder() { - deserialize_with_new_constant_field("bin", BinFileRecorder::::new()) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() { - deserialize_with_removed_constant_field( - "bin", - BinFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - #[should_panic] - fn deserialize_with_new_field_order_works_with_bin_file_recorder() { - deserialize_with_new_field_order("bin", BinFileRecorder::::new()) - .unwrap(); - } - - fn deserialize_with_new_optional_field(name: &str, recorder: R) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = format!("/tmp/deserialize_with_new_optional_field-{name}").into(); - let model = Model { - single_const: 32.0, - linear1: nn::LinearConfig::new(20, 20).init::(), - array_const: [2, 2], - linear2: nn::LinearConfig::new(20, 20).init::(), - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_removed_optional_field( - name: &str, - recorder: R, - ) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = - format!("/tmp/deserialize_with_removed_optional_field-{name}").into(); - let model = ModelNewOptionalField { - single_const: 32.0, - linear1: nn::LinearConfig::new(20, 20).init::(), - array_const: [2, 2], - linear2: nn::LinearConfig::new(20, 20).init::(), - new_field: None, - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_new_constant_field(name: &str, recorder: R) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = format!("/tmp/deserialize_with_new_constant_field-{name}").into(); - let model = Model { - single_const: 32.0, - array_const: [2, 2], - linear1: nn::LinearConfig::new(20, 20).init::(), - linear2: nn::LinearConfig::new(20, 20).init::(), - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_removed_constant_field( - name: &str, - recorder: R, - ) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = - format!("/tmp/deserialize_with_removed_constant_field-{name}").into(); - let model = ModelNewConstantField { - single_const: 32.0, - array_const: [2, 2], - linear1: nn::LinearConfig::new(20, 20).init::(), - linear2: nn::LinearConfig::new(20, 20).init::(), - new_field: 0, - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_new_field_order(name: &str, recorder: R) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = format!("/tmp/deserialize_with_new_field_order-{name}").into(); - let model = Model { - array_const: [2, 2], - single_const: 32.0, - linear1: nn::LinearConfig::new(20, 20).init::(), - linear2: nn::LinearConfig::new(20, 20).init::(), - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - - let result = recorder.load::>(file_path.clone()); - 
std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_removed_optional_field( + name: &str, + recorder: R, + ) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_removed_optional_field-{name}").into(); + let model = ModelNewOptionalField { + single_const: 32.0, + linear1: nn::LinearConfig::new(20, 20).init::(), + array_const: [2, 2], + linear2: nn::LinearConfig::new(20, 20).init::(), + new_field: None, + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_new_constant_field(name: &str, recorder: R) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_new_constant_field-{name}").into(); + let model = Model { + single_const: 32.0, + array_const: [2, 2], + linear1: nn::LinearConfig::new(20, 20).init::(), + linear2: nn::LinearConfig::new(20, 20).init::(), + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_removed_constant_field( + name: &str, + recorder: R, + ) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_removed_constant_field-{name}").into(); + let model = ModelNewConstantField { + single_const: 32.0, + array_const: [2, 2], + linear1: nn::LinearConfig::new(20, 20).init::(), + linear2: nn::LinearConfig::new(20, 20).init::(), + new_field: 0, + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_new_field_order(name: &str, recorder: R) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_new_field_order-{name}").into(); + let model = Model { + array_const: [2, 2], + single_const: 32.0, + linear1: nn::LinearConfig::new(20, 20).init::(), + linear2: nn::LinearConfig::new(20, 20).init::(), + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } } diff --git a/burn-dataset/examples/speech_commands.rs b/burn-dataset/examples/speech_commands.rs index cce7f131e1..5b4ff7791d 100644 --- a/burn-dataset/examples/speech_commands.rs +++ b/burn-dataset/examples/speech_commands.rs @@ -3,21 +3,21 @@ use burn_dataset::{audio::SpeechCommandsDataset, Dataset}; #[cfg(feature = "audio")] fn speech_command() { - let index: usize = 4835; - let test = SpeechCommandsDataset::test(); - let item = test.get(index).unwrap(); + let index: usize = 4835; + let test = SpeechCommandsDataset::test(); + let item = test.get(index).unwrap(); - println!("Item: {:?}", item); - println!("Item Length: {:?}", item.audio_samples.len()); - println!("Label: {}", item.label.to_string()); + println!("Item: {:?}", item); + println!("Item Length: {:?}", item.audio_samples.len()); + println!("Label: {}", 
item.label.to_string()); - assert_eq!(test.len(), 4890); - assert_eq!(item.label.to_string(), "Yes"); - assert_eq!(item.sample_rate, 16000); - assert_eq!(item.audio_samples.len(), 16000); + assert_eq!(test.len(), 4890); + assert_eq!(item.label.to_string(), "Yes"); + assert_eq!(item.sample_rate, 16000); + assert_eq!(item.audio_samples.len(), 16000); } fn main() { - #[cfg(feature = "audio")] - speech_command() + #[cfg(feature = "audio")] + speech_command() } diff --git a/burn-dataset/src/audio/speech_commands.rs b/burn-dataset/src/audio/speech_commands.rs index 28c2d34f20..401f8965ac 100644 --- a/burn-dataset/src/audio/speech_commands.rs +++ b/burn-dataset/src/audio/speech_commands.rs @@ -1,6 +1,6 @@ use crate::{ - transform::{Mapper, MapperDataset}, - Dataset, HuggingfaceDatasetLoader, SqliteDataset, + transform::{Mapper, MapperDataset}, + Dataset, HuggingfaceDatasetLoader, SqliteDataset, }; use hound::WavReader; @@ -17,65 +17,65 @@ type MappedDataset = MapperDataset, ConvertSamples, #[allow(missing_docs)] #[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)] pub enum SpeechCommandClass { - // Target command words - Yes = 0, - No = 1, - Up = 2, - Down = 3, - Left = 4, - Right = 5, - On = 6, - Off = 7, - Stop = 8, - Go = 9, - Zero = 10, - One = 11, - Two = 12, - Three = 13, - Four = 14, - Five = 15, - Six = 16, - Seven = 17, - Eight = 18, - Nine = 19, - - // Non-target words that can be grouped into "Other" - Bed = 20, - Bird = 21, - Cat = 22, - Dog = 23, - Happy = 24, - House = 25, - Marvin = 26, - Sheila = 27, - Tree = 28, - Wow = 29, - - // Commands from v2 dataset, that can be grouped into "Other" - Backward = 30, - Forward = 31, - Follow = 32, - Learn = 33, - Visual = 34, - - // Background noise - Silence = 35, - - // Other miscellaneous words - Other = 36, + // Target command words + Yes = 0, + No = 1, + Up = 2, + Down = 3, + Left = 4, + Right = 5, + On = 6, + Off = 7, + Stop = 8, + Go = 9, + Zero = 10, + One = 11, + Two = 12, + Three = 13, + Four = 14, + Five = 15, + Six = 16, + Seven = 17, + Eight = 18, + Nine = 19, + + // Non-target words that can be grouped into "Other" + Bed = 20, + Bird = 21, + Cat = 22, + Dog = 23, + Happy = 24, + House = 25, + Marvin = 26, + Sheila = 27, + Tree = 28, + Wow = 29, + + // Commands from v2 dataset, that can be grouped into "Other" + Backward = 30, + Forward = 31, + Follow = 32, + Learn = 33, + Visual = 34, + + // Background noise + Silence = 35, + + // Other miscellaneous words + Other = 36, } /// Struct containing raw speech data returned from a database. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SpeechItemRaw { - /// Audio file bytes. - pub audio_bytes: Vec, + /// Audio file bytes. + pub audio_bytes: Vec, - /// Label index. - pub label: usize, + /// Label index. + pub label: usize, - /// Indicates if the label is unknown. - pub is_unknown: bool, + /// Indicates if the label is unknown. + pub is_unknown: bool, } /// Speech item with audio samples and label. @@ -88,14 +88,14 @@ pub struct SpeechItemRaw { /// The original label is also stored in the `label_original` field for debugging and remapping if needed. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SpeechItem { - /// Audio samples in the range [-1.0, 1.0]. - pub audio_samples: Vec, + /// Audio samples in the range [-1.0, 1.0]. + pub audio_samples: Vec, - /// The sample rate of the audio. - pub sample_rate: usize, + /// The sample rate of the audio. + pub sample_rate: usize, - /// The label of the audio. 
- pub label: SpeechCommandClass, + /// The label of the audio. + pub label: SpeechCommandClass, } /// Speech Commands dataset from Huggingface v0.02. @@ -114,96 +114,95 @@ pub struct SpeechItem { /// - test: 4,890 audio files /// - validation: 9,982 audio files pub struct SpeechCommandsDataset { - dataset: MappedDataset, + dataset: MappedDataset, } impl SpeechCommandsDataset { - /// Create a new dataset with the given split. - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = - HuggingfaceDatasetLoader::new("speech_commands") - .with_subset("v0.02") - .dataset(split) - .unwrap(); - let dataset = MapperDataset::new(dataset, ConvertSamples); - Self { dataset } - } - - /// Create a new dataset with the train split. - pub fn train() -> Self { - Self::new("train") - } - - /// Create a new dataset with the test split. - pub fn test() -> Self { - Self::new("test") - } - - /// Create a new dataset with the validation split. - pub fn validation() -> Self { - Self::new("validation") - } - - /// Returns the number of classes in the dataset - pub fn num_classes() -> usize { - SpeechCommandClass::COUNT - } + /// Create a new dataset with the given split. + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("speech_commands") + .with_subset("v0.02") + .dataset(split) + .unwrap(); + let dataset = MapperDataset::new(dataset, ConvertSamples); + Self { dataset } + } + + /// Create a new dataset with the train split. + pub fn train() -> Self { + Self::new("train") + } + + /// Create a new dataset with the test split. + pub fn test() -> Self { + Self::new("test") + } + + /// Create a new dataset with the validation split. + pub fn validation() -> Self { + Self::new("validation") + } + + /// Returns the number of classes in the dataset + pub fn num_classes() -> usize { + SpeechCommandClass::COUNT + } } impl Dataset for SpeechCommandsDataset { - fn get(&self, index: usize) -> Option { - self.dataset.get(index) - } + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } /// Mapper converting audio bytes into audio samples and the label to enum class. struct ConvertSamples; impl ConvertSamples { - /// Convert label to enum class. - fn to_speechcommandclass(label: usize) -> SpeechCommandClass { - SpeechCommandClass::from_repr(label).unwrap() - } - - /// Convert audio bytes into samples of floats [-1.0, 1.0]. - fn to_audiosamples(bytes: &Vec) -> (Vec, usize) { - let reader = WavReader::new(bytes.as_slice()).unwrap(); - let spec = reader.spec(); - - // Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample). - let max_value = (1 << (spec.bits_per_sample - 1)) as f32; - - // The sample rate of the audio. - let sample_rate = spec.sample_rate as usize; - - // Convert the audio samples to floats [-1.0, 1.0]. - let audio_samples: Vec = reader - .into_samples::() - .filter_map(Result::ok) - .map(|sample| sample as f32 / max_value) - .collect(); - - (audio_samples, sample_rate) - } + /// Convert label to enum class. + fn to_speechcommandclass(label: usize) -> SpeechCommandClass { + SpeechCommandClass::from_repr(label).unwrap() + } + + /// Convert audio bytes into samples of floats [-1.0, 1.0]. 
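    // Worked example for the conversion below: with 16-bit PCM, `bits_per_sample` is 16,
    // so `max_value = (1 << 15) as f32 = 32768.0`. A raw sample of 16384 becomes
    // 16384.0 / 32768.0 = 0.5 and the minimum value -32768 becomes -1.0, so every
    // converted sample lies in [-1.0, 1.0].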
+ fn to_audiosamples(bytes: &Vec) -> (Vec, usize) { + let reader = WavReader::new(bytes.as_slice()).unwrap(); + let spec = reader.spec(); + + // Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample). + let max_value = (1 << (spec.bits_per_sample - 1)) as f32; + + // The sample rate of the audio. + let sample_rate = spec.sample_rate as usize; + + // Convert the audio samples to floats [-1.0, 1.0]. + let audio_samples: Vec = reader + .into_samples::() + .filter_map(Result::ok) + .map(|sample| sample as f32 / max_value) + .collect(); + + (audio_samples, sample_rate) + } } impl Mapper for ConvertSamples { - /// Convert audio bytes into samples of floats [-1.0, 1.0] - /// and the label to enum class with the target word, other and silence classes. - fn map(&self, item: &SpeechItemRaw) -> SpeechItem { - let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes); - - // Convert the label to enum class, with the target words, other and silence classes. - let label = Self::to_speechcommandclass(item.label); - - SpeechItem { - audio_samples, - sample_rate, - label, - } + /// Convert audio bytes into samples of floats [-1.0, 1.0] + /// and the label to enum class with the target word, other and silence classes. + fn map(&self, item: &SpeechItemRaw) -> SpeechItem { + let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes); + + // Convert the label to enum class, with the target words, other and silence classes. + let label = Self::to_speechcommandclass(item.label); + + SpeechItem { + audio_samples, + sample_rate, + label, } + } } diff --git a/burn-dataset/src/dataset/base.rs b/burn-dataset/src/dataset/base.rs index eb53980c94..6f4caead7d 100644 --- a/burn-dataset/src/dataset/base.rs +++ b/burn-dataset/src/dataset/base.rs @@ -4,68 +4,68 @@ use crate::DatasetIterator; /// The dataset trait defines a basic collection of items with a predefined size. pub trait Dataset: Send + Sync { - /// Gets the item at the given index. - fn get(&self, index: usize) -> Option; + /// Gets the item at the given index. + fn get(&self, index: usize) -> Option; - /// Gets the number of items in the dataset. - fn len(&self) -> usize; + /// Gets the number of items in the dataset. + fn len(&self) -> usize; - /// Checks if the dataset is empty. - fn is_empty(&self) -> bool { - self.len() == 0 - } + /// Checks if the dataset is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } - /// Returns an iterator over the dataset. - fn iter(&self) -> DatasetIterator<'_, I> - where - Self: Sized, - { - DatasetIterator::new(self) - } + /// Returns an iterator over the dataset. 
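    // A minimal sketch of a custom implementor (the `Squares` type is hypothetical);
    // only `get` and `len` are required, `is_empty` and `iter` come from the defaults:
    //
    //     struct Squares(usize);
    //
    //     impl Dataset<usize> for Squares {
    //         fn get(&self, index: usize) -> Option<usize> {
    //             (index < self.0).then(|| index * index)
    //         }
    //         fn len(&self) -> usize {
    //             self.0
    //         }
    //     }
    //
    //     // Squares(4).iter() yields 0, 1, 4, 9.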
+ fn iter(&self) -> DatasetIterator<'_, I> + where + Self: Sized, + { + DatasetIterator::new(self) + } } impl Dataset for Arc where - D: Dataset, + D: Dataset, { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } impl Dataset for Arc> { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } impl Dataset for Box where - D: Dataset, + D: Dataset, { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } impl Dataset for Box> { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } diff --git a/burn-dataset/src/dataset/fake.rs b/burn-dataset/src/dataset/fake.rs index c27f8cf0d3..af8762d5e0 100644 --- a/burn-dataset/src/dataset/fake.rs +++ b/burn-dataset/src/dataset/fake.rs @@ -3,36 +3,36 @@ use fake::{Dummy, Fake, Faker}; /// Dataset filled with fake items generated from the [fake](fake) crate. pub struct FakeDataset { - dataset: InMemDataset, + dataset: InMemDataset, } impl> FakeDataset { - /// Create a new fake dataset with the given size. - pub fn new(size: usize) -> Self { - let mut items = Vec::with_capacity(size); - for _ in 0..size { - items.push(Faker.fake()); - } - let dataset = InMemDataset::new(items); - - Self { dataset } + /// Create a new fake dataset with the given size. + pub fn new(size: usize) -> Self { + let mut items = Vec::with_capacity(size); + for _ in 0..size { + items.push(Faker.fake()); } + let dataset = InMemDataset::new(items); + + Self { dataset } + } } impl Dataset for FakeDataset { - fn iter(&self) -> DatasetIterator<'_, I> { - DatasetIterator::new(self) - } + fn iter(&self) -> DatasetIterator<'_, I> { + DatasetIterator::new(self) + } - fn get(&self, index: usize) -> Option { - self.dataset.get(index) - } + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } - fn is_empty(&self) -> bool { - self.dataset.is_empty() - } + fn is_empty(&self) -> bool { + self.dataset.is_empty() + } } diff --git a/burn-dataset/src/dataset/in_memory.rs b/burn-dataset/src/dataset/in_memory.rs index a3b167f0c7..1091e8f08c 100644 --- a/burn-dataset/src/dataset/in_memory.rs +++ b/burn-dataset/src/dataset/in_memory.rs @@ -1,7 +1,7 @@ use std::{ - fs::File, - io::{BufRead, BufReader}, - path::Path, + fs::File, + io::{BufRead, BufReader}, + path::Path, }; use serde::de::DeserializeOwned; @@ -10,162 +10,162 @@ use crate::Dataset; /// Dataset where all items are stored in ram. pub struct InMemDataset { - items: Vec, + items: Vec, } impl InMemDataset { - /// Creates a new in memory dataset from the given items. - pub fn new(items: Vec) -> Self { - InMemDataset { items } - } + /// Creates a new in memory dataset from the given items. 
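    // Usage sketch (any item type that is Clone + Send + Sync works; the values here
    // are illustrative):
    //
    //     let dataset = InMemDataset::new(vec!["a".to_string(), "b".to_string()]);
    //     assert_eq!(dataset.len(), 2);
    //     assert_eq!(dataset.get(2), None); // out-of-range indices return None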
+ pub fn new(items: Vec) -> Self { + InMemDataset { items } + } } impl Dataset for InMemDataset where - I: Clone + Send + Sync, + I: Clone + Send + Sync, { - fn get(&self, index: usize) -> Option { - self.items.get(index).cloned() - } - fn len(&self) -> usize { - self.items.len() - } + fn get(&self, index: usize) -> Option { + self.items.get(index).cloned() + } + fn len(&self) -> usize { + self.items.len() + } } impl InMemDataset where - I: Clone + DeserializeOwned, + I: Clone + DeserializeOwned, { - /// Create from a dataset. All items are loaded in memory. - pub fn from_dataset(dataset: &impl Dataset) -> Self { - let items: Vec = dataset.iter().collect(); - Self::new(items) + /// Create from a dataset. All items are loaded in memory. + pub fn from_dataset(dataset: &impl Dataset) -> Self { + let items: Vec = dataset.iter().collect(); + Self::new(items) + } + + /// Create from a json rows file (one json per line). + /// + /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html) + pub fn from_json_rows>(path: P) -> Result { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut items = Vec::new(); + + for line in reader.lines() { + let item = serde_json::from_str(line.unwrap().as_str()).unwrap(); + items.push(item); } - /// Create from a json rows file (one json per line). - /// - /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html) - pub fn from_json_rows>(path: P) -> Result { - let file = File::open(path)?; - let reader = BufReader::new(file); - let mut items = Vec::new(); + let dataset = Self::new(items); - for line in reader.lines() { - let item = serde_json::from_str(line.unwrap().as_str()).unwrap(); - items.push(item); - } + Ok(dataset) + } - let dataset = Self::new(items); - - Ok(dataset) - } + /// Create from a csv file. + /// + /// The first line of the csv file must be the header. The header must contain the name of the fields in the struct. + /// + /// The supported field types are: String, integer, float, and bool. + /// + /// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde) + pub fn from_csv>(path: P) -> Result { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut rdr = csv::Reader::from_reader(reader); - /// Create from a csv file. - /// - /// The first line of the csv file must be the header. The header must contain the name of the fields in the struct. - /// - /// The supported field types are: String, integer, float, and bool. 
- /// - /// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde) - pub fn from_csv>(path: P) -> Result { - let file = File::open(path)?; - let reader = BufReader::new(file); - let mut rdr = csv::Reader::from_reader(reader); + let mut items = Vec::new(); - let mut items = Vec::new(); - - for result in rdr.deserialize() { - let item: I = result?; - items.push(item); - } + for result in rdr.deserialize() { + let item: I = result?; + items.push(item); + } - let dataset = Self::new(items); + let dataset = Self::new(items); - Ok(dataset) - } + Ok(dataset) + } } #[cfg(test)] mod tests { - use super::*; - use crate::{test_data, SqliteDataset}; - - use rstest::{fixture, rstest}; - use serde::{Deserialize, Serialize}; - - const DB_FILE: &str = "tests/data/sqlite-dataset.db"; - const JSON_FILE: &str = "tests/data/dataset.json"; - const CSV_FILE: &str = "tests/data/dataset.csv"; - - type SqlDs = SqliteDataset; - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct Sample { - column_str: String, - column_bytes: Vec, - column_int: i64, - column_bool: bool, - column_float: f64, - } - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct SampleCvs { - column_str: String, - column_int: i64, - column_bool: bool, - column_float: f64, - } - - #[fixture] - fn train_dataset() -> SqlDs { - SqliteDataset::from_db_file(DB_FILE, "train").unwrap() - } - - #[rstest] - pub fn from_dataset(train_dataset: SqlDs) { - let dataset = InMemDataset::from_dataset(&train_dataset); - - let non_existing_record_index: usize = 10; - let record_index: usize = 0; - - assert_eq!(train_dataset.get(non_existing_record_index), None); - assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1"); - } - - #[test] - pub fn from_json_rows() { - let dataset = InMemDataset::::from_json_rows(JSON_FILE).unwrap(); - - let non_existing_record_index: usize = 10; - let record_index: usize = 1; - - assert_eq!(dataset.get(non_existing_record_index), None); - assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); - assert!(!dataset.get(record_index).unwrap().column_bool); - } - - #[test] - pub fn from_csv_rows() { - let dataset = InMemDataset::::from_csv(CSV_FILE).unwrap(); - - let non_existing_record_index: usize = 10; - let record_index: usize = 1; - - assert_eq!(dataset.get(non_existing_record_index), None); - assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); - assert_eq!(dataset.get(record_index).unwrap().column_int, 1); - assert!(!dataset.get(record_index).unwrap().column_bool); - assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0); - } - - #[test] - pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() { - let items_original = test_data::string_items(); - let dataset = InMemDataset::new(items_original.clone()); - - let items: Vec = dataset.iter().collect(); - - assert_eq!(items_original, items); - } + use super::*; + use crate::{test_data, SqliteDataset}; + + use rstest::{fixture, rstest}; + use serde::{Deserialize, Serialize}; + + const DB_FILE: &str = "tests/data/sqlite-dataset.db"; + const JSON_FILE: &str = "tests/data/dataset.json"; + const CSV_FILE: &str = "tests/data/dataset.csv"; + + type SqlDs = SqliteDataset; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct Sample { + column_str: String, + column_bytes: Vec, + column_int: i64, + column_bool: bool, + column_float: f64, + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct SampleCvs { + 
column_str: String, + column_int: i64, + column_bool: bool, + column_float: f64, + } + + #[fixture] + fn train_dataset() -> SqlDs { + SqliteDataset::from_db_file(DB_FILE, "train").unwrap() + } + + #[rstest] + pub fn from_dataset(train_dataset: SqlDs) { + let dataset = InMemDataset::from_dataset(&train_dataset); + + let non_existing_record_index: usize = 10; + let record_index: usize = 0; + + assert_eq!(train_dataset.get(non_existing_record_index), None); + assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1"); + } + + #[test] + pub fn from_json_rows() { + let dataset = InMemDataset::::from_json_rows(JSON_FILE).unwrap(); + + let non_existing_record_index: usize = 10; + let record_index: usize = 1; + + assert_eq!(dataset.get(non_existing_record_index), None); + assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); + assert!(!dataset.get(record_index).unwrap().column_bool); + } + + #[test] + pub fn from_csv_rows() { + let dataset = InMemDataset::::from_csv(CSV_FILE).unwrap(); + + let non_existing_record_index: usize = 10; + let record_index: usize = 1; + + assert_eq!(dataset.get(non_existing_record_index), None); + assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); + assert_eq!(dataset.get(record_index).unwrap().column_int, 1); + assert!(!dataset.get(record_index).unwrap().column_bool); + assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0); + } + + #[test] + pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() { + let items_original = test_data::string_items(); + let dataset = InMemDataset::new(items_original.clone()); + + let items: Vec = dataset.iter().collect(); + + assert_eq!(items_original, items); + } } diff --git a/burn-dataset/src/dataset/iterator.rs b/burn-dataset/src/dataset/iterator.rs index e513e1a08c..4d4045d16a 100644 --- a/burn-dataset/src/dataset/iterator.rs +++ b/burn-dataset/src/dataset/iterator.rs @@ -3,29 +3,29 @@ use std::iter::Iterator; /// Dataset iterator. pub struct DatasetIterator<'a, I> { - current: usize, - dataset: &'a dyn Dataset, + current: usize, + dataset: &'a dyn Dataset, } impl<'a, I> DatasetIterator<'a, I> { - /// Creates a new dataset iterator. - pub fn new(dataset: &'a D) -> Self - where - D: Dataset, - { - DatasetIterator { - current: 0, - dataset, - } + /// Creates a new dataset iterator. 
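    // `next` below simply reads `dataset.get(self.current)` and advances the cursor,
    // so iteration ends as soon as `get` returns `None` past the last index.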
+ pub fn new(dataset: &'a D) -> Self + where + D: Dataset, + { + DatasetIterator { + current: 0, + dataset, } + } } impl<'a, I> Iterator for DatasetIterator<'a, I> { - type Item = I; + type Item = I; - fn next(&mut self) -> Option { - let item = self.dataset.get(self.current); - self.current += 1; - item - } + fn next(&mut self) -> Option { + let item = self.dataset.get(self.current); + self.current += 1; + item + } } diff --git a/burn-dataset/src/dataset/sqlite.rs b/burn-dataset/src/dataset/sqlite.rs index 3428642982..1f6aa55877 100644 --- a/burn-dataset/src/dataset/sqlite.rs +++ b/burn-dataset/src/dataset/sqlite.rs @@ -1,21 +1,21 @@ use std::{ - collections::HashSet, - fs, io, - marker::PhantomData, - path::{Path, PathBuf}, - sync::{Arc, RwLock}, + collections::HashSet, + fs, io, + marker::PhantomData, + path::{Path, PathBuf}, + sync::{Arc, RwLock}, }; use crate::Dataset; use gix_tempfile::{ - handle::{persist, Writable}, - AutoRemove, ContainingDirectory, Handle, + handle::{persist, Writable}, + AutoRemove, ContainingDirectory, Handle, }; use r2d2::{Pool, PooledConnection}; use r2d2_sqlite::{ - rusqlite::{OpenFlags, OptionalExtension}, - SqliteConnectionManager, + rusqlite::{OpenFlags, OptionalExtension}, + SqliteConnectionManager, }; use sanitize_filename::sanitize; use serde::{de::DeserializeOwned, Serialize}; @@ -27,39 +27,39 @@ pub type Result = core::result::Result; /// Sqlite dataset error. #[derive(thiserror::Error, Debug)] pub enum SqliteDatasetError { - /// IO related error. - #[error("IO error: {0}")] - Io(#[from] io::Error), + /// IO related error. + #[error("IO error: {0}")] + Io(#[from] io::Error), - /// Sql related error. - #[error("Sql error: {0}")] - Sql(#[from] serde_rusqlite::rusqlite::Error), + /// Sql related error. + #[error("Sql error: {0}")] + Sql(#[from] serde_rusqlite::rusqlite::Error), - /// Serde related error. - #[error("Serde error: {0}")] - Serde(#[from] rmp_serde::encode::Error), + /// Serde related error. + #[error("Serde error: {0}")] + Serde(#[from] rmp_serde::encode::Error), - /// The database file already exists error. - #[error("Overwrite flag is set to false and the database file already exists: {0}")] - FileExists(PathBuf), + /// The database file already exists error. + #[error("Overwrite flag is set to false and the database file already exists: {0}")] + FileExists(PathBuf), - /// Error when creating the connection pool. - #[error("Failed to create connection pool: {0}")] - ConnectionPool(#[from] r2d2::Error), + /// Error when creating the connection pool. + #[error("Failed to create connection pool: {0}")] + ConnectionPool(#[from] r2d2::Error), - /// Error when persisting the temporary database file. - #[error("Could not persist the temporary database file: {0}")] - PersistDbFile(#[from] persist::Error), + /// Error when persisting the temporary database file. + #[error("Could not persist the temporary database file: {0}")] + PersistDbFile(#[from] persist::Error), - /// Any other error. - #[error("{0}")] - Other(&'static str), + /// Any other error. + #[error("{0}")] + Other(&'static str), } impl From<&'static str> for SqliteDatasetError { - fn from(s: &'static str) -> Self { - SqliteDatasetError::Other(s) - } + fn from(s: &'static str) -> Self { + SqliteDatasetError::Other(s) + } } /// This struct represents a dataset where all items are stored in an SQLite database. @@ -89,323 +89,322 @@ impl From<&'static str> for SqliteDatasetError { /// method to read the data from the table. 
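// Usage sketch, using only the constructor and accessors defined below (the file
// name and the `MyItem` type are hypothetical; the item must be DeserializeOwned):
//
//     let dataset = SqliteDataset::<MyItem>::from_db_file("data.db", "train").unwrap();
//     println!("rows: {}", dataset.len());
//     let first: Option<MyItem> = dataset.get(0); // index 0 maps to row_id 1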
#[derive(Debug)] pub struct SqliteDataset { - db_file: PathBuf, - split: String, - conn_pool: Pool, - columns: Vec, - len: usize, - select_statement: String, - row_serialized: bool, - phantom: PhantomData, + db_file: PathBuf, + split: String, + conn_pool: Pool, + columns: Vec, + len: usize, + select_statement: String, + row_serialized: bool, + phantom: PhantomData, } impl SqliteDataset { - /// Initializes a `SqliteDataset` from a SQLite database file and a split name. - pub fn from_db_file>(db_file: P, split: &str) -> Result { - // Create a connection pool - let conn_pool = create_conn_pool(&db_file, false)?; - - // Determine how the table is stored - let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?; - - // Create a select statement and save it - let select_statement = if row_serialized { - format!("select item from {split} where row_id = ?") - } else { - format!("select * from {split} where row_id = ?") - }; - - // Save the column names and the number of rows - let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?; - - Ok(SqliteDataset { - db_file: db_file.as_ref().to_path_buf(), - split: split.to_string(), - conn_pool, - columns, - len, - select_statement, - row_serialized, - phantom: PhantomData, - }) - } + /// Initializes a `SqliteDataset` from a SQLite database file and a split name. + pub fn from_db_file>(db_file: P, split: &str) -> Result { + // Create a connection pool + let conn_pool = create_conn_pool(&db_file, false)?; - /// Returns true if table has two columns: row_id (integer) and item (blob). - /// - /// This is used to determine if the table is row serialized or not. - fn check_if_row_serialized( - conn_pool: &Pool, - split: &str, - ) -> Result { - // This struct is used to store the column name and type - struct Column { - name: String, - ty: String, - } - - const COLUMN_NAME: usize = 1; - const COLUMN_TYPE: usize = 2; - - let sql_statement = format!("PRAGMA table_info({split})"); - - let conn = conn_pool.get()?; - - let mut stmt = conn.prepare(sql_statement.as_str())?; - let column_iter = stmt.query_map([], |row| { - Ok(Column { - name: row - .get::(COLUMN_NAME) - .unwrap() - .to_lowercase(), - ty: row - .get::(COLUMN_TYPE) - .unwrap() - .to_lowercase(), - }) - })?; - - let mut columns: Vec = vec![]; - - for column in column_iter { - columns.push(column?); - } - - if columns.len() != 2 { - Ok(false) - } else { - // Check if the column names and types match the expected values - Ok(columns[0].name == "row_id" - && columns[0].ty == "integer" - && columns[1].name == "item" - && columns[1].ty == "blob") - } - } + // Determine how the table is stored + let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?; - /// Get the database file name. - pub fn db_file(&self) -> PathBuf { - self.db_file.clone() - } + // Create a select statement and save it + let select_statement = if row_serialized { + format!("select item from {split} where row_id = ?") + } else { + format!("select * from {split} where row_id = ?") + }; + + // Save the column names and the number of rows + let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?; + + Ok(SqliteDataset { + db_file: db_file.as_ref().to_path_buf(), + split: split.to_string(), + conn_pool, + columns, + len, + select_statement, + row_serialized, + phantom: PhantomData, + }) + } + + /// Returns true if table has two columns: row_id (integer) and item (blob). + /// + /// This is used to determine if the table is row serialized or not. 
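    // "Row serialized" means the table has exactly two columns, `row_id INTEGER` and
    // `item BLOB`, with each item stored as a single MessagePack blob; otherwise every
    // struct field is expected to be its own column and rows are decoded with
    // serde_rusqlite (see `get` below).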
+ fn check_if_row_serialized( + conn_pool: &Pool, + split: &str, + ) -> Result { + // This struct is used to store the column name and type + struct Column { + name: String, + ty: String, + } + + const COLUMN_NAME: usize = 1; + const COLUMN_TYPE: usize = 2; + + let sql_statement = format!("PRAGMA table_info({split})"); + + let conn = conn_pool.get()?; + + let mut stmt = conn.prepare(sql_statement.as_str())?; + let column_iter = stmt.query_map([], |row| { + Ok(Column { + name: row + .get::(COLUMN_NAME) + .unwrap() + .to_lowercase(), + ty: row + .get::(COLUMN_TYPE) + .unwrap() + .to_lowercase(), + }) + })?; - /// Get the split name. - pub fn split(&self) -> &str { - self.split.as_str() + let mut columns: Vec = vec![]; + + for column in column_iter { + columns.push(column?); } + + if columns.len() != 2 { + Ok(false) + } else { + // Check if the column names and types match the expected values + Ok( + columns[0].name == "row_id" + && columns[0].ty == "integer" + && columns[1].name == "item" + && columns[1].ty == "blob", + ) + } + } + + /// Get the database file name. + pub fn db_file(&self) -> PathBuf { + self.db_file.clone() + } + + /// Get the split name. + pub fn split(&self) -> &str { + self.split.as_str() + } } impl Dataset for SqliteDataset where - I: Clone + Send + Sync + DeserializeOwned, + I: Clone + Send + Sync + DeserializeOwned, { - /// Get an item from the dataset. - fn get(&self, index: usize) -> Option { - // Row ids start with 1 (one) and index starts with 0 (zero) - let row_id = index + 1; - - // Get a connection from the pool - let connection = self.conn_pool.get().unwrap(); - let mut statement = connection.prepare(self.select_statement.as_str()).unwrap(); - - if self.row_serialized { - // Fetch with a single column `item` and deserialize it with MessagePack - statement - .query_row([row_id], |row| { - // Deserialize item (blob) with MessagePack (rmp-serde) - Ok( - rmp_serde::from_slice::(row.get_ref(0).unwrap().as_blob().unwrap()) - .unwrap(), - ) - }) - .optional() //Converts Error (not found) to None - .unwrap() - } else { - // Fetch a row with multiple columns and deserialize it serde_rusqlite - statement - .query_row([row_id], |row| { - // Deserialize the row with serde_rusqlite - Ok(from_row_with_columns::(row, &self.columns).unwrap()) - }) - .optional() //Converts Error (not found) to None - .unwrap() - } + /// Get an item from the dataset. + fn get(&self, index: usize) -> Option { + // Row ids start with 1 (one) and index starts with 0 (zero) + let row_id = index + 1; + + // Get a connection from the pool + let connection = self.conn_pool.get().unwrap(); + let mut statement = connection.prepare(self.select_statement.as_str()).unwrap(); + + if self.row_serialized { + // Fetch with a single column `item` and deserialize it with MessagePack + statement + .query_row([row_id], |row| { + // Deserialize item (blob) with MessagePack (rmp-serde) + Ok(rmp_serde::from_slice::(row.get_ref(0).unwrap().as_blob().unwrap()).unwrap()) + }) + .optional() //Converts Error (not found) to None + .unwrap() + } else { + // Fetch a row with multiple columns and deserialize it serde_rusqlite + statement + .query_row([row_id], |row| { + // Deserialize the row with serde_rusqlite + Ok(from_row_with_columns::(row, &self.columns).unwrap()) + }) + .optional() //Converts Error (not found) to None + .unwrap() } + } - /// Return the number of rows in the dataset. - fn len(&self) -> usize { - self.len - } + /// Return the number of rows in the dataset. 
+ fn len(&self) -> usize { + self.len + } } /// Fetch the column names and the number of rows from the database. fn fetch_columns_and_len( - conn_pool: &Pool, - select_statement: &str, - split: &str, + conn_pool: &Pool, + select_statement: &str, + split: &str, ) -> Result<(Vec, usize)> { - // Save the column names - let connection = conn_pool.get()?; - let statement = connection.prepare(select_statement)?; - let columns = columns_from_statement(&statement); - - // Count the number of rows and save it as len - // - // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables. - // The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id, - // which corresponds to the number of rows in the table. - // The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps. - // This is true for all the datasets that we are using, otherwise row_id will not correspond to the index. - let mut statement = - connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?; - - let len = statement.query_row([], |row| { - let len: usize = row.get(0)?; - Ok(len) - })?; - Ok((columns, len)) + // Save the column names + let connection = conn_pool.get()?; + let statement = connection.prepare(select_statement)?; + let columns = columns_from_statement(&statement); + + // Count the number of rows and save it as len + // + // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables. + // The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id, + // which corresponds to the number of rows in the table. + // The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps. + // This is true for all the datasets that we are using, otherwise row_id will not correspond to the index. + let mut statement = + connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?; + + let len = statement.query_row([], |row| { + let len: usize = row.get(0)?; + Ok(len) + })?; + Ok((columns, len)) } /// Helper function to create a connection pool fn create_conn_pool>( - db_file: P, - write: bool, + db_file: P, + write: bool, ) -> Result> { - let sqlite_flags = if write { - OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE - } else { - OpenFlags::SQLITE_OPEN_READ_ONLY - }; + let sqlite_flags = if write { + OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE + } else { + OpenFlags::SQLITE_OPEN_READ_ONLY + }; - // Create a connection pool and make sure the connections are read only - let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags); + // Create a connection pool and make sure the connections are read only + let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags); - Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool) + Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool) } /// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets. /// It consists of an optional name, a database file path, and a base directory for storage. #[derive(Clone, Debug)] pub struct SqliteDatasetStorage { - name: Option, - db_file: Option, - base_dir: Option, + name: Option, + db_file: Option, + base_dir: Option, } impl SqliteDatasetStorage { - /// Creates a new instance of `SqliteDatasetStorage` using a dataset name. 
- /// - /// # Arguments - /// - /// * `name` - A string slice that holds the name of the dataset. - pub fn from_name(name: &str) -> Self { - SqliteDatasetStorage { - name: Some(name.to_string()), - db_file: None, - base_dir: None, - } - } - - /// Creates a new instance of `SqliteDatasetStorage` using a database file path. - /// - /// # Arguments - /// - /// * `db_file` - A reference to the Path that represents the database file path. - pub fn from_file>(db_file: P) -> Self { - SqliteDatasetStorage { - name: None, - db_file: Some(db_file.as_ref().to_path_buf()), - base_dir: None, - } - } - - /// Sets the base directory for storing the dataset. - /// - /// # Arguments - /// - /// * `base_dir` - A string slice that represents the base directory. - pub fn with_base_dir>(mut self, base_dir: P) -> Self { - self.base_dir = Some(base_dir.as_ref().to_path_buf()); - self - } - - /// Checks if the database file exists in the given path. - /// - /// # Returns - /// - /// * A boolean value indicating whether the file exists or not. - pub fn exists(&self) -> bool { - self.db_file().exists() - } - - /// Fetches the database file path. - /// - /// # Returns - /// - /// * A `PathBuf` instance representing the file path. - pub fn db_file(&self) -> PathBuf { - let db_file = match &self.db_file { - Some(db_file) => db_file.clone(), - None => { - let name = sanitize(self.name.as_ref().expect("Name is not set")); - Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db")) - } - }; - db_file - } - - /// Determines the base directory for storing the dataset. - /// - /// # Arguments - /// - /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory. - /// - /// # Returns - /// - /// * A `PathBuf` instance representing the base directory. - pub fn base_dir(base_dir: Option) -> PathBuf { - match base_dir { - Some(base_dir) => base_dir, - None => { - let home_dir = dirs::home_dir().expect("Could not get home directory"); - - home_dir.join(".cache").join("burn-dataset") - } - } - } - - /// Provides a writer instance for the SQLite dataset. - /// - /// # Arguments - /// - /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. - pub fn writer(&self, overwrite: bool) -> Result> - where - I: Clone + Send + Sync + Serialize + DeserializeOwned, - { - SqliteDatasetWriter::new(self.db_file(), overwrite) + /// Creates a new instance of `SqliteDatasetStorage` using a dataset name. + /// + /// # Arguments + /// + /// * `name` - A string slice that holds the name of the dataset. + pub fn from_name(name: &str) -> Self { + SqliteDatasetStorage { + name: Some(name.to_string()), + db_file: None, + base_dir: None, + } + } + + /// Creates a new instance of `SqliteDatasetStorage` using a database file path. + /// + /// # Arguments + /// + /// * `db_file` - A reference to the Path that represents the database file path. + pub fn from_file>(db_file: P) -> Self { + SqliteDatasetStorage { + name: None, + db_file: Some(db_file.as_ref().to_path_buf()), + base_dir: None, + } + } + + /// Sets the base directory for storing the dataset. + /// + /// # Arguments + /// + /// * `base_dir` - A string slice that represents the base directory. + pub fn with_base_dir>(mut self, base_dir: P) -> Self { + self.base_dir = Some(base_dir.as_ref().to_path_buf()); + self + } + + /// Checks if the database file exists in the given path. 
+ /// + /// # Returns + /// + /// * A boolean value indicating whether the file exists or not. + pub fn exists(&self) -> bool { + self.db_file().exists() + } + + /// Fetches the database file path. + /// + /// # Returns + /// + /// * A `PathBuf` instance representing the file path. + pub fn db_file(&self) -> PathBuf { + let db_file = match &self.db_file { + Some(db_file) => db_file.clone(), + None => { + let name = sanitize(self.name.as_ref().expect("Name is not set")); + Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db")) + } + }; + db_file + } + + /// Determines the base directory for storing the dataset. + /// + /// # Arguments + /// + /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory. + /// + /// # Returns + /// + /// * A `PathBuf` instance representing the base directory. + pub fn base_dir(base_dir: Option) -> PathBuf { + match base_dir { + Some(base_dir) => base_dir, + None => { + let home_dir = dirs::home_dir().expect("Could not get home directory"); + + home_dir.join(".cache").join("burn-dataset") + } + } + } + + /// Provides a writer instance for the SQLite dataset. + /// + /// # Arguments + /// + /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. + pub fn writer(&self, overwrite: bool) -> Result> + where + I: Clone + Send + Sync + Serialize + DeserializeOwned, + { + SqliteDatasetWriter::new(self.db_file(), overwrite) + } + + /// Provides a reader instance for the SQLite dataset. + /// + /// # Arguments + /// + /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test"). + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise. + pub fn reader(&self, split: &str) -> Result> + where + I: Clone + Send + Sync + Serialize + DeserializeOwned, + { + if !self.exists() { + panic!("The database file does not exist"); } - /// Provides a reader instance for the SQLite dataset. - /// - /// # Arguments - /// - /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test"). - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise. - pub fn reader(&self, split: &str) -> Result> - where - I: Clone + Send + Sync + Serialize + DeserializeOwned, - { - if !self.exists() { - panic!("The database file does not exist"); - } - - SqliteDataset::from_db_file(self.db_file(), split) - } + SqliteDataset::from_db_file(self.db_file(), split) + } } /// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets. @@ -420,190 +419,190 @@ impl SqliteDatasetStorage { /// - Enlargement of a dataset's item count post preprocessing #[derive(Debug)] pub struct SqliteDatasetWriter { - db_file: PathBuf, - db_file_tmp: Option>, - splits: Arc>>, - overwrite: bool, - conn_pool: Option>, - is_completed: Arc>, - phantom: PhantomData, + db_file: PathBuf, + db_file_tmp: Option>, + splits: Arc>>, + overwrite: bool, + conn_pool: Option>, + is_completed: Arc>, + phantom: PhantomData, } impl SqliteDatasetWriter where - I: Clone + Send + Sync + Serialize + DeserializeOwned, + I: Clone + Send + Sync + Serialize + DeserializeOwned, { - /// Creates a new instance of `SqliteDatasetWriter`. - /// - /// # Arguments - /// - /// * `db_file` - A reference to the Path that represents the database file path. 
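Taken together, the storage type above supports a simple write-then-read round trip. A usage sketch, assuming the crate-root re-exports (`burn_dataset::{Dataset, SqliteDatasetStorage}`) and a hypothetical `MyItem` type and base directory:

// Usage sketch for SqliteDatasetStorage; `MyItem` and the base dir are placeholders.
use burn_dataset::{Dataset, SqliteDatasetStorage};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
struct MyItem {
    label: usize,
    text: String,
}

fn main() {
    let storage = SqliteDatasetStorage::from_name("my-dataset").with_base_dir("/tmp/burn-example");

    // Write a couple of items into the "train" split, then finalize the db file.
    let mut writer = storage.writer::<MyItem>(true).unwrap();
    writer.write("train", &MyItem { label: 0, text: "a".into() }).unwrap();
    writer.write("train", &MyItem { label: 1, text: "b".into() }).unwrap();
    writer.set_completed().unwrap();

    // Read the split back through the regular Dataset API.
    let train = storage.reader::<MyItem>("train").unwrap();
    assert_eq!(train.len(), 2);
    assert_eq!(train.get(0).unwrap().label, 0);
}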
- /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. - pub fn new>(db_file: P, overwrite: bool) -> Result { - let writer = Self { - db_file: db_file.as_ref().to_path_buf(), - db_file_tmp: None, - splits: Arc::new(RwLock::new(HashSet::new())), - overwrite, - conn_pool: None, - is_completed: Arc::new(RwLock::new(false)), - phantom: PhantomData, - }; - - writer.init() - } - - /// Initializes the dataset writer by creating the database file, tables, and connection pool. - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise. - fn init(mut self) -> Result { - // Remove the db file if it already exists - if self.db_file.exists() { - if self.overwrite { - fs::remove_file(&self.db_file)?; - } else { - return Err(SqliteDatasetError::FileExists(self.db_file)); - } - } - - // Create the database file directory if it does not exist - let db_file_dir = self - .db_file - .parent() - .ok_or("Unable to get parent directory")?; - - if !db_file_dir.exists() { - fs::create_dir_all(db_file_dir)?; - } - - // Create a temp database file name as {base_dir}/{name}.db.tmp - let mut db_file_tmp = self.db_file.clone(); - db_file_tmp.set_extension("db.tmp"); - if db_file_tmp.exists() { - fs::remove_file(&db_file_tmp)?; - } - - // Create the temp database file and wrap it with a gix_tempfile::Handle - // This will ensure that the temp file is deleted when the writer is dropped - // or when process exits with SIGINT or SIGTERM (tempfile crate does not do this) - gix_tempfile::signal::setup(Default::default()); - self.db_file_tmp = Some(gix_tempfile::writable_at( - &db_file_tmp, - ContainingDirectory::Exists, - AutoRemove::Tempfile, - )?); - - let conn_pool = create_conn_pool(db_file_tmp, true)?; - self.conn_pool = Some(conn_pool); - - Ok(self) - } - - /// Serializes and writes an item to the database. The item is written to the table for the - /// specified split. If the table does not exist, it is created. If the table exists, the item - /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/) - /// - /// # Arguments - /// - /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test"). - /// * `item` - A reference to the item to be written to the database. - /// - /// # Returns - /// - /// * A `Result` containing the index of the inserted row if successful, an error otherwise. - pub fn write(&self, split: &str, item: &I) -> Result { - // Acquire the read lock (wont't block other reads) - let is_completed = self.is_completed.read().unwrap(); - - // If the writer is completed, return an error - if *is_completed { - return Err(SqliteDatasetError::Other( - "Cannot save to a completed dataset writer", - )); - } - - // create the table for the split if it does not exist - if !self.splits.read().unwrap().contains(split) { - self.create_table(split)?; - } - - // Get a connection from the pool - let conn_pool = self.conn_pool.as_ref().unwrap(); - let conn = conn_pool.get()?; - - // Serialize the item using MessagePack - let serialized_item = rmp_serde::to_vec(item)?; - - // Turn off the synchronous and journal mode for speed up - // We are sacrificing durability for speed but it's okay because - // we always recreate the dataset if it is not completed. 
- pragma_update_with_error_handling(&conn, "synchronous", "OFF")?; - pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?; - - // Insert the serialized item into the database - let insert_statement = format!("insert into {split} (item) values (?)", split = split); - conn.execute(insert_statement.as_str(), [serialized_item])?; - - // Get the primary key of the last inserted row and convert to index (row_id-1) - let index = (conn.last_insert_rowid() - 1) as usize; - - Ok(index) - } - - /// Marks the dataset as completed and persists the temporary database file. - pub fn set_completed(&mut self) -> Result<()> { - let mut is_completed = self.is_completed.write().unwrap(); - - // Rename the database file from tmp to db - let _file_result = self - .db_file_tmp - .take() // take ownership of the temporary file and set to None - .unwrap() // unwrap the temporary file - .persist(&self.db_file)? - .ok_or("Unable to persist the database file")?; - - *is_completed = true; - Ok(()) - } + /// Creates a new instance of `SqliteDatasetWriter`. + /// + /// # Arguments + /// + /// * `db_file` - A reference to the Path that represents the database file path. + /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. + pub fn new>(db_file: P, overwrite: bool) -> Result { + let writer = Self { + db_file: db_file.as_ref().to_path_buf(), + db_file_tmp: None, + splits: Arc::new(RwLock::new(HashSet::new())), + overwrite, + conn_pool: None, + is_completed: Arc::new(RwLock::new(false)), + phantom: PhantomData, + }; - /// Creates table for the data split. - /// - /// Note: call is idempotent and thread-safe. - /// - /// # Arguments - /// - /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test"). - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise. - /// - /// TODO (@antimora): add support creating a table with columns corresponding to the item fields - fn create_table(&self, split: &str) -> Result<()> { - // Check if the split already exists - if self.splits.read().unwrap().contains(split) { - return Ok(()); - } - - let conn_pool = self.conn_pool.as_ref().unwrap(); - let connection = conn_pool.get()?; - let create_table_statement = format!( + writer.init() + } + + /// Initializes the dataset writer by creating the database file, tables, and connection pool. + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise. 
+ fn init(mut self) -> Result { + // Remove the db file if it already exists + if self.db_file.exists() { + if self.overwrite { + fs::remove_file(&self.db_file)?; + } else { + return Err(SqliteDatasetError::FileExists(self.db_file)); + } + } + + // Create the database file directory if it does not exist + let db_file_dir = self + .db_file + .parent() + .ok_or("Unable to get parent directory")?; + + if !db_file_dir.exists() { + fs::create_dir_all(db_file_dir)?; + } + + // Create a temp database file name as {base_dir}/{name}.db.tmp + let mut db_file_tmp = self.db_file.clone(); + db_file_tmp.set_extension("db.tmp"); + if db_file_tmp.exists() { + fs::remove_file(&db_file_tmp)?; + } + + // Create the temp database file and wrap it with a gix_tempfile::Handle + // This will ensure that the temp file is deleted when the writer is dropped + // or when process exits with SIGINT or SIGTERM (tempfile crate does not do this) + gix_tempfile::signal::setup(Default::default()); + self.db_file_tmp = Some(gix_tempfile::writable_at( + &db_file_tmp, + ContainingDirectory::Exists, + AutoRemove::Tempfile, + )?); + + let conn_pool = create_conn_pool(db_file_tmp, true)?; + self.conn_pool = Some(conn_pool); + + Ok(self) + } + + /// Serializes and writes an item to the database. The item is written to the table for the + /// specified split. If the table does not exist, it is created. If the table exists, the item + /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/) + /// + /// # Arguments + /// + /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test"). + /// * `item` - A reference to the item to be written to the database. + /// + /// # Returns + /// + /// * A `Result` containing the index of the inserted row if successful, an error otherwise. + pub fn write(&self, split: &str, item: &I) -> Result { + // Acquire the read lock (wont't block other reads) + let is_completed = self.is_completed.read().unwrap(); + + // If the writer is completed, return an error + if *is_completed { + return Err(SqliteDatasetError::Other( + "Cannot save to a completed dataset writer", + )); + } + + // create the table for the split if it does not exist + if !self.splits.read().unwrap().contains(split) { + self.create_table(split)?; + } + + // Get a connection from the pool + let conn_pool = self.conn_pool.as_ref().unwrap(); + let conn = conn_pool.get()?; + + // Serialize the item using MessagePack + let serialized_item = rmp_serde::to_vec(item)?; + + // Turn off the synchronous and journal mode for speed up + // We are sacrificing durability for speed but it's okay because + // we always recreate the dataset if it is not completed. + pragma_update_with_error_handling(&conn, "synchronous", "OFF")?; + pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?; + + // Insert the serialized item into the database + let insert_statement = format!("insert into {split} (item) values (?)", split = split); + conn.execute(insert_statement.as_str(), [serialized_item])?; + + // Get the primary key of the last inserted row and convert to index (row_id-1) + let index = (conn.last_insert_rowid() - 1) as usize; + + Ok(index) + } + + /// Marks the dataset as completed and persists the temporary database file. 
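The write path above deliberately trades durability for speed: it switches `synchronous` and `journal_mode` off, inserts the MessagePack blob, and derives the returned index from SQLite's one-based last insert rowid. A condensed rusqlite-only sketch of that sequence, using an in-memory database and a placeholder blob:

// Sketch of the fast-insert path used by the writer (in-memory db, placeholder table and blob).
use rusqlite::Connection;

fn main() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute(
        "create table if not exists train (row_id integer primary key autoincrement not null, item blob not null)",
        [],
    )?;

    // Durability is intentionally sacrificed: a half-written dataset is simply rebuilt.
    // Some pragmas return a result row, which rusqlite reports as ExecuteReturnedResults;
    // that error can be ignored, as pragma_update_with_error_handling does in this file.
    for (pragma, value) in [("synchronous", "OFF"), ("journal_mode", "OFF")] {
        if let Err(err) = conn.pragma_update(None, pragma, value) {
            if err != rusqlite::Error::ExecuteReturnedResults {
                return Err(err);
            }
        }
    }

    let blob: Vec<u8> = vec![0x90]; // stand-in for an rmp-serde encoded item
    conn.execute("insert into train (item) values (?)", [blob])?;

    // row_id is 1-based; the dataset index handed back to the caller is 0-based.
    let index = (conn.last_insert_rowid() - 1) as usize;
    assert_eq!(index, 0);
    Ok(())
}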
+ pub fn set_completed(&mut self) -> Result<()> { + let mut is_completed = self.is_completed.write().unwrap(); + + // Rename the database file from tmp to db + let _file_result = self + .db_file_tmp + .take() // take ownership of the temporary file and set to None + .unwrap() // unwrap the temporary file + .persist(&self.db_file)? + .ok_or("Unable to persist the database file")?; + + *is_completed = true; + Ok(()) + } + + /// Creates table for the data split. + /// + /// Note: call is idempotent and thread-safe. + /// + /// # Arguments + /// + /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test"). + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise. + /// + /// TODO (@antimora): add support creating a table with columns corresponding to the item fields + fn create_table(&self, split: &str) -> Result<()> { + // Check if the split already exists + if self.splits.read().unwrap().contains(split) { + return Ok(()); + } + + let conn_pool = self.conn_pool.as_ref().unwrap(); + let connection = conn_pool.get()?; + let create_table_statement = format!( "create table if not exists {split} (row_id integer primary key autoincrement not null, item blob not null)" ); - connection.execute(create_table_statement.as_str(), [])?; + connection.execute(create_table_statement.as_str(), [])?; - // Add the split to the splits - self.splits.write().unwrap().insert(split.to_string()); + // Add the split to the splits + self.splits.write().unwrap().insert(split.to_string()); - Ok(()) - } + Ok(()) + } } /// Runs a pragma update and ignores the `ExecuteReturnedResults` error. @@ -612,237 +611,235 @@ where /// and can be ignored. This function runs the pragma update and ignores the error if it is /// `ExecuteReturnedResults`. 
fn pragma_update_with_error_handling( - conn: &PooledConnection, - setting: &str, - value: &str, + conn: &PooledConnection, + setting: &str, + value: &str, ) -> Result<()> { - let result = conn.pragma_update(None, setting, value); - if let Err(error) = result { - if error != rusqlite::Error::ExecuteReturnedResults { - return Err(SqliteDatasetError::Sql(error)); - } + let result = conn.pragma_update(None, setting, value); + if let Err(error) = result { + if error != rusqlite::Error::ExecuteReturnedResults { + return Err(SqliteDatasetError::Sql(error)); } - Ok(()) + } + Ok(()) } #[cfg(test)] mod tests { - use rayon::prelude::*; - use rstest::{fixture, rstest}; - use serde::{Deserialize, Serialize}; - use tempfile::{tempdir, NamedTempFile, TempDir}; - - use super::*; - - type SqlDs = SqliteDataset; - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct Sample { - column_str: String, - column_bytes: Vec, - column_int: i64, - column_bool: bool, - column_float: f64, - } - - #[fixture] - fn train_dataset() -> SqlDs { - SqliteDataset::::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap() - } - - #[rstest] - pub fn len(train_dataset: SqlDs) { - assert_eq!(train_dataset.len(), 2); - } - - #[rstest] - pub fn get_some(train_dataset: SqlDs) { - let item = train_dataset.get(0).unwrap(); - assert_eq!(item.column_str, "HI1"); - assert_eq!(item.column_bytes, vec![55, 231, 159]); - assert_eq!(item.column_int, 1); - assert!(item.column_bool); - assert_eq!(item.column_float, 1.0); - } - - #[rstest] - pub fn get_none(train_dataset: SqlDs) { - assert_eq!(train_dataset.get(10), None); - } - - #[rstest] - pub fn multi_thread(train_dataset: SqlDs) { - let indices: Vec = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1]; - let results: Vec> = - indices.par_iter().map(|&i| train_dataset.get(i)).collect(); - - let mut match_count = 0; - for (_index, result) in indices.iter().zip(results.iter()) { - match result { - Some(_val) => match_count += 1, - None => (), - } - } - - assert_eq!(match_count, 5); - } - - #[test] - fn sqlite_dataset_storage() { - // Test with non-existing file - let storage = SqliteDatasetStorage::from_file("non-existing.db"); - assert!(!storage.exists()); - - // Test with non-existing name - let storage = SqliteDatasetStorage::from_name("non-existing.db"); - assert!(!storage.exists()); - - // Test with existing file - let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db"); - assert!(storage.exists()); - let result = storage.reader::("train"); - assert!(result.is_ok()); - let train = result.unwrap(); - assert_eq!(train.len(), 2); - - // Test get writer - let temp_file = NamedTempFile::new().unwrap(); - let storage = SqliteDatasetStorage::from_file(temp_file.path()); - assert!(storage.exists()); - let result = storage.writer::(true); - assert!(result.is_ok()); - } - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct Complex { - column_str: String, - column_bytes: Vec, - column_int: i64, - column_bool: bool, - column_float: f64, - column_complex: Vec>>, - } - - /// Create a temporary directory. - #[fixture] - fn tmp_dir() -> TempDir { - // Create a TempDir. This object will be automatically - // deleted when it goes out of scope. - tempdir().unwrap() - } - type Writer = SqliteDatasetWriter; - - /// Create a SqliteDatasetWriter with a temporary directory. - /// Make sure to return the temporary directory so that it is not deleted. 
- #[fixture] - fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) { - let temp_dir_str = tmp_dir.path(); - let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str); - let overwrite = true; - let result = storage.writer::(overwrite); - assert!(result.is_ok()); - let writer = result.unwrap(); - (writer, tmp_dir) - } - - #[test] - fn test_new() { - // Test that the constructor works with overwrite = true - let test_path = NamedTempFile::new().unwrap(); - let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); - assert!(!test_path.path().exists()); - - // Test that the constructor works with overwrite = false - let test_path = NamedTempFile::new().unwrap(); - let result = SqliteDatasetWriter::::new(&test_path, false); - assert!(result.is_err()); - - // Test that the constructor works with no existing file - let temp = NamedTempFile::new().unwrap(); - let test_path = temp.path().to_path_buf(); - assert!(temp.close().is_ok()); - assert!(!test_path.exists()); - let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); - assert!(!test_path.exists()); - } - - #[rstest] - pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) { - // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) - let (writer, _tmp_dir) = writer_fixture; - - assert!(writer.overwrite); - assert!(!writer.db_file.exists()); + use rayon::prelude::*; + use rstest::{fixture, rstest}; + use serde::{Deserialize, Serialize}; + use tempfile::{tempdir, NamedTempFile, TempDir}; + + use super::*; + + type SqlDs = SqliteDataset; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct Sample { + column_str: String, + column_bytes: Vec, + column_int: i64, + column_bool: bool, + column_float: f64, + } + + #[fixture] + fn train_dataset() -> SqlDs { + SqliteDataset::::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap() + } + + #[rstest] + pub fn len(train_dataset: SqlDs) { + assert_eq!(train_dataset.len(), 2); + } + + #[rstest] + pub fn get_some(train_dataset: SqlDs) { + let item = train_dataset.get(0).unwrap(); + assert_eq!(item.column_str, "HI1"); + assert_eq!(item.column_bytes, vec![55, 231, 159]); + assert_eq!(item.column_int, 1); + assert!(item.column_bool); + assert_eq!(item.column_float, 1.0); + } + + #[rstest] + pub fn get_none(train_dataset: SqlDs) { + assert_eq!(train_dataset.get(10), None); + } + + #[rstest] + pub fn multi_thread(train_dataset: SqlDs) { + let indices: Vec = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1]; + let results: Vec> = indices.par_iter().map(|&i| train_dataset.get(i)).collect(); + + let mut match_count = 0; + for (_index, result) in indices.iter().zip(results.iter()) { + match result { + Some(_val) => match_count += 1, + None => (), + } + } + + assert_eq!(match_count, 5); + } + + #[test] + fn sqlite_dataset_storage() { + // Test with non-existing file + let storage = SqliteDatasetStorage::from_file("non-existing.db"); + assert!(!storage.exists()); + + // Test with non-existing name + let storage = SqliteDatasetStorage::from_name("non-existing.db"); + assert!(!storage.exists()); + + // Test with existing file + let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db"); + assert!(storage.exists()); + let result = storage.reader::("train"); + assert!(result.is_ok()); + let train = result.unwrap(); + assert_eq!(train.len(), 2); + + // Test get writer + let temp_file = NamedTempFile::new().unwrap(); + let storage = SqliteDatasetStorage::from_file(temp_file.path()); + 
assert!(storage.exists()); + let result = storage.writer::(true); + assert!(result.is_ok()); + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct Complex { + column_str: String, + column_bytes: Vec, + column_int: i64, + column_bool: bool, + column_float: f64, + column_complex: Vec>>, + } + + /// Create a temporary directory. + #[fixture] + fn tmp_dir() -> TempDir { + // Create a TempDir. This object will be automatically + // deleted when it goes out of scope. + tempdir().unwrap() + } + type Writer = SqliteDatasetWriter; + + /// Create a SqliteDatasetWriter with a temporary directory. + /// Make sure to return the temporary directory so that it is not deleted. + #[fixture] + fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) { + let temp_dir_str = tmp_dir.path(); + let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str); + let overwrite = true; + let result = storage.writer::(overwrite); + assert!(result.is_ok()); + let writer = result.unwrap(); + (writer, tmp_dir) + } + + #[test] + fn test_new() { + // Test that the constructor works with overwrite = true + let test_path = NamedTempFile::new().unwrap(); + let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); + assert!(!test_path.path().exists()); + + // Test that the constructor works with overwrite = false + let test_path = NamedTempFile::new().unwrap(); + let result = SqliteDatasetWriter::::new(&test_path, false); + assert!(result.is_err()); + + // Test that the constructor works with no existing file + let temp = NamedTempFile::new().unwrap(); + let test_path = temp.path().to_path_buf(); + assert!(temp.close().is_ok()); + assert!(!test_path.exists()); + let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); + assert!(!test_path.exists()); + } + + #[rstest] + pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) { + // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) + let (writer, _tmp_dir) = writer_fixture; + + assert!(writer.overwrite); + assert!(!writer.db_file.exists()); + + let new_item = Complex { + column_str: "HI1".to_string(), + column_bytes: vec![1_u8, 2, 3], + column_int: 0, + column_bool: true, + column_float: 1.0, + column_complex: vec![vec![vec![[1, 23_u8, 3]]]], + }; - let new_item = Complex { - column_str: "HI1".to_string(), - column_bytes: vec![1_u8, 2, 3], - column_int: 0, - column_bool: true, - column_float: 1.0, - column_complex: vec![vec![vec![[1, 23_u8, 3]]]], - }; + let index = writer.write("train", &new_item).unwrap(); + assert_eq!(index, 0); - let index = writer.write("train", &new_item).unwrap(); - assert_eq!(index, 0); + let mut writer = writer; - let mut writer = writer; + writer.set_completed().expect("Failed to set completed"); - writer.set_completed().expect("Failed to set completed"); + assert!(writer.db_file.exists()); + assert!(writer.db_file_tmp.is_none()); - assert!(writer.db_file.exists()); - assert!(writer.db_file_tmp.is_none()); + let result = writer.write("train", &new_item); - let result = writer.write("train", &new_item); + // Should fail because the writer is completed + assert!(result.is_err()); - // Should fail because the writer is completed - assert!(result.is_err()); + let dataset = SqliteDataset::::from_db_file(writer.db_file, "train").unwrap(); - let dataset = SqliteDataset::::from_db_file(writer.db_file, "train").unwrap(); + let fetched_item = dataset.get(0).unwrap(); + assert_eq!(fetched_item, new_item); + assert_eq!(dataset.len(), 1); + } - let 
fetched_item = dataset.get(0).unwrap(); - assert_eq!(fetched_item, new_item); - assert_eq!(dataset.len(), 1); - } + #[rstest] + pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) { + // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) + let (writer, _tmp_dir) = writer_fixture; - #[rstest] - pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) { - // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) - let (writer, _tmp_dir) = writer_fixture; + let writer = Arc::new(writer); + let record_count = 20; - let writer = Arc::new(writer); - let record_count = 20; + let splits = ["train", "test"]; - let splits = ["train", "test"]; + (0..record_count).into_par_iter().for_each(|index: i64| { + let thread_id: std::thread::ThreadId = std::thread::current().id(); + let sample = Complex { + column_str: format!("test_{:?}_{}", thread_id, index), + column_bytes: vec![index as u8, 2, 3], + column_int: index, + column_bool: true, + column_float: 1.0, + column_complex: vec![vec![vec![[1, index as u8, 3]]]], + }; - (0..record_count).into_par_iter().for_each(|index: i64| { - let thread_id: std::thread::ThreadId = std::thread::current().id(); - let sample = Complex { - column_str: format!("test_{:?}_{}", thread_id, index), - column_bytes: vec![index as u8, 2, 3], - column_int: index, - column_bool: true, - column_float: 1.0, - column_complex: vec![vec![vec![[1, index as u8, 3]]]], - }; + // half for train and half for test + let split = splits[index as usize % 2]; - // half for train and half for test - let split = splits[index as usize % 2]; + let _index = writer.write(split, &sample).unwrap(); + }); - let _index = writer.write(split, &sample).unwrap(); - }); + let mut writer = Arc::try_unwrap(writer).unwrap(); - let mut writer = Arc::try_unwrap(writer).unwrap(); + writer + .set_completed() + .expect("Should set completed successfully"); - writer - .set_completed() - .expect("Should set completed successfully"); + let train = SqliteDataset::::from_db_file(writer.db_file.clone(), "train").unwrap(); + let test = SqliteDataset::::from_db_file(writer.db_file, "test").unwrap(); - let train = - SqliteDataset::::from_db_file(writer.db_file.clone(), "train").unwrap(); - let test = SqliteDataset::::from_db_file(writer.db_file, "test").unwrap(); - - assert_eq!(train.len(), record_count as usize / 2); - assert_eq!(test.len(), record_count as usize / 2); - } + assert_eq!(train.len(), record_count as usize / 2); + assert_eq!(test.len(), record_count as usize / 2); + } } diff --git a/burn-dataset/src/lib.rs b/burn-dataset/src/lib.rs index 1d8045878a..1719a87de2 100644 --- a/burn-dataset/src/lib.rs +++ b/burn-dataset/src/lib.rs @@ -26,12 +26,12 @@ pub use source::huggingface::downloader::*; #[cfg(test)] mod test_data { - pub fn string_items() -> Vec { - vec![ - "1 Item".to_string(), - "2 Items".to_string(), - "3 Items".to_string(), - "4 Items".to_string(), - ] - } + pub fn string_items() -> Vec { + vec![ + "1 Item".to_string(), + "2 Items".to_string(), + "3 Items".to_string(), + "4 Items".to_string(), + ] + } } diff --git a/burn-dataset/src/source/huggingface/downloader.rs b/burn-dataset/src/source/huggingface/downloader.rs index 9d3ef277dc..b6a8121b20 100644 --- a/burn-dataset/src/source/huggingface/downloader.rs +++ b/burn-dataset/src/source/huggingface/downloader.rs @@ -17,25 +17,25 @@ const VENV_BIN_PYTHON: &str = "Scripts\\python"; /// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader). 
#[derive(Error, Debug)] pub enum ImporterError { - /// Unknown error. - #[error("unknown: `{0}`")] - Unknown(String), + /// Unknown error. + #[error("unknown: `{0}`")] + Unknown(String), - /// Fail to download python dependencies. - #[error("fail to download python dependencies: `{0}`")] - FailToDownloadPythonDependencies(String), + /// Fail to download python dependencies. + #[error("fail to download python dependencies: `{0}`")] + FailToDownloadPythonDependencies(String), - /// Fail to create sqlite dataset. - #[error("sqlite dataset: `{0}`")] - SqliteDataset(#[from] SqliteDatasetError), + /// Fail to create sqlite dataset. + #[error("sqlite dataset: `{0}`")] + SqliteDataset(#[from] SqliteDatasetError), - /// python3 is not installed. - #[error("python3 is not installed")] - PythonNotInstalled, + /// python3 is not installed. + #[error("python3 is not installed")] + PythonNotInstalled, - /// venv environment is not initialized. - #[error("venv environment is not initialized")] - VenvNotInitialized, + /// venv environment is not initialized. + #[error("venv environment is not initialized")] + VenvNotInitialized, } /// Load a dataset from [huggingface datasets](https://huggingface.co/datasets). @@ -58,237 +58,237 @@ pub enum ImporterError { /// .dataset("train") /// .unwrap(); pub struct HuggingfaceDatasetLoader { - name: String, - subset: Option, - base_dir: Option, - huggingface_token: Option, - huggingface_cache_dir: Option, + name: String, + subset: Option, + base_dir: Option, + huggingface_token: Option, + huggingface_cache_dir: Option, } impl HuggingfaceDatasetLoader { - /// Create a huggingface dataset loader. - pub fn new(name: &str) -> Self { - Self { - name: name.to_string(), - subset: None, - base_dir: None, - huggingface_token: None, - huggingface_cache_dir: None, - } - } - - /// Create a huggingface dataset loader for a subset of the dataset. - /// - /// The subset name must be one of the subsets listed in the dataset page. - /// - /// If no subset names are listed, then do not use this method. - pub fn with_subset(mut self, subset: &str) -> Self { - self.subset = Some(subset.to_string()); - self - } - - /// Specify a base directory to store the dataset. - /// - /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`. - pub fn with_base_dir(mut self, base_dir: &str) -> Self { - self.base_dir = Some(base_dir.into()); - self - } - - /// Specify a huggingface token to download datasets behind authentication. - /// - /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens) - pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self { - self.huggingface_token = Some(huggingface_token.to_string()); - self + /// Create a huggingface dataset loader. + pub fn new(name: &str) -> Self { + Self { + name: name.to_string(), + subset: None, + base_dir: None, + huggingface_token: None, + huggingface_cache_dir: None, } - - /// Specify a huggingface cache directory to store the downloaded datasets. - /// - /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`. - pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self { - self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string()); - self + } + + /// Create a huggingface dataset loader for a subset of the dataset. + /// + /// The subset name must be one of the subsets listed in the dataset page. + /// + /// If no subset names are listed, then do not use this method. 
+ pub fn with_subset(mut self, subset: &str) -> Self { + self.subset = Some(subset.to_string()); + self + } + + /// Specify a base directory to store the dataset. + /// + /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`. + pub fn with_base_dir(mut self, base_dir: &str) -> Self { + self.base_dir = Some(base_dir.into()); + self + } + + /// Specify a huggingface token to download datasets behind authentication. + /// + /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens) + pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self { + self.huggingface_token = Some(huggingface_token.to_string()); + self + } + + /// Specify a huggingface cache directory to store the downloaded datasets. + /// + /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`. + pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self { + self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string()); + self + } + + /// Load the dataset. + pub fn dataset( + self, + split: &str, + ) -> Result, ImporterError> { + let db_file = self.db_file()?; + let dataset = SqliteDataset::from_db_file(db_file, split)?; + Ok(dataset) + } + + /// Get the path to the sqlite database file. + /// + /// If the database file does not exist, it will be downloaded and imported. + pub fn db_file(self) -> Result { + // determine (and create if needed) the base directory + let base_dir = SqliteDatasetStorage::base_dir(self.base_dir); + + if !base_dir.exists() { + create_dir_all(&base_dir).expect("Failed to create base directory"); } - /// Load the dataset. - pub fn dataset( - self, - split: &str, - ) -> Result, ImporterError> { - let db_file = self.db_file()?; - let dataset = SqliteDataset::from_db_file(db_file, split)?; - Ok(dataset) + //sanitize the name and subset + let name = sanitize(self.name.as_str()); + + // create the db file path + let db_file_name = if let Some(subset) = self.subset.clone() { + format!("{}-{}.db", name, sanitize(subset.as_str())) + } else { + format!("{}.db", name) + }; + + let db_file = base_dir.join(db_file_name); + + // import the dataset if needed + if !Path::new(&db_file).exists() { + import( + self.name, + self.subset, + db_file.clone(), + base_dir, + self.huggingface_token, + self.huggingface_cache_dir, + )?; } - /// Get the path to the sqlite database file. - /// - /// If the database file does not exist, it will be downloaded and imported. - pub fn db_file(self) -> Result { - // determine (and create if needed) the base directory - let base_dir = SqliteDatasetStorage::base_dir(self.base_dir); - - if !base_dir.exists() { - create_dir_all(&base_dir).expect("Failed to create base directory"); - } - - //sanitize the name and subset - let name = sanitize(self.name.as_str()); - - // create the db file path - let db_file_name = if let Some(subset) = self.subset.clone() { - format!("{}-{}.db", name, sanitize(subset.as_str())) - } else { - format!("{}.db", name) - }; - - let db_file = base_dir.join(db_file_name); - - // import the dataset if needed - if !Path::new(&db_file).exists() { - import( - self.name, - self.subset, - db_file.clone(), - base_dir, - self.huggingface_token, - self.huggingface_cache_dir, - )?; - } - - Ok(db_file) - } + Ok(db_file) + } } /// Import a dataset from huggingface. The transformed dataset is stored as sqlite database. 
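In practice the whole builder reduces to a short call chain: the first run creates the python venv, downloads the dataset, and converts it into the SQLite cache; later runs just open the cached db file. A usage sketch, assuming the crate-root re-exports and using `ag_news` with hypothetical `text`/`label` columns purely as an example:

// Usage sketch; the dataset name and field names are examples, not guarantees about the
// columns the Hugging Face dataset actually exposes.
use burn_dataset::{Dataset, HuggingfaceDatasetLoader, SqliteDataset};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
struct AgNewsItem {
    text: String,
    label: usize,
}

fn main() {
    // First run: downloads and converts the dataset into ~/.cache/burn-dataset/ag_news.db
    // (by default). Subsequent runs only open the existing db file.
    let train: SqliteDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news")
        .dataset("train")
        .unwrap();

    println!("{} training rows", train.len());
}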
fn import( - name: String, - subset: Option, - base_file: PathBuf, - base_dir: PathBuf, - huggingface_token: Option, - huggingface_cache_dir: Option, + name: String, + subset: Option, + base_file: PathBuf, + base_dir: PathBuf, + huggingface_token: Option, + huggingface_cache_dir: Option, ) -> Result<(), ImporterError> { - let venv_python_path = install_python_deps(&base_dir)?; + let venv_python_path = install_python_deps(&base_dir)?; - let mut command = Command::new(venv_python_path); + let mut command = Command::new(venv_python_path); - command.arg(importer_script_path(&base_dir)); + command.arg(importer_script_path(&base_dir)); - command.arg("--name"); - command.arg(name); + command.arg("--name"); + command.arg(name); - command.arg("--file"); - command.arg(base_file); + command.arg("--file"); + command.arg(base_file); - if let Some(subset) = subset { - command.arg("--subset"); - command.arg(subset); - } + if let Some(subset) = subset { + command.arg("--subset"); + command.arg(subset); + } - if let Some(huggingface_token) = huggingface_token { - command.arg("--token"); - command.arg(huggingface_token); - } + if let Some(huggingface_token) = huggingface_token { + command.arg("--token"); + command.arg(huggingface_token); + } - if let Some(huggingface_cache_dir) = huggingface_cache_dir { - command.arg("--cache_dir"); - command.arg(huggingface_cache_dir); - } + if let Some(huggingface_cache_dir) = huggingface_cache_dir { + command.arg("--cache_dir"); + command.arg(huggingface_cache_dir); + } - let mut handle = command.spawn().unwrap(); - handle - .wait() - .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?; + let mut handle = command.spawn().unwrap(); + handle + .wait() + .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?; - Ok(()) + Ok(()) } /// check python --version output is `Python 3.x.x` fn check_python_version_is_3(python: &str) -> bool { - let output = Command::new(python).arg("--version").output(); - match output { - Ok(output) => { - if output.status.success() { - let version_string = String::from_utf8_lossy(&output.stdout); - if let Some(index) = version_string.find(' ') { - let version = &version_string[index + 1..]; - version.starts_with("3.") - } else { - false - } - } else { - false - } + let output = Command::new(python).arg("--version").output(); + match output { + Ok(output) => { + if output.status.success() { + let version_string = String::from_utf8_lossy(&output.stdout); + if let Some(index) = version_string.find(' ') { + let version = &version_string[index + 1..]; + version.starts_with("3.") + } else { + false } - Err(_error) => false, + } else { + false + } } + Err(_error) => false, + } } /// get python3 name `python` `python3` or `py` fn get_python_name() -> Result<&'static str, ImporterError> { - let python_name_list = ["python3", "python", "py"]; - for python_name in python_name_list.iter() { - if check_python_version_is_3(python_name) { - return Ok(python_name); - } + let python_name_list = ["python3", "python", "py"]; + for python_name in python_name_list.iter() { + if check_python_version_is_3(python_name) { + return Ok(python_name); } - Err(ImporterError::PythonNotInstalled) + } + Err(ImporterError::PythonNotInstalled) } fn importer_script_path(base_dir: &Path) -> PathBuf { - let path_file = base_dir.join("importer.py"); + let path_file = base_dir.join("importer.py"); - fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader"); - path_file + fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader"); + 
path_file } fn install_python_deps(base_dir: &Path) -> Result { - let venv_dir = base_dir.join("venv"); - let venv_python_path = venv_dir.join(VENV_BIN_PYTHON); - // If the venv environment is already initialized, skip the initialization. - if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { - let python_name = get_python_name()?; - let mut command = Command::new(python_name); - command.args([ - "-m", - "venv", - venv_dir - .as_os_str() - .to_str() - .expect("Path utf8 conversion should not fail"), - ]); - - // Spawn the venv creation process and wait for it to complete. - let mut handle = command.spawn().unwrap(); - - handle.wait().map_err(|err| { - ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)) - })?; - // Check if the venv environment can be used successfully." - if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { - return Err(ImporterError::VenvNotInitialized); - } - } - - let mut command = Command::new(&venv_python_path); + let venv_dir = base_dir.join("venv"); + let venv_python_path = venv_dir.join(VENV_BIN_PYTHON); + // If the venv environment is already initialized, skip the initialization. + if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { + let python_name = get_python_name()?; + let mut command = Command::new(python_name); command.args([ - "-m", - "pip", - "--quiet", - "install", - "pyarrow", - "sqlalchemy", - "Pillow", - "soundfile", - "datasets", + "-m", + "venv", + venv_dir + .as_os_str() + .to_str() + .expect("Path utf8 conversion should not fail"), ]); - // Spawn the pip install process and wait for it to complete. + // Spawn the venv creation process and wait for it to complete. let mut handle = command.spawn().unwrap(); - handle.wait().map_err(|err| { - ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)) - })?; - Ok(venv_python_path) + handle + .wait() + .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)))?; + // Check if the venv environment can be used successfully." + if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { + return Err(ImporterError::VenvNotInitialized); + } + } + + let mut command = Command::new(&venv_python_path); + command.args([ + "-m", + "pip", + "--quiet", + "install", + "pyarrow", + "sqlalchemy", + "Pillow", + "soundfile", + "datasets", + ]); + + // Spawn the pip install process and wait for it to complete. + let mut handle = command.spawn().unwrap(); + handle + .wait() + .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)))?; + + Ok(venv_python_path) } diff --git a/burn-dataset/src/source/huggingface/mnist.rs b/burn-dataset/src/source/huggingface/mnist.rs index 88b37180ac..6126141dbc 100644 --- a/burn-dataset/src/source/huggingface/mnist.rs +++ b/burn-dataset/src/source/huggingface/mnist.rs @@ -11,43 +11,43 @@ const HEIGHT: usize = 28; /// MNIST item. #[derive(Deserialize, Serialize, Debug, Clone)] pub struct MNISTItem { - /// Image as a 2D array of floats. - pub image: [[f32; WIDTH]; HEIGHT], + /// Image as a 2D array of floats. + pub image: [[f32; WIDTH]; HEIGHT], - /// Label of the image. - pub label: usize, + /// Label of the image. + pub label: usize, } #[derive(Deserialize, Debug, Clone)] struct MNISTItemRaw { - pub image_bytes: Vec, - pub label: usize, + pub image_bytes: Vec, + pub label: usize, } struct BytesToImage; impl Mapper for BytesToImage { - /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image). 
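The mapper below decodes the PNG bytes and copies the flat luma buffer into a fixed `[[f32; 28]; 28]` array; for a row-major buffer the row index is `i / WIDTH` (the `i / HEIGHT` in the code is equivalent only because the image is square). A minimal sketch of just that copy step, without the `image` decoding:

// Sketch of the flat-buffer -> 2D array copy used by the MNIST mapper.
const WIDTH: usize = 28;
const HEIGHT: usize = 28;

fn to_2d(raw: &[u8]) -> [[f32; WIDTH]; HEIGHT] {
    assert_eq!(raw.len(), WIDTH * HEIGHT);
    let mut image = [[0f32; WIDTH]; HEIGHT];
    for (i, pixel) in raw.iter().enumerate() {
        let x = i % WIDTH; // column
        let y = i / WIDTH; // row (row-major layout)
        image[y][x] = *pixel as f32;
    }
    image
}

fn main() {
    let raw = vec![0u8; WIDTH * HEIGHT];
    let image = to_2d(&raw);
    assert_eq!(image[27][27], 0.0);
}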
- fn map(&self, item: &MNISTItemRaw) -> MNISTItem { - let image = image::load_from_memory(&item.image_bytes).unwrap(); - let image = image.as_luma8().unwrap(); - - // Ensure the image dimensions are correct. - debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32)); - - // Convert the image to a 2D array of floats. - let mut image_array = [[0f32; WIDTH]; HEIGHT]; - for (i, pixel) in image.as_raw().iter().enumerate() { - let x = i % WIDTH; - let y = i / HEIGHT; - image_array[y][x] = *pixel as f32; - } - - MNISTItem { - image: image_array, - label: item.label, - } + /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image). + fn map(&self, item: &MNISTItemRaw) -> MNISTItem { + let image = image::load_from_memory(&item.image_bytes).unwrap(); + let image = image.as_luma8().unwrap(); + + // Ensure the image dimensions are correct. + debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32)); + + // Convert the image to a 2D array of floats. + let mut image_array = [[0f32; WIDTH]; HEIGHT]; + for (i, pixel) in image.as_raw().iter().enumerate() { + let x = i % WIDTH; + let y = i / HEIGHT; + image_array[y][x] = *pixel as f32; } + + MNISTItem { + image: image_array, + label: item.label, + } + } } type MappedDataset = MapperDataset, BytesToImage, MNISTItemRaw>; @@ -56,37 +56,37 @@ type MappedDataset = MapperDataset, BytesToImage, MN /// /// The data is downloaded from Huggingface and stored in a SQLite database. pub struct MNISTDataset { - dataset: MappedDataset, + dataset: MappedDataset, } impl Dataset for MNISTDataset { - fn get(&self, index: usize) -> Option { - self.dataset.get(index) - } + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } impl MNISTDataset { - /// Creates a new train dataset. - pub fn train() -> Self { - Self::new("train") - } + /// Creates a new train dataset. + pub fn train() -> Self { + Self::new("train") + } - /// Creates a new test dataset. - pub fn test() -> Self { - Self::new("test") - } + /// Creates a new test dataset. + pub fn test() -> Self { + Self::new("test") + } - fn new(split: &str) -> Self { - let dataset = HuggingfaceDatasetLoader::new("mnist") - .dataset(split) - .unwrap(); + fn new(split: &str) -> Self { + let dataset = HuggingfaceDatasetLoader::new("mnist") + .dataset(split) + .unwrap(); - let dataset = MapperDataset::new(dataset, BytesToImage); + let dataset = MapperDataset::new(dataset, BytesToImage); - Self { dataset } - } + Self { dataset } + } } diff --git a/burn-dataset/src/transform/composed.rs b/burn-dataset/src/transform/composed.rs index 8f26bd5976..0903e7059e 100644 --- a/burn-dataset/src/transform/composed.rs +++ b/burn-dataset/src/transform/composed.rs @@ -3,29 +3,29 @@ use crate::Dataset; /// Compose multiple datasets together to create a bigger one. 
#[derive(new)] pub struct ComposedDataset { - datasets: Vec, + datasets: Vec, } impl Dataset for ComposedDataset where - D: Dataset, - I: Clone, + D: Dataset, + I: Clone, { - fn get(&self, index: usize) -> Option { - let mut current_index = 0; - for dataset in self.datasets.iter() { - if index < dataset.len() + current_index { - return dataset.get(index - current_index); - } - current_index += dataset.len(); - } - None + fn get(&self, index: usize) -> Option { + let mut current_index = 0; + for dataset in self.datasets.iter() { + if index < dataset.len() + current_index { + return dataset.get(index - current_index); + } + current_index += dataset.len(); } - fn len(&self) -> usize { - let mut total = 0; - for dataset in self.datasets.iter() { - total += dataset.len(); - } - total + None + } + fn len(&self) -> usize { + let mut total = 0; + for dataset in self.datasets.iter() { + total += dataset.len(); } + total + } } diff --git a/burn-dataset/src/transform/mapper.rs b/burn-dataset/src/transform/mapper.rs index b089a375ed..1cf39ea2aa 100644 --- a/burn-dataset/src/transform/mapper.rs +++ b/burn-dataset/src/transform/mapper.rs @@ -3,58 +3,58 @@ use std::marker::PhantomData; /// Basic mapper trait to be used with the [mapper dataset](MapperDataset). pub trait Mapper: Send + Sync { - /// Maps an item of type I to an item of type O. - fn map(&self, item: &I) -> O; + /// Maps an item of type I to an item of type O. + fn map(&self, item: &I) -> O; } /// Dataset mapping each element in an inner dataset to another element type lazily. #[derive(new)] pub struct MapperDataset { - dataset: D, - mapper: M, - input: PhantomData, + dataset: D, + mapper: M, + input: PhantomData, } impl Dataset for MapperDataset where - D: Dataset, - M: Mapper + Send + Sync, - I: Send + Sync, - O: Send + Sync, + D: Dataset, + M: Mapper + Send + Sync, + I: Send + Sync, + O: Send + Sync, { - fn get(&self, index: usize) -> Option { - let item = self.dataset.get(index); - item.map(|item| self.mapper.map(&item)) - } - - fn len(&self) -> usize { - self.dataset.len() - } + fn get(&self, index: usize) -> Option { + let item = self.dataset.get(index); + item.map(|item| self.mapper.map(&item)) + } + + fn len(&self) -> usize { + self.dataset.len() + } } #[cfg(test)] mod tests { - use super::*; - use crate::{test_data, InMemDataset}; - - #[test] - pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() { - struct StringToFirstChar; - - impl Mapper for StringToFirstChar { - fn map(&self, item: &String) -> String { - let mut item = item.clone(); - item.truncate(1); - item - } - } + use super::*; + use crate::{test_data, InMemDataset}; + + #[test] + pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() { + struct StringToFirstChar; + + impl Mapper for StringToFirstChar { + fn map(&self, item: &String) -> String { + let mut item = item.clone(); + item.truncate(1); + item + } + } - let items_original = test_data::string_items(); - let dataset = InMemDataset::new(items_original); - let dataset = MapperDataset::new(dataset, StringToFirstChar); + let items_original = test_data::string_items(); + let dataset = InMemDataset::new(items_original); + let dataset = MapperDataset::new(dataset, StringToFirstChar); - let items: Vec = dataset.iter().collect(); + let items: Vec = dataset.iter().collect(); - assert_eq!(vec!["1", "2", "3", "4"], items); - } + assert_eq!(vec!["1", "2", "3", "4"], items); + } } diff --git a/burn-dataset/src/transform/partial.rs b/burn-dataset/src/transform/partial.rs index 
c8bd53f08b..2cc018b3ee 100644 --- a/burn-dataset/src/transform/partial.rs +++ b/burn-dataset/src/transform/partial.rs @@ -4,136 +4,136 @@ use std::{marker::PhantomData, sync::Arc}; /// Only use a fraction of an existing dataset lazily. #[derive(new)] pub struct PartialDataset { - dataset: D, - start_index: usize, - end_index: usize, - input: PhantomData, + dataset: D, + start_index: usize, + end_index: usize, + input: PhantomData, } impl PartialDataset where - D: Dataset, + D: Dataset, { - /// Splits a dataset into multiple partial datasets. - pub fn split(dataset: D, num: usize) -> Vec, I>> { - let dataset = Arc::new(dataset); // cheap cloning. + /// Splits a dataset into multiple partial datasets. + pub fn split(dataset: D, num: usize) -> Vec, I>> { + let dataset = Arc::new(dataset); // cheap cloning. - let mut current = 0; - let mut datasets = Vec::with_capacity(num); + let mut current = 0; + let mut datasets = Vec::with_capacity(num); - let batch_size = dataset.len() / num; + let batch_size = dataset.len() / num; - for i in 0..num { - let start = current; - let mut end = current + batch_size; + for i in 0..num { + let start = current; + let mut end = current + batch_size; - if i == (num - 1) { - end = dataset.len(); - } + if i == (num - 1) { + end = dataset.len(); + } - let dataset = PartialDataset::new(dataset.clone(), start, end); + let dataset = PartialDataset::new(dataset.clone(), start, end); - current += batch_size; - datasets.push(dataset); - } - - datasets + current += batch_size; + datasets.push(dataset); } + + datasets + } } impl Dataset for PartialDataset where - D: Dataset, - I: Clone + Send + Sync, + D: Dataset, + I: Clone + Send + Sync, { - fn get(&self, index: usize) -> Option { - let index = index + self.start_index; - if index < self.start_index || index >= self.end_index { - return None; - } - self.dataset.get(index) + fn get(&self, index: usize) -> Option { + let index = index + self.start_index; + if index < self.start_index || index >= self.end_index { + return None; } + self.dataset.get(index) + } - fn len(&self) -> usize { - usize::min(self.end_index - self.start_index, self.dataset.len()) - } + fn len(&self) -> usize { + usize::min(self.end_index - self.start_index, self.dataset.len()) + } } #[cfg(test)] mod tests { - use super::*; - use crate::FakeDataset; - use std::collections::HashSet; - - #[test] - fn test_start_from_beginning() { - let dataset_original = FakeDataset::::new(27); - let mut items_original_1 = HashSet::new(); - let mut items_original_2 = HashSet::new(); - let mut items_partial = HashSet::new(); - dataset_original.iter().enumerate().for_each(|(i, item)| { - match i >= 10 { - true => items_original_2.insert(item), - false => items_original_1.insert(item), - }; - }); - - let dataset_partial = PartialDataset::new(dataset_original, 0, 10); - - for item in dataset_partial.iter() { - items_partial.insert(item); - } - - assert_eq!(dataset_partial.len(), 10); - assert_eq!(items_original_1, items_partial); - for item in items_original_2 { - assert!(!items_partial.contains(&item)); - } + use super::*; + use crate::FakeDataset; + use std::collections::HashSet; + + #[test] + fn test_start_from_beginning() { + let dataset_original = FakeDataset::::new(27); + let mut items_original_1 = HashSet::new(); + let mut items_original_2 = HashSet::new(); + let mut items_partial = HashSet::new(); + dataset_original.iter().enumerate().for_each(|(i, item)| { + match i >= 10 { + true => items_original_2.insert(item), + false => items_original_1.insert(item), + }; + }); + 
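`PartialDataset::split` above hands out `num` contiguous windows of `len / num` items each, with the last window absorbing the remainder so no index is dropped. A small sketch of just the range arithmetic, independent of any dataset type:

// Sketch of the range arithmetic behind PartialDataset::split.
fn split_ranges(len: usize, num: usize) -> Vec<(usize, usize)> {
    let batch_size = len / num;
    let mut ranges = Vec::with_capacity(num);
    let mut start = 0;
    for i in 0..num {
        // The last partition runs to the end so leftover items are not dropped.
        let end = if i == num - 1 { len } else { start + batch_size };
        ranges.push((start, end));
        start += batch_size;
    }
    ranges
}

fn main() {
    // 27 items over 4 partitions: 6 + 6 + 6 + 9, matching the test below.
    assert_eq!(split_ranges(27, 4), vec![(0, 6), (6, 12), (12, 18), (18, 27)]);
}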
+ let dataset_partial = PartialDataset::new(dataset_original, 0, 10); + + for item in dataset_partial.iter() { + items_partial.insert(item); } - #[test] - fn test_start_inside() { - let dataset_original = FakeDataset::::new(27); - let mut items_original_1 = HashSet::new(); - let mut items_original_2 = HashSet::new(); - let mut items_partial = HashSet::new(); - - dataset_original.iter().enumerate().for_each(|(i, item)| { - match !(10..20).contains(&i) { - true => items_original_2.insert(item), - false => items_original_1.insert(item), - }; - }); - - let dataset_partial = PartialDataset::new(dataset_original, 10, 20); - for item in dataset_partial.iter() { - items_partial.insert(item); - } - - assert_eq!(dataset_partial.len(), 10); - assert_eq!(items_original_1, items_partial); - for item in items_original_2 { - assert!(!items_partial.contains(&item)); - } + assert_eq!(dataset_partial.len(), 10); + assert_eq!(items_original_1, items_partial); + for item in items_original_2 { + assert!(!items_partial.contains(&item)); + } + } + + #[test] + fn test_start_inside() { + let dataset_original = FakeDataset::::new(27); + let mut items_original_1 = HashSet::new(); + let mut items_original_2 = HashSet::new(); + let mut items_partial = HashSet::new(); + + dataset_original.iter().enumerate().for_each(|(i, item)| { + match !(10..20).contains(&i) { + true => items_original_2.insert(item), + false => items_original_1.insert(item), + }; + }); + + let dataset_partial = PartialDataset::new(dataset_original, 10, 20); + for item in dataset_partial.iter() { + items_partial.insert(item); } - #[test] - fn test_split_contains_all_items_without_duplicates() { - let dataset_original = FakeDataset::::new(27); - let mut items_original = Vec::new(); - let mut items_partial = Vec::new(); - for item in dataset_original.iter() { - items_original.push(item); - } - - let dataset_partials = PartialDataset::split(dataset_original, 4); + assert_eq!(dataset_partial.len(), 10); + assert_eq!(items_original_1, items_partial); + for item in items_original_2 { + assert!(!items_partial.contains(&item)); + } + } + + #[test] + fn test_split_contains_all_items_without_duplicates() { + let dataset_original = FakeDataset::::new(27); + let mut items_original = Vec::new(); + let mut items_partial = Vec::new(); + for item in dataset_original.iter() { + items_original.push(item); + } - for dataset in dataset_partials { - for item in dataset.iter() { - items_partial.push(item); - } - } + let dataset_partials = PartialDataset::split(dataset_original, 4); - assert_eq!(items_original, items_partial); + for dataset in dataset_partials { + for item in dataset.iter() { + items_partial.push(item); + } } + + assert_eq!(items_original, items_partial); + } } diff --git a/burn-dataset/src/transform/random.rs b/burn-dataset/src/transform/random.rs index 5b9de9d8d3..0dd79754bc 100644 --- a/burn-dataset/src/transform/random.rs +++ b/burn-dataset/src/transform/random.rs @@ -5,51 +5,51 @@ use std::marker::PhantomData; /// Shuffled a dataset, consider using [sampler dataset](crate::transform::SamplerDataset) is you /// want a probability distribution that is computed lazily. pub struct ShuffledDataset { - dataset: D, - indices: Vec, - input: PhantomData, + dataset: D, + indices: Vec, + input: PhantomData, } impl ShuffledDataset where - D: Dataset, + D: Dataset, { - /// Creates a new shuffled dataset. 
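The shuffled wrapper permutes an index vector once up front and resolves every `get` through it, so a fixed seed makes the whole permutation reproducible. A sketch of that idea with the same `rand` primitives, assuming rand 0.8-style APIs:

// Sketch of the seeded index permutation behind ShuffledDataset::with_seed (rand 0.8 APIs).
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;

fn shuffled_indices(len: usize, seed: u64) -> Vec<usize> {
    let mut rng = StdRng::seed_from_u64(seed);
    let mut indices: Vec<usize> = (0..len).collect();
    indices.shuffle(&mut rng);
    indices
}

fn main() {
    // Same seed, same permutation: lookups stay reproducible across runs.
    assert_eq!(shuffled_indices(5, 42), shuffled_indices(5, 42));

    let indices = shuffled_indices(5, 42);
    let data = ["a", "b", "c", "d", "e"];
    // ShuffledDataset::get(i) is essentially data[indices[i]].
    let shuffled: Vec<&str> = indices.iter().map(|&i| data[i]).collect();
    println!("{shuffled:?}");
}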
- pub fn new(dataset: D, rng: &mut StdRng) -> Self { - let mut indices = Vec::with_capacity(dataset.len()); - for i in 0..dataset.len() { - indices.push(i); - } - indices.shuffle(rng); - - Self { - dataset, - indices, - input: PhantomData, - } + /// Creates a new shuffled dataset. + pub fn new(dataset: D, rng: &mut StdRng) -> Self { + let mut indices = Vec::with_capacity(dataset.len()); + for i in 0..dataset.len() { + indices.push(i); } + indices.shuffle(rng); - /// Creates a new shuffled dataset with a fixed seed. - pub fn with_seed(dataset: D, seed: u64) -> Self { - let mut rng = StdRng::seed_from_u64(seed); - Self::new(dataset, &mut rng) + Self { + dataset, + indices, + input: PhantomData, } + } + + /// Creates a new shuffled dataset with a fixed seed. + pub fn with_seed(dataset: D, seed: u64) -> Self { + let mut rng = StdRng::seed_from_u64(seed); + Self::new(dataset, &mut rng) + } } impl Dataset for ShuffledDataset where - D: Dataset, - I: Clone + Send + Sync, + D: Dataset, + I: Clone + Send + Sync, { - fn get(&self, index: usize) -> Option { - let index = match self.indices.get(index) { - Some(index) => index, - None => return None, - }; - self.dataset.get(*index) - } + fn get(&self, index: usize) -> Option { + let index = match self.indices.get(index) { + Some(index) => index, + None => return None, + }; + self.dataset.get(*index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } diff --git a/burn-dataset/src/transform/sampler.rs b/burn-dataset/src/transform/sampler.rs index b3c0077ab8..69b6c1639e 100644 --- a/burn-dataset/src/transform/sampler.rs +++ b/burn-dataset/src/transform/sampler.rs @@ -15,132 +15,132 @@ use std::{marker::PhantomData, ops::DerefMut, sync::Mutex}; /// set the dataset to an arbitrary size. Once every item has been used, a new cycle is /// created with a new random suffle. pub struct SamplerDataset { - dataset: D, - size: usize, - state: Mutex, - input: PhantomData, + dataset: D, + size: usize, + state: Mutex, + input: PhantomData, } enum SamplerState { - WithReplacement(StdRng), - WithoutReplacement(StdRng, Vec), + WithReplacement(StdRng), + WithoutReplacement(StdRng, Vec), } impl SamplerDataset where - D: Dataset, - I: Send + Sync, + D: Dataset, + I: Send + Sync, { - /// Creates a new sampler dataset with replacement. - pub fn new(dataset: D, size: usize) -> Self { - Self { - dataset, - size, - state: Mutex::new(SamplerState::WithReplacement(StdRng::from_entropy())), - input: PhantomData, - } - } - - /// Creates a new sampler dataset with replacement. - pub fn with_replacement(dataset: D, size: usize) -> Self { - Self::new(dataset, size) + /// Creates a new sampler dataset with replacement. + pub fn new(dataset: D, size: usize) -> Self { + Self { + dataset, + size, + state: Mutex::new(SamplerState::WithReplacement(StdRng::from_entropy())), + input: PhantomData, } - - /// Creates a new sampler dataset without replacement. - pub fn without_replacement(dataset: D, size: usize) -> Self { - Self { - dataset, - size, - state: Mutex::new(SamplerState::WithoutReplacement( - StdRng::from_entropy(), - Vec::new(), - )), - input: PhantomData, - } + } + + /// Creates a new sampler dataset with replacement. + pub fn with_replacement(dataset: D, size: usize) -> Self { + Self::new(dataset, size) + } + + /// Creates a new sampler dataset without replacement. 
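A companion sketch for the shuffle and sampler transforms (again illustrative; the without_replacement constructor continues in the hunk just below, and the Numbers helper from the previous sketch is repeated so this compiles on its own):

use burn_dataset::transform::{SamplerDataset, ShuffledDataset};
use burn_dataset::Dataset;

struct Numbers(Vec<i32>);

impl Dataset<i32> for Numbers {
    fn get(&self, index: usize) -> Option<i32> {
        self.0.get(index).copied()
    }

    fn len(&self) -> usize {
        self.0.len()
    }
}

fn main() {
    // A fixed seed gives a reproducible permutation of the 10 items.
    let shuffled = ShuffledDataset::with_seed(Numbers((0..10).collect()), 42);
    let mut items: Vec<i32> = shuffled.iter().collect();
    items.sort();
    assert_eq!(items, (0..10).collect::<Vec<i32>>());

    // Present a 10-item dataset as 30 virtual items, sampled with replacement.
    let sampler = SamplerDataset::with_replacement(Numbers((0..10).collect()), 30);
    assert_eq!(sampler.len(), 30);
    assert_eq!(sampler.iter().count(), 30);
}
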
+ pub fn without_replacement(dataset: D, size: usize) -> Self { + Self { + dataset, + size, + state: Mutex::new(SamplerState::WithoutReplacement( + StdRng::from_entropy(), + Vec::new(), + )), + input: PhantomData, } + } - fn index(&self) -> usize { - let mut state = self.state.lock().unwrap(); - - match state.deref_mut() { - SamplerState::WithReplacement(rng) => rng.sample(Uniform::new(0, self.dataset.len())), - SamplerState::WithoutReplacement(rng, indices) => { - if indices.is_empty() { - // Refill the state. - *indices = (0..self.dataset.len()).choose_multiple(rng, self.dataset.len()); - } + fn index(&self) -> usize { + let mut state = self.state.lock().unwrap(); - indices.pop().expect("Indices are refilled when empty.") - } + match state.deref_mut() { + SamplerState::WithReplacement(rng) => rng.sample(Uniform::new(0, self.dataset.len())), + SamplerState::WithoutReplacement(rng, indices) => { + if indices.is_empty() { + // Refill the state. + *indices = (0..self.dataset.len()).choose_multiple(rng, self.dataset.len()); } + + indices.pop().expect("Indices are refilled when empty.") + } } + } } impl Dataset for SamplerDataset where - D: Dataset, - I: Send + Sync, + D: Dataset, + I: Send + Sync, { - fn get(&self, index: usize) -> Option { - if index >= self.size { - return None; - } - - self.dataset.get(self.index()) + fn get(&self, index: usize) -> Option { + if index >= self.size { + return None; } - fn len(&self) -> usize { - self.size - } + self.dataset.get(self.index()) + } + + fn len(&self) -> usize { + self.size + } } #[cfg(test)] mod tests { - use super::*; - use crate::FakeDataset; - use std::collections::HashMap; - - #[test] - fn sampler_dataset_with_replacement_iter() { - let factor = 3; - let len_original = 10; - let dataset_sampler = SamplerDataset::with_replacement( - FakeDataset::::new(len_original), - len_original * factor, - ); - let mut total = 0; - - for _item in dataset_sampler.iter() { - total += 1; - } - - assert_eq!(total, factor * len_original); + use super::*; + use crate::FakeDataset; + use std::collections::HashMap; + + #[test] + fn sampler_dataset_with_replacement_iter() { + let factor = 3; + let len_original = 10; + let dataset_sampler = SamplerDataset::with_replacement( + FakeDataset::::new(len_original), + len_original * factor, + ); + let mut total = 0; + + for _item in dataset_sampler.iter() { + total += 1; } - #[test] - fn sampler_dataset_without_replacement_bucket_test() { - let factor = 3; - let len_original = 10; - let dataset_sampler = SamplerDataset::without_replacement( - FakeDataset::::new(len_original), - len_original * factor, - ); - let mut buckets = HashMap::new(); - - for item in dataset_sampler.iter() { - let count = match buckets.get(&item) { - Some(count) => count + 1, - None => 1, - }; - - buckets.insert(item, count); - } + assert_eq!(total, factor * len_original); + } + + #[test] + fn sampler_dataset_without_replacement_bucket_test() { + let factor = 3; + let len_original = 10; + let dataset_sampler = SamplerDataset::without_replacement( + FakeDataset::::new(len_original), + len_original * factor, + ); + let mut buckets = HashMap::new(); + + for item in dataset_sampler.iter() { + let count = match buckets.get(&item) { + Some(count) => count + 1, + None => 1, + }; + + buckets.insert(item, count); + } - let mut total = 0; - for count in buckets.into_values() { - assert_eq!(count, factor); - total += count; - } - assert_eq!(total, factor * len_original); + let mut total = 0; + for count in buckets.into_values() { + assert_eq!(count, factor); + 
total += count; } + assert_eq!(total, factor * len_original); + } } diff --git a/burn-derive/src/config/analyzer.rs b/burn-derive/src/config/analyzer.rs index e5e628585c..af55e06b1e 100644 --- a/burn-derive/src/config/analyzer.rs +++ b/burn-derive/src/config/analyzer.rs @@ -8,80 +8,80 @@ use syn::{Field, Ident}; pub struct ConfigAnalyzerFactory {} pub trait ConfigAnalyzer { - fn gen_new_fn(&self) -> TokenStream { - quote! {} - } - fn gen_builder_fns(&self) -> TokenStream { - quote! {} - } - fn gen_serde_impl(&self) -> TokenStream; - fn gen_clone_impl(&self) -> TokenStream; - fn gen_display_impl(&self) -> TokenStream; - fn gen_config_impl(&self) -> TokenStream; + fn gen_new_fn(&self) -> TokenStream { + quote! {} + } + fn gen_builder_fns(&self) -> TokenStream { + quote! {} + } + fn gen_serde_impl(&self) -> TokenStream; + fn gen_clone_impl(&self) -> TokenStream; + fn gen_display_impl(&self) -> TokenStream; + fn gen_config_impl(&self) -> TokenStream; } impl ConfigAnalyzerFactory { - pub fn new() -> Self { - Self {} - } + pub fn new() -> Self { + Self {} + } - pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box { - let name = item.ident.clone(); - let config_type = parse_asm(item); + pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box { + let name = item.ident.clone(); + let config_type = parse_asm(item); - match config_type { - ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)), - ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)), - } + match config_type { + ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)), + ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)), } + } - fn create_struct_analyzer(&self, name: Ident, fields: Vec) -> ConfigStructAnalyzer { - let fields = fields.into_iter().map(FieldTypeAnalyzer::new); + fn create_struct_analyzer(&self, name: Ident, fields: Vec) -> ConfigStructAnalyzer { + let fields = fields.into_iter().map(FieldTypeAnalyzer::new); - let mut fields_required = Vec::new(); - let mut fields_option = Vec::new(); - let mut fields_default = Vec::new(); + let mut fields_required = Vec::new(); + let mut fields_option = Vec::new(); + let mut fields_default = Vec::new(); - for field in fields { - let attributes: Vec = field - .attributes() - .filter(|attr| attr.has_name("config")) - .map(|attr| attr.item()) - .collect(); + for field in fields { + let attributes: Vec = field + .attributes() + .filter(|attr| attr.has_name("config")) + .map(|attr| attr.item()) + .collect(); - if !attributes.is_empty() { - let item = attributes.first().unwrap().clone(); - fields_default.push((field.clone(), item)); - continue; - } + if !attributes.is_empty() { + let item = attributes.first().unwrap().clone(); + fields_default.push((field.clone(), item)); + continue; + } - if field.is_of_type(&["Option"]) { - fields_option.push(field.clone()); - continue; - } + if field.is_of_type(&["Option"]) { + fields_option.push(field.clone()); + continue; + } - fields_required.push(field.clone()); - } - - ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default) + fields_required.push(field.clone()); } - fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer { - ConfigEnumAnalyzer::new(name, data) - } + ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default) + } + + fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer { + ConfigEnumAnalyzer::new(name, data) + } } enum 
ConfigType { - Struct(Vec), - Enum(syn::DataEnum), + Struct(Vec), + Enum(syn::DataEnum), } fn parse_asm(ast: &syn::DeriveInput) -> ConfigType { - match &ast.data { - syn::Data::Struct(struct_data) => { - ConfigType::Struct(struct_data.fields.clone().into_iter().collect()) - } - syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()), - syn::Data::Union(_) => panic!("Only struct and enum can be derived"), + match &ast.data { + syn::Data::Struct(struct_data) => { + ConfigType::Struct(struct_data.fields.clone().into_iter().collect()) } + syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()), + syn::Data::Union(_) => panic!("Only struct and enum can be derived"), + } } diff --git a/burn-derive/src/config/analyzer_enum.rs b/burn-derive/src/config/analyzer_enum.rs index e926f4c502..9c817bc2a4 100644 --- a/burn-derive/src/config/analyzer_enum.rs +++ b/burn-derive/src/config/analyzer_enum.rs @@ -4,174 +4,174 @@ use quote::quote; use syn::{FieldsNamed, Variant}; pub struct ConfigEnumAnalyzer { - name: Ident, - data: syn::DataEnum, + name: Ident, + data: syn::DataEnum, } impl ConfigEnumAnalyzer { - pub fn new(name: Ident, data: syn::DataEnum) -> Self { - Self { name, data } - } + pub fn new(name: Ident, data: syn::DataEnum) -> Self { + Self { name, data } + } + + fn serde_enum_ident(&self) -> Ident { + Ident::new(&format!("{}Serde", self.name), self.name.span()) + } + + fn gen_serde_enum(&self) -> TokenStream { + let enum_name = self.serde_enum_ident(); + let data = &self.data.variants; + + quote! { + #[derive(serde::Serialize, serde::Deserialize)] + enum #enum_name { + #data + } - fn serde_enum_ident(&self) -> Ident { - Ident::new(&format!("{}Serde", self.name), self.name.span()) } + } - fn gen_serde_enum(&self) -> TokenStream { - let enum_name = self.serde_enum_ident(); - let data = &self.data.variants; - - quote! { - #[derive(serde::Serialize, serde::Deserialize)] - enum #enum_name { - #data - } + fn gen_variant_field(&self, variant: &Variant) -> (TokenStream, TokenStream) { + let gen_fields_unnamed = |num: usize| { + let mut input = Vec::new(); + let mut output = Vec::new(); - } - } + for i in 0..num { + let arg_name = Ident::new(&format!("arg_{i}"), self.name.span()); - fn gen_variant_field(&self, variant: &Variant) -> (TokenStream, TokenStream) { - let gen_fields_unnamed = |num: usize| { - let mut input = Vec::new(); - let mut output = Vec::new(); + input.push(quote! { #arg_name }); + output.push(quote! { #arg_name.clone() }); + } - for i in 0..num { - let arg_name = Ident::new(&format!("arg_{i}"), self.name.span()); + (quote! (( #(#input),* )), quote! (( #(#output),* ))) + }; + let gen_fields_named = |fields: &FieldsNamed| { + let mut input = Vec::new(); + let mut output = Vec::new(); - input.push(quote! { #arg_name }); - output.push(quote! { #arg_name.clone() }); - } + fields.named.iter().for_each(|field| { + let ident = &field.ident; - (quote! (( #(#input),* )), quote! (( #(#output),* ))) - }; - let gen_fields_named = |fields: &FieldsNamed| { - let mut input = Vec::new(); - let mut output = Vec::new(); - - fields.named.iter().for_each(|field| { - let ident = &field.ident; - - input.push(quote! { - #ident - }); - output.push(quote! { - #ident: #ident.clone() - }); - }); - - (quote! {{ #(#input),* }}, quote! {{ #(#output),* }}) - }; - - match &variant.fields { - syn::Fields::Named(fields) => gen_fields_named(fields), - syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()), - syn::Fields::Unit => (quote! {}, quote! {}), - } - } + input.push(quote! 
{ + #ident + }); + output.push(quote! { + #ident: #ident.clone() + }); + }); - fn gen_serialize_fn(&self) -> TokenStream { - let enum_name = self.serde_enum_ident(); - let variants = self.data.variants.iter().map(|variant| { - let variant_name = &variant.ident; - let (variant_input, variant_output) = self.gen_variant_field(variant); + (quote! {{ #(#input),* }}, quote! {{ #(#output),* }}) + }; - quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output } - }); - let name = &self.name; - - quote! { - impl serde::Serialize for #name { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer { - let serde_state = match self { - #(#variants),* - }; - serde_state.serialize(serializer) - } + match &variant.fields { + syn::Fields::Named(fields) => gen_fields_named(fields), + syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()), + syn::Fields::Unit => (quote! {}, quote! {}), + } + } + + fn gen_serialize_fn(&self) -> TokenStream { + let enum_name = self.serde_enum_ident(); + let variants = self.data.variants.iter().map(|variant| { + let variant_name = &variant.ident; + let (variant_input, variant_output) = self.gen_variant_field(variant); + + quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output } + }); + let name = &self.name; + + quote! { + impl serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer { + let serde_state = match self { + #(#variants),* + }; + serde_state.serialize(serializer) } - } - } - - fn gen_deserialize_fn(&self) -> TokenStream { - let enum_name = self.serde_enum_ident(); - let variants = self.data.variants.iter().map(|variant| { - let variant_name = &variant.ident; - let (variant_input, variant_output) = self.gen_variant_field(variant); - quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output } - }); - let name = &self.name; - - quote! { - impl<'de> serde::Deserialize<'de> for #name { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de> { - let serde_state = #enum_name::deserialize(deserializer)?; - Ok(match serde_state { - #(#variants),* - }) - } + } + } + + fn gen_deserialize_fn(&self) -> TokenStream { + let enum_name = self.serde_enum_ident(); + let variants = self.data.variants.iter().map(|variant| { + let variant_name = &variant.ident; + let (variant_input, variant_output) = self.gen_variant_field(variant); + + quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output } + }); + let name = &self.name; + + quote! { + impl<'de> serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + let serde_state = #enum_name::deserialize(deserializer)?; + Ok(match serde_state { + #(#variants),* + }) } - } + } + } } impl ConfigAnalyzer for ConfigEnumAnalyzer { - fn gen_serde_impl(&self) -> TokenStream { - let struct_gen = self.gen_serde_enum(); - let serialize_gen = self.gen_serialize_fn(); - let deserialize_gen = self.gen_deserialize_fn(); - - quote! { - #struct_gen - #serialize_gen - #deserialize_gen - } + fn gen_serde_impl(&self) -> TokenStream { + let struct_gen = self.gen_serde_enum(); + let serialize_gen = self.gen_serialize_fn(); + let deserialize_gen = self.gen_deserialize_fn(); + + quote! 
{ + #struct_gen + #serialize_gen + #deserialize_gen } - - fn gen_clone_impl(&self) -> TokenStream { - let variants = self.data.variants.iter().map(|variant| { - let variant_name = &variant.ident; - let (variant_input, variant_output) = self.gen_variant_field(variant); - - quote! { Self::#variant_name #variant_input => Self::#variant_name #variant_output } - }); - let name = &self.name; - - quote! { - impl Clone for #name { - fn clone(&self) -> Self { - match self { - #(#variants),* - } + } + + fn gen_clone_impl(&self) -> TokenStream { + let variants = self.data.variants.iter().map(|variant| { + let variant_name = &variant.ident; + let (variant_input, variant_output) = self.gen_variant_field(variant); + + quote! { Self::#variant_name #variant_input => Self::#variant_name #variant_output } + }); + let name = &self.name; + + quote! { + impl Clone for #name { + fn clone(&self) -> Self { + match self { + #(#variants),* } } - } + } + } - fn gen_display_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_display_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl core::fmt::Display for #name { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(&burn::config::config_to_json(self)) - } + quote! { + impl core::fmt::Display for #name { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(&burn::config::config_to_json(self)) } } } + } - fn gen_config_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_config_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl burn::config::Config for #name { - } + quote! { + impl burn::config::Config for #name { } } + } } diff --git a/burn-derive/src/config/analyzer_struct.rs b/burn-derive/src/config/analyzer_struct.rs index 18ec62c169..699bc49661 100644 --- a/burn-derive/src/config/analyzer_struct.rs +++ b/burn-derive/src/config/analyzer_struct.rs @@ -4,294 +4,294 @@ use proc_macro2::{Ident, TokenStream}; use quote::quote; pub struct ConfigStructAnalyzer { + name: Ident, + fields_required: Vec, + fields_option: Vec, + fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, +} + +impl ConfigStructAnalyzer { + pub fn new( name: Ident, fields_required: Vec, fields_option: Vec, fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, -} - -impl ConfigStructAnalyzer { - pub fn new( - name: Ident, - fields_required: Vec, - fields_option: Vec, - fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, - ) -> Self { - Self { - name, - fields_required, - fields_option, - fields_default, - } + ) -> Self { + Self { + name, + fields_required, + fields_option, + fields_default, } + } - fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream { - let name = &self.name; + fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream { + let name = &self.name; - quote! { - impl #name { - #tokens - } + quote! 
{ + impl #name { + #tokens } } + } - fn names(&self) -> Vec { - let mut names = Vec::new(); - - for field in self.fields_required.iter() { - names.push(field.clone()); - } + fn names(&self) -> Vec { + let mut names = Vec::new(); - for field in self.fields_option.iter() { - names.push(field.clone()); - } - - for (field, _) in self.fields_default.iter() { - names.push(field.clone()); - } + for field in self.fields_required.iter() { + names.push(field.clone()); + } - names + for field in self.fields_option.iter() { + names.push(field.clone()); } - fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec { - let mut name_types = Vec::new(); + for (field, _) in self.fields_default.iter() { + names.push(field.clone()); + } - for field in names.iter() { - let name = field.ident(); - let ty = &field.field.ty; + names + } - name_types.push(quote! { - #name: #ty - }); - } + fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec { + let mut name_types = Vec::new(); - name_types - } + for field in names.iter() { + let name = field.ident(); + let ty = &field.field.ty; - fn serde_struct_ident(&self) -> Ident { - Ident::new(&format!("{}Serde", self.name), self.name.span()) + name_types.push(quote! { + #name: #ty + }); } - fn gen_serialize_fn( - &self, - struct_name: &Ident, - struct_gen: &TokenStream, - names: &[FieldTypeAnalyzer], - ) -> TokenStream { - let name = &self.name; - let names = names.iter().map(|name| { - let name = name.ident(); - quote! { #name: self.#name.clone() } - }); - - quote! { - impl serde::Serialize for #name { - - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer { - #[derive(serde::Serialize)] - #struct_gen - - let serde_state = #struct_name { - #(#names),* - }; - serde_state.serialize(serializer) - } + name_types + } + + fn serde_struct_ident(&self) -> Ident { + Ident::new(&format!("{}Serde", self.name), self.name.span()) + } + + fn gen_serialize_fn( + &self, + struct_name: &Ident, + struct_gen: &TokenStream, + names: &[FieldTypeAnalyzer], + ) -> TokenStream { + let name = &self.name; + let names = names.iter().map(|name| { + let name = name.ident(); + quote! { #name: self.#name.clone() } + }); + + quote! { + impl serde::Serialize for #name { + + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer { + #[derive(serde::Serialize)] + #struct_gen + + let serde_state = #struct_name { + #(#names),* + }; + serde_state.serialize(serializer) } - } - } - fn gen_deserialize_fn( - &self, - struct_name: &Ident, - struct_gen: &TokenStream, - names: &[FieldTypeAnalyzer], - ) -> TokenStream { - let name = &self.name; - let names = names.iter().map(|name| { - let name = name.ident(); - quote! { #name: serde_state.#name } - }); - - quote! { - impl<'de> serde::Deserialize<'de> for #name { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de> { - #[derive(serde::Deserialize)] - #struct_gen - - let serde_state = #struct_name::deserialize(deserializer)?; - Ok(#name { - #(#names),* - }) - } + } + } + + fn gen_deserialize_fn( + &self, + struct_name: &Ident, + struct_gen: &TokenStream, + names: &[FieldTypeAnalyzer], + ) -> TokenStream { + let name = &self.name; + let names = names.iter().map(|name| { + let name = name.ident(); + quote! { #name: serde_state.#name } + }); + + quote! 
{ + impl<'de> serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + #[derive(serde::Deserialize)] + #struct_gen + + let serde_state = #struct_name::deserialize(deserializer)?; + Ok(#name { + #(#names),* + }) } - } - } - fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream { - let struct_name = self.serde_struct_ident(); + } + } - quote! { - struct #struct_name { - #(#names),* - } + fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream { + let struct_name = self.serde_struct_ident(); + quote! { + struct #struct_name { + #(#names),* } + } + } } impl ConfigAnalyzer for ConfigStructAnalyzer { - fn gen_new_fn(&self) -> TokenStream { - let mut body = quote! {}; - let mut names = Vec::new(); - - for field in self.fields_required.iter() { - let name = field.ident(); - let ty = &field.field.ty; - - body.extend(quote! { - #name: #name, - }); - names.push(quote! { - #name: #ty - }); - } + fn gen_new_fn(&self) -> TokenStream { + let mut body = quote! {}; + let mut names = Vec::new(); + + for field in self.fields_required.iter() { + let name = field.ident(); + let ty = &field.field.ty; + + body.extend(quote! { + #name: #name, + }); + names.push(quote! { + #name: #ty + }); + } - for field in self.fields_option.iter() { - let name = field.ident(); + for field in self.fields_option.iter() { + let name = field.ident(); - body.extend(quote! { - #name: None, - }); - } + body.extend(quote! { + #name: None, + }); + } - for (field, attribute) in self.fields_default.iter() { - let name = field.ident(); - let value = &attribute.value; - match value { - syn::Lit::Str(value) => { - let stream: proc_macro2::TokenStream = value.value().parse().unwrap(); + for (field, attribute) in self.fields_default.iter() { + let name = field.ident(); + let value = &attribute.value; + match value { + syn::Lit::Str(value) => { + let stream: proc_macro2::TokenStream = value.value().parse().unwrap(); - body.extend(quote! { - #name: #stream, - }); - } - _ => { - body.extend(quote! { - #name: #value, - }); - } - }; + body.extend(quote! { + #name: #stream, + }); } - - let body = quote! { - /// Create a new instance of the config. - pub fn new( - #(#names),* - ) -> Self { - Self { #body } - } - }; - self.wrap_impl_block(body) + _ => { + body.extend(quote! { + #name: #value, + }); + } + }; } - fn gen_builder_fns(&self) -> TokenStream { - let mut body = quote! {}; - - for (field, _) in self.fields_default.iter() { - let name = field.ident(); - let doc = field.doc().unwrap_or_else(|| { - quote! { - /// Set the default value for the field. - } - }); - let ty = &field.field.ty; - let fn_name = Ident::new(&format!("with_{name}"), name.span()); - - body.extend(quote! { - #doc - pub fn #fn_name(mut self, #name: #ty) -> Self { - self.#name = #name; - self - } - }); + let body = quote! { + /// Create a new instance of the config. + pub fn new( + #(#names),* + ) -> Self { + Self { #body } } + }; + self.wrap_impl_block(body) + } - for field in self.fields_option.iter() { - let name = field.ident(); - let ty = &field.field.ty; - let fn_name = Ident::new(&format!("with_{name}"), name.span()); + fn gen_builder_fns(&self) -> TokenStream { + let mut body = quote! {}; - body.extend(quote! { + for (field, _) in self.fields_default.iter() { + let name = field.ident(); + let doc = field.doc().unwrap_or_else(|| { + quote! { /// Set the default value for the field. 
- pub fn #fn_name(mut self, #name: #ty) -> Self { - self.#name = #name; - self - } - }); } - - self.wrap_impl_block(body) + }); + let ty = &field.field.ty; + let fn_name = Ident::new(&format!("with_{name}"), name.span()); + + body.extend(quote! { + #doc + pub fn #fn_name(mut self, #name: #ty) -> Self { + self.#name = #name; + self + } + }); } - fn gen_serde_impl(&self) -> TokenStream { - let names = self.names(); + for field in self.fields_option.iter() { + let name = field.ident(); + let ty = &field.field.ty; + let fn_name = Ident::new(&format!("with_{name}"), name.span()); + + body.extend(quote! { + /// Set the default value for the field. + pub fn #fn_name(mut self, #name: #ty) -> Self { + self.#name = #name; + self + } + }); + } - let struct_name = self.serde_struct_ident(); - let name_types = self.name_types(&names); - let struct_gen = self.gen_serde_struct(&name_types); + self.wrap_impl_block(body) + } - let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names); - let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names); + fn gen_serde_impl(&self) -> TokenStream { + let names = self.names(); - quote! { - #serialize_gen - #deserialize_gen - } - } + let struct_name = self.serde_struct_ident(); + let name_types = self.name_types(&names); + let struct_gen = self.gen_serde_struct(&name_types); - fn gen_clone_impl(&self) -> TokenStream { - let name = &self.name; - let names = self.names().into_iter().map(|name| { - let name = name.ident(); - quote! { #name: self.#name.clone() } - }); + let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names); + let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names); - quote! { - impl Clone for #name { - fn clone(&self) -> Self { - Self { - #(#names),* - } + quote! { + #serialize_gen + #deserialize_gen + } + } + + fn gen_clone_impl(&self) -> TokenStream { + let name = &self.name; + let names = self.names().into_iter().map(|name| { + let name = name.ident(); + quote! { #name: self.#name.clone() } + }); + + quote! { + impl Clone for #name { + fn clone(&self) -> Self { + Self { + #(#names),* } } - } + } + } - fn gen_display_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_display_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl core::fmt::Display for #name { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(&burn::config::config_to_json(self)) - } + quote! { + impl core::fmt::Display for #name { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(&burn::config::config_to_json(self)) } } } + } - fn gen_config_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_config_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl burn::config::Config for #name { - } + quote! 
{ + impl burn::config::Config for #name { } } + } } diff --git a/burn-derive/src/config/base.rs b/burn-derive/src/config/base.rs index cca3f0430b..c1a45e28f4 100644 --- a/burn-derive/src/config/base.rs +++ b/burn-derive/src/config/base.rs @@ -2,23 +2,23 @@ use super::ConfigAnalyzerFactory; use quote::quote; pub(crate) fn derive_impl(item: &syn::DeriveInput) -> proc_macro::TokenStream { - let factory = ConfigAnalyzerFactory::new(); - let analyzer = factory.create_analyzer(item); + let factory = ConfigAnalyzerFactory::new(); + let analyzer = factory.create_analyzer(item); - let constructor = analyzer.gen_new_fn(); - let builders = analyzer.gen_builder_fns(); - let serde = analyzer.gen_serde_impl(); - let clone = analyzer.gen_clone_impl(); - let display = analyzer.gen_display_impl(); - let config_impl = analyzer.gen_config_impl(); + let constructor = analyzer.gen_new_fn(); + let builders = analyzer.gen_builder_fns(); + let serde = analyzer.gen_serde_impl(); + let clone = analyzer.gen_clone_impl(); + let display = analyzer.gen_display_impl(); + let config_impl = analyzer.gen_config_impl(); - quote! { - #config_impl - #constructor - #builders - #serde - #clone - #display - } - .into() + quote! { + #config_impl + #constructor + #builders + #serde + #clone + #display + } + .into() } diff --git a/burn-derive/src/lib.rs b/burn-derive/src/lib.rs index 91f254b88f..35a5e22f71 100644 --- a/burn-derive/src/lib.rs +++ b/burn-derive/src/lib.rs @@ -15,20 +15,20 @@ pub(crate) mod shared; /// Derive macro for the module. #[proc_macro_derive(Module)] pub fn module_derive(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - module::derive_impl(&input) + let input = syn::parse(input).unwrap(); + module::derive_impl(&input) } /// Derive macro for the record. #[proc_macro_derive(Record)] pub fn record_derive(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - record::derive_impl(&input) + let input = syn::parse(input).unwrap(); + record::derive_impl(&input) } /// Derive macro for the config. 
#[proc_macro_derive(Config, attributes(config))] pub fn config_derive(input: TokenStream) -> TokenStream { - let item = syn::parse(input).unwrap(); - config::derive_impl(&item) + let item = syn::parse(input).unwrap(); + config::derive_impl(&item) } diff --git a/burn-derive/src/module/base.rs b/burn-derive/src/module/base.rs index 1cf41bb5ed..9536a7adc7 100644 --- a/burn-derive/src/module/base.rs +++ b/burn-derive/src/module/base.rs @@ -1,6 +1,6 @@ use super::{ - codegen::ModuleCodegen, codegen_struct::StructModuleCodegen, record::ModuleRecordCodegen, - record_struct::StructModuleRecordCodegen, + codegen::ModuleCodegen, codegen_struct::StructModuleCodegen, record::ModuleRecordCodegen, + record_struct::StructModuleRecordCodegen, }; use crate::module::display; use proc_macro::TokenStream; @@ -8,149 +8,149 @@ use quote::quote; use syn::{parse_quote, Ident}; pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; - let has_backend = ast - .generics - .type_params() - .map(|param| param.ident == "B") - .reduce(|accum, is_backend| is_backend || accum) - .unwrap_or(false); - - if !has_backend { - return constant_impl(ast); - } - - let (generics, generics_ty, generics_where) = ast.generics.split_for_impl(); - let backend_trait = fetch_backend_trait(&ast.generics); - - let display_fn = display::display_fn(name); - - let generator = StructModuleCodegen::from_ast(ast); - let num_params_fn = generator.gen_num_params(); - let visit = generator.gen_visit(); - let map_mut = generator.gen_map(); - let valid_fn = generator.gen_valid(); - let into_record_fn = generator.gen_into_record(); - let load_record_fn = generator.gen_load_record(); - let clone_fn = generator.gen_clone(); - let generics_names_except_backend = generics_names_except_backend(&ast.generics); - - let record_name = Ident::new(format!("{}Record", name).as_str(), name.span()); - let record_gen = StructModuleRecordCodegen::new(generator.fields); - let record_struct = record_gen.gen_record_type(&record_name, &ast.generics); - - let gen = quote! 
{ - impl #generics burn::module::Module for #name #generics_ty #generics_where { - type Record = #record_name #generics_ty; - - #load_record_fn - #into_record_fn - - #num_params_fn - - #visit - #map_mut - } - - impl #generics burn::module::AutodiffModule for #name #generics_ty - where - B: burn::tensor::backend::AutodiffBackend, - ::InnerBackend: #backend_trait, - { - type InnerModule=#name; - - #valid_fn - } - - impl #generics core::fmt::Display for #name #generics_ty #generics_where { - #display_fn - } - - impl #generics Clone for #name #generics_ty #generics_where { - #clone_fn - } - - #record_struct - }; - - gen.into() + let name = &ast.ident; + let has_backend = ast + .generics + .type_params() + .map(|param| param.ident == "B") + .reduce(|accum, is_backend| is_backend || accum) + .unwrap_or(false); + + if !has_backend { + return constant_impl(ast); + } + + let (generics, generics_ty, generics_where) = ast.generics.split_for_impl(); + let backend_trait = fetch_backend_trait(&ast.generics); + + let display_fn = display::display_fn(name); + + let generator = StructModuleCodegen::from_ast(ast); + let num_params_fn = generator.gen_num_params(); + let visit = generator.gen_visit(); + let map_mut = generator.gen_map(); + let valid_fn = generator.gen_valid(); + let into_record_fn = generator.gen_into_record(); + let load_record_fn = generator.gen_load_record(); + let clone_fn = generator.gen_clone(); + let generics_names_except_backend = generics_names_except_backend(&ast.generics); + + let record_name = Ident::new(format!("{}Record", name).as_str(), name.span()); + let record_gen = StructModuleRecordCodegen::new(generator.fields); + let record_struct = record_gen.gen_record_type(&record_name, &ast.generics); + + let gen = quote! { + impl #generics burn::module::Module for #name #generics_ty #generics_where { + type Record = #record_name #generics_ty; + + #load_record_fn + #into_record_fn + + #num_params_fn + + #visit + #map_mut + } + + impl #generics burn::module::AutodiffModule for #name #generics_ty + where + B: burn::tensor::backend::AutodiffBackend, + ::InnerBackend: #backend_trait, + { + type InnerModule=#name; + + #valid_fn + } + + impl #generics core::fmt::Display for #name #generics_ty #generics_where { + #display_fn + } + + impl #generics Clone for #name #generics_ty #generics_where { + #clone_fn + } + + #record_struct + }; + + gen.into() } // When there is no backend in the generic parameter, the struct is considered as a constant. fn constant_impl(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; - let (_, generics_ty, generics_where) = ast.generics.split_for_impl(); - - let backend: syn::Generics = parse_quote! { }; - let backend_ad: syn::Generics = parse_quote! { }; - - let mut generics_module = ast.generics.clone(); - let mut generics_module_ad = ast.generics.clone(); - - for param in backend.params.into_iter() { - generics_module.params.push(param); - } - for param in backend_ad.params.into_iter() { - generics_module_ad.params.push(param); - } - let (generics_module, _, _) = generics_module.split_for_impl(); - let (generics_module_ad, _, _) = generics_module_ad.split_for_impl(); - - let gen = quote! 
{ - impl #generics_module burn::module::Module for #name #generics_ty #generics_where { - burn::constant!(module); - } - - impl #generics_module_ad burn::module::AutodiffModule for #name #generics_ty #generics_where { - burn::constant!(ad_module, #name #generics_ty); - } - }; - - gen.into() + let name = &ast.ident; + let (_, generics_ty, generics_where) = ast.generics.split_for_impl(); + + let backend: syn::Generics = parse_quote! { }; + let backend_ad: syn::Generics = parse_quote! { }; + + let mut generics_module = ast.generics.clone(); + let mut generics_module_ad = ast.generics.clone(); + + for param in backend.params.into_iter() { + generics_module.params.push(param); + } + for param in backend_ad.params.into_iter() { + generics_module_ad.params.push(param); + } + let (generics_module, _, _) = generics_module.split_for_impl(); + let (generics_module_ad, _, _) = generics_module_ad.split_for_impl(); + + let gen = quote! { + impl #generics_module burn::module::Module for #name #generics_ty #generics_where { + burn::constant!(module); + } + + impl #generics_module_ad burn::module::AutodiffModule for #name #generics_ty #generics_where { + burn::constant!(ad_module, #name #generics_ty); + } + }; + + gen.into() } fn fetch_backend_trait(generics: &syn::Generics) -> proc_macro2::TokenStream { - static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str = "Modules should be generic over a backend. + static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str = "Modules should be generic over a backend. - The generic argument named `B` should have its first trait bound being a backend trait. - The default backend trait is `burn::tensor::backend::Backend`. - Any backend trait is supported."; - for param in generics.params.iter() { - if let syn::GenericParam::Type(ty) = ¶m { - if ty.ident == "B" { - let bound = ty - .bounds - .first() - .expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG); - - return quote! { - #bound - }; - } - } + for param in generics.params.iter() { + if let syn::GenericParam::Type(ty) = ¶m { + if ty.ident == "B" { + let bound = ty + .bounds + .first() + .expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG); + + return quote! { + #bound + }; + } } + } - panic!("{BACKEND_TRAIT_COMPILATION_ERROR_MSG}"); + panic!("{BACKEND_TRAIT_COMPILATION_ERROR_MSG}"); } fn generics_names_except_backend(generics: &syn::Generics) -> proc_macro2::TokenStream { - let mut named = quote! {}; - - generics.params.iter().for_each(|param| { - match param { - syn::GenericParam::Type(ty) => { - if ty.ident != "B" { - let ident = &ty.ident; - named.extend(quote! { #ident, }); - } - } - syn::GenericParam::Lifetime(_) => panic!("Lifetime not supported in module"), - syn::GenericParam::Const(c) => { - let ident = &c.ident; - named.extend(quote! { #ident, }); - } - }; - }); + let mut named = quote! {}; + + generics.params.iter().for_each(|param| { + match param { + syn::GenericParam::Type(ty) => { + if ty.ident != "B" { + let ident = &ty.ident; + named.extend(quote! { #ident, }); + } + } + syn::GenericParam::Lifetime(_) => panic!("Lifetime not supported in module"), + syn::GenericParam::Const(c) => { + let ident = &c.ident; + named.extend(quote! { #ident, }); + } + }; + }); - named + named } diff --git a/burn-derive/src/module/codegen.rs b/burn-derive/src/module/codegen.rs index 852138018c..2a7d52cedf 100644 --- a/burn-derive/src/module/codegen.rs +++ b/burn-derive/src/module/codegen.rs @@ -2,11 +2,11 @@ use proc_macro2::TokenStream; /// Basic trait to be implemented for Module generation. 
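For orientation on the module codegen, a minimal sketch of the derive from the user's side; MyModel, its linear field and the use of burn::nn::Linear are assumptions for illustration:

use burn::module::Module;
use burn::nn::Linear;
use burn::tensor::backend::Backend;

// Because the struct has a generic parameter `B` bound by a backend trait,
// derive_impl above generates, roughly:
//   - impl<B: Backend> Module<B> for MyModel<B> with `type Record = MyModelRecord<B>`,
//     where MyModelRecord is the record struct built by the record codegen;
//   - an AutodiffModule impl gated on B: AutodiffBackend;
//   - Clone, plus a Display impl printing "MyModel[num_params=...]".
// A struct without a `B` parameter is routed to constant_impl instead and becomes
// a constant module.
#[derive(Module, Debug)]
pub struct MyModel<B: Backend> {
    linear: Linear<B>,
}

fn main() {}
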
pub(crate) trait ModuleCodegen { - fn gen_num_params(&self) -> TokenStream; - fn gen_visit(&self) -> TokenStream; - fn gen_map(&self) -> TokenStream; - fn gen_valid(&self) -> TokenStream; - fn gen_into_record(&self) -> TokenStream; - fn gen_load_record(&self) -> TokenStream; - fn gen_clone(&self) -> TokenStream; + fn gen_num_params(&self) -> TokenStream; + fn gen_visit(&self) -> TokenStream; + fn gen_map(&self) -> TokenStream; + fn gen_valid(&self) -> TokenStream; + fn gen_into_record(&self) -> TokenStream; + fn gen_load_record(&self) -> TokenStream; + fn gen_clone(&self) -> TokenStream; } diff --git a/burn-derive/src/module/codegen_struct.rs b/burn-derive/src/module/codegen_struct.rs index a6988b2bd4..1f163b413a 100644 --- a/burn-derive/src/module/codegen_struct.rs +++ b/burn-derive/src/module/codegen_struct.rs @@ -5,164 +5,164 @@ use quote::quote; use super::codegen::ModuleCodegen; pub(crate) struct StructModuleCodegen { - pub fields: Vec, + pub fields: Vec, } impl ModuleCodegen for StructModuleCodegen { - fn gen_num_params(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - num_params += burn::module::Module::::num_params(&self.#name); - } - }); - - quote! { - fn num_params(&self) -> usize { - let mut num_params = 0; - #body - num_params - } + fn gen_num_params(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + num_params += burn::module::Module::::num_params(&self.#name); + } + }); + + quote! { + fn num_params(&self) -> usize { + let mut num_params = 0; + #body + num_params } } - - fn gen_visit(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - burn::module::Module::visit(&self.#name, visitor); - } - }); - - quote! { - fn visit>(&self, visitor: &mut V) { - #body - } + } + + fn gen_visit(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + burn::module::Module::visit(&self.#name, visitor); + } + }); + + quote! { + fn visit>(&self, visitor: &mut V) { + #body } } + } - fn gen_map(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = burn::module::Module::map(self.#name, mapper); - } - }); + fn gen_map(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = burn::module::Module::map(self.#name, mapper); + } + }); - quote! { - fn map>(self, mapper: &mut M) -> Self { - #body + quote! { + fn map>(self, mapper: &mut M) -> Self { + #body - Self { - #(#names),* - } + Self { + #(#names),* } } } + } - fn gen_valid(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = burn::module::AutodiffModule::::valid(&self.#name); - } - }); + fn gen_valid(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = burn::module::AutodiffModule::::valid(&self.#name); + } + }); - quote! { - fn valid(&self) -> Self::InnerModule { - #body + quote! { + fn valid(&self) -> Self::InnerModule { + #body - Self::InnerModule { - #(#names),* - } + Self::InnerModule { + #(#names),* } } } - - fn gen_into_record(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - #name: burn::module::Module::::into_record(self.#name), - } - }); - - quote! { - fn into_record(self) -> Self::Record { - Self::Record { - #body - } + } + + fn gen_into_record(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + #name: burn::module::Module::::into_record(self.#name), + } + }); + + quote! 
{ + fn into_record(self) -> Self::Record { + Self::Record { + #body } } } - - fn gen_load_record(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - #name: burn::module::Module::::load_record(self.#name, record.#name), - } - }); - - quote! { - fn load_record(self, record: Self::Record) -> Self { - Self { - #body - } + } + + fn gen_load_record(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + #name: burn::module::Module::::load_record(self.#name, record.#name), + } + }); + + quote! { + fn load_record(self, record: Self::Record) -> Self { + Self { + #body } } } + } - fn gen_clone(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = self.#name.clone(); - } - }); + fn gen_clone(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = self.#name.clone(); + } + }); - quote! { - fn clone(&self) -> Self { - #body + quote! { + fn clone(&self) -> Self { + #body - Self { - #(#names),* - } + Self { + #(#names),* } } } + } } impl StructModuleCodegen { - pub fn from_ast(ast: &syn::DeriveInput) -> Self { - Self { - fields: parse_fields(ast) - .into_iter() - .map(FieldTypeAnalyzer::new) - .collect(), - } + pub fn from_ast(ast: &syn::DeriveInput) -> Self { + Self { + fields: parse_fields(ast) + .into_iter() + .map(FieldTypeAnalyzer::new) + .collect(), } + } - fn gen_fields_fn_names(&self, func: F) -> (Vec, TokenStream) - where - F: Fn(Ident) -> TokenStream, - { - let mut body = quote! {}; - let mut names = Vec::new(); + fn gen_fields_fn_names(&self, func: F) -> (Vec, TokenStream) + where + F: Fn(Ident) -> TokenStream, + { + let mut body = quote! {}; + let mut names = Vec::new(); - for field in self.fields.iter() { - let name = field.ident(); + for field in self.fields.iter() { + let name = field.ident(); - names.push(name.clone()); - body.extend(func(field.ident())); - } - - (names, body) + names.push(name.clone()); + body.extend(func(field.ident())); } - fn gen_fields_fn(&self, func: F) -> TokenStream - where - F: Fn(Ident) -> TokenStream, - { - let mut body = quote! {}; + (names, body) + } - for field in self.fields.iter() { - body.extend(func(field.ident())); - } + fn gen_fields_fn(&self, func: F) -> TokenStream + where + F: Fn(Ident) -> TokenStream, + { + let mut body = quote! {}; - body + for field in self.fields.iter() { + body.extend(func(field.ident())); } + + body + } } diff --git a/burn-derive/src/module/display.rs b/burn-derive/src/module/display.rs index f9c024ff49..3e15331ebe 100644 --- a/burn-derive/src/module/display.rs +++ b/burn-derive/src/module/display.rs @@ -2,10 +2,10 @@ use proc_macro2::Ident; use quote::quote; pub fn display_fn(name: &Ident) -> proc_macro2::TokenStream { - quote! { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}[num_params={}]", stringify!(#name), self.num_params()) + quote! { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}[num_params={}]", stringify!(#name), self.num_params()) - } - } + } + } } diff --git a/burn-derive/src/module/record.rs b/burn-derive/src/module/record.rs index 72e3cd2872..be05e34f82 100644 --- a/burn-derive/src/module/record.rs +++ b/burn-derive/src/module/record.rs @@ -3,6 +3,6 @@ use syn::Generics; /// Basic trait to generate a record type based on the Module struct. 
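A matching sketch for the record side; MyState and its fields are invented, and the extra generic parameter added to the generated item type is left out because it is elided in this diff:

use burn::record::Record;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

// For this struct, the Record derive shown in record/base.rs and
// record/codegen_struct.rs below generates a companion `MyStateItem` struct whose
// fields are the `Record::Item` types of `weights` and `step`, attaches the serde
// bounds assembled in gen_item_type, and implements `Record` for MyState by
// delegating through into_item / from_item field by field.
#[derive(Record)]
pub struct MyState<B: Backend> {
    pub weights: Tensor<B, 2>,
    pub step: usize,
}

fn main() {}
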
pub(crate) trait ModuleRecordCodegen { - /// Generate the record type (i.e a struct) - fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream; + /// Generate the record type (i.e a struct) + fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream; } diff --git a/burn-derive/src/module/record_struct.rs b/burn-derive/src/module/record_struct.rs index 47c330f96e..2521e52b03 100644 --- a/burn-derive/src/module/record_struct.rs +++ b/burn-derive/src/module/record_struct.rs @@ -7,30 +7,30 @@ use super::record::ModuleRecordCodegen; #[derive(new)] pub(crate) struct StructModuleRecordCodegen { - fields: Vec, + fields: Vec, } impl ModuleRecordCodegen for StructModuleRecordCodegen { - fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { - let mut fields = quote! {}; + fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { + let mut fields = quote! {}; - for field in self.fields.iter() { - let ty = &field.field.ty; - let name = &field.field.ident; + for field in self.fields.iter() { + let ty = &field.field.ty; + let name = &field.field.ident; - fields.extend(quote! { - /// The module record associative type. - pub #name: <#ty as burn::module::Module>::Record, - }); - } + fields.extend(quote! { + /// The module record associative type. + pub #name: <#ty as burn::module::Module>::Record, + }); + } - quote! { + quote! { - /// The record type for the module. - #[derive(burn::record::Record, Debug, Clone)] - pub struct #record_name #generics { - #fields - } + /// The record type for the module. + #[derive(burn::record::Record, Debug, Clone)] + pub struct #record_name #generics { + #fields } } + } } diff --git a/burn-derive/src/record/base.rs b/burn-derive/src/record/base.rs index e446e961fc..deeaf9551b 100644 --- a/burn-derive/src/record/base.rs +++ b/burn-derive/src/record/base.rs @@ -6,83 +6,83 @@ use super::{codegen::RecordItemCodegen, codegen_struct::StructRecordItemCodegen} use crate::shared::field::{parse_fields, FieldTypeAnalyzer}; pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream { - let record_gen = RecordDeriveCodegen::from_ast(ast); - let item_struct = record_gen.gen_record_type(); - let record_impl = record_gen.gen_impl_record(); - - quote! { - #item_struct - #record_impl - } - .into() + let record_gen = RecordDeriveCodegen::from_ast(ast); + let item_struct = record_gen.gen_record_type(); + let record_impl = record_gen.gen_impl_record(); + + quote! 
{ + #item_struct + #record_impl + } + .into() } struct RecordDeriveCodegen { - name_record: Ident, - name_item: Ident, - gen: StructRecordItemCodegen, - generics: Generics, + name_record: Ident, + name_item: Ident, + gen: StructRecordItemCodegen, + generics: Generics, } impl RecordDeriveCodegen { - pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self { - let name_record = ast.ident.clone(); - let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span()); - - Self { - name_record, - name_item, - gen: StructRecordItemCodegen::new( - parse_fields(ast) - .into_iter() - .map(FieldTypeAnalyzer::new) - .collect(), - ), - generics: ast.generics.clone(), - } + pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self { + let name_record = ast.ident.clone(); + let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span()); + + Self { + name_record, + name_item, + gen: StructRecordItemCodegen::new( + parse_fields(ast) + .into_iter() + .map(FieldTypeAnalyzer::new) + .collect(), + ), + generics: ast.generics.clone(), } + } - /// Generate the record type with the correct generics. - pub(crate) fn gen_record_type(&self) -> TokenStream { - let param: syn::Generics = parse_quote! { }; - let mut generics = self.generics.clone(); - - for param in param.params.into_iter() { - generics.params.push(param); - } + /// Generate the record type with the correct generics. + pub(crate) fn gen_record_type(&self) -> TokenStream { + let param: syn::Generics = parse_quote! { }; + let mut generics = self.generics.clone(); - self.gen.gen_item_type(&self.name_item, &generics) + for param in param.params.into_iter() { + generics.params.push(param); } - /// Generate the implementation for the Record trait. - pub(crate) fn gen_impl_record(&self) -> TokenStream { - let name = &self.name_record; - let item_generics = self.record_item_generics(); - let (_, ty_generics_item, _) = item_generics.split_for_impl(); - let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + self.gen.gen_item_type(&self.name_item, &generics) + } - let name_item = &self.name_item; - let into_item_fn = self.gen.gen_into_item(name_item); - let from_item_fn = self.gen.gen_from_item(); + /// Generate the implementation for the Record trait. + pub(crate) fn gen_impl_record(&self) -> TokenStream { + let name = &self.name_record; + let item_generics = self.record_item_generics(); + let (_, ty_generics_item, _) = item_generics.split_for_impl(); + let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); - quote! { - impl #impl_generics burn::record::Record for #name #ty_generics #where_clause { - type Item = #name_item #ty_generics_item; + let name_item = &self.name_item; + let into_item_fn = self.gen.gen_into_item(name_item); + let from_item_fn = self.gen.gen_from_item(); - #into_item_fn - #from_item_fn + quote! { + impl #impl_generics burn::record::Record for #name #ty_generics #where_clause { + type Item = #name_item #ty_generics_item; - } - } - } + #into_item_fn + #from_item_fn - fn record_item_generics(&self) -> Generics { - let param: syn::Generics = parse_quote! { }; - let mut generics = self.generics.clone(); - for param in param.params.into_iter() { - generics.params.push(param); } + } + } - generics + fn record_item_generics(&self) -> Generics { + let param: syn::Generics = parse_quote! 
{ }; + let mut generics = self.generics.clone(); + for param in param.params.into_iter() { + generics.params.push(param); } + + generics + } } diff --git a/burn-derive/src/record/codegen.rs b/burn-derive/src/record/codegen.rs index aafcead97a..5af75f4678 100644 --- a/burn-derive/src/record/codegen.rs +++ b/burn-derive/src/record/codegen.rs @@ -3,10 +3,10 @@ use syn::Generics; /// Basic trait to be implemented for record generation. pub(crate) trait RecordItemCodegen { - /// Generate the record item type (i.e a struct) - fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream; - /// Generate the into_item function. - fn gen_into_item(&self, item_name: &Ident) -> TokenStream; - /// Generate the from item function. - fn gen_from_item(&self) -> TokenStream; + /// Generate the record item type (i.e a struct) + fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream; + /// Generate the into_item function. + fn gen_into_item(&self, item_name: &Ident) -> TokenStream; + /// Generate the from item function. + fn gen_from_item(&self) -> TokenStream; } diff --git a/burn-derive/src/record/codegen_struct.rs b/burn-derive/src/record/codegen_struct.rs index fcefaaad48..9fa1e7e1c5 100644 --- a/burn-derive/src/record/codegen_struct.rs +++ b/burn-derive/src/record/codegen_struct.rs @@ -7,76 +7,76 @@ use super::codegen::RecordItemCodegen; #[derive(new)] pub(crate) struct StructRecordItemCodegen { - fields: Vec, + fields: Vec, } impl RecordItemCodegen for StructRecordItemCodegen { - fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream { - let mut fields = quote! {}; - let mut bounds = quote! {}; + fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream { + let mut fields = quote! {}; + let mut bounds = quote! {}; - for field in self.fields.iter() { - let ty = &field.field.ty; - let name = &field.field.ident; + for field in self.fields.iter() { + let ty = &field.field.ty; + let name = &field.field.ident; - fields.extend(quote! { - /// Field to be serialized. - pub #name: <#ty as burn::record::Record>::Item, - }); - bounds.extend(quote!{ - <#ty as burn::record::Record>::Item: serde::Serialize + serde::de::DeserializeOwned, - }); - } - let bound = bounds.to_string(); + fields.extend(quote! { + /// Field to be serialized. + pub #name: <#ty as burn::record::Record>::Item, + }); + bounds.extend(quote! { + <#ty as burn::record::Record>::Item: serde::Serialize + serde::de::DeserializeOwned, + }); + } + let bound = bounds.to_string(); - quote! { + quote! { - /// The record item type for the module. - #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] - #[serde(bound = #bound)] - pub struct #item_name #generics { - #fields - } + /// The record item type for the module. + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + #[serde(bound = #bound)] + pub struct #item_name #generics { + #fields } } + } - fn gen_into_item(&self, item_name: &Ident) -> TokenStream { - let mut body_into_item = quote! {}; + fn gen_into_item(&self, item_name: &Ident) -> TokenStream { + let mut body_into_item = quote! {}; - for field in self.fields.iter() { - let name = &field.field.ident; + for field in self.fields.iter() { + let name = &field.field.ident; - body_into_item.extend(quote! { - #name: burn::record::Record::into_item::(self.#name), - }); - } + body_into_item.extend(quote! { + #name: burn::record::Record::into_item::(self.#name), + }); + } - quote! 
{ - fn into_item(self) -> Self::Item { - #item_name { - #body_into_item - } + quote! { + fn into_item(self) -> Self::Item { + #item_name { + #body_into_item } } } + } - fn gen_from_item(&self) -> TokenStream { - let mut body_from_item = quote! {}; + fn gen_from_item(&self) -> TokenStream { + let mut body_from_item = quote! {}; - for field in self.fields.iter() { - let name = &field.field.ident; + for field in self.fields.iter() { + let name = &field.field.ident; - body_from_item.extend(quote! { - #name: burn::record::Record::from_item::(item.#name), - }); - } + body_from_item.extend(quote! { + #name: burn::record::Record::from_item::(item.#name), + }); + } - quote! { - fn from_item(item: Self::Item) -> Self { - Self { - #body_from_item - } + quote! { + fn from_item(item: Self::Item) -> Self { + Self { + #body_from_item } } } + } } diff --git a/burn-derive/src/shared/attribute.rs b/burn-derive/src/shared/attribute.rs index b678b3a7ec..e1cd0d6c33 100644 --- a/burn-derive/src/shared/attribute.rs +++ b/burn-derive/src/shared/attribute.rs @@ -1,53 +1,53 @@ use syn::{Attribute, Ident, Meta}; pub struct AttributeAnalyzer { - attr: Attribute, + attr: Attribute, } #[derive(Clone)] pub struct AttributeItem { - pub ident: Ident, - pub value: syn::Lit, + pub ident: Ident, + pub value: syn::Lit, } impl AttributeAnalyzer { - pub fn new(attr: Attribute) -> Self { - Self { attr } + pub fn new(attr: Attribute) -> Self { + Self { attr } + } + + pub fn item(&self) -> AttributeItem { + let value = match &self.attr.meta { + Meta::List(val) => val.parse_args::().unwrap(), + Meta::NameValue(meta) => meta.clone(), + Meta::Path(_) => panic!("Path meta unsupported"), + }; + + let lit = match value.value { + syn::Expr::Lit(lit) => lit.lit, + _ => panic!("Only literal is supported"), + }; + + AttributeItem { + ident: value.path.get_ident().unwrap().clone(), + value: lit, } - - pub fn item(&self) -> AttributeItem { - let value = match &self.attr.meta { - Meta::List(val) => val.parse_args::().unwrap(), - Meta::NameValue(meta) => meta.clone(), - Meta::Path(_) => panic!("Path meta unsupported"), - }; - - let lit = match value.value { - syn::Expr::Lit(lit) => lit.lit, - _ => panic!("Only literal is supported"), - }; - - AttributeItem { - ident: value.path.get_ident().unwrap().clone(), - value: lit, - } - } - - pub fn has_name(&self, name: &str) -> bool { - Self::path_syn_name(self.attr.path()) == name - } - - fn path_syn_name(path: &syn::Path) -> String { - let length = path.segments.len(); - let mut name = String::new(); - for (i, segment) in path.segments.iter().enumerate() { - if i == length - 1 { - name += segment.ident.to_string().as_str(); - } else { - let tmp = segment.ident.to_string() + "::"; - name += tmp.as_str(); - } - } - name + } + + pub fn has_name(&self, name: &str) -> bool { + Self::path_syn_name(self.attr.path()) == name + } + + fn path_syn_name(path: &syn::Path) -> String { + let length = path.segments.len(); + let mut name = String::new(); + for (i, segment) in path.segments.iter().enumerate() { + if i == length - 1 { + name += segment.ident.to_string().as_str(); + } else { + let tmp = segment.ident.to_string() + "::"; + name += tmp.as_str(); + } } + name + } } diff --git a/burn-derive/src/shared/field.rs b/burn-derive/src/shared/field.rs index 274cabcc15..b48a3c3e2e 100644 --- a/burn-derive/src/shared/field.rs +++ b/burn-derive/src/shared/field.rs @@ -5,101 +5,103 @@ use syn::{Field, Type, TypePath}; #[derive(Clone)] pub struct FieldTypeAnalyzer { - pub field: Field, + pub field: Field, } impl 
FieldTypeAnalyzer { - pub fn new(field: Field) -> Self { - FieldTypeAnalyzer { field } - } + pub fn new(field: Field) -> Self { + FieldTypeAnalyzer { field } + } - pub fn ident(&self) -> Ident { - self.field.ident.clone().unwrap() - } + pub fn ident(&self) -> Ident { + self.field.ident.clone().unwrap() + } - pub fn is_of_type(&self, paths: &[&str]) -> bool { - match &self.field.ty { - syn::Type::Path(path) => { - let name = Self::path_name(path); - paths.contains(&name.as_str()) - } - _ => false, - } + pub fn is_of_type(&self, paths: &[&str]) -> bool { + match &self.field.ty { + syn::Type::Path(path) => { + let name = Self::path_name(path); + paths.contains(&name.as_str()) + } + _ => false, } + } - #[allow(dead_code)] - pub fn first_generic_field(&self) -> TypePath { - let err = || panic!("Field {} as no generic", self.field.ident.clone().unwrap()); - match &self.field.ty { - syn::Type::Path(path) => Self::path_generic_argument(path), - _ => err(), - } + #[allow(dead_code)] + pub fn first_generic_field(&self) -> TypePath { + let err = || panic!("Field {} as no generic", self.field.ident.clone().unwrap()); + match &self.field.ty { + syn::Type::Path(path) => Self::path_generic_argument(path), + _ => err(), } - pub fn path_generic_argument(path: &TypePath) -> TypePath { - let segment = path.path.segments.last().unwrap(); - let err = || panic!("Path segment {} has no generic", segment.ident.clone(),); - match &segment.arguments { - syn::PathArguments::None => err(), - syn::PathArguments::AngleBracketed(param) => { - let first_param = param.args.first().unwrap(); + } + pub fn path_generic_argument(path: &TypePath) -> TypePath { + let segment = path.path.segments.last().unwrap(); + let err = || panic!("Path segment {} has no generic", segment.ident.clone(),); + match &segment.arguments { + syn::PathArguments::None => err(), + syn::PathArguments::AngleBracketed(param) => { + let first_param = param.args.first().unwrap(); - if let syn::GenericArgument::Type(Type::Path(path)) = first_param { - path.clone() - } else { - err() - } - } - syn::PathArguments::Parenthesized(_) => err(), + if let syn::GenericArgument::Type(Type::Path(path)) = first_param { + path.clone() + } else { + err() } + } + syn::PathArguments::Parenthesized(_) => err(), } + } - fn path_name(path: &TypePath) -> String { - let length = path.path.segments.len(); - let mut name = String::new(); - for (i, segment) in path.path.segments.iter().enumerate() { - if i == length - 1 { - name += segment.ident.to_string().as_str(); - } else { - let tmp = segment.ident.to_string() + "::"; - name += tmp.as_str(); - } - } - name + fn path_name(path: &TypePath) -> String { + let length = path.path.segments.len(); + let mut name = String::new(); + for (i, segment) in path.path.segments.iter().enumerate() { + if i == length - 1 { + name += segment.ident.to_string().as_str(); + } else { + let tmp = segment.ident.to_string() + "::"; + name += tmp.as_str(); + } } + name + } - /// Returns the doc of the field if present. - pub fn doc(&self) -> Option { - self.field - .attrs - .iter() - .find(|attr| attr.path().is_ident("doc")) - .map(|doc| { - quote! { - #doc - } - }) - } + /// Returns the doc of the field if present. + pub fn doc(&self) -> Option { + self + .field + .attrs + .iter() + .find(|attr| attr.path().is_ident("doc")) + .map(|doc| { + quote! 
{ + #doc + } + }) + } - pub fn attributes(&self) -> impl Iterator { - self.field - .attrs - .clone() - .into_iter() - .map(AttributeAnalyzer::new) - } + pub fn attributes(&self) -> impl Iterator { + self + .field + .attrs + .clone() + .into_iter() + .map(AttributeAnalyzer::new) + } } pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec { - let mut fields = Vec::new(); + let mut fields = Vec::new(); - match &ast.data { - syn::Data::Struct(struct_data) => { - for field in struct_data.fields.iter() { - fields.push(field.clone()); - } - } - syn::Data::Enum(_) => panic!("Only struct can be derived"), - syn::Data::Union(_) => panic!("Only struct can be derived"), - }; - fields + match &ast.data { + syn::Data::Struct(struct_data) => { + for field in struct_data.fields.iter() { + fields.push(field.clone()); + } + } + syn::Data::Enum(_) => panic!("Only struct can be derived"), + syn::Data::Union(_) => panic!("Only struct can be derived"), + }; + fields } diff --git a/burn-fusion/src/backend.rs b/burn-fusion/src/backend.rs index 1a9a30604a..2db49004d3 100644 --- a/burn-fusion/src/backend.rs +++ b/burn-fusion/src/backend.rs @@ -1,6 +1,6 @@ use crate::{ - client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor, - HandleContainer, + client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor, + HandleContainer, }; use burn_tensor::{backend::Backend, Device, Shape}; use core::marker::PhantomData; @@ -9,62 +9,62 @@ use std::sync::Arc; pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new(); pub(crate) fn get_client(device: &B::FusionDevice) -> B::FusionClient { - CLIENTS.client(device) + CLIENTS.client(device) } /// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend). #[derive(Clone, Debug, Default)] pub struct Fusion { - _backend: PhantomData, + _backend: PhantomData, } impl Backend for Fusion { - type Device = B::Device; + type Device = B::Device; - // TODO: Find a better way to handle full precision. - type FullPrecisionBackend = Self; - type FullPrecisionElem = B::FloatElem; + // TODO: Find a better way to handle full precision. + type FullPrecisionBackend = Self; + type FullPrecisionElem = B::FloatElem; - type TensorPrimitive = FusionTensor; + type TensorPrimitive = FusionTensor; - type FloatElem = B::FloatElem; + type FloatElem = B::FloatElem; - type IntTensorPrimitive = FusionTensor; + type IntTensorPrimitive = FusionTensor; - type IntElem = B::IntElem; + type IntElem = B::IntElem; - type BoolTensorPrimitive = FusionTensor; + type BoolTensorPrimitive = FusionTensor; - fn name() -> String { - format!("fusion<{}>", B::name()) - } + fn name() -> String { + format!("fusion<{}>", B::name()) + } - fn seed(seed: u64) { - B::seed(seed); - } + fn seed(seed: u64) { + B::seed(seed); + } - fn sync(device: &Self::Device) { - let client = CLIENTS.client::(&device.clone().into()); - client.drain_graph(); - B::sync(device) - } + fn sync(device: &Self::Device) { + let client = CLIENTS.client::(&device.clone().into()); + client.drain_graph(); + B::sync(device) + } } /// The status of a [fusion ops](FusionOps). pub enum FusionStatus { - /// No more operations can be fused. - Closed(FusionProperties), - /// More operations can be fused. - Open(FusionProperties), + /// No more operations can be fused. + Closed(FusionProperties), + /// More operations can be fused. + Open(FusionProperties), } /// The properties of a [fusion ops](FusionOps). 
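For context, the Fusion wrapper above is itself a Backend, so it can be dropped in wherever a backend type parameter is expected. The snippet below is a minimal, hypothetical sketch, not part of this patch; the import paths and the crate-root re-export of Fusion are assumptions based on the crate layout shown here.

use burn_fusion::{Fusion, FusionBackend};
use burn_tensor::backend::Backend;

// Returns "fusion<...>", per the name() implementation above. Assumes B
// implements FusionBackend so that Fusion<B> implements Backend.
fn fused_backend_name<B: FusionBackend>() -> String {
    <Fusion<B> as Backend>::name()
}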
#[derive(Debug, Clone, Copy, Default)] pub struct FusionProperties { - /// The score of the optimization, higher is better. - pub score: u64, - /// If the operation is ready to be executed. - pub ready: bool, + /// The score of the optimization, higher is better. + pub score: u64, + /// If the operation is ready to be executed. + pub ready: bool, } /// The fusion operation abstraction allows implementations to fuse many @@ -80,77 +80,77 @@ pub struct FusionProperties { /// Also, it is important to return (FusionStatus::Closed) when no more registered operation can /// improve the performance. pub trait FusionOps: Send { - /// Register a new [tensor operation](TensorOpsDescription). - /// - /// The return value should be either [closed](FusionStatus::Closed) or - /// [open](FusionStatus::Open). - /// - /// When [closed](FusionStatus::Closed), it's assumed that no more operation can be added - /// to the current fusion operation. No [tensor operation](TensorOpsDescription) can be - /// ignored, they are either accepted or rejected, and the [status](FusionStatus) describes it. - fn register(&mut self, ops: Arc>) -> FusionStatus; - /// Execute the operation. - fn execute(&mut self, handles: &mut HandleContainer); - /// Reset the state. - fn reset(&mut self); - /// The size of operations fused. - fn len(&self) -> usize; - /// If the current operation is empty. - fn is_empty(&self) -> bool { - self.len() == 0 - } + /// Register a new [tensor operation](TensorOpsDescription). + /// + /// The return value should be either [closed](FusionStatus::Closed) or + /// [open](FusionStatus::Open). + /// + /// When [closed](FusionStatus::Closed), it's assumed that no more operation can be added + /// to the current fusion operation. No [tensor operation](TensorOpsDescription) can be + /// ignored, they are either accepted or rejected, and the [status](FusionStatus) describes it. + fn register(&mut self, ops: Arc>) -> FusionStatus; + /// Execute the operation. + fn execute(&mut self, handles: &mut HandleContainer); + /// Reset the state. + fn reset(&mut self); + /// The size of operations fused. + fn len(&self) -> usize; + /// If the current operation is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } } /// The device id. #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)] pub struct DeviceId { - /// The type id identifies the type of the device. - pub type_id: u16, - /// The index id identifies the device number. - pub index_id: u32, + /// The type id identifies the type of the device. + pub type_id: u16, + /// The index id identifies the device number. + pub index_id: u32, } /// The handle device trait allows to get an id for a backend device. pub trait FusionDevice: Clone + Send + Sync + PartialEq { - /// Return the [device id](DeviceId). - fn id(&self) -> DeviceId; + /// Return the [device id](DeviceId). + fn id(&self) -> DeviceId; } /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [fusion operation](crate::FusionOps). pub trait FusionBackend: Backend { - /// The device type that can return an ID. - /// - /// It can be the same as (Backend::Device), but must implement (FusionDevice). - type FusionDevice: FusionDevice + From + Into + core::fmt::Debug; - /// The type that can be used to point to a tensor of any kind. - type Handle: Sync + Send + Clone; - /// What kind of client should be used. - type FusionClient: FusionClient; - - /// The list of operations that will be used to optimize the computational graph. 
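To make the register/execute contract above concrete, here is a minimal, hypothetical FusionOps implementation that never fuses anything. This is a sketch only; the NeverFuse type and the import paths are assumptions, not code from this patch.

use std::sync::Arc;
use burn_fusion::{
    graph::TensorOpsDescription, FusionBackend, FusionOps, FusionProperties, FusionStatus,
    HandleContainer,
};

// Rejects every operation: always closed, score 0, never ready.
struct NeverFuse;

impl<B: FusionBackend> FusionOps<B> for NeverFuse {
    fn register(&mut self, _ops: Arc<TensorOpsDescription<B>>) -> FusionStatus {
        FusionStatus::Closed(FusionProperties::default())
    }

    fn execute(&mut self, _handles: &mut HandleContainer<B>) {
        // Nothing was fused, so there is nothing to execute.
    }

    fn reset(&mut self) {}

    fn len(&self) -> usize {
        0
    }
}

A real optimization would instead return FusionStatus::Open with a growing score while it can still absorb operations, and switch to FusionStatus::Closed once adding more operations would no longer improve performance.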
- fn operations(device: &Device) -> Vec>>; - - /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive). - fn float_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::TensorPrimitive; - /// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive). - fn int_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::IntTensorPrimitive; - /// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). - fn bool_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::BoolTensorPrimitive; - - /// Convert a [float tensor](Backend::TensorPrimitive) to a [handle](FusionBackend::Handle). - fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle; - /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle). - fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle; - /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle). - fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle; + /// The device type that can return an ID. + /// + /// It can be the same as (Backend::Device), but must implement (FusionDevice). + type FusionDevice: FusionDevice + From + Into + core::fmt::Debug; + /// The type that can be used to point to a tensor of any kind. + type Handle: Sync + Send + Clone; + /// What kind of client should be used. + type FusionClient: FusionClient; + + /// The list of operations that will be used to optimize the computational graph. + fn operations(device: &Device) -> Vec>>; + + /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive). + fn float_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::TensorPrimitive; + /// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive). + fn int_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::IntTensorPrimitive; + /// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). + fn bool_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::BoolTensorPrimitive; + + /// Convert a [float tensor](Backend::TensorPrimitive) to a [handle](FusionBackend::Handle). + fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle; + /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle). + fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle; + /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle). + fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle; } diff --git a/burn-fusion/src/client/base.rs b/burn-fusion/src/client/base.rs index 778c71030f..e8329ff5cf 100644 --- a/burn-fusion/src/client/base.rs +++ b/burn-fusion/src/client/base.rs @@ -1,65 +1,65 @@ use crate::{ - graph::{GraphExecution, TensorOpsDescription}, - FusionBackend, FusionTensor, Handle, TensorDescription, TensorId, + graph::{GraphExecution, TensorOpsDescription}, + FusionBackend, FusionTensor, Handle, TensorDescription, TensorId, }; use burn_tensor::{ - ops::{FloatElem, IntElem}, - Data, Reader, + ops::{FloatElem, IntElem}, + Data, Reader, }; /// Define how to interact with the fusion server. pub trait FusionClient: Send + Sync + Clone { - /// The [fusion backend](FusionBackend) associated type. - type FusionBackend: FusionBackend; - /// The [graph execution](GraphExecution) associated type. 
- type GraphExecution: GraphExecution; + /// The [fusion backend](FusionBackend) associated type. + type FusionBackend: FusionBackend; + /// The [graph execution](GraphExecution) associated type. + type GraphExecution: GraphExecution; - /// Create a new client for the given [fusion device](FusionBackend::FusionDevice). - fn new(device: ::FusionDevice) -> Self; - /// Register a new [tensor operation description](TensorOpsDescription). - fn register(&self, ops: TensorOpsDescription); - /// Register all lazy computation. - fn drain_graph(&self); - /// Get the current device used by all operations handled by this client. - fn device(&self) -> &::FusionDevice; - /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; - /// Create a tensor with the given handle and shape. - fn register_tensor( - &self, - handle: Handle, - shape: Vec, - ) -> FusionTensor; - /// Read the values contained by a float tensor. - fn read_tensor_float( - &self, - tensor: TensorDescription, - ) -> Reader, D>>; - /// Read the values contained by an int tensor. - fn read_tensor_int( - &self, - tensor: TensorDescription, - ) -> Reader, D>>; - /// Read the values contained by a bool tensor. - fn read_tensor_bool(&self, tensor: TensorDescription) -> Reader>; - /// Change the client of the given float tensor. - fn change_client_float( - &self, - tensor: TensorDescription, - client: Self, - ) -> FusionTensor; - /// Change the client of the given int tensor. - fn change_client_int( - &self, - tensor: TensorDescription, - client: Self, - ) -> FusionTensor; - /// Change the client of the given bool tensor. - fn change_client_bool( - &self, - tensor: TensorDescription, - client: Self, - ) -> FusionTensor; - /// Drop the tensor with the given [tensor id](TensorId). - fn register_orphan(&self, id: &TensorId); + /// Create a new client for the given [fusion device](FusionBackend::FusionDevice). + fn new(device: ::FusionDevice) -> Self; + /// Register a new [tensor operation description](TensorOpsDescription). + fn register(&self, ops: TensorOpsDescription); + /// Register all lazy computation. + fn drain_graph(&self); + /// Get the current device used by all operations handled by this client. + fn device(&self) -> &::FusionDevice; + /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. + fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; + /// Create a tensor with the given handle and shape. + fn register_tensor( + &self, + handle: Handle, + shape: Vec, + ) -> FusionTensor; + /// Read the values contained by a float tensor. + fn read_tensor_float( + &self, + tensor: TensorDescription, + ) -> Reader, D>>; + /// Read the values contained by an int tensor. + fn read_tensor_int( + &self, + tensor: TensorDescription, + ) -> Reader, D>>; + /// Read the values contained by a bool tensor. + fn read_tensor_bool(&self, tensor: TensorDescription) -> Reader>; + /// Change the client of the given float tensor. + fn change_client_float( + &self, + tensor: TensorDescription, + client: Self, + ) -> FusionTensor; + /// Change the client of the given int tensor. + fn change_client_int( + &self, + tensor: TensorDescription, + client: Self, + ) -> FusionTensor; + /// Change the client of the given bool tensor. + fn change_client_bool( + &self, + tensor: TensorDescription, + client: Self, + ) -> FusionTensor; + /// Drop the tensor with the given [tensor id](TensorId). 
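As a usage illustration for the client trait above, the sketch below drives a generic FusionClient: it creates a client for a device, declares an uninitialized tensor, and forces the recorded graph to run. This is hypothetical code; the flush name, the import paths, and the Vec<usize> shape type are assumptions.

use burn_fusion::{client::FusionClient, FusionBackend};

fn flush<C: FusionClient>(device: <C::FusionBackend as FusionBackend>::FusionDevice) {
    // Each client is bound to a single fusion device.
    let client = C::new(device);

    // Tensors can be declared before any data exists; resources are only
    // allocated once the recorded operations actually execute.
    let _placeholder = client.tensor_uninitialized(vec![2, 3]);

    // Run everything registered so far instead of waiting for a read or sync.
    client.drain_graph();
}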
+ fn register_orphan(&self, id: &TensorId); } diff --git a/burn-fusion/src/client/mutex.rs b/burn-fusion/src/client/mutex.rs index db4bceb55a..aa7408181a 100644 --- a/burn-fusion/src/client/mutex.rs +++ b/burn-fusion/src/client/mutex.rs @@ -1,7 +1,7 @@ use super::FusionClient; use crate::{ - graph::{GraphExecution, TensorOpsDescription}, - FusionBackend, FusionServer, FusionTensor, Handle, + graph::{GraphExecution, TensorOpsDescription}, + FusionBackend, FusionServer, FusionTensor, Handle, }; use burn_tensor::ops::FloatElem; use spin::Mutex; @@ -10,150 +10,149 @@ use std::sync::Arc; /// Use a mutex to communicate with the fusion server. pub struct MutexFusionClient where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - server: Arc>>, - device: B::FusionDevice, + server: Arc>>, + device: B::FusionDevice, } impl Clone for MutexFusionClient where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - device: self.device.clone(), - } + fn clone(&self) -> Self { + Self { + server: self.server.clone(), + device: self.device.clone(), } + } } impl FusionClient for MutexFusionClient where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - type FusionBackend = B; - type GraphExecution = G; - - fn new(device: B::FusionDevice) -> Self { - Self { - device: device.clone(), - server: Arc::new(Mutex::new(FusionServer::new(device))), - } - } - - fn register(&self, ops: TensorOpsDescription) { - self.server.lock().register(ops); - } - - fn drain_graph(&self) { - self.server.lock().drain_graph(); - } - - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor { - let id = self.server.lock().create_empty_handle(); - - FusionTensor::new(id, shape, self.clone()) - } - - fn device(&self) -> &::FusionDevice { - &self.device - } - fn register_tensor( - &self, - handle: Handle, - shape: Vec, - ) -> FusionTensor { - let mut server = self.server.lock(); - let id = server.create_empty_handle(); - server.handles.register_handle(id.as_ref().clone(), handle); - core::mem::drop(server); - - FusionTensor::new(id, shape, self.clone()) - } - - fn read_tensor_float( - &self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> { - self.server.lock().read_float(tensor) - } - - fn read_tensor_int( - &self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> - { - self.server.lock().read_int(tensor) - } - - fn read_tensor_bool( - &self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader> { - self.server.lock().read_bool(tensor) - } - - fn change_client_float( - &self, - tensor: crate::TensorDescription, - client: Self, - ) -> FusionTensor { - let device = client.device.clone().into(); - - let mut other_server = client.server.lock(); - - let id = self - .server - .lock() - .change_server_float::(&tensor, &device, &mut other_server); - - core::mem::drop(other_server); - - FusionTensor::new(id, tensor.shape, client) - } - fn change_client_int( - &self, - tensor: crate::TensorDescription, - client: Self, - ) -> FusionTensor { - let device = client.device.clone().into(); - - let mut other_server = client.server.lock(); - - let id = self - .server - .lock() - .change_server_int::(&tensor, &device, &mut other_server); - - core::mem::drop(other_server); - - FusionTensor::new(id, tensor.shape, client) - } - - fn change_client_bool( - &self, - tensor: crate::TensorDescription, - client: Self, - ) -> FusionTensor { - let 
device = client.device.clone().into(); - - let mut other_server = client.server.lock(); - - let id = self - .server - .lock() - .change_server_bool::(&tensor, &device, &mut other_server); - - core::mem::drop(other_server); - - FusionTensor::new(id, tensor.shape, client) - } + type FusionBackend = B; + type GraphExecution = G; - fn register_orphan(&self, id: &crate::TensorId) { - self.server.lock().drop_tensor_handle(id.clone()); + fn new(device: B::FusionDevice) -> Self { + Self { + device: device.clone(), + server: Arc::new(Mutex::new(FusionServer::new(device))), } + } + + fn register(&self, ops: TensorOpsDescription) { + self.server.lock().register(ops); + } + + fn drain_graph(&self) { + self.server.lock().drain_graph(); + } + + fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor { + let id = self.server.lock().create_empty_handle(); + + FusionTensor::new(id, shape, self.clone()) + } + + fn device(&self) -> &::FusionDevice { + &self.device + } + fn register_tensor( + &self, + handle: Handle, + shape: Vec, + ) -> FusionTensor { + let mut server = self.server.lock(); + let id = server.create_empty_handle(); + server.handles.register_handle(id.as_ref().clone(), handle); + core::mem::drop(server); + + FusionTensor::new(id, shape, self.clone()) + } + + fn read_tensor_float( + &self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> { + self.server.lock().read_float(tensor) + } + + fn read_tensor_int( + &self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> { + self.server.lock().read_int(tensor) + } + + fn read_tensor_bool( + &self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader> { + self.server.lock().read_bool(tensor) + } + + fn change_client_float( + &self, + tensor: crate::TensorDescription, + client: Self, + ) -> FusionTensor { + let device = client.device.clone().into(); + + let mut other_server = client.server.lock(); + + let id = self + .server + .lock() + .change_server_float::(&tensor, &device, &mut other_server); + + core::mem::drop(other_server); + + FusionTensor::new(id, tensor.shape, client) + } + fn change_client_int( + &self, + tensor: crate::TensorDescription, + client: Self, + ) -> FusionTensor { + let device = client.device.clone().into(); + + let mut other_server = client.server.lock(); + + let id = self + .server + .lock() + .change_server_int::(&tensor, &device, &mut other_server); + + core::mem::drop(other_server); + + FusionTensor::new(id, tensor.shape, client) + } + + fn change_client_bool( + &self, + tensor: crate::TensorDescription, + client: Self, + ) -> FusionTensor { + let device = client.device.clone().into(); + + let mut other_server = client.server.lock(); + + let id = self + .server + .lock() + .change_server_bool::(&tensor, &device, &mut other_server); + + core::mem::drop(other_server); + + FusionTensor::new(id, tensor.shape, client) + } + + fn register_orphan(&self, id: &crate::TensorId) { + self.server.lock().drop_tensor_handle(id.clone()); + } } diff --git a/burn-fusion/src/fusion.rs b/burn-fusion/src/fusion.rs index f26630a561..bc5ea0937e 100644 --- a/burn-fusion/src/fusion.rs +++ b/burn-fusion/src/fusion.rs @@ -6,65 +6,65 @@ pub type Handle = ::Handle; type Key = (core::any::TypeId, DeviceId); pub(crate) struct FusionClientLocator { - clients: spin::Mutex>>>, + clients: spin::Mutex>>>, } impl FusionClientLocator { - /// Create a new client locator. - pub const fn new() -> Self { - Self { - clients: spin::Mutex::new(None), - } + /// Create a new client locator. 
+ pub const fn new() -> Self { + Self { + clients: spin::Mutex::new(None), } + } - /// Get the fusion client for the given device. - /// - /// Provide the init function to create a new client if it isn't already initialized. - pub fn client( - &self, - device: &::FusionDevice, - ) -> C { - let device_id = device.id(); - let client_id = (core::any::TypeId::of::(), device_id); - let mut clients = self.clients.lock(); + /// Get the fusion client for the given device. + /// + /// Provide the init function to create a new client if it isn't already initialized. + pub fn client( + &self, + device: &::FusionDevice, + ) -> C { + let device_id = device.id(); + let client_id = (core::any::TypeId::of::(), device_id); + let mut clients = self.clients.lock(); - if clients.is_none() { - let client = C::new(device.clone()); - Self::register_inner::(client_id, client, &mut clients); - } + if clients.is_none() { + let client = C::new(device.clone()); + Self::register_inner::(client_id, client, &mut clients); + } - match clients.deref_mut() { - Some(clients) => match clients.get(&client_id) { - Some(client) => { - let client: &C = client.downcast_ref().unwrap(); - client.clone() - } - None => { - let client = C::new(device.clone()); - let any = Box::new(client.clone()); - clients.insert(client_id, any); - client - } - }, - _ => unreachable!(), + match clients.deref_mut() { + Some(clients) => match clients.get(&client_id) { + Some(client) => { + let client: &C = client.downcast_ref().unwrap(); + client.clone() + } + None => { + let client = C::new(device.clone()); + let any = Box::new(client.clone()); + clients.insert(client_id, any); + client } + }, + _ => unreachable!(), } + } - fn register_inner( - key: Key, - client: C, - clients: &mut Option>>, - ) { - if clients.is_none() { - *clients = Some(HashMap::new()); - } + fn register_inner( + key: Key, + client: C, + clients: &mut Option>>, + ) { + if clients.is_none() { + *clients = Some(HashMap::new()); + } - if let Some(clients) = clients { - if clients.contains_key(&key) { - panic!("Client already created for device {:?}", key); - } + if let Some(clients) = clients { + if clients.contains_key(&key) { + panic!("Client already created for device {:?}", key); + } - clients.insert(key, Box::new(client)); - } + clients.insert(key, Box::new(client)); } + } } diff --git a/burn-fusion/src/graph/base.rs b/burn-fusion/src/graph/base.rs index 28bcade60f..43fe5677cb 100644 --- a/burn-fusion/src/graph/base.rs +++ b/burn-fusion/src/graph/base.rs @@ -4,95 +4,95 @@ use std::{ops::RangeBounds, sync::Arc, vec::Drain}; /// The computational graph containing a list of [tensor operation descriptions](TensorOpsDescription). pub struct Graph { - operations: Vec>>, + operations: Vec>>, } impl Graph { - pub(crate) fn new() -> Self { - Self { - operations: Vec::new(), - } + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), } - pub(crate) fn add(&mut self, ops: Arc>) { - self.operations.push(ops); + } + pub(crate) fn add(&mut self, ops: Arc>) { + self.operations.push(ops); + } + + /// The size of the graph. + pub fn len(&self) -> usize { + self.operations.len() + } + + /// If the graph is empty. + pub fn is_empty(&self) -> bool { + self.operations.len() == 0 + } + + fn drain(&mut self, range: R) -> Drain<'_, Arc>> + where + R: RangeBounds, + { + self.operations.drain(range) + } + + fn remove>(&mut self, range: R, handles: &mut HandleContainer) { + for ops in self.operations.drain(range) { + ops.cleanup_tensor(handles) } - - /// The size of the graph. 
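The FusionClientLocator above caches one client per (client type, device id) pair and creates it lazily on first access. The following self-contained sketch restates that pattern with plain std types; all names here are illustrative, not part of this patch.

use std::any::{Any, TypeId};
use std::collections::HashMap;

#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
struct DeviceId {
    type_id: u16,
    index_id: u32,
}

struct Locator {
    // One boxed client per (concrete client type, device) key.
    clients: HashMap<(TypeId, DeviceId), Box<dyn Any + Send>>,
}

impl Locator {
    fn new() -> Self {
        Self { clients: HashMap::new() }
    }

    fn client_or_init<C: Clone + Send + 'static>(
        &mut self,
        device: DeviceId,
        init: impl FnOnce() -> C,
    ) -> C {
        let key = (TypeId::of::<C>(), device);
        let entry = self
            .clients
            .entry(key)
            .or_insert_with(|| Box::new(init()) as Box<dyn Any + Send>);
        // The entry was inserted with type C, so the downcast cannot fail.
        entry.downcast_ref::<C>().unwrap().clone()
    }
}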
- pub fn len(&self) -> usize { - self.operations.len() - } - - /// If the graph is empty. - pub fn is_empty(&self) -> bool { - self.operations.len() == 0 - } - - fn drain(&mut self, range: R) -> Drain<'_, Arc>> - where - R: RangeBounds, - { - self.operations.drain(range) - } - - fn remove>(&mut self, range: R, handles: &mut HandleContainer) { - for ops in self.operations.drain(range) { - ops.cleanup_tensor(handles) - } + } + + fn nodes(&self) -> &[Arc>] { + &self.operations + } + + pub(crate) fn execute_optimization( + &mut self, + handles: &mut HandleContainer, + index: usize, + optimizations: &mut [Optimization], + ) { + let optimization = optimizations.get_mut(index).unwrap(); + let num_keep = optimization.ops.len(); + optimization.ops.execute(handles); + + self.remove(0..num_keep, handles); + + for optimization in optimizations.iter_mut() { + optimization.reset(); + + for node in self.nodes() { + optimization.register(node); + } } + } - fn nodes(&self) -> &[Arc>] { - &self.operations - } - - pub(crate) fn execute_optimization( - &mut self, - handles: &mut HandleContainer, - index: usize, - optimizations: &mut [Optimization], - ) { - let optimization = optimizations.get_mut(index).unwrap(); - let num_keep = optimization.ops.len(); - optimization.ops.execute(handles); - - self.remove(0..num_keep, handles); - - for optimization in optimizations.iter_mut() { - optimization.reset(); - - for node in self.nodes() { - optimization.register(node); - } - } - } - - pub(crate) fn execute(&mut self, handles: &mut HandleContainer) { - for ops in self.drain(..) { - ops.execute(handles); - ops.cleanup_tensor(handles); - } + pub(crate) fn execute(&mut self, handles: &mut HandleContainer) { + for ops in self.drain(..) { + ops.execute(handles); + ops.cleanup_tensor(handles); } + } } /// An optimization that can be executed. #[derive(new)] pub struct Optimization { - /// The [fusion operation](FusionOps) to potentially be executed. - pub ops: Box>, - /// The current status of the optimization. - pub status: FusionStatus, + /// The [fusion operation](FusionOps) to potentially be executed. + pub ops: Box>, + /// The current status of the optimization. + pub status: FusionStatus, } impl Optimization { - pub(crate) fn register(&mut self, ops: &Arc>) { - if let FusionStatus::Closed(_) = self.status { - return; - } - - self.status = self.ops.register(ops.clone()); + pub(crate) fn register(&mut self, ops: &Arc>) { + if let FusionStatus::Closed(_) = self.status { + return; } - pub(crate) fn reset(&mut self) { - self.ops.reset(); - self.status = FusionStatus::Open(FusionProperties::default()); - } + self.status = self.ops.register(ops.clone()); + } + + pub(crate) fn reset(&mut self) { + self.ops.reset(); + self.status = FusionStatus::Open(FusionProperties::default()); + } } diff --git a/burn-fusion/src/graph/execution.rs b/burn-fusion/src/graph/execution.rs index 36cbf1a6d3..85c5159a3c 100644 --- a/burn-fusion/src/graph/execution.rs +++ b/burn-fusion/src/graph/execution.rs @@ -3,15 +3,15 @@ use crate::{FusionBackend, FusionStatus, HandleContainer}; /// The graph execution trait abstracts the way the graph is executing optimizations. pub trait GraphExecution: Default + Send { - /// Execute the given graph using the list of potential [optimizations](Optimization). 
- /// May do nothing if empty or not ready - fn maybe_execute( - &mut self, - graph: &mut Graph, - handles: &mut HandleContainer, - optimizations: &mut [Optimization], - force: bool, - ); + /// Execute the given graph using the list of potential [optimizations](Optimization). + /// May do nothing if empty or not ready + fn maybe_execute( + &mut self, + graph: &mut Graph, + handles: &mut HandleContainer, + optimizations: &mut [Optimization], + force: bool, + ); } /// Execute an optimization following a greedy algorithm. @@ -19,65 +19,65 @@ pub trait GraphExecution: Default + Send { pub struct GreedyGraphExecution; impl GraphExecution for GreedyGraphExecution { - fn maybe_execute( - &mut self, - graph: &mut Graph, - handles: &mut HandleContainer, - optimizations: &mut [Optimization], - force: bool, - ) { - loop { - if !force && still_optimizing(optimizations) { - break; - } + fn maybe_execute( + &mut self, + graph: &mut Graph, + handles: &mut HandleContainer, + optimizations: &mut [Optimization], + force: bool, + ) { + loop { + if !force && still_optimizing(optimizations) { + break; + } - match find_best_optimization_index(optimizations) { - Some(index) => { - graph.execute_optimization(handles, index, optimizations); - } - None => { - graph.execute(handles); - optimizations.iter_mut().for_each(|ops| ops.reset()); - } - } - - if graph.is_empty() { - // No more ops to fuse. - break; - } + match find_best_optimization_index(optimizations) { + Some(index) => { + graph.execute_optimization(handles, index, optimizations); + } + None => { + graph.execute(handles); + optimizations.iter_mut().for_each(|ops| ops.reset()); } + } + + if graph.is_empty() { + // No more ops to fuse. + break; + } } + } } fn still_optimizing(optimizations: &[Optimization]) -> bool { - let mut num_stopped = 0; + let mut num_stopped = 0; - for optimization in optimizations.iter() { - if let FusionStatus::Closed(_) = optimization.status { - num_stopped += 1 - } + for optimization in optimizations.iter() { + if let FusionStatus::Closed(_) = optimization.status { + num_stopped += 1 } + } - num_stopped < optimizations.len() + num_stopped < optimizations.len() } fn find_best_optimization_index( - optimizations: &[Optimization], + optimizations: &[Optimization], ) -> Option { - let mut best_index = None; - let mut best_score = 0; + let mut best_index = None; + let mut best_score = 0; - for (i, optimization) in optimizations.iter().enumerate() { - let properties = match optimization.status { - FusionStatus::Closed(properties) => properties, - FusionStatus::Open(properties) => properties, - }; + for (i, optimization) in optimizations.iter().enumerate() { + let properties = match optimization.status { + FusionStatus::Closed(properties) => properties, + FusionStatus::Open(properties) => properties, + }; - if properties.ready && properties.score >= best_score { - best_index = Some(i); - best_score = properties.score; - } + if properties.ready && properties.score >= best_score { + best_index = Some(i); + best_score = properties.score; } + } - best_index + best_index } diff --git a/burn-fusion/src/graph/ops.rs b/burn-fusion/src/graph/ops.rs index 3f437bd1f5..ba1dd224bd 100644 --- a/burn-fusion/src/graph/ops.rs +++ b/burn-fusion/src/graph/ops.rs @@ -2,1487 +2,1479 @@ use crate::FusionBackend; use crate::{HandleContainer, TensorDescription}; use burn_tensor::ops::FloatElem; use burn_tensor::{ - ops::{ConvOptions, ConvTransposeOptions}, - Distribution, Element, + ops::{ConvOptions, ConvTransposeOptions}, + Distribution, Element, }; use 
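The greedy scheduler above executes the best ready optimization first. The selection rule in find_best_optimization_index boils down to "highest score among ready candidates, later entries win ties", restated here as a self-contained sketch with a small usage check (illustrative only, not part of this patch).

// Pick the index of the ready candidate with the highest score; using `>=`
// means a later candidate wins a tie, matching the loop above.
fn best_candidate(candidates: &[(bool, u64)]) -> Option<usize> {
    let mut best_index = None;
    let mut best_score = 0;

    for (i, (ready, score)) in candidates.iter().enumerate() {
        if *ready && *score >= best_score {
            best_index = Some(i);
            best_score = *score;
        }
    }

    best_index
}

fn main() {
    // Index 2 wins: it is ready and has the highest score.
    assert_eq!(best_candidate(&[(true, 3), (false, 9), (true, 7)]), Some(2));
    // No ready candidate means the graph falls back to eager execution.
    assert_eq!(best_candidate(&[(false, 1)]), None);
}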
core::hash::Hash; use std::ops::Range; /// General trait to abstract how a single operation is executed. pub trait Ops: Send + Sync { - /// The argument necessary for the execution to happen. - type Args: Send + Sync; + /// The argument necessary for the execution to happen. + type Args: Send + Sync; - /// Execute the operation. - fn execute(&self, args: &Self::Args, handles: &mut HandleContainer); + /// Execute the operation. + fn execute(&self, args: &Self::Args, handles: &mut HandleContainer); } /// Describe all tensor operations possible. pub enum TensorOpsDescription { - /// Basic operation on a float tensor. - BaseOpsFloat(BaseOpsDescription), - /// Basic operation on an int tensor. - BaseOpsInt(BaseOpsDescription), - /// Basic operation on a bool tensor. - BaseOpsBool(BaseOpsDescription), - /// Numeric operation on a float tensor. - NumericOpsFloat(NumericOpsDescription), - /// Numeric operation on an int tensor. - NumericOpsInt(NumericOpsDescription), - /// Operation specific to a bool tensor. - BoolOps(BoolOpsDescription), - /// Operation specific to an int tensor. - IntOps(IntOpsDescription), - /// Operation specific to a float tensor. - FloatOps(FloatOpsDescription), - /// Module operation. - ModuleOps(ModuleOpsDescription), + /// Basic operation on a float tensor. + BaseOpsFloat(BaseOpsDescription), + /// Basic operation on an int tensor. + BaseOpsInt(BaseOpsDescription), + /// Basic operation on a bool tensor. + BaseOpsBool(BaseOpsDescription), + /// Numeric operation on a float tensor. + NumericOpsFloat(NumericOpsDescription), + /// Numeric operation on an int tensor. + NumericOpsInt(NumericOpsDescription), + /// Operation specific to a bool tensor. + BoolOps(BoolOpsDescription), + /// Operation specific to an int tensor. + IntOps(IntOpsDescription), + /// Operation specific to a float tensor. + FloatOps(FloatOpsDescription), + /// Module operation. + ModuleOps(ModuleOpsDescription), } /// Operation description specific to a float tensor. pub enum FloatOpsDescription { - /// Operation corresponding to [exp](burn_tensor::ops::TensorOps::exp). - Exp( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [log](burn_tensor::ops::TensorOps::log). - Log( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [log1p](burn_tensor::ops::TensorOps::log1p). - Log1p( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [erf](burn_tensor::ops::TensorOps::erf). - Erf( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [powf](burn_tensor::ops::TensorOps::powf). - Powf( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to [sqrt](burn_tensor::ops::TensorOps::sqrt). - Sqrt( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [cos](burn_tensor::ops::TensorOps::cos). - Cos( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [sin](burn_tensor::ops::TensorOps::sin). - Sin( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [tanh](burn_tensor::ops::TensorOps::tanh). - Tanh( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [into_int](burn_tensor::ops::TensorOps::into_int). - IntoInt( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [matmul](burn_tensor::ops::TensorOps::matmul). - Matmul( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to [random](burn_tensor::ops::TensorOps::random). 
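The Ops trait near the top of this file is what every boxed operation stored inside these descriptions implements. A hypothetical single-operation implementation could look like the sketch below; the HandleContainer accessor names (get_float_tensor, register_float_tensor) and the UnaryOpsDescription fields are assumptions based on how the rest of the crate uses them, not code from this patch.

use burn_fusion::{
    graph::{Ops, UnaryOpsDescription},
    FusionBackend, HandleContainer,
};

// Eagerly computes exp on the inner backend when the fused graph executes.
struct ExpOps<const D: usize>;

impl<const D: usize, B: FusionBackend> Ops<B> for ExpOps<D> {
    type Args = UnaryOpsDescription;

    fn execute(&self, args: &Self::Args, handles: &mut HandleContainer<B>) {
        // Fetch the input primitive, run the op on the inner backend, and
        // register the output under the id announced in the description
        // (accessor names assumed).
        let input = handles.get_float_tensor::<D>(&args.input);
        let output = B::exp(input);
        handles.register_float_tensor(&args.out.id, output);
    }
}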
- Random( - (TensorDescription, Distribution>), - Box>)>>, - ), - /// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip). - Recip( - UnaryOpsDescription, - Box>, - ), + /// Operation corresponding to [exp](burn_tensor::ops::TensorOps::exp). + Exp( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [log](burn_tensor::ops::TensorOps::log). + Log( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [log1p](burn_tensor::ops::TensorOps::log1p). + Log1p( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [erf](burn_tensor::ops::TensorOps::erf). + Erf( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [powf](burn_tensor::ops::TensorOps::powf). + Powf( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to [sqrt](burn_tensor::ops::TensorOps::sqrt). + Sqrt( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [cos](burn_tensor::ops::TensorOps::cos). + Cos( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [sin](burn_tensor::ops::TensorOps::sin). + Sin( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [tanh](burn_tensor::ops::TensorOps::tanh). + Tanh( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [into_int](burn_tensor::ops::TensorOps::into_int). + IntoInt( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [matmul](burn_tensor::ops::TensorOps::matmul). + Matmul( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to [random](burn_tensor::ops::TensorOps::random). + Random( + (TensorDescription, Distribution>), + Box>)>>, + ), + /// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip). + Recip( + UnaryOpsDescription, + Box>, + ), } /// Operation description specific to module. pub enum ModuleOpsDescription { - /// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding). - Embedding( - EmbeddingDescription, - Box>, - ), - /// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward). - EmbeddingBackward( - EmbeddingBackwardDescription, - Box>, - ), - /// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d). - Conv1d(Conv1dDescription, Box>), - /// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d). - Conv2d(Conv2dDescription, Box>), - /// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d). - ConvTranspose1d( - ConvTranspose1dDescription, - Box>, - ), - /// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d). - ConvTranspose2d( - ConvTranspose2dDescription, - Box>, - ), - /// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d). - AvgPool1d( - AvgPool1dDescription, - Box>, - ), - /// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d). - AvgPool2d( - AvgPool2dDescription, - Box>, - ), - /// Operation corresponding to - /// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward). - AvgPool1dBackward( - AvgPool1dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward). - AvgPool2dBackward( - AvgPool2dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d). 
- AdaptiveAvgPool1d( - AdaptiveAvgPool1dDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d). - AdaptiveAvgPool2d( - AdaptiveAvgPool2dDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward). - AdaptiveAvgPool1dBackward( - AdaptiveAvgPool1dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward). - AdaptiveAvgPool2dBackward( - AdaptiveAvgPool2dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d). - MaxPool1d( - MaxPool1dDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices). - MaxPool1dWithIndices( - MaxPool1dWithIndicesDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward). - MaxPool1dWithIndicesBackward( - MaxPool1dWithIndicesBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool1d). - MaxPool2d( - MaxPool2dDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices). - MaxPool2dWithIndices( - MaxPool2dWithIndicesDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward). - MaxPool2dWithIndicesBackward( - MaxPool2dWithIndicesBackwardDescription, - Box>, - ), + /// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding). + Embedding( + EmbeddingDescription, + Box>, + ), + /// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward). + EmbeddingBackward( + EmbeddingBackwardDescription, + Box>, + ), + /// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d). + Conv1d(Conv1dDescription, Box>), + /// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d). + Conv2d(Conv2dDescription, Box>), + /// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d). + ConvTranspose1d( + ConvTranspose1dDescription, + Box>, + ), + /// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d). + ConvTranspose2d( + ConvTranspose2dDescription, + Box>, + ), + /// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d). + AvgPool1d( + AvgPool1dDescription, + Box>, + ), + /// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d). + AvgPool2d( + AvgPool2dDescription, + Box>, + ), + /// Operation corresponding to + /// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward). + AvgPool1dBackward( + AvgPool1dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward). + AvgPool2dBackward( + AvgPool2dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d). 
+ AdaptiveAvgPool1d( + AdaptiveAvgPool1dDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d). + AdaptiveAvgPool2d( + AdaptiveAvgPool2dDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward). + AdaptiveAvgPool1dBackward( + AdaptiveAvgPool1dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward). + AdaptiveAvgPool2dBackward( + AdaptiveAvgPool2dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d). + MaxPool1d( + MaxPool1dDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices). + MaxPool1dWithIndices( + MaxPool1dWithIndicesDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward). + MaxPool1dWithIndicesBackward( + MaxPool1dWithIndicesBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool1d). + MaxPool2d( + MaxPool2dDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices). + MaxPool2dWithIndices( + MaxPool2dWithIndicesDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward). + MaxPool2dWithIndicesBackward( + MaxPool2dWithIndicesBackwardDescription, + Box>, + ), } /// Basic operations that can be done on any tensor type. pub enum BaseOpsDescription { - /// Operation corresponding to: - /// - /// Float => [to device](burn_tensor::ops::TensorOps::to_device). - /// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device). - /// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device). - ToDevice( - (TensorDescription, B::Device), - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [reshape](burn_tensor::ops::TensorOps::reshape). - /// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape). - /// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape). - Reshape( - ReshapeDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [swap_dims](burn_tensor::ops::TensorOps::swap_dims). - /// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims). - /// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims). - SwapDims( - SwapDimsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [slice](burn_tensor::ops::TensorOps::slice). - /// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice). - /// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice). - Slice( - SliceOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [slice assign](burn_tensor::ops::TensorOps::slice_assign). - /// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign). - /// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign). - SliceAssign( - SliceAssignOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [equal](burn_tensor::ops::TensorOps::equal). 
- /// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal). - /// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal). - Equal( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [repeat](burn_tensor::ops::TensorOps::repeat). - /// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat). - /// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat). - Repeat( - RepeatOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [cat](burn_tensor::ops::TensorOps::cat). - /// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat). - /// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat). - Cat(CatOpsDescription, Box>), + /// Operation corresponding to: + /// + /// Float => [to device](burn_tensor::ops::TensorOps::to_device). + /// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device). + /// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device). + ToDevice( + (TensorDescription, B::Device), + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [reshape](burn_tensor::ops::TensorOps::reshape). + /// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape). + /// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape). + Reshape( + ReshapeDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [swap_dims](burn_tensor::ops::TensorOps::swap_dims). + /// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims). + /// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims). + SwapDims( + SwapDimsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [slice](burn_tensor::ops::TensorOps::slice). + /// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice). + /// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice). + Slice( + SliceOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [slice assign](burn_tensor::ops::TensorOps::slice_assign). + /// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign). + /// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign). + SliceAssign( + SliceAssignOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [equal](burn_tensor::ops::TensorOps::equal). + /// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal). + /// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal). + Equal( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [repeat](burn_tensor::ops::TensorOps::repeat). + /// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat). + /// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat). + Repeat( + RepeatOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [cat](burn_tensor::ops::TensorOps::cat). + /// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat). + /// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat). + Cat(CatOpsDescription, Box>), } /// Numeric operations on int and float tensors. pub enum NumericOpsDescription { - /// Operation corresponding to: - /// - /// Float => [add](burn_tensor::ops::TensorOps::add). - /// Int => [add](burn_tensor::ops::IntTensorOps::int_add). - Add( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [add scalar](burn_tensor::ops::TensorOps::add_scalar). - /// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar). 
- AddScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [sub](burn_tensor::ops::TensorOps::sub). - /// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub). - Sub( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [sub scalar](burn_tensor::ops::TensorOps::sub_scalar). - /// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar). - SubScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [div](burn_tensor::ops::TensorOps::div). - /// Int => [div](burn_tensor::ops::IntTensorOps::int_div). - Div( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [div scalar](burn_tensor::ops::TensorOps::div_scalar). - /// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar). - DivScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [mul](burn_tensor::ops::TensorOps::mul). - /// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul). - Mul( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [mul scalar](burn_tensor::ops::TensorOps::mul_scalar). - /// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar). - MulScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [abs](burn_tensor::ops::TensorOps::abs). - /// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs). - Abs( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [ones](burn_tensor::ops::TensorOps::ones). - /// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones). - Ones(TensorDescription, Box>), - /// Operation corresponding to: - /// - /// Float => [zeros](burn_tensor::ops::TensorOps::zeros). - /// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros). - Zeros(TensorDescription, Box>), - /// Operation corresponding to: - /// - /// Float => [full](burn_tensor::ops::TensorOps::full). - /// Int => [full](burn_tensor::ops::IntTensorOps::int_full). - Full( - (TensorDescription, E), - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [gather](burn_tensor::ops::TensorOps::gather). - /// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather). - Gather( - GatherOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [scatter](burn_tensor::ops::TensorOps::scatter). - /// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter). - Scatter( - ScatterOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [select](burn_tensor::ops::TensorOps::select). - /// Int => [select](burn_tensor::ops::IntTensorOps::int_select). - Select( - SelectOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [select assign](burn_tensor::ops::TensorOps::select_assign). - /// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign). - SelectAssign( - SelectAssignOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [mask where](burn_tensor::ops::TensorOps::mask_where). - /// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where). - MaskWhere( - MaskWhereOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [mask fill](burn_tensor::ops::TensorOps::mask_fill). - /// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill). 
- MaskFill( - MaskFillOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [mean dim](burn_tensor::ops::TensorOps::mean_dim). - /// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim). - MeanDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [mean](burn_tensor::ops::TensorOps::mean). - /// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean). - Mean( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [sum](burn_tensor::ops::TensorOps::sum). - /// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum). - Sum( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [sum dim](burn_tensor::ops::TensorOps::sum_dim). - /// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim). - SumDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [equal elem](burn_tensor::ops::TensorOps::equal_elem). - /// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem). - EqualElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [greater](burn_tensor::ops::TensorOps::greater). - /// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater). - Greater( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [greater elem](burn_tensor::ops::TensorOps::greater_elem). - /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). - GreaterElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [greater equal](burn_tensor::ops::TensorOps::greater_elem). - /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). - GreaterEqual( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [greater equal elem](burn_tensor::ops::TensorOps::greater_equal_elem). - /// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem). - GreaterEqualElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [lower](burn_tensor::ops::TensorOps::lower). - /// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower). - Lower( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [lower elem](burn_tensor::ops::TensorOps::lower_elem). - /// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem). - LowerElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [lower equal](burn_tensor::ops::TensorOps::lower_equal). - /// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal). - LowerEqual( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [lower equal elem](burn_tensor::ops::TensorOps::lower_equal_elem). - /// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem). - LowerEqualElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [argmax](burn_tensor::ops::TensorOps::argmax). - /// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax). - ArgMax( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [argmin](burn_tensor::ops::TensorOps::argmin). - /// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin). 
- ArgMin( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [max](burn_tensor::ops::TensorOps::max). - /// Int => [max](burn_tensor::ops::IntTensorOps::int_max). - Max( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [max dim with indices](burn_tensor::ops::TensorOps::max_dim_with_indices). - /// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices). - MaxDimWithIndices( - ReduceDimWithIndicesDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [min dim with indices](burn_tensor::ops::TensorOps::min_dim_with_indices). - /// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices). - MinDimWithIndices( - ReduceDimWithIndicesDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [min](burn_tensor::ops::TensorOps::min). - /// Int => [min](burn_tensor::ops::IntTensorOps::int_min). - Min( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [max dim](burn_tensor::ops::TensorOps::max_dim). - /// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim). - MaxDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [min dim](burn_tensor::ops::TensorOps::min_dim). - /// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim). - MinDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [clamp](burn_tensor::ops::TensorOps::clamp). - /// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp). - Clamp( - ClampOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max). - /// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max). - ClampMax( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min). - /// Int => [cleamp min](burn_tensor::ops::IntTensorOps::int_clamp_min). - ClampMin( - ScalarOpsDescription, - Box>>, - ), + /// Operation corresponding to: + /// + /// Float => [add](burn_tensor::ops::TensorOps::add). + /// Int => [add](burn_tensor::ops::IntTensorOps::int_add). + Add( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [add scalar](burn_tensor::ops::TensorOps::add_scalar). + /// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar). + AddScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [sub](burn_tensor::ops::TensorOps::sub). + /// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub). + Sub( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [sub scalar](burn_tensor::ops::TensorOps::sub_scalar). + /// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar). + SubScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [div](burn_tensor::ops::TensorOps::div). + /// Int => [div](burn_tensor::ops::IntTensorOps::int_div). + Div( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [div scalar](burn_tensor::ops::TensorOps::div_scalar). + /// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar). 
+ DivScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [mul](burn_tensor::ops::TensorOps::mul). + /// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul). + Mul( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [mul scalar](burn_tensor::ops::TensorOps::mul_scalar). + /// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar). + MulScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [abs](burn_tensor::ops::TensorOps::abs). + /// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs). + Abs( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [ones](burn_tensor::ops::TensorOps::ones). + /// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones). + Ones(TensorDescription, Box>), + /// Operation corresponding to: + /// + /// Float => [zeros](burn_tensor::ops::TensorOps::zeros). + /// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros). + Zeros(TensorDescription, Box>), + /// Operation corresponding to: + /// + /// Float => [full](burn_tensor::ops::TensorOps::full). + /// Int => [full](burn_tensor::ops::IntTensorOps::int_full). + Full( + (TensorDescription, E), + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [gather](burn_tensor::ops::TensorOps::gather). + /// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather). + Gather( + GatherOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [scatter](burn_tensor::ops::TensorOps::scatter). + /// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter). + Scatter( + ScatterOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [select](burn_tensor::ops::TensorOps::select). + /// Int => [select](burn_tensor::ops::IntTensorOps::int_select). + Select( + SelectOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [select assign](burn_tensor::ops::TensorOps::select_assign). + /// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign). + SelectAssign( + SelectAssignOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [mask where](burn_tensor::ops::TensorOps::mask_where). + /// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where). + MaskWhere( + MaskWhereOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [mask fill](burn_tensor::ops::TensorOps::mask_fill). + /// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill). + MaskFill( + MaskFillOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [mean dim](burn_tensor::ops::TensorOps::mean_dim). + /// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim). + MeanDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [mean](burn_tensor::ops::TensorOps::mean). + /// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean). + Mean( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [sum](burn_tensor::ops::TensorOps::sum). + /// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum). + Sum( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [sum dim](burn_tensor::ops::TensorOps::sum_dim). + /// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim). 
+ SumDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [equal elem](burn_tensor::ops::TensorOps::equal_elem). + /// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem). + EqualElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [greater](burn_tensor::ops::TensorOps::greater). + /// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater). + Greater( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [greater elem](burn_tensor::ops::TensorOps::greater_elem). + /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). + GreaterElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [greater equal](burn_tensor::ops::TensorOps::greater_equal). + /// Int => [greater equal](burn_tensor::ops::IntTensorOps::int_greater_equal). + GreaterEqual( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [greater equal elem](burn_tensor::ops::TensorOps::greater_equal_elem). + /// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem). + GreaterEqualElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [lower](burn_tensor::ops::TensorOps::lower). + /// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower). + Lower( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [lower elem](burn_tensor::ops::TensorOps::lower_elem). + /// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem). + LowerElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [lower equal](burn_tensor::ops::TensorOps::lower_equal). + /// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal). + LowerEqual( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [lower equal elem](burn_tensor::ops::TensorOps::lower_equal_elem). + /// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem). + LowerEqualElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [argmax](burn_tensor::ops::TensorOps::argmax). + /// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax). + ArgMax( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [argmin](burn_tensor::ops::TensorOps::argmin). + /// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin). + ArgMin( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [max](burn_tensor::ops::TensorOps::max). + /// Int => [max](burn_tensor::ops::IntTensorOps::int_max). + Max( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [max dim with indices](burn_tensor::ops::TensorOps::max_dim_with_indices). + /// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices). + MaxDimWithIndices( + ReduceDimWithIndicesDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [min dim with indices](burn_tensor::ops::TensorOps::min_dim_with_indices). + /// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices). + MinDimWithIndices( + ReduceDimWithIndicesDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [min](burn_tensor::ops::TensorOps::min).
+ /// Int => [min](burn_tensor::ops::IntTensorOps::int_min). + Min( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [max dim](burn_tensor::ops::TensorOps::max_dim). + /// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim). + MaxDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [min dim](burn_tensor::ops::TensorOps::min_dim). + /// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim). + MinDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [clamp](burn_tensor::ops::TensorOps::clamp). + /// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp). + Clamp( + ClampOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max). + /// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max). + ClampMax( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min). + /// Int => [clamp min](burn_tensor::ops::IntTensorOps::int_clamp_min). + ClampMin( + ScalarOpsDescription, + Box>>, + ), } /// Operation description specific to an int tensor. pub enum IntOpsDescription { - /// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float). - IntoFloat( - UnaryOpsDescription, - Box>, - ), + /// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float). + IntoFloat( + UnaryOpsDescription, + Box>, + ), } /// Operation description specific to a bool tensor. pub enum BoolOpsDescription { - /// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float). - IntoFloat( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int). - IntoInt( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not). - Not( - UnaryOpsDescription, - Box>, - ), + /// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float). + IntoFloat( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int). + IntoInt( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not). + Not( + UnaryOpsDescription, + Box>, + ), } #[derive(Hash)] /// Swap dim operation description. pub struct SwapDimsDescription { - /// Input tensor description. - pub input: TensorDescription, - /// output tensor description. - pub out: TensorDescription, - /// The first dim to swap. - pub dim1: usize, - /// The second dim to swap. - pub dim2: usize, + /// Input tensor description. + pub input: TensorDescription, + /// Output tensor description. + pub out: TensorDescription, + /// The first dim to swap. + pub dim1: usize, + /// The second dim to swap.
+ pub dim2: usize, } #[derive(Hash)] #[allow(missing_docs)] pub struct ReshapeDescription { - pub input: TensorDescription, - pub out: TensorDescription, - pub shape: Vec, + pub input: TensorDescription, + pub out: TensorDescription, + pub shape: Vec, } #[derive(Hash)] #[allow(missing_docs)] pub struct BinaryOpsDescription { - pub lhs: TensorDescription, - pub rhs: TensorDescription, - pub out: TensorDescription, + pub lhs: TensorDescription, + pub rhs: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct UnaryOpsDescription { - pub input: TensorDescription, - pub out: TensorDescription, + pub input: TensorDescription, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct ScalarOpsDescription { - pub lhs: TensorDescription, - pub rhs: E, - pub out: TensorDescription, + pub lhs: TensorDescription, + pub rhs: E, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct GatherOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ScatterOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub value: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SelectOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SelectAssignOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub value: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SliceOpsDescription { - pub tensor: TensorDescription, - pub ranges: Vec>, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub ranges: Vec>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SliceAssignOpsDescription { - pub tensor: TensorDescription, - pub ranges: Vec>, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub ranges: Vec>, + pub value: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaskWhereOpsDescription { - pub tensor: TensorDescription, - pub mask: TensorDescription, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub mask: TensorDescription, + pub value: TensorDescription, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct MaskFillOpsDescription { - pub tensor: TensorDescription, - pub mask: TensorDescription, - pub value: E, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub mask: TensorDescription, + pub value: E, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct ClampOpsDescription { - pub tensor: TensorDescription, - pub min: E, - pub max: E, - pub 
out: TensorDescription, + pub tensor: TensorDescription, + pub min: E, + pub max: E, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct RepeatOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub times: usize, - pub shape: Vec, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub times: usize, + pub shape: Vec, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct CatOpsDescription { - pub tensors: Vec, - pub dim: usize, - pub out: TensorDescription, + pub tensors: Vec, + pub dim: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ReduceDimWithIndicesDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub out: TensorDescription, - pub out_indices: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub out: TensorDescription, + pub out_indices: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct EmbeddingDescription { - pub weights: TensorDescription, - pub indices: TensorDescription, - pub out: TensorDescription, + pub weights: TensorDescription, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct EmbeddingBackwardDescription { - pub weights: TensorDescription, - pub out_grad: TensorDescription, - pub indices: TensorDescription, - pub out: TensorDescription, + pub weights: TensorDescription, + pub out_grad: TensorDescription, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct Conv1dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - pub options: ConvOptions<1>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvOptions<1>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct Conv2dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - pub options: ConvOptions<2>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvOptions<2>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ConvTranspose1dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - pub options: ConvTransposeOptions<1>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvTransposeOptions<1>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ConvTranspose2dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - pub options: ConvTransposeOptions<2>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvTransposeOptions<2>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool1dDescription { - pub x: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool2dDescription { - pub x: TensorDescription, - pub kernel_size: 
[usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool1dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool2dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dDescription { - pub x: TensorDescription, - pub output_size: usize, - pub out: TensorDescription, + pub x: TensorDescription, + pub output_size: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dDescription { - pub x: TensorDescription, - pub output_size: [usize; 2], - pub out: TensorDescription, + pub x: TensorDescription, + pub output_size: [usize; 2], + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dDescription { - pub x: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub dilation: usize, - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub dilation: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesDescription { - pub x: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub dilation: usize, - pub out: TensorDescription, - pub out_indices: TensorDescription, + pub x: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub dilation: usize, + pub out: TensorDescription, + pub out_indices: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub indices: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub dilation: usize, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub indices: TensorDescription, + pub 
kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub dilation: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dDescription { - pub x: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub dilation: [usize; 2], - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub dilation: [usize; 2], + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dWithIndicesDescription { - pub x: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub dilation: [usize; 2], - pub out: TensorDescription, - pub out_indices: TensorDescription, + pub x: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub dilation: [usize; 2], + pub out: TensorDescription, + pub out_indices: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dWithIndicesBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub indices: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub dilation: [usize; 2], - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub indices: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub dilation: [usize; 2], + pub out: TensorDescription, } impl TensorOpsDescription { - /// Cleanup the remaining tensor handles that have not been used. - pub(crate) fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - TensorOpsDescription::BaseOpsFloat(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::BaseOpsInt(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::BaseOpsBool(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::NumericOpsFloat(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::NumericOpsInt(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::BoolOps(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::IntOps(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::FloatOps(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::ModuleOps(ops) => ops.cleanup_tensor(handles), - } - - // Cleanup tensor handles that were outputted, but ignored. - handles.cleanup_orphans(); + /// Cleanup the remaining tensor handles that have not been used. + pub(crate) fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + TensorOpsDescription::BaseOpsFloat(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::BaseOpsInt(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::BaseOpsBool(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::NumericOpsFloat(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::NumericOpsInt(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::BoolOps(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::IntOps(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::FloatOps(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::ModuleOps(ops) => ops.cleanup_tensor(handles), } - /// Execute the operation. 
- pub(crate) fn execute(&self, handles: &mut HandleContainer) { - match self { - TensorOpsDescription::BaseOpsFloat(ops) => ops.execute(handles), - TensorOpsDescription::BaseOpsInt(ops) => ops.execute(handles), - TensorOpsDescription::BaseOpsBool(ops) => ops.execute(handles), - TensorOpsDescription::NumericOpsFloat(ops) => ops.execute(handles), - TensorOpsDescription::NumericOpsInt(ops) => ops.execute(handles), - TensorOpsDescription::BoolOps(ops) => ops.execute(handles), - TensorOpsDescription::IntOps(ops) => ops.execute(handles), - TensorOpsDescription::FloatOps(ops) => ops.execute(handles), - TensorOpsDescription::ModuleOps(ops) => ops.execute(handles), - } + + // Cleanup tensor handles that were outputted, but ignored. + handles.cleanup_orphans(); + } + /// Execute the operation. + pub(crate) fn execute(&self, handles: &mut HandleContainer) { + match self { + TensorOpsDescription::BaseOpsFloat(ops) => ops.execute(handles), + TensorOpsDescription::BaseOpsInt(ops) => ops.execute(handles), + TensorOpsDescription::BaseOpsBool(ops) => ops.execute(handles), + TensorOpsDescription::NumericOpsFloat(ops) => ops.execute(handles), + TensorOpsDescription::NumericOpsInt(ops) => ops.execute(handles), + TensorOpsDescription::BoolOps(ops) => ops.execute(handles), + TensorOpsDescription::IntOps(ops) => ops.execute(handles), + TensorOpsDescription::FloatOps(ops) => ops.execute(handles), + TensorOpsDescription::ModuleOps(ops) => ops.execute(handles), } + } } impl BaseOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - BaseOpsDescription::ToDevice(_, _) => (), - BaseOpsDescription::Reshape(desc, _) => { - handles.cleanup(&desc.input); - } - BaseOpsDescription::SwapDims(desc, _) => { - handles.cleanup(&desc.input); - } - BaseOpsDescription::Slice(desc, _) => { - handles.cleanup(&desc.tensor); - } - BaseOpsDescription::SliceAssign(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.value); - } - BaseOpsDescription::Equal(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - BaseOpsDescription::Repeat(desc, _) => { - handles.cleanup(&desc.tensor); - } - BaseOpsDescription::Cat(desc, _) => { - for t in desc.tensors.iter() { - handles.cleanup(t); - } - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + BaseOpsDescription::ToDevice(_, _) => (), + BaseOpsDescription::Reshape(desc, _) => { + handles.cleanup(&desc.input); + } + BaseOpsDescription::SwapDims(desc, _) => { + handles.cleanup(&desc.input); + } + BaseOpsDescription::Slice(desc, _) => { + handles.cleanup(&desc.tensor); + } + BaseOpsDescription::SliceAssign(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.value); + } + BaseOpsDescription::Equal(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + BaseOpsDescription::Repeat(desc, _) => { + handles.cleanup(&desc.tensor); + } + BaseOpsDescription::Cat(desc, _) => { + for t in desc.tensors.iter() { + handles.cleanup(t); } + } } - fn execute(&self, handles: &mut HandleContainer) { - match self { - BaseOpsDescription::ToDevice(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Reshape(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::SwapDims(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Slice(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::SliceAssign(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Equal(desc, ops) => ops.execute(desc, handles), - 
BaseOpsDescription::Repeat(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Cat(desc, ops) => ops.execute(desc, handles), - } + } + fn execute(&self, handles: &mut HandleContainer) { + match self { + BaseOpsDescription::ToDevice(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Reshape(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::SwapDims(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Slice(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::SliceAssign(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Equal(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Repeat(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Cat(desc, ops) => ops.execute(desc, handles), } + } } impl NumericOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - NumericOpsDescription::Add(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::AddScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Sub(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::SubScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Mul(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::MulScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Div(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::DivScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Ones(_, _) => {} - NumericOpsDescription::Gather(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - } - NumericOpsDescription::Scatter(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - handles.cleanup(&desc.value); - } - NumericOpsDescription::Select(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - } - NumericOpsDescription::SelectAssign(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - handles.cleanup(&desc.value); - } - NumericOpsDescription::MaskWhere(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.value); - handles.cleanup(&desc.mask); - } - NumericOpsDescription::MaskFill(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.mask); - } - NumericOpsDescription::EqualElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::GreaterElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::GreaterEqualElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::LowerElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::LowerEqualElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Greater(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::GreaterEqual(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::Lower(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::LowerEqual(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::ArgMax(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::ArgMin(desc, _) => { - handles.cleanup(&desc.lhs); - } - 
NumericOpsDescription::Clamp(desc, _) => { - handles.cleanup(&desc.tensor); - } - NumericOpsDescription::ClampMin(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::ClampMax(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Abs(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::Zeros(_, _) => {} - NumericOpsDescription::Full(_, _) => {} - NumericOpsDescription::MeanDim(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Mean(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::Sum(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::SumDim(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Max(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::MaxDimWithIndices(desc, _) => { - handles.cleanup(&desc.tensor); - } - NumericOpsDescription::MinDimWithIndices(desc, _) => { - handles.cleanup(&desc.tensor); - } - NumericOpsDescription::Min(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::MaxDim(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::MinDim(desc, _) => { - handles.cleanup(&desc.lhs); - } - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + NumericOpsDescription::Add(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::AddScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Sub(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::SubScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Mul(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::MulScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Div(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::DivScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Ones(_, _) => {} + NumericOpsDescription::Gather(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + } + NumericOpsDescription::Scatter(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + handles.cleanup(&desc.value); + } + NumericOpsDescription::Select(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + } + NumericOpsDescription::SelectAssign(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + handles.cleanup(&desc.value); + } + NumericOpsDescription::MaskWhere(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.value); + handles.cleanup(&desc.mask); + } + NumericOpsDescription::MaskFill(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.mask); + } + NumericOpsDescription::EqualElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::GreaterElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::GreaterEqualElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::LowerElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::LowerEqualElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Greater(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::GreaterEqual(desc, _) => { + handles.cleanup(&desc.lhs); + 
handles.cleanup(&desc.rhs); + } + NumericOpsDescription::Lower(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::LowerEqual(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::ArgMax(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::ArgMin(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Clamp(desc, _) => { + handles.cleanup(&desc.tensor); + } + NumericOpsDescription::ClampMin(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::ClampMax(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Abs(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::Zeros(_, _) => {} + NumericOpsDescription::Full(_, _) => {} + NumericOpsDescription::MeanDim(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Mean(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::Sum(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::SumDim(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Max(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::MaxDimWithIndices(desc, _) => { + handles.cleanup(&desc.tensor); + } + NumericOpsDescription::MinDimWithIndices(desc, _) => { + handles.cleanup(&desc.tensor); + } + NumericOpsDescription::Min(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::MaxDim(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::MinDim(desc, _) => { + handles.cleanup(&desc.lhs); + } } + } - fn execute(&self, handles: &mut HandleContainer) { - match self { - NumericOpsDescription::Add(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::AddScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Sub(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::SubScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Div(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::DivScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Mul(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MulScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Ones(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Gather(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Scatter(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Select(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::SelectAssign(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaskWhere(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaskFill(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::EqualElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Greater(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::GreaterElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::GreaterEqual(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::GreaterEqualElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Lower(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::LowerElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::LowerEqual(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::LowerEqualElem(desc, ops) => ops.execute(desc, handles), - 
NumericOpsDescription::ArgMax(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::ArgMin(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Clamp(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::ClampMin(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::ClampMax(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Abs(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Zeros(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Full(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MeanDim(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Mean(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Sum(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::SumDim(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Max(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaxDimWithIndices(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MinDimWithIndices(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Min(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaxDim(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MinDim(desc, ops) => ops.execute(desc, handles), - } + fn execute(&self, handles: &mut HandleContainer) { + match self { + NumericOpsDescription::Add(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::AddScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Sub(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::SubScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Div(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::DivScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Mul(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MulScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Ones(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Gather(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Scatter(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Select(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::SelectAssign(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaskWhere(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaskFill(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::EqualElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Greater(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::GreaterElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::GreaterEqual(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::GreaterEqualElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Lower(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::LowerElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::LowerEqual(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::LowerEqualElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::ArgMax(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::ArgMin(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Clamp(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::ClampMin(desc, ops) => ops.execute(desc, handles), + 
NumericOpsDescription::ClampMax(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Abs(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Zeros(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Full(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MeanDim(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Mean(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Sum(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::SumDim(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Max(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaxDimWithIndices(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MinDimWithIndices(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Min(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaxDim(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MinDim(desc, ops) => ops.execute(desc, handles), } + } } impl FloatOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - FloatOpsDescription::Matmul(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - FloatOpsDescription::Random(_, _) => {} - FloatOpsDescription::Exp(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs), - FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Sin(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Tanh(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::IntoInt(desc, _) => handles.cleanup(&desc.input), - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + FloatOpsDescription::Matmul(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + FloatOpsDescription::Random(_, _) => {} + FloatOpsDescription::Exp(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs), + FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Sin(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Tanh(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::IntoInt(desc, _) => handles.cleanup(&desc.input), } - fn execute(&self, handles: &mut HandleContainer) { - match self { - FloatOpsDescription::Matmul(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Random(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Exp(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, 
handles), - FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Sin(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Tanh(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), - } + } + fn execute(&self, handles: &mut HandleContainer) { + match self { + FloatOpsDescription::Matmul(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Random(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Exp(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Sin(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Tanh(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), } + } } impl IntOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - IntOpsDescription::IntoFloat(desc, _) => { - handles.cleanup(&desc.input); - } - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + IntOpsDescription::IntoFloat(desc, _) => { + handles.cleanup(&desc.input); + } } - fn execute(&self, handles: &mut HandleContainer) { - match self { - IntOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), - } + } + fn execute(&self, handles: &mut HandleContainer) { + match self { + IntOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), } + } } impl BoolOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - BoolOpsDescription::IntoFloat(desc, _) => { - handles.cleanup(&desc.input); - } - BoolOpsDescription::IntoInt(desc, _) => { - handles.cleanup(&desc.input); - } - BoolOpsDescription::Not(desc, _) => { - handles.cleanup(&desc.input); - } - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + BoolOpsDescription::IntoFloat(desc, _) => { + handles.cleanup(&desc.input); + } + BoolOpsDescription::IntoInt(desc, _) => { + handles.cleanup(&desc.input); + } + BoolOpsDescription::Not(desc, _) => { + handles.cleanup(&desc.input); + } } - fn execute(&self, handles: &mut HandleContainer) { - match self { - BoolOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), - BoolOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), - BoolOpsDescription::Not(desc, ops) => ops.execute(desc, handles), - } + } + fn execute(&self, handles: &mut HandleContainer) { + match self { + BoolOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), + BoolOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), + BoolOpsDescription::Not(desc, ops) => ops.execute(desc, handles), } + } } impl ModuleOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - ModuleOpsDescription::Embedding(desc, _) => { - 
handles.cleanup(&desc.weights); - handles.cleanup(&desc.indices); - } - ModuleOpsDescription::EmbeddingBackward(desc, _) => { - handles.cleanup(&desc.weights); - handles.cleanup(&desc.out_grad); - handles.cleanup(&desc.indices); - } - ModuleOpsDescription::Conv1d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + ModuleOpsDescription::Embedding(desc, _) => { + handles.cleanup(&desc.weights); + handles.cleanup(&desc.indices); + } + ModuleOpsDescription::EmbeddingBackward(desc, _) => { + handles.cleanup(&desc.weights); + handles.cleanup(&desc.out_grad); + handles.cleanup(&desc.indices); + } + ModuleOpsDescription::Conv1d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - handles.cleanup(bias); - } - } - ModuleOpsDescription::Conv2d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + if let Some(bias) = &desc.bias { + handles.cleanup(bias); + } + } + ModuleOpsDescription::Conv2d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - handles.cleanup(bias); - } - } - ModuleOpsDescription::ConvTranspose1d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + if let Some(bias) = &desc.bias { + handles.cleanup(bias); + } + } + ModuleOpsDescription::ConvTranspose1d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - handles.cleanup(bias); - } - } - ModuleOpsDescription::ConvTranspose2d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + if let Some(bias) = &desc.bias { + handles.cleanup(bias); + } + } + ModuleOpsDescription::ConvTranspose2d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - handles.cleanup(bias); - } - } - ModuleOpsDescription::AvgPool1d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AvgPool2d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AvgPool1dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::AvgPool2dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::AdaptiveAvgPool1d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AdaptiveAvgPool2d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::MaxPool1d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool1dWithIndices(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - handles.cleanup(&desc.indices); - } - ModuleOpsDescription::MaxPool2d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool2dWithIndices(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - handles.cleanup(&desc.indices); - } + if let Some(bias) = &desc.bias { + handles.cleanup(bias); } + } + ModuleOpsDescription::AvgPool1d(desc, 
_) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AvgPool2d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AvgPool1dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::AvgPool2dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::AdaptiveAvgPool1d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AdaptiveAvgPool2d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::MaxPool1d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool1dWithIndices(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + handles.cleanup(&desc.indices); + } + ModuleOpsDescription::MaxPool2d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool2dWithIndices(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + handles.cleanup(&desc.indices); + } } - fn execute(&self, handles: &mut HandleContainer) { - match self { - ModuleOpsDescription::Embedding(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::EmbeddingBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::Conv1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::Conv2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::ConvTranspose1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::ConvTranspose2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool1dBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool2dBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AdaptiveAvgPool1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AdaptiveAvgPool2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, ops) => { - ops.execute(desc, handles) - } - ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, ops) => { - ops.execute(desc, handles) - } - ModuleOpsDescription::MaxPool1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool1dWithIndices(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, ops) => { - ops.execute(desc, handles) - } - ModuleOpsDescription::MaxPool2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool2dWithIndices(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, ops) => { - ops.execute(desc, handles) - } - } + } + fn execute(&self, handles: &mut HandleContainer) { + match self { + ModuleOpsDescription::Embedding(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::EmbeddingBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::Conv1d(desc, ops) => ops.execute(desc, handles), + 
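// Illustrative sketch (not part of this patch): the pattern `ModuleOpsDescription`
// uses above, reduced to a self-contained toy. Each variant pairs a plain-data
// description with a boxed ops object; `cleanup_tensor` only walks the description's
// tensor fields (bias is optional), while `execute` forwards to the boxed op.
// The names below (`ToyOps`, `ToyDescription`, `ConvDesc`, `Handles`) are made up
// for the example and are not burn-fusion APIs.
use std::collections::HashSet;

type Handles = HashSet<u64>; // stands in for HandleContainer<B>

trait ToyOps {
    fn execute(&self, desc: &ConvDesc, handles: &mut Handles);
}

struct ConvDesc {
    x: u64,
    weight: u64,
    bias: Option<u64>,
}

enum ToyDescription {
    Conv(ConvDesc, Box<dyn ToyOps>),
}

impl ToyDescription {
    fn cleanup_tensor(&self, handles: &mut Handles) {
        match self {
            ToyDescription::Conv(desc, _) => {
                handles.remove(&desc.x);
                handles.remove(&desc.weight);
                if let Some(bias) = &desc.bias {
                    handles.remove(bias);
                }
            }
        }
    }

    fn execute(&self, handles: &mut Handles) {
        match self {
            ToyDescription::Conv(desc, ops) => ops.execute(desc, handles),
        }
    }
}

fn main() {
    struct NoopConv;
    impl ToyOps for NoopConv {
        fn execute(&self, _desc: &ConvDesc, _handles: &mut Handles) {}
    }

    let desc = ToyDescription::Conv(ConvDesc { x: 0, weight: 1, bias: None }, Box::new(NoopConv));
    let mut handles: Handles = [0u64, 1].into_iter().collect();
    desc.execute(&mut handles);
    desc.cleanup_tensor(&mut handles);
    assert!(handles.is_empty());
}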
ModuleOpsDescription::Conv2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::ConvTranspose1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::ConvTranspose2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool1dBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool2dBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AdaptiveAvgPool1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AdaptiveAvgPool2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool1dWithIndices(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool2dWithIndices(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, ops) => ops.execute(desc, handles), } + } } diff --git a/burn-fusion/src/handle.rs b/burn-fusion/src/handle.rs index 10a6dbbef3..8b2d36eb85 100644 --- a/burn-fusion/src/handle.rs +++ b/burn-fusion/src/handle.rs @@ -6,132 +6,132 @@ use std::{collections::HashMap, sync::Arc}; /// are used optimally. #[derive(Default)] pub struct HandleContainer { - handles: HashMap>, - counter: u64, - pub(crate) handles_orphan: Vec, - /// The device on which all tensors are held. - pub device: B::Device, + handles: HashMap>, + counter: u64, + pub(crate) handles_orphan: Vec, + /// The device on which all tensors are held. + pub device: B::Device, } enum Handle { - NotInit, - Existing(B::Handle), + NotInit, + Existing(B::Handle), } impl HandleContainer { - pub(crate) fn new(device_handle: B::FusionDevice) -> Self { - Self { - handles: HashMap::new(), - handles_orphan: Vec::new(), - counter: 0, - device: device_handle.clone().into(), - } - } - - /// Register a handle for the given [tensor id](TensorId). - pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) { - self.handles.insert(id, Handle::Existing(handle)); + pub(crate) fn new(device_handle: B::FusionDevice) -> Self { + Self { + handles: HashMap::new(), + handles_orphan: Vec::new(), + counter: 0, + device: device_handle.clone().into(), } - - /// Get the handle for the given [tensor id](TensorId). - pub fn get_handle(&mut self, tensor: &TensorDescription) -> B::Handle { - let (id, handle) = self - .handles - .remove_entry(&tensor.id) - .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id)); - - match handle { - Handle::Existing(handle) => match tensor.status { - TensorStatus::ReadOnly => { - self.handles.insert(id, Handle::Existing(handle.clone())); - handle - } - TensorStatus::ReadWrite => handle, - TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."), - }, - Handle::NotInit => panic!("Cannot get uninitialized handle."), + } + + /// Register a handle for the given [tensor id](TensorId). 
+ pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) { + self.handles.insert(id, Handle::Existing(handle)); + } + + /// Get the handle for the given [tensor id](TensorId). + pub fn get_handle(&mut self, tensor: &TensorDescription) -> B::Handle { + let (id, handle) = self + .handles + .remove_entry(&tensor.id) + .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id)); + + match handle { + Handle::Existing(handle) => match tensor.status { + TensorStatus::ReadOnly => { + self.handles.insert(id, Handle::Existing(handle.clone())); + handle } + TensorStatus::ReadWrite => handle, + TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."), + }, + Handle::NotInit => panic!("Cannot get uninitialized handle."), } - - /// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the - /// given [tensor description](TensorDescription). - pub fn get_float_tensor( - &mut self, - tensor: &TensorDescription, - ) -> B::TensorPrimitive { - B::float_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) - } - - /// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the - /// given [tensor description](TensorDescription). - pub fn get_int_tensor( - &mut self, - tensor: &TensorDescription, - ) -> B::IntTensorPrimitive { - B::int_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) - } - - /// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the - /// given [tensor description](TensorDescription). - pub fn get_bool_tensor( - &mut self, - tensor: &TensorDescription, - ) -> B::BoolTensorPrimitive { - B::bool_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) - } - - /// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_float_tensor( - &mut self, - id: &TensorId, - tensor: B::TensorPrimitive, - ) { - let handle = B::float_tensor_handle(tensor); - self.handles.insert(id.clone(), Handle::Existing(handle)); - } - - /// Register a new [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_int_tensor( - &mut self, - id: &TensorId, - tensor: B::IntTensorPrimitive, - ) { - let handle = B::int_tensor_handle(tensor); - self.handles.insert(id.clone(), Handle::Existing(handle)); - } - - /// Register a new [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_bool_tensor( - &mut self, - id: &TensorId, - tensor: B::BoolTensorPrimitive, - ) { - let handle = B::bool_tensor_handle(tensor); - self.handles.insert(id.clone(), Handle::Existing(handle)); + } + + /// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the + /// given [tensor description](TensorDescription). + pub fn get_float_tensor( + &mut self, + tensor: &TensorDescription, + ) -> B::TensorPrimitive { + B::float_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) + } + + /// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the + /// given [tensor description](TensorDescription). 
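// Illustrative sketch (not part of this patch): how `get_handle` treats the tensor
// status, modeled with a plain `String` handle instead of `B::Handle`. A read-only
// access clones the handle back into the map, a read-write access moves it out, and
// an uninitialized handle is a usage bug that panics. `ToyContainer` is a made-up
// stand-in for `HandleContainer<B>`.
use std::collections::HashMap;

#[derive(Clone, Copy)]
enum TensorStatus {
    ReadOnly,
    ReadWrite,
    NotInit,
}

struct ToyContainer {
    handles: HashMap<u64, String>,
}

impl ToyContainer {
    fn get_handle(&mut self, id: u64, status: TensorStatus) -> String {
        let handle = self
            .handles
            .remove(&id)
            .unwrap_or_else(|| panic!("Should have handle for tensor {id:?}"));

        match status {
            TensorStatus::ReadOnly => {
                self.handles.insert(id, handle.clone());
                handle
            }
            TensorStatus::ReadWrite => handle,
            TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."),
        }
    }
}

fn main() {
    let mut container = ToyContainer { handles: HashMap::from([(0, "buffer".to_string())]) };
    // Read-only access leaves the handle in place for later ops.
    let _h = container.get_handle(0, TensorStatus::ReadOnly);
    // Read-write access consumes it.
    let _h = container.get_handle(0, TensorStatus::ReadWrite);
    assert!(container.handles.is_empty());
}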
+ pub fn get_int_tensor( + &mut self, + tensor: &TensorDescription, + ) -> B::IntTensorPrimitive { + B::int_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) + } + + /// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the + /// given [tensor description](TensorDescription). + pub fn get_bool_tensor( + &mut self, + tensor: &TensorDescription, + ) -> B::BoolTensorPrimitive { + B::bool_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) + } + + /// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId). + pub fn register_float_tensor( + &mut self, + id: &TensorId, + tensor: B::TensorPrimitive, + ) { + let handle = B::float_tensor_handle(tensor); + self.handles.insert(id.clone(), Handle::Existing(handle)); + } + + /// Register a new [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). + pub fn register_int_tensor( + &mut self, + id: &TensorId, + tensor: B::IntTensorPrimitive, + ) { + let handle = B::int_tensor_handle(tensor); + self.handles.insert(id.clone(), Handle::Existing(handle)); + } + + /// Register a new [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). + pub fn register_bool_tensor( + &mut self, + id: &TensorId, + tensor: B::BoolTensorPrimitive, + ) { + let handle = B::bool_tensor_handle(tensor); + self.handles.insert(id.clone(), Handle::Existing(handle)); + } + + /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId). + pub fn create_tensor_uninit(&mut self) -> Arc { + let id = TensorId::new(self.counter); + self.counter += 1; + self.handles.insert(id.clone(), Handle::NotInit); + + Arc::new(id) + } + + pub(crate) fn cleanup(&mut self, tensor: &TensorDescription) { + match tensor.status { + TensorStatus::ReadOnly => (), + TensorStatus::NotInit => (), + TensorStatus::ReadWrite => { + self.handles.remove(&tensor.id); + } } + } - /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId). - pub fn create_tensor_uninit(&mut self) -> Arc { - let id = TensorId::new(self.counter); - self.counter += 1; - self.handles.insert(id.clone(), Handle::NotInit); - - Arc::new(id) - } - - pub(crate) fn cleanup(&mut self, tensor: &TensorDescription) { - match tensor.status { - TensorStatus::ReadOnly => (), - TensorStatus::NotInit => (), - TensorStatus::ReadWrite => { - self.handles.remove(&tensor.id); - } - } - } - - pub(crate) fn cleanup_orphans(&mut self) { - for id in self.handles_orphan.drain(..) { - self.handles.remove(&id); - } + pub(crate) fn cleanup_orphans(&mut self) { + for id in self.handles_orphan.drain(..) { + self.handles.remove(&id); } + } } diff --git a/burn-fusion/src/ops/binary.rs b/burn-fusion/src/ops/binary.rs index 05d859252a..c3148eb3c4 100644 --- a/burn-fusion/src/ops/binary.rs +++ b/burn-fusion/src/ops/binary.rs @@ -1,101 +1,101 @@ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
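// Illustrative sketch (not part of this patch): the lazy tensor lifecycle in
// `HandleContainer`, reduced to ids only. `create_tensor_uninit` hands out a fresh
// id from a counter without allocating anything; `cleanup` only frees a handle once
// its last use is read-write; `cleanup_orphans` drains ids detached from any
// description. `ToyLifecycle` is a made-up stand-in for the real container.
use std::collections::HashMap;

#[derive(Default)]
struct ToyLifecycle {
    handles: HashMap<u64, Option<Vec<f32>>>, // None == not yet initialized
    handles_orphan: Vec<u64>,
    counter: u64,
}

impl ToyLifecycle {
    fn create_tensor_uninit(&mut self) -> u64 {
        let id = self.counter;
        self.counter += 1;
        self.handles.insert(id, None);
        id
    }

    fn cleanup(&mut self, id: u64, read_write: bool) {
        if read_write {
            self.handles.remove(&id);
        }
    }

    fn cleanup_orphans(&mut self) {
        for id in self.handles_orphan.drain(..) {
            self.handles.remove(&id);
        }
    }
}

fn main() {
    let mut container = ToyLifecycle::default();
    let id = container.create_tensor_uninit();
    container.handles.insert(id, Some(vec![1.0, 2.0])); // materialized later
    container.cleanup(id, true); // last use was read-write: free it
    assert!(container.handles.is_empty());
}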
binary_float_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let rhs = handles.get_float_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let rhs = handles.get_float_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_float_tensor(&args.out.id, output); - } - } - }; + handles.register_float_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_float_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let rhs = handles.get_float_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let rhs = handles.get_float_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_int_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let rhs = handles.get_int_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let rhs = handles.get_int_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } pub(crate) fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec { - let mut shape_out = Vec::with_capacity(lhs.len()); + let mut shape_out = Vec::with_capacity(lhs.len()); - for (l, r) in lhs.iter().zip(rhs.iter()) { - shape_out.push(usize::max(*l, *r)); - } + for (l, r) in lhs.iter().zip(rhs.iter()) { + shape_out.push(usize::max(*l, *r)); + } - shape_out + shape_out } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
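// Illustrative usage (not part of this patch): `binary_ops_shape` above picks the
// broadcast output shape as the element-wise max of the two input shapes. A
// standalone copy with a small check:
fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    let mut shape_out = Vec::with_capacity(lhs.len());

    for (l, r) in lhs.iter().zip(rhs.iter()) {
        shape_out.push(usize::max(*l, *r));
    }

    shape_out
}

fn main() {
    // [2, 1, 4] broadcast against [1, 3, 4] gives [2, 3, 4].
    assert_eq!(binary_ops_shape(&[2, 1, 4], &[1, 3, 4]), vec![2, 3, 4]);
}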
binary_int_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let rhs = handles.get_int_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let rhs = handles.get_int_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } diff --git a/burn-fusion/src/ops/boolean.rs b/burn-fusion/src/ops/boolean.rs index 179db25d4d..43d1f63e72 100644 --- a/burn-fusion/src/ops/boolean.rs +++ b/burn-fusion/src/ops/boolean.rs @@ -1,402 +1,399 @@ use crate::{ - client::FusionClient, - get_client, - graph::{ - BaseOpsDescription, BinaryOpsDescription, BoolOpsDescription, CatOpsDescription, Ops, - ReshapeDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, - TensorOpsDescription, UnaryOpsDescription, - }, - ops::binary::binary_ops_shape, - Fusion, FusionBackend, + client::FusionClient, + get_client, + graph::{ + BaseOpsDescription, BinaryOpsDescription, BoolOpsDescription, CatOpsDescription, Ops, + ReshapeDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, + TensorOpsDescription, UnaryOpsDescription, + }, + ops::binary::binary_ops_shape, + Fusion, FusionBackend, }; use burn_tensor::{ - ops::{BoolTensor, BoolTensorOps}, - Device, Shape, + ops::{BoolTensor, BoolTensorOps}, + Device, Shape, }; impl BoolTensorOps for Fusion { - fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::bool_empty(shape.clone(), device); - - client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::bool_empty(shape.clone(), device); + + client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + tensor.shape() + } + + fn bool_into_data( + tensor: BoolTensor, + ) -> burn_tensor::Reader> { + tensor.bool_into_data() + } + + fn bool_from_data( + data: burn_tensor::Data, + device: &Device, + ) -> BoolTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::bool_from_data(data, device); + let shape = B::bool_shape(&tensor); + + client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) + } + + fn bool_into_int( + tensor: BoolTensor, + ) -> burn_tensor::ops::IntTensor { + struct IntoIntOps; + + impl Ops for IntoIntOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_into_int(input); + handles.register_int_tensor(&args.out.id, output); + } } - fn bool_shape(tensor: &BoolTensor) -> Shape { - tensor.shape() + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoIntOps::), + ))); + + out + 
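// Illustrative sketch (not part of this patch): what the `binary_*_ops!` macros
// above generate, shrunk to scalars. Each invocation declares a zero-sized struct
// and an `Ops` impl whose `execute` reads both operands, applies the captured
// function, and registers the result. `ToyOps`, `ToyArgs`, and `toy_binary_ops!`
// are made-up names for the example, not burn-fusion items.
trait ToyOps {
    fn execute(&self, args: &ToyArgs) -> i64;
}

struct ToyArgs {
    lhs: i64,
    rhs: i64,
}

macro_rules! toy_binary_ops {
    ($name:ident, $ops:expr) => {
        struct $name;

        impl ToyOps for $name {
            fn execute(&self, args: &ToyArgs) -> i64 {
                $ops(args.lhs, args.rhs)
            }
        }
    };
}

toy_binary_ops!(AddOps, |lhs, rhs| lhs + rhs);

fn main() {
    assert_eq!(AddOps.execute(&ToyArgs { lhs: 2, rhs: 3 }), 5);
}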
} + + fn bool_into_float( + tensor: BoolTensor, + ) -> burn_tensor::ops::FloatTensor { + struct IntoFloatOps; + + impl Ops for IntoFloatOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_into_float(input); + handles.register_float_tensor(&args.out.id, output); + } } - fn bool_into_data( - tensor: BoolTensor, - ) -> burn_tensor::Reader> { - tensor.bool_into_data() + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::BoolOps( + BoolOpsDescription::IntoFloat( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoFloatOps::), + ), + )); + + out + } + + fn bool_device(tensor: &BoolTensor) -> Device { + tensor.client.device().clone().into() + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + let device_original: &B::FusionDevice = tensor.client.device(); + let device_target: B::FusionDevice = device.clone().into(); + + if device_original == &device_target { + return tensor; } - fn bool_from_data( - data: burn_tensor::Data, - device: &Device, - ) -> BoolTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::bool_from_data(data, device); - let shape = B::bool_shape(&tensor); + let client_target = get_client::(&device_target); + let client_original = tensor.client.clone(); - client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) - } + client_original + .clone() + .change_client_bool::(tensor.into_description(), client_target) + } - fn bool_into_int( - tensor: BoolTensor, - ) -> burn_tensor::ops::IntTensor { - struct IntoIntOps; - - impl Ops for IntoIntOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_into_int(input); - handles.register_int_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client - .register(TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoIntOps::), - ))); - - out - } + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor { + struct ReshapeDimsOps; - fn bool_into_float( - tensor: BoolTensor, - ) -> burn_tensor::ops::FloatTensor { - struct IntoFloatOps; - - impl Ops for IntoFloatOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_into_float(input); - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::BoolOps( - BoolOpsDescription::IntoFloat( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoFloatOps::), - ), - )); - - out - } + impl Ops for ReshapeDimsOps { + type Args = ReshapeDescription; - fn bool_device(tensor: &BoolTensor) -> Device { - tensor.client.device().clone().into() + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_reshape::(input, 
Shape::from(&args.shape)); + handles.register_bool_tensor(&args.out.id, output); + } } - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); - - if device_original == &device_target { - return tensor; - } + let shape: Vec = shape.dims.into(); + let out = tensor.client.tensor_uninitialized(shape.clone()); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::Reshape( + ReshapeDescription { + input: tensor.into_description(), + shape, + out: out.to_description_out(), + }, + Box::new(ReshapeDimsOps::), + ), + )); + + out + } + + fn bool_slice( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + ) -> BoolTensor { + struct SliceOps; + + impl Ops for SliceOps { + type Args = SliceOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_bool_tensor::(&args.tensor); + + let output = B::bool_slice::(tensor, args.ranges.clone().try_into().unwrap()); + + handles.register_bool_tensor(&args.out.id, output); + } + } - let client_target = get_client::(&device_target); - let client_original = tensor.client.clone(); + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); - client_original - .clone() - .change_client_bool::(tensor.into_description(), client_target) + for i in shape.len()..D1 { + shape.push(tensor.shape[i]); } - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - struct ReshapeDimsOps; - - impl Ops for ReshapeDimsOps { - type Args = ReshapeDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_reshape::(input, Shape::from(&args.shape)); - handles.register_bool_tensor(&args.out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape.clone()); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::Reshape( - ReshapeDescription { - input: tensor.into_description(), - shape, - out: out.to_description_out(), - }, - Box::new(ReshapeDimsOps::), - ), - )); - - out + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::Slice( + SliceOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + out: out.to_description_out(), + }, + Box::new(SliceOps::), + ), + )); + + out + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + value: BoolTensor, + ) -> BoolTensor { + struct SliceAssignOps; + + impl Ops for SliceAssignOps { + type Args = SliceAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_bool_tensor::(&args.tensor); + let value = handles.get_bool_tensor::(&args.value); + + let output = + B::bool_slice_assign::(tensor, args.ranges.clone().try_into().unwrap(), value); + + handles.register_bool_tensor(&args.out.id, output); + } } - fn bool_slice( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - ) -> BoolTensor { - struct SliceOps; - - impl Ops for SliceOps { - type Args = SliceOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_bool_tensor::(&args.tensor); - 
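// Illustrative sketch (not part of this patch): the output-shape rule used by
// `bool_slice` (and the float `slice` later in this patch). The leading dims come
// from the range lengths; any remaining dims are carried over from the input shape
// unchanged. `slice_shape` is a made-up name for the example.
use std::ops::Range;

fn slice_shape(input: &[usize], ranges: &[Range<usize>]) -> Vec<usize> {
    let mut shape: Vec<usize> = ranges.iter().map(|range| range.end - range.start).collect();

    for i in shape.len()..input.len() {
        shape.push(input[i]);
    }

    shape
}

fn main() {
    // Slicing a [4, 6, 8] tensor with ranges on the first two dims only.
    assert_eq!(slice_shape(&[4, 6, 8], &[1..3, 0..6]), vec![2, 6, 8]);
}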
- let output = - B::bool_slice::(tensor, args.ranges.clone().try_into().unwrap()); - - handles.register_bool_tensor(&args.out.id, output); - } - } - - let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); - - for i in shape.len()..D1 { - shape.push(tensor.shape[i]); - } - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::Slice( - SliceOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - out: out.to_description_out(), - }, - Box::new(SliceOps::), - ), - )); - - out + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::SliceAssign( + SliceAssignOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SliceAssignOps::), + ), + )); + + out + } + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> BoolTensor { + struct CatOps; + + impl Ops for CatOps { + type Args = CatOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensors = args + .tensors + .iter() + .map(|tensor| handles.get_bool_tensor(tensor)) + .collect(); + + let output = B::bool_cat::(tensors, args.dim); + + handles.register_bool_tensor(&args.out.id, output); + } } - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - value: BoolTensor, - ) -> BoolTensor { - struct SliceAssignOps; - - impl Ops for SliceAssignOps { - type Args = SliceAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_bool_tensor::(&args.tensor); - let value = handles.get_bool_tensor::(&args.value); - - let output = B::bool_slice_assign::( - tensor, - args.ranges.clone().try_into().unwrap(), - value, - ); - - handles.register_bool_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::SliceAssign( - SliceAssignOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SliceAssignOps::), - ), - )); - - out - } + let tensor_first = tensors.get(0).unwrap(); + let client = tensor_first.client.clone(); - fn bool_cat( - tensors: Vec>, - dim: usize, - ) -> BoolTensor { - struct CatOps; - - impl Ops for CatOps { - type Args = CatOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensors = args - .tensors - .iter() - .map(|tensor| handles.get_bool_tensor(tensor)) - .collect(); - - let output = B::bool_cat::(tensors, args.dim); - - handles.register_bool_tensor(&args.out.id, output); - } - } - - let tensor_first = tensors.get(0).unwrap(); - let client = tensor_first.client.clone(); - - // Calculate the output shape - let mut shape: Vec = tensor_first.shape.clone(); - shape[dim] = 0; - for tensor in tensors.iter() { - shape[dim] += tensor.shape[dim]; - } - - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat( - CatOpsDescription { - tensors: tensors.into_iter().map(|t| t.into_description()).collect(), - dim, - out: 
out.to_description_out(), - }, - Box::new(CatOps::), - ))); - - out + // Calculate the output shape + let mut shape: Vec = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; } - fn bool_equal( - lhs: BoolTensor, - rhs: BoolTensor, - ) -> BoolTensor { - struct EqualOps; - - impl Ops for EqualOps { - type Args = BinaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let lhs = handles.get_bool_tensor::(&args.lhs); - let rhs = handles.get_bool_tensor(&args.rhs); - let output = B::bool_equal(lhs, rhs); - handles.register_bool_tensor(&args.out.id, output); - } - } - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::Equal( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(EqualOps::), - ), - )); - - out + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat( + CatOpsDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }, + Box::new(CatOps::), + ))); + + out + } + + fn bool_equal( + lhs: BoolTensor, + rhs: BoolTensor, + ) -> BoolTensor { + struct EqualOps; + + impl Ops for EqualOps { + type Args = BinaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let lhs = handles.get_bool_tensor::(&args.lhs); + let rhs = handles.get_bool_tensor(&args.rhs); + let output = B::bool_equal(lhs, rhs); + handles.register_bool_tensor(&args.out.id, output); + } } - fn bool_not(tensor: BoolTensor) -> BoolTensor { - struct NotOps; - - impl Ops for NotOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_not(input); - handles.register_bool_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::BoolOps( - crate::graph::BoolOpsDescription::Not( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(NotOps::), - ), - )); - - out + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::Equal( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(EqualOps::), + ), + )); + + out + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + struct NotOps; + + impl Ops for NotOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_not(input); + handles.register_bool_tensor(&args.out.id, output); + } } - fn bool_swap_dims( - tensor: BoolTensor, - dim1: usize, - dim2: usize, - ) -> BoolTensor { - struct SwapDimsOps; - - impl Ops for SwapDimsOps { - type Args = SwapDimsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_swap_dims(input, args.dim1, args.dim2); - 
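// Illustrative sketch (not part of this patch): the output-shape rule used by
// `bool_cat` above. All dims are taken from the first tensor except the
// concatenation dim, which is the sum of that dim over all inputs. `cat_shape`
// is a made-up name for the example.
fn cat_shape(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
    let mut shape = shapes[0].clone();
    shape[dim] = 0;

    for s in shapes {
        shape[dim] += s[dim];
    }

    shape
}

fn main() {
    // Concatenating [2, 3] and [5, 3] along dim 0 gives [7, 3].
    assert_eq!(cat_shape(&[vec![2, 3], vec![5, 3]], 0), vec![7, 3]);
}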
handles.register_bool_tensor(&args.out.id, output); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim1] = tensor.shape[dim2]; - shape[dim2] = tensor.shape[dim1]; - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::SwapDims( - SwapDimsDescription { - input: tensor.into_description(), - dim1, - dim2, - out: out.to_description_out(), - }, - Box::new(SwapDimsOps::), - ), - )); - - out + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::BoolOps( + crate::graph::BoolOpsDescription::Not( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(NotOps::), + ), + )); + + out + } + + fn bool_swap_dims( + tensor: BoolTensor, + dim1: usize, + dim2: usize, + ) -> BoolTensor { + struct SwapDimsOps; + + impl Ops for SwapDimsOps { + type Args = SwapDimsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_swap_dims(input, args.dim1, args.dim2); + handles.register_bool_tensor(&args.out.id, output); + } } + + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::SwapDims( + SwapDimsDescription { + input: tensor.into_description(), + dim1, + dim2, + out: out.to_description_out(), + }, + Box::new(SwapDimsOps::), + ), + )); + + out + } } diff --git a/burn-fusion/src/ops/float.rs b/burn-fusion/src/ops/float.rs index fc60d3ef4b..156fcb259d 100644 --- a/burn-fusion/src/ops/float.rs +++ b/burn-fusion/src/ops/float.rs @@ -1,1669 +1,1672 @@ use crate::{ - binary_float_cmp_ops, binary_float_ops, - client::FusionClient, - get_client, - graph::{ - BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription, - FloatOpsDescription, GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, - NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription, - ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription, - SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, - TensorOpsDescription, UnaryOpsDescription, - }, - ops::binary::binary_ops_shape, - scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, unary_float_ops, Fusion, - FusionBackend, TensorDescription, + binary_float_cmp_ops, binary_float_ops, + client::FusionClient, + get_client, + graph::{ + BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription, + FloatOpsDescription, GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, + NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription, + ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription, SelectOpsDescription, + SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, TensorOpsDescription, + UnaryOpsDescription, + }, + ops::binary::binary_ops_shape, + scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, unary_float_ops, Fusion, + FusionBackend, TensorDescription, }; use burn_tensor::{ - ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, - Data, Device, Distribution, Reader, Shape, + ops::{BoolTensor, 
FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, + Data, Device, Distribution, Reader, Shape, }; use std::ops::Range; impl TensorOps for Fusion { - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::from_data(data, device); - let shape = B::shape(&tensor); - - client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) - } - - fn random( - shape: Shape, - distribution: Distribution>, - device: &Device, - ) -> FloatTensor { - struct RandomOps; - - impl Ops for RandomOps { - type Args = (TensorDescription, Distribution>); - - fn execute( - &self, - (out, distribution): &Self::Args, - handles: &mut crate::HandleContainer, - ) { - let shape = Shape::from(out.shape.clone()); - let output: B::TensorPrimitive = - B::random(shape, *distribution, &handles.device); - handles.register_float_tensor(&out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Random( - (out.to_description_out(), distribution), - Box::new(RandomOps::), - ))); - - out - } - - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - struct ZerosOps; - - impl Ops for ZerosOps { - type Args = TensorDescription; - - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::zeros::(shape, &handles.device); - handles.register_float_tensor(&out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), - )); - - out - } - - fn ones(shape: Shape, device: &Device) -> FloatTensor { - struct OnesOps; - - impl Ops for OnesOps { - type Args = TensorDescription; - - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::ones::(shape, &handles.device); - handles.register_float_tensor(&out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), - )); - - out - } - - fn full( - shape: Shape, - fill_value: FloatElem, - device: &Device, - ) -> FloatTensor { - struct FullOps; - - impl Ops for FullOps { - type Args = (TensorDescription, FloatElem); - - fn execute(&self, (out, value): &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output: B::TensorPrimitive = B::full(shape, *value, &handles.device); - handles.register_float_tensor(&out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Full( - (out.to_description_out(), fill_value), - Box::new(FullOps::), - ), - )); - - out - } - - fn shape(tensor: &FloatTensor) -> Shape { - tensor.shape() - } - - fn into_data(tensor: FloatTensor) -> Reader, D>> { - tensor.into_data() - } - - fn device(tensor: 
&FloatTensor) -> Device { - tensor.client.device().clone().into() - } - - fn to_device( - tensor: FloatTensor, - device: &Device, - ) -> FloatTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); - - if device_original == &device_target { - return tensor; - } - - let client_target = get_client::(&device_target); - let client_original = tensor.client.clone(); - - client_original - .clone() - .change_client_float::(tensor.into_description(), client_target) - } - - fn into_int(tensor: FloatTensor) -> IntTensor { - struct IntoIntOps; - - impl Ops for IntoIntOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = B::into_int(input); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::FloatOps( - FloatOpsDescription::IntoInt( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoIntOps::), - ), - )); - - out - } - - fn empty(shape: Shape, device: &Device) -> FloatTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::empty(shape.clone(), device); - - client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) - } - - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(AddOps, B::add); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Add( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(AddOps, B::add_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::AddScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out - } - - fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(ClampMinOps, B::clamp_min); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ClampMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: min, - out: out.to_description_out(), - }, - Box::new(ClampMinOps::), - ), - )); - - out - } - - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(ClampMaxOps, B::clamp_max); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ClampMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: max, - out: out.to_description_out(), - }, - Box::new(ClampMaxOps::), - ), - )); - - out - } - - fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - struct ClampOps; - - impl Ops for ClampOps { - type Args = ClampOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input 
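// Illustrative sketch (not part of this patch): the deferred-execution pattern every
// op in this file follows. The frontend never computes eagerly: it reserves an
// uninitialized output id, pushes a description plus a boxed op onto the client's
// graph, and returns the output tensor immediately; execution happens later against
// the handle container. `ToyClient`, `ToyOp`, and `Handles` are made-up stand-ins.
use std::collections::HashMap;

type Handles = HashMap<u64, f32>;

struct ToyOp {
    lhs: u64,
    rhs: u64,
    out: u64,
}

#[derive(Default)]
struct ToyClient {
    counter: u64,
    graph: Vec<ToyOp>,
}

impl ToyClient {
    fn tensor_uninitialized(&mut self) -> u64 {
        let id = self.counter;
        self.counter += 1;
        id
    }

    fn add(&mut self, lhs: u64, rhs: u64) -> u64 {
        let out = self.tensor_uninitialized();
        self.graph.push(ToyOp { lhs, rhs, out });
        out // nothing has executed yet
    }

    fn execute(&mut self, handles: &mut Handles) {
        for op in self.graph.drain(..) {
            let result = handles[&op.lhs] + handles[&op.rhs];
            handles.insert(op.out, result);
        }
    }
}

fn main() {
    let mut client = ToyClient::default();
    let (a, b) = (client.tensor_uninitialized(), client.tensor_uninitialized());
    let c = client.add(a, b);

    let mut handles: Handles = HashMap::from([(a, 1.5), (b, 2.0)]);
    client.execute(&mut handles);
    assert_eq!(handles[&c], 3.5);
}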
= handles.get_float_tensor::(&args.tensor); - let output = B::clamp(input, args.min, args.max); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Clamp( - ClampOpsDescription { - tensor: tensor.into_description(), - min, - max, - out: out.to_description_out(), - }, - Box::new(ClampOps::), - ), - )); - - out - } - - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(SubOps, B::sub); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Sub( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(SubOps, B::sub_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::SubScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(MulOps, B::mul); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Mul( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(MulOps, B::mul_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MulScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(DivOps, B::div); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Div( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(DivOps, B::div_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::DivScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out - } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(MatmulOps, B::matmul); - - let mut shape = binary_ops_shape(&lhs.shape, &rhs.shape); - - shape[D - 2] = lhs.shape[D - 2]; - shape[D - 1] = rhs.shape[D - 1]; - - let out = lhs.client.tensor_uninitialized(shape); - - out.client - 
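// Illustrative sketch (not part of this patch): the output-shape rule used by
// `matmul` above. Batch dims are broadcast with the element-wise max (as in
// `binary_ops_shape`), then the last two dims become [lhs rows, rhs cols].
// `matmul_shape` is a made-up name for the example.
fn matmul_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    let d = lhs.len();
    let mut shape: Vec<usize> = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(l, r)| usize::max(*l, *r))
        .collect();

    shape[d - 2] = lhs[d - 2];
    shape[d - 1] = rhs[d - 1];

    shape
}

fn main() {
    // [8, 2, 3] x [8, 3, 5] -> [8, 2, 5]
    assert_eq!(matmul_shape(&[8, 2, 3], &[8, 3, 5]), vec![8, 2, 5]);
}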
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Matmul( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(MatmulOps::), - ))); - - out - } - - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor { - struct SwapDimsOps; - - impl Ops for SwapDimsOps { - type Args = SwapDimsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = B::swap_dims(input, args.dim1, args.dim2); - handles.register_float_tensor(&args.out.id, output); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim1] = tensor.shape[dim2]; - shape[dim2] = tensor.shape[dim1]; - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::SwapDims( - SwapDimsDescription { - input: tensor.into_description(), - dim1, - dim2, - out: out.to_description_out(), - }, - Box::new(SwapDimsOps::), - ), - )); - - out - } - - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - struct ReshapeDimsOps; - - impl Ops for ReshapeDimsOps { - type Args = ReshapeDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = B::reshape::(input, Shape::from(&args.shape)); - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape.clone()); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::Reshape( - ReshapeDescription { - input: tensor.into_description(), - shape, - out: out.to_description_out(), - }, - Box::new(ReshapeDimsOps::), - ), - )); - - out - } - - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - struct GatherOps; - - impl Ops for GatherOps { - type Args = GatherOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::gather(args.dim, tensor, indices); - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Gather( - GatherOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(GatherOps::), - ), - )); - - out - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - struct ScatterOps; - - impl Ops for ScatterOps { - type Args = ScatterOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - let value = handles.get_float_tensor(&args.value); - - let output = B::scatter(args.dim, tensor, indices, value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - 
.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Scatter( - ScatterOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(ScatterOps::), - ), - )); - - out - } - - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - struct SelectOps; - - impl Ops for SelectOps { - type Args = SelectOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::select(tensor, args.dim, indices); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let mut shape: Vec = tensor.shape.clone(); - shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Select( - SelectOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectOps::), - ), - )); - - out - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - struct SelectAssignOps; - - impl Ops for SelectAssignOps { - type Args = SelectAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - let value = handles.get_float_tensor(&args.value); - - let output = B::select_assign(tensor, args.dim, indices, value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::SelectAssign( - SelectAssignOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectAssignOps::), - ), - )); - - out - } - - fn slice( - tensor: FloatTensor, - ranges: [Range; D2], - ) -> FloatTensor { - struct SliceOps; - - impl Ops for SliceOps { - type Args = SliceOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - - let output = B::slice::(tensor, args.ranges.clone().try_into().unwrap()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); - - for i in shape.len()..D1 { - shape.push(tensor.shape[i]); - } - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::Slice( - SliceOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - out: out.to_description_out(), - }, - Box::new(SliceOps::), - ), - )); - - out - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [Range; D2], - value: FloatTensor, - ) -> FloatTensor { - struct SliceAssignOps; - - impl Ops for SliceAssignOps { - type Args = SliceAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let 
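// Illustrative sketch (not part of this patch): the output-shape rules used by
// `gather` and `select` above. `gather` produces the indices' shape, while `select`
// keeps the input shape and replaces the selected dim with the number of indices.
// `gather_shape` and `select_shape` are made-up names for the example.
fn gather_shape(indices: &[usize]) -> Vec<usize> {
    indices.to_vec()
}

fn select_shape(input: &[usize], dim: usize, num_indices: usize) -> Vec<usize> {
    let mut shape = input.to_vec();
    shape[dim] = num_indices;
    shape
}

fn main() {
    assert_eq!(gather_shape(&[4, 2]), vec![4, 2]);
    assert_eq!(select_shape(&[4, 6, 8], 1, 3), vec![4, 3, 8]);
}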
tensor = handles.get_float_tensor::(&args.tensor); - let value = handles.get_float_tensor::(&args.value); - - let output = B::slice_assign::( - tensor, - args.ranges.clone().try_into().unwrap(), - value, - ); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::SliceAssign( - SliceAssignOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SliceAssignOps::), - ), - )); - - out - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor { - struct MaskWhereOps; - - impl Ops for MaskWhereOps { - type Args = MaskWhereOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let value = handles.get_float_tensor(&args.value); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::mask_where(tensor, mask, value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaskWhere( - MaskWhereOpsDescription { - tensor: tensor.into_description(), - value: value.into_description(), - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskWhereOps::), - ), - )); - - out - } - - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - struct MaskFillOps; - - impl Ops for MaskFillOps { - type Args = MaskFillOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::mask_fill(tensor, mask, args.value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaskFill( - MaskFillOpsDescription { - tensor: tensor.into_description(), - value, - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskFillOps::), - ), - )); - - out - } - - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(EqualOps, B::equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::Equal( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(EqualOps::), - ), - )); - - out - } - - fn equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(EqualElemOps, B::equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::EqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(EqualElemOps::), - ), - )); - - out - } - - fn greater( 
- lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(GreaterOps, B::greater); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Greater( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterOps::), - ), - )); - - out - } - - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::GreaterElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterElemOps::), - ), - )); - - out - } - - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(GreaterEqualOps, B::greater_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::GreaterEqual( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterEqualOps::), - ), - )); - - out - } - - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::GreaterEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterEqualElemOps::), - ), - )); - - out - } - - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(LowerOps, B::lower); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Lower( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerOps::), - ), - )); - - out - } - - fn lower_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(LowerElemOps, B::lower_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::LowerElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerElemOps::), - ), - )); - - out - } - - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(LowerEqualOps, B::lower_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::LowerEqual( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerEqualOps::), - ), - )); - - out - } - - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem); - - let out = 
lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::LowerEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerEqualElemOps::), - ), - )); - - out - } - - fn sum(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SumOps, B::sum); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Sum( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SumOps::), - ), - )); - - out - } - - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(SumDimOps, B::sum_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::SumDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(SumDimOps::), - ), - )); - - out - } - - fn mean(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MeanOps, B::mean); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Mean( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MeanOps::), - ), - )); - - out - } - - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(MeanDimOps, B::mean_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MeanDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MeanDimOps::), - ), - )); - - out - } - - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - tensor.clone() - } - - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - tensor - } - - fn exp(lhs: FloatTensor) -> FloatTensor { - unary_float_ops!(ExpOps, B::exp); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Exp( - UnaryOpsDescription { - input: lhs.into_description(), - out: out.to_description_out(), - }, - Box::new(ExpOps::), - ))); - - out - } - - fn log(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(LogOps, B::log); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::from_data(data, device); + let shape = B::shape(&tensor); - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(LogOps::), - ))); + client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) + } - out - } - - fn log1p(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(Log1pOps, B::log1p); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log1p( - UnaryOpsDescription { - input: 
tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(Log1pOps::), - ))); - - out - } - - fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { - scalar_float_ops!(PowfOps, B::powf, f32); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Powf( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(PowfOps::), - ))); - - out - } - - fn sqrt(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SqrtOps, B::sqrt); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sqrt( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SqrtOps::), - ))); - - out - } - - fn abs(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(AbsOps, B::abs); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Abs( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(AbsOps::), - ), - )); - - out - } - - fn cos(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(CosOps, B::cos); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Cos( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(CosOps::), - ))); + fn random( + shape: Shape, + distribution: Distribution>, + device: &Device, + ) -> FloatTensor { + struct RandomOps; - out - } - - fn sin(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SinOps, B::sin); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sin( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SinOps::), - ))); - - out - } + impl Ops for RandomOps { + type Args = (TensorDescription, Distribution>); - fn tanh(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(TanhOps, B::tanh); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Tanh( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(TanhOps::), - ))); - - out + fn execute(&self, (out, distribution): &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output: B::TensorPrimitive = B::random(shape, *distribution, &handles.device); + handles.register_float_tensor(&out.id, output); + } } - fn recip(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(Recip, B::recip); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(Recip::), - ))); - out - } - - fn erf(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(TanhOps, B::erf); + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); - let out = 
tensor.client.tensor_uninitialized(tensor.shape.clone()); + client.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Random( + (out.to_description_out(), distribution), + Box::new(RandomOps::), + ))); + + out + } + + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + struct ZerosOps; - out.client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Erf( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(TanhOps::), - ))); + impl Ops for ZerosOps { + type Args = TensorDescription; - out + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::zeros::(shape, &handles.device); + handles.register_float_tensor(&out.id, output); + } } - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - struct CatOps; - - impl Ops for CatOps { - type Args = CatOpsDescription; + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensors = args - .tensors - .iter() - .map(|tensor| handles.get_float_tensor(tensor)) - .collect(); + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), + )); - let output = B::cat::(tensors, args.dim); + out + } - handles.register_float_tensor(&args.out.id, output); - } - } + fn ones(shape: Shape, device: &Device) -> FloatTensor { + struct OnesOps; - let tensor_first = tensors.get(0).unwrap(); - let client = tensor_first.client.clone(); + impl Ops for OnesOps { + type Args = TensorDescription; - // Calculate the output shape - let mut shape: Vec = tensor_first.shape.clone(); - shape[dim] = 0; - for tensor in tensors.iter() { - shape[dim] += tensor.shape[dim]; - } - - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::BaseOpsFloat(BaseOpsDescription::Cat( - CatOpsDescription { - tensors: tensors.into_iter().map(|t| t.into_description()).collect(), - dim, - out: out.to_description_out(), - }, - Box::new(CatOps::), - ))); - - out - } - - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - scalar_float2int_ops!(ArgMaxOps, B::argmax, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ArgMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMaxOps::), - ), - )); - - out - } - - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - scalar_float2int_ops!(ArgMinOps, B::argmin, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ArgMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMinOps::), - ), - )); - - out + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::ones::(shape, &handles.device); + handles.register_float_tensor(&out.id, output); + } } - fn max(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MaxOps, B::max); + let shape: Vec = shape.dims.into(); + let 
client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); - let out = tensor.client.tensor_uninitialized(vec![1]); + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), + )); - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Max( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MaxOps::), - ), - )); + out + } - out - } + fn full( + shape: Shape, + fill_value: FloatElem, + device: &Device, + ) -> FloatTensor { + struct FullOps; - fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(MaxDimOps, B::max_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaxDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MaxDimOps::), - ), - )); - - out - } + impl Ops for FullOps { + type Args = (TensorDescription, FloatElem); - fn max_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - struct MaxDimWithIndicesOps; - - impl Ops for MaxDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let (output, indices) = B::max_dim_with_indices(tensor, args.dim); - - handles.register_float_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaxDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxDimWithIndicesOps::), - ), - )); - - (out, out_indices) + fn execute(&self, (out, value): &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output: B::TensorPrimitive = B::full(shape, *value, &handles.device); + handles.register_float_tensor(&out.id, output); + } } - fn min(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MinOps, B::min); + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); - let out = tensor.client.tensor_uninitialized(vec![1]); + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Full( + (out.to_description_out(), fill_value), + Box::new(FullOps::), + ), + )); - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Min( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MinOps::), - ), - )); + out + } - out - } + fn shape(tensor: &FloatTensor) -> Shape { + tensor.shape() + } - fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(MinDimOps, B::min_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - 
out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MinDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MinDimOps::), - ), - )); - - out - } + fn into_data(tensor: FloatTensor) -> Reader, D>> { + tensor.into_data() + } - fn min_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - struct MinDimWithIndicesOps; - - impl Ops for MinDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let (output, indices) = B::min_dim_with_indices(tensor, args.dim); - - handles.register_float_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MinDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MinDimWithIndicesOps::), - ), - )); - - (out, out_indices) - } + fn device(tensor: &FloatTensor) -> Device { + tensor.client.device().clone().into() + } + + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor { + let device_original: &B::FusionDevice = tensor.client.device(); + let device_target: B::FusionDevice = device.clone().into(); + + if device_original == &device_target { + return tensor; + } + + let client_target = get_client::(&device_target); + let client_original = tensor.client.clone(); + + client_original + .clone() + .change_client_float::(tensor.into_description(), client_target) + } + + fn into_int(tensor: FloatTensor) -> IntTensor { + struct IntoIntOps; + + impl Ops for IntoIntOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = B::into_int(input); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::FloatOps( + FloatOpsDescription::IntoInt( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoIntOps::), + ), + )); + + out + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::empty(shape.clone(), device); + + client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(AddOps, B::add); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Add( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(AddOps, B::add_scalar); + + let out = 
lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::AddScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out + } + + fn clamp_min( + tensor: FloatTensor, + min: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(ClampMinOps, B::clamp_min); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ClampMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: min, + out: out.to_description_out(), + }, + Box::new(ClampMinOps::), + ), + )); + + out + } + + fn clamp_max( + tensor: FloatTensor, + max: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(ClampMaxOps, B::clamp_max); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ClampMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: max, + out: out.to_description_out(), + }, + Box::new(ClampMaxOps::), + ), + )); + + out + } + + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + struct ClampOps; + + impl Ops for ClampOps { + type Args = ClampOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.tensor); + let output = B::clamp(input, args.min, args.max); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Clamp( + ClampOpsDescription { + tensor: tensor.into_description(), + min, + max, + out: out.to_description_out(), + }, + Box::new(ClampOps::), + ), + )); + + out + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(SubOps, B::sub); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Sub( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(SubOps, B::sub_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::SubScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(MulOps, B::mul); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Mul( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(MulOps::), + ), + )); + + out + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(MulOps, B::mul_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + 
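// --- Aside: illustrative sketch, not part of the patch ---------------------------
// The ops in this impl reserve their output with `tensor_uninitialized(shape)` before
// anything is computed, so each one has to predict the output shape up front:
// element-wise binaries broadcast the two input shapes, reductions collapse the
// reduced dimension to 1, and matmul keeps the leading (batch) dims and combines the
// last two. The helpers below are a self-contained toy model of that shape logic,
// written for this sketch only; they are not the real `binary_ops_shape` from
// burn-fusion.

/// Broadcast two shapes the way an element-wise binary op would (toy version).
fn sketch_broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    lhs.iter()
        .zip(rhs.iter())
        .map(|(&l, &r)| if l == 1 { r } else { l })
        .collect()
}

/// Shape of a reduction along `dim` that keeps the dimension (its size becomes 1).
fn sketch_reduce_dim_shape(input: &[usize], dim: usize) -> Vec<usize> {
    let mut shape = input.to_vec();
    shape[dim] = 1;
    shape
}

/// Output shape of a batched matmul: batch dims broadcast, last two dims are (m, n).
fn sketch_matmul_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    let d = lhs.len();
    let mut shape = sketch_broadcast_shape(lhs, rhs);
    shape[d - 2] = lhs[d - 2];
    shape[d - 1] = rhs[d - 1];
    shape
}

fn main() {
    assert_eq!(sketch_broadcast_shape(&[2, 1, 4], &[2, 3, 4]), vec![2, 3, 4]);
    assert_eq!(sketch_reduce_dim_shape(&[2, 3, 4], 1), vec![2, 1, 4]);
    assert_eq!(sketch_matmul_shape(&[2, 3, 4], &[2, 4, 5]), vec![2, 3, 5]);
}
// --- End of aside -----------------------------------------------------------------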
out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MulScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(MulOps::), + ), + )); + + out + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(DivOps, B::div); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Div( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(DivOps, B::div_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::DivScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(MatmulOps, B::matmul); + + let mut shape = binary_ops_shape(&lhs.shape, &rhs.shape); + + shape[D - 2] = lhs.shape[D - 2]; + shape[D - 1] = rhs.shape[D - 1]; + + let out = lhs.client.tensor_uninitialized(shape); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Matmul( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(MatmulOps::), + ))); + + out + } + + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + struct SwapDimsOps; + + impl Ops for SwapDimsOps { + type Args = SwapDimsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = B::swap_dims(input, args.dim1, args.dim2); + handles.register_float_tensor(&args.out.id, output); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::SwapDims( + SwapDimsDescription { + input: tensor.into_description(), + dim1, + dim2, + out: out.to_description_out(), + }, + Box::new(SwapDimsOps::), + ), + )); + + out + } + + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + struct ReshapeDimsOps; + + impl Ops for ReshapeDimsOps { + type Args = ReshapeDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = B::reshape::(input, Shape::from(&args.shape)); + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let out = tensor.client.tensor_uninitialized(shape.clone()); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::Reshape( + ReshapeDescription { + input: tensor.into_description(), + shape, + out: out.to_description_out(), + }, + Box::new(ReshapeDimsOps::), + ), + )); + + out + } + + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + struct GatherOps; + + impl Ops for GatherOps { + type Args = 
GatherOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::gather(args.dim, tensor, indices); + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = indices.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Gather( + GatherOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(GatherOps::), + ), + )); + + out + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + struct ScatterOps; + + impl Ops for ScatterOps { + type Args = ScatterOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_float_tensor(&args.value); + + let output = B::scatter(args.dim, tensor, indices, value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Scatter( + ScatterOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(ScatterOps::), + ), + )); + + out + } + + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + struct SelectOps; + + impl Ops for SelectOps { + type Args = SelectOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::select(tensor, args.dim, indices); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let mut shape: Vec = tensor.shape.clone(); + shape[dim] = indices.shape[0]; + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Select( + SelectOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectOps::), + ), + )); + + out + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + struct SelectAssignOps; + + impl Ops for SelectAssignOps { + type Args = SelectAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_float_tensor(&args.value); + + let output = B::select_assign(tensor, args.dim, indices, value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::SelectAssign( + 
SelectAssignOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectAssignOps::), + ), + )); + + out + } + + fn slice( + tensor: FloatTensor, + ranges: [Range; D2], + ) -> FloatTensor { + struct SliceOps; + + impl Ops for SliceOps { + type Args = SliceOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + + let output = B::slice::(tensor, args.ranges.clone().try_into().unwrap()); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + + for i in shape.len()..D1 { + shape.push(tensor.shape[i]); + } + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::Slice( + SliceOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + out: out.to_description_out(), + }, + Box::new(SliceOps::), + ), + )); + + out + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [Range; D2], + value: FloatTensor, + ) -> FloatTensor { + struct SliceAssignOps; + + impl Ops for SliceAssignOps { + type Args = SliceAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let value = handles.get_float_tensor::(&args.value); + + let output = + B::slice_assign::(tensor, args.ranges.clone().try_into().unwrap(), value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::SliceAssign( + SliceAssignOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SliceAssignOps::), + ), + )); + + out + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor { + struct MaskWhereOps; + + impl Ops for MaskWhereOps { + type Args = MaskWhereOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let value = handles.get_float_tensor(&args.value); + let mask = handles.get_bool_tensor(&args.mask); + + let output = B::mask_where(tensor, mask, value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaskWhere( + MaskWhereOpsDescription { + tensor: tensor.into_description(), + value: value.into_description(), + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskWhereOps::), + ), + )); + + out + } + + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + struct MaskFillOps; + + impl Ops for MaskFillOps { + type Args = MaskFillOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let mask = handles.get_bool_tensor(&args.mask); + + let 
output = B::mask_fill(tensor, mask, args.value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaskFill( + MaskFillOpsDescription { + tensor: tensor.into_description(), + value, + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskFillOps::), + ), + )); + + out + } + + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(EqualOps, B::equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::Equal( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(EqualOps::), + ), + )); + + out + } + + fn equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(EqualElemOps, B::equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::EqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(EqualElemOps::), + ), + )); + + out + } + + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(GreaterOps, B::greater); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Greater( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterOps::), + ), + )); + + out + } + + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::GreaterElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterElemOps::), + ), + )); + + out + } + + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(GreaterEqualOps, B::greater_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::GreaterEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterEqualOps::), + ), + )); + + out + } + + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::GreaterEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterEqualElemOps::), + ), + )); + + out + } + + fn lower( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(LowerOps, B::lower); + + let out = lhs + .client + 
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Lower( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerOps::), + ), + )); + + out + } + + fn lower_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(LowerElemOps, B::lower_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::LowerElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerElemOps::), + ), + )); + + out + } + + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(LowerEqualOps, B::lower_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::LowerEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerEqualOps::), + ), + )); + + out + } + + fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::LowerEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerEqualElemOps::), + ), + )); + + out + } + + fn sum(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(SumOps, B::sum); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Sum( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SumOps::), + ), + )); + + out + } + + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(SumDimOps, B::sum_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::SumDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(SumDimOps::), + ), + )); + + out + } + + fn mean(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(MeanOps, B::mean); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Mean( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MeanOps::), + ), + )); + + out + } + + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(MeanDimOps, B::mean_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MeanDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MeanDimOps::), + ), + )); + + out + } + + fn to_full_precision( + tensor: &FloatTensor, + ) 
-> FloatTensor, D> { + tensor.clone() + } + + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + tensor + } + + fn exp(lhs: FloatTensor) -> FloatTensor { + unary_float_ops!(ExpOps, B::exp); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Exp( + UnaryOpsDescription { + input: lhs.into_description(), + out: out.to_description_out(), + }, + Box::new(ExpOps::), + ))); + + out + } + + fn log(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(LogOps, B::log); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(LogOps::), + ))); + + out + } + + fn log1p(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(Log1pOps, B::log1p); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log1p( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(Log1pOps::), + ))); + + out + } + + fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { + scalar_float_ops!(PowfOps, B::powf, f32); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Powf( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(PowfOps::), + ))); + + out + } + + fn sqrt(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(SqrtOps, B::sqrt); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sqrt( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SqrtOps::), + ))); + + out + } + + fn abs(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(AbsOps, B::abs); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Abs( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(AbsOps::), + ), + )); + + out + } + + fn cos(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(CosOps, B::cos); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Cos( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(CosOps::), + ))); + + out + } + + fn sin(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(SinOps, B::sin); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sin( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SinOps::), + ))); + + out + } + + fn tanh(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(TanhOps, B::tanh); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Tanh( + UnaryOpsDescription { + input: tensor.into_description(), + out: 
out.to_description_out(), + }, + Box::new(TanhOps::), + ))); + + out + } + + fn recip(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(Recip, B::recip); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(Recip::), + ))); + out + } + + fn erf(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(TanhOps, B::erf); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out + .client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Erf( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(TanhOps::), + ))); + + out + } + + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + struct CatOps; + + impl Ops for CatOps { + type Args = CatOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensors = args + .tensors + .iter() + .map(|tensor| handles.get_float_tensor(tensor)) + .collect(); + + let output = B::cat::(tensors, args.dim); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let tensor_first = tensors.get(0).unwrap(); + let client = tensor_first.client.clone(); + + // Calculate the output shape + let mut shape: Vec = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; + } + + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::BaseOpsFloat(BaseOpsDescription::Cat( + CatOpsDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }, + Box::new(CatOps::), + ))); + + out + } + + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + scalar_float2int_ops!(ArgMaxOps, B::argmax, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ArgMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMaxOps::), + ), + )); + + out + } + + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + scalar_float2int_ops!(ArgMinOps, B::argmin, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ArgMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMinOps::), + ), + )); + + out + } + + fn max(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(MaxOps, B::max); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Max( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MaxOps::), + ), + )); + + out + } + + fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(MaxDimOps, B::max_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaxDim( + ScalarOpsDescription { + lhs: 
tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MaxDimOps::), + ), + )); + + out + } + + fn max_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + struct MaxDimWithIndicesOps; + + impl Ops for MaxDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let (output, indices) = B::max_dim_with_indices(tensor, args.dim); + + handles.register_float_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaxDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxDimWithIndicesOps::), + ), + )); + + (out, out_indices) + } + + fn min(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(MinOps, B::min); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Min( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MinOps::), + ), + )); + + out + } + + fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(MinDimOps, B::min_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MinDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MinDimOps::), + ), + )); + + out + } + + fn min_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + struct MinDimWithIndicesOps; + + impl Ops for MinDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let (output, indices) = B::min_dim_with_indices(tensor, args.dim); + + handles.register_float_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MinDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MinDimWithIndicesOps::), + ), + )); + + (out, out_indices) + } } diff --git a/burn-fusion/src/ops/int.rs b/burn-fusion/src/ops/int.rs index 32d2d6a547..8e7f989d8f 100644 --- a/burn-fusion/src/ops/int.rs +++ b/burn-fusion/src/ops/int.rs @@ -1,1401 +1,1406 @@ use crate::{ - binary_int_cmp_ops, binary_int_ops, - client::FusionClient, - get_client, - graph::{ - self, BaseOpsDescription, BinaryOpsDescription, 
CatOpsDescription, ClampOpsDescription, - GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, - NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription, - ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription, - SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, - TensorOpsDescription, UnaryOpsDescription, - }, - ops::binary::binary_ops_shape, - scalar_int_cmp_ops, scalar_int_ops, unary_int_ops, Fusion, FusionBackend, TensorDescription, + binary_int_cmp_ops, binary_int_ops, + client::FusionClient, + get_client, + graph::{ + self, BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription, + GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, NumericOpsDescription, + Ops, ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOpsDescription, + ScatterOpsDescription, SelectAssignOpsDescription, SelectOpsDescription, + SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, TensorOpsDescription, + UnaryOpsDescription, + }, + ops::binary::binary_ops_shape, + scalar_int_cmp_ops, scalar_int_ops, unary_int_ops, Fusion, FusionBackend, TensorDescription, }; use burn_tensor::{ - ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, - Data, Device, Reader, Shape, + ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, + Data, Device, Reader, Shape, }; use core::ops::Range; impl IntTensorOps for Fusion { - fn int_empty(shape: Shape, device: &Device) -> IntTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::int_empty(shape.clone(), device); - - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - tensor.shape() - } - - fn int_into_data(tensor: IntTensor) -> Reader, D>> { - tensor.int_into_data() - } - - fn int_from_data( - data: Data, D>, - device: &Device, - ) -> IntTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::int_from_data(data, device); - let shape = B::int_shape(&tensor); - - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::int_empty(shape.clone(), device); + + client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + tensor.shape() + } + + fn int_into_data(tensor: IntTensor) -> Reader, D>> { + tensor.int_into_data() + } + + fn int_from_data( + data: Data, D>, + device: &Device, + ) -> IntTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::int_from_data(data, device); + let shape = B::int_shape(&tensor); + + client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) + } + + fn int_device(tensor: &IntTensor) -> Device { + tensor.client.device().clone().into() + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + let device_original: &B::FusionDevice = tensor.client.device(); + let device_target: B::FusionDevice = device.clone().into(); + + if device_original == &device_target { + return tensor; } - fn int_device(tensor: &IntTensor) -> Device { - tensor.client.device().clone().into() - } + let client_target = get_client::(&device_target); + let client_original = tensor.client.clone(); - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - let device_original: 
&B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); + client_original + .clone() + .change_client_int::(tensor.into_description(), client_target) + } - if device_original == &device_target { - return tensor; - } + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + struct ReshapeDimsOps; - let client_target = get_client::(&device_target); - let client_original = tensor.client.clone(); + impl Ops for ReshapeDimsOps { + type Args = ReshapeDescription; - client_original - .clone() - .change_client_int::(tensor.into_description(), client_target) + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = B::int_reshape::(input, Shape::from(&args.shape)); + handles.register_int_tensor(&args.out.id, output); + } } - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - struct ReshapeDimsOps; - - impl Ops for ReshapeDimsOps { - type Args = ReshapeDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = B::int_reshape::(input, Shape::from(&args.shape)); - handles.register_int_tensor(&args.out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape.clone()); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt( - BaseOpsDescription::Reshape( - ReshapeDescription { - input: tensor.into_description(), - shape, - out: out.to_description_out(), - }, - Box::new(ReshapeDimsOps::), - ), - )); - - out + let shape: Vec = shape.dims.into(); + let out = tensor.client.tensor_uninitialized(shape.clone()); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt( + BaseOpsDescription::Reshape( + ReshapeDescription { + input: tensor.into_description(), + shape, + out: out.to_description_out(), + }, + Box::new(ReshapeDimsOps::), + ), + )); + + out + } + + fn int_slice( + tensor: IntTensor, + ranges: [Range; D2], + ) -> IntTensor { + struct SliceOps; + + impl Ops for SliceOps { + type Args = SliceOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + + let output = B::int_slice::(tensor, args.ranges.clone().try_into().unwrap()); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_slice( - tensor: IntTensor, - ranges: [Range; D2], - ) -> IntTensor { - struct SliceOps; - - impl Ops for SliceOps { - type Args = SliceOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - - let output = - B::int_slice::(tensor, args.ranges.clone().try_into().unwrap()); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); - for i in shape.len()..D1 { - shape.push(tensor.shape[i]); - } - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Slice( - SliceOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - out: out.to_description_out(), - }, - Box::new(SliceOps::), - ))); - - out + for i in shape.len()..D1 { + shape.push(tensor.shape[i]); } - 
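// --- Aside: illustrative sketch, not part of the patch ---------------------------
// Every float/int op in these files follows the same deferred-execution pattern:
//   1. build a plain-data description of the op,
//   2. reserve an uninitialized output tensor on the fusion client,
//   3. enqueue the description together with a boxed `Ops` object whose `execute`
//      later resolves tensor handles and calls the concrete backend.
// The types below (`SketchHandles`, `SketchOp`, `SketchUnaryDesc`) are simplified
// stand-ins invented for this sketch; they are NOT the real burn-fusion types.

use std::collections::HashMap;

#[derive(Clone, Debug)]
struct SketchTensorId(u64);

#[derive(Clone, Debug)]
struct SketchUnaryDesc {
    input: SketchTensorId,
    out: SketchTensorId,
}

/// Stand-in for the handle container: maps tensor ids to concrete values.
#[derive(Default)]
struct SketchHandles {
    tensors: HashMap<u64, Vec<f32>>,
}

/// Stand-in for the `Ops` trait: a recorded operation that can run later.
trait SketchOp {
    fn execute(&self, handles: &mut SketchHandles);
}

/// One recorded unary op (exp), analogous to an `ExpOps` + unary description pair.
struct SketchExpOp {
    desc: SketchUnaryDesc,
}

impl SketchOp for SketchExpOp {
    fn execute(&self, handles: &mut SketchHandles) {
        // Resolve the input handle, run the concrete computation, register the output.
        let input = handles
            .tensors
            .get(&self.desc.input.0)
            .cloned()
            .unwrap_or_default();
        let output: Vec<f32> = input.iter().map(|x| x.exp()).collect();
        handles.tensors.insert(self.desc.out.0, output);
    }
}

fn main() {
    // "Client side": describe the op now, execute it later.
    let mut queue: Vec<Box<dyn SketchOp>> = Vec::new();
    let mut handles = SketchHandles::default();
    handles.tensors.insert(0, vec![0.0, 1.0, 2.0]);

    let desc = SketchUnaryDesc {
        input: SketchTensorId(0),
        out: SketchTensorId(1),
    };
    queue.push(Box::new(SketchExpOp { desc }));

    // "Execution side": drain the queue and materialize results.
    for op in queue.drain(..) {
        op.execute(&mut handles);
    }
    println!("{:?}", handles.tensors.get(&1));
}
// --- End of aside -----------------------------------------------------------------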
fn int_slice_assign( - tensor: IntTensor, - ranges: [Range; D2], - value: IntTensor, - ) -> IntTensor { - struct SliceAssignOps; - - impl Ops for SliceAssignOps { - type Args = SliceAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let value = handles.get_int_tensor::(&args.value); - - let output = B::int_slice_assign::( - tensor, - args.ranges.clone().try_into().unwrap(), - value, - ); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt( - BaseOpsDescription::SliceAssign( - SliceAssignOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SliceAssignOps::), - ), - )); - - out + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Slice( + SliceOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + out: out.to_description_out(), + }, + Box::new(SliceOps::), + ))); + + out + } + + fn int_slice_assign( + tensor: IntTensor, + ranges: [Range; D2], + value: IntTensor, + ) -> IntTensor { + struct SliceAssignOps; + + impl Ops for SliceAssignOps { + type Args = SliceAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let value = handles.get_int_tensor::(&args.value); + + let output = + B::int_slice_assign::(tensor, args.ranges.clone().try_into().unwrap(), value); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - value: IntTensor, - ) -> IntTensor { - struct MaskWhereOps; - - impl Ops for MaskWhereOps { - type Args = MaskWhereOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let value = handles.get_int_tensor(&args.value); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::int_mask_where(tensor, mask, value); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaskWhere( - MaskWhereOpsDescription { - tensor: tensor.into_description(), - value: value.into_description(), - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskWhereOps::), - ), - )); - - out + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt( + BaseOpsDescription::SliceAssign( + SliceAssignOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SliceAssignOps::), + ), + )); + + out + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + value: IntTensor, + ) -> IntTensor { + struct MaskWhereOps; + + impl Ops for MaskWhereOps { + type Args = MaskWhereOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut 
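// Shape rule used by int_slice above: sliced dimensions shrink to the range
// length, and any trailing dimension not covered by a range keeps its input
// extent. Standalone sketch; the helper name and plain Vec/Range types are
// illustrative only.
use std::ops::Range;

fn slice_output_shape(input_shape: &[usize], ranges: &[Range<usize>]) -> Vec<usize> {
    let mut shape: Vec<usize> = ranges.iter().map(|r| r.end - r.start).collect();
    // Dimensions beyond the provided ranges are copied from the input.
    for i in shape.len()..input_shape.len() {
        shape.push(input_shape[i]);
    }
    shape
}

fn main() {
    // Slicing a [4, 6, 8] tensor with ranges [1..3, 0..6] leaves the last dim intact.
    assert_eq!(slice_output_shape(&[4, 6, 8], &[1..3, 0..6]), vec![2, 6, 8]);
}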
crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let value = handles.get_int_tensor(&args.value); + let mask = handles.get_bool_tensor(&args.mask); + + let output = B::int_mask_where(tensor, mask, value); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor { - struct MaskFillOps; - - impl Ops for MaskFillOps { - type Args = MaskFillOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::int_mask_fill(tensor, mask, args.value); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaskFill( - MaskFillOpsDescription { - tensor: tensor.into_description(), - value, - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskFillOps::), - ), - )); - - out + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaskWhere( + MaskWhereOpsDescription { + tensor: tensor.into_description(), + value: value.into_description(), + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskWhereOps::), + ), + )); + + out + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor { + struct MaskFillOps; + + impl Ops for MaskFillOps { + type Args = MaskFillOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let mask = handles.get_bool_tensor(&args.mask); + + let output = B::int_mask_fill(tensor, mask, args.value); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - struct GatherOps; - - impl Ops for GatherOps { - type Args = GatherOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::int_gather(args.dim, tensor, indices); - handles.register_int_tensor(&args.out.id, output); - } - } - - let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Gather( - GatherOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(GatherOps::), - ), - )); - - out + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaskFill( + MaskFillOpsDescription { + tensor: tensor.into_description(), + value, + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskFillOps::), + ), + )); + + out + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + struct GatherOps; + 
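// Element-wise meaning of the mask ops registered above, sketched on plain
// slices: mask_where picks from `value` where the mask is true, mask_fill
// writes a single scalar there; both keep the input shape, which is why the
// output is allocated from tensor.shape. The fused versions defer the real
// work to B::int_mask_where / B::int_mask_fill.
fn mask_where(tensor: &[i64], mask: &[bool], value: &[i64]) -> Vec<i64> {
    tensor
        .iter()
        .zip(mask)
        .zip(value)
        .map(|((&t, &m), &v)| if m { v } else { t })
        .collect()
}

fn mask_fill(tensor: &[i64], mask: &[bool], value: i64) -> Vec<i64> {
    tensor
        .iter()
        .zip(mask)
        .map(|(&t, &m)| if m { value } else { t })
        .collect()
}

fn main() {
    assert_eq!(mask_where(&[1, 2, 3], &[true, false, true], &[9, 8, 7]), vec![9, 2, 7]);
    assert_eq!(mask_fill(&[1, 2, 3], &[false, true, false], 0), vec![1, 0, 3]);
}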
+ impl Ops for GatherOps { + type Args = GatherOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::int_gather(args.dim, tensor, indices); + handles.register_int_tensor(&args.out.id, output); + } } - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - struct ScatterOps; - - impl Ops for ScatterOps { - type Args = ScatterOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - let value = handles.get_int_tensor(&args.value); - - let output = B::int_scatter(args.dim, tensor, indices, value); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Scatter( - ScatterOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(ScatterOps::), - ), - )); - - out + let shape: Vec = indices.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Gather( + GatherOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(GatherOps::), + ), + )); + + out + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + struct ScatterOps; + + impl Ops for ScatterOps { + type Args = ScatterOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_int_tensor(&args.value); + + let output = B::int_scatter(args.dim, tensor, indices, value); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - struct SelectOps; - - impl Ops for SelectOps { - type Args = SelectOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::int_select(tensor, args.dim, indices); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let mut shape: Vec = tensor.shape.clone(); - shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Select( - SelectOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectOps::), - ), - )); - - out + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Scatter( + ScatterOpsDescription { + tensor: 
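// 1-D sketch of the gather/scatter pair above. Gather's output takes the
// shape of `indices` (matching the code, which allocates `out` from
// indices.shape); scatter keeps the input's shape. The accumulation rule in
// scatter_1d (summing into the target) is assumed for illustration; the fused
// op simply defers to whatever B::int_scatter does.
fn gather_1d(tensor: &[i64], indices: &[usize]) -> Vec<i64> {
    indices.iter().map(|&i| tensor[i]).collect()
}

fn scatter_1d(tensor: &[i64], indices: &[usize], value: &[i64]) -> Vec<i64> {
    let mut out = tensor.to_vec();
    for (&i, &v) in indices.iter().zip(value) {
        out[i] += v; // assumed sum-accumulation, see note above
    }
    out
}

fn main() {
    assert_eq!(gather_1d(&[10, 20, 30], &[2, 0]), vec![30, 10]);
    assert_eq!(scatter_1d(&[1, 1, 1], &[0, 2], &[5, 7]), vec![6, 1, 8]);
}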
tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(ScatterOps::), + ), + )); + + out + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + struct SelectOps; + + impl Ops for SelectOps { + type Args = SelectOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::int_select(tensor, args.dim, indices); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - struct SelectAssignOps; - - impl Ops for SelectAssignOps { - type Args = SelectAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - let value = handles.get_int_tensor(&args.value); - - let output = B::int_select_assign(tensor, args.dim, indices, value); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::SelectAssign( - SelectAssignOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectAssignOps::), - ), - )); - - out + let mut shape: Vec = tensor.shape.clone(); + shape[dim] = indices.shape[0]; + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Select( + SelectOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectOps::), + ), + )); + + out + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + struct SelectAssignOps; + + impl Ops for SelectAssignOps { + type Args = SelectAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_int_tensor(&args.value); + + let output = B::int_select_assign(tensor, args.dim, indices, value); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - struct CatOps; - - impl Ops for CatOps { - type Args = CatOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensors = args - .tensors - .iter() - .map(|tensor| handles.get_int_tensor(tensor)) - .collect(); - - let output = B::int_cat::(tensors, args.dim); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let tensor_first = tensors.get(0).unwrap(); - let client = tensor_first.client.clone(); - - // Calculate the output shape - let mut shape: Vec = tensor_first.shape.clone(); - shape[dim] = 0; - for tensor in tensors.iter() { - shape[dim] += tensor.shape[dim]; - } - - let out = client.tensor_uninitialized(shape); - - 
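// Shape rule used by int_select above: only the selected dimension changes,
// taking the number of indices; every other dimension is untouched. Helper
// name is illustrative.
fn select_output_shape(input_shape: &[usize], dim: usize, num_indices: usize) -> Vec<usize> {
    let mut shape = input_shape.to_vec();
    shape[dim] = num_indices;
    shape
}

fn main() {
    assert_eq!(select_output_shape(&[4, 6, 8], 1, 3), vec![4, 3, 8]);
}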
client.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat( - CatOpsDescription { - tensors: tensors.into_iter().map(|t| t.into_description()).collect(), - dim, - out: out.to_description_out(), - }, - Box::new(CatOps::), - ))); - - out + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::SelectAssign( + SelectAssignOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectAssignOps::), + ), + )); + + out + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + struct CatOps; + + impl Ops for CatOps { + type Args = CatOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensors = args + .tensors + .iter() + .map(|tensor| handles.get_int_tensor(tensor)) + .collect(); + + let output = B::int_cat::(tensors, args.dim); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(EqualOps, B::int_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client - .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(EqualOps::), - ))); - - out - } + let tensor_first = tensors.get(0).unwrap(); + let client = tensor_first.client.clone(); - fn int_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::EqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(EqualElemOps::), - ), - )); - - out + // Calculate the output shape + let mut shape: Vec = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; } - fn int_greater( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(GreaterOps, B::int_greater); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Greater( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterOps::), - ), - )); - - out + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat( + CatOpsDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }, + Box::new(CatOps::), + ))); + + out + } + + fn int_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(EqualOps, B::int_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out + .client + .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(EqualOps::), 
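// The cat output shape computed above: every dimension matches the first
// tensor except `dim`, which is the sum of that dimension over all inputs.
// Standalone sketch with plain Vec shapes.
fn cat_output_shape(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
    let mut out = shapes[0].clone();
    out[dim] = shapes.iter().map(|s| s[dim]).sum();
    out
}

fn main() {
    let shapes: Vec<Vec<usize>> = vec![vec![2, 3], vec![2, 5], vec![2, 1]];
    assert_eq!(cat_output_shape(&shapes, 1), vec![2, 9]);
}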
+ ))); + + out + } + + fn int_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::EqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(EqualElemOps::), + ), + )); + + out + } + + fn int_greater( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(GreaterOps, B::int_greater); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Greater( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterOps::), + ), + )); + + out + } + + fn int_greater_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::GreaterElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterElemOps::), + ), + )); + + out + } + + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::GreaterEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterEqualOps::), + ), + )); + + out + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::GreaterEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterEqualElemOps::), + ), + )); + + out + } + + fn int_lower( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(LowerOps, B::int_lower); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Lower( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerOps::), + ), + )); + + out + } + + fn int_lower_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::LowerElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerElemOps::), + ), + )); + + out + } + + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal); + + let out = lhs + .client + 
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::LowerEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerEqualOps::), + ), + )); + + out + } + + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::LowerEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerEqualElemOps::), + ), + )); + + out + } + + fn int_add( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + binary_int_ops!(AddOps, B::int_add); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Add( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out + } + + fn int_add_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(AddOps, B::int_add_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::AddScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out + } + + fn int_sub( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + binary_int_ops!(SubOps, B::int_sub); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Sub( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out + } + + fn int_sub_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(SubOps, B::int_sub_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::SubScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out + } + + fn int_mul( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + binary_int_ops!(MulOps, B::int_mul); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Mul( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(MulOps::), + ), + )); + + out + } + + fn int_mul_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(MulOps, B::int_mul_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MulScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + 
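// The binary ops above size their output with binary_ops_shape, i.e. by
// broadcasting the two input shapes element-wise; the *_elem / *_scalar
// variants keep lhs.shape unchanged. This standalone version assumes
// same-rank inputs where each pair of dims is either equal or 1, which is the
// usual broadcasting contract.
fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    lhs.iter()
        .zip(rhs)
        .map(|(&l, &r)| {
            assert!(l == r || l == 1 || r == 1, "shapes are not broadcastable");
            l.max(r)
        })
        .collect()
}

fn main() {
    assert_eq!(broadcast_shape(&[4, 1, 8], &[1, 6, 8]), vec![4, 6, 8]);
}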
Box::new(MulOps::), + ), + )); + + out + } + + fn int_div( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + binary_int_ops!(DivOps, B::int_div); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Div( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out + } + + fn int_div_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(DivOps, B::int_div_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out + .client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::DivScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + struct ZerosOps; + + impl Ops for ZerosOps { + type Args = TensorDescription; + + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::int_zeros::(shape, &handles.device); + handles.register_int_tensor(&out.id, output); + } } - fn int_greater_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::GreaterElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterElemOps::), - ), - )); - - out - } + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::GreaterEqual( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterEqualOps::), - ), - )); - - out - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::GreaterEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterEqualElemOps::), - ), - )); - - out - } - - fn int_lower( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(LowerOps, B::int_lower); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Lower( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerOps::), - ), - )); - - out - } - - fn int_lower_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); 
- - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::LowerElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerElemOps::), - ), - )); - - out - } + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), + )); - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::LowerEqual( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerEqualOps::), - ), - )); - - out - } + out + } - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::LowerEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerEqualElemOps::), - ), - )); - - out - } + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + struct OnesOps; - fn int_add( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(AddOps, B::int_add); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Add( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out - } + impl Ops for OnesOps { + type Args = TensorDescription; - fn int_add_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(AddOps, B::int_add_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::AddScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::int_ones::(shape, &handles.device); + handles.register_int_tensor(&out.id, output); + } } - fn int_sub( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(SubOps, B::int_sub); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Sub( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), + )); + + out + } + + fn int_sum(tensor: IntTensor) -> IntTensor { + 
unary_int_ops!(SumOps, B::int_sum); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Sum( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SumOps::), + ), + )); + + out + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(SumDimOps, B::int_sum_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::SumDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(SumDimOps::), + ), + )); + + out + } + + fn int_mean(tensor: IntTensor) -> IntTensor { + unary_int_ops!(MeanOps, B::int_mean); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Mean( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MeanOps::), + ), + )); + + out + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(MeanDimOps, B::int_mean_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MeanDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MeanDimOps::), + ), + )); + + out + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(ArgMaxOps, B::int_argmax, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ArgMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMaxOps::), + ), + )); + + out + } + + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(ArgMinOps, B::int_argmin, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ArgMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMinOps::), + ), + )); + + out + } + + fn int_clamp_min( + tensor: IntTensor, + min: IntElem, + ) -> IntTensor { + scalar_int_ops!(ClampMinOps, B::int_clamp_min); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ClampMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: min, + out: out.to_description_out(), + }, + Box::new(ClampMinOps::), + ), + )); + + out + } + + fn int_clamp_max( + tensor: IntTensor, + max: IntElem, + ) -> IntTensor { + scalar_int_ops!(ClampMaxOps, B::int_clamp_max); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ClampMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: max, + out: out.to_description_out(), + }, + Box::new(ClampMaxOps::), + ), + )); + + 
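// The reductions above use two shape rules: a full reduction (int_sum,
// int_mean, int_max, int_min) allocates a rank-1 output of shape [1], while a
// per-dimension reduction (sum_dim, mean_dim, argmax, argmin, ...) keeps the
// rank and sets the reduced dimension to 1. Helper names are illustrative.
fn full_reduce_shape() -> Vec<usize> {
    vec![1]
}

fn dim_reduce_shape(input_shape: &[usize], dim: usize) -> Vec<usize> {
    let mut shape = input_shape.to_vec();
    shape[dim] = 1;
    shape
}

fn main() {
    assert_eq!(full_reduce_shape(), vec![1]);
    assert_eq!(dim_reduce_shape(&[4, 6, 8], 1), vec![4, 1, 8]);
}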
out + } + + fn int_clamp( + tensor: IntTensor, + min: IntElem, + max: IntElem, + ) -> IntTensor { + struct ClampOps; + + impl Ops for ClampOps { + type Args = ClampOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.tensor); + let output = B::int_clamp(input, args.min, args.max); + + handles.register_int_tensor(&args.out.id, output); + } } - fn int_sub_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(SubOps, B::int_sub_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::SubScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Clamp( + ClampOpsDescription { + tensor: tensor.into_description(), + min, + max, + out: out.to_description_out(), + }, + Box::new(ClampOps::), + ), + )); + + out + } + + fn int_abs(tensor: IntTensor) -> IntTensor { + unary_int_ops!(AbsOps, B::int_abs); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Abs( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(AbsOps::), + ), + )); + + out + } + + fn int_into_float(tensor: IntTensor) -> FloatTensor { + struct IntoFloatOps; + + impl Ops for IntoFloatOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = B::int_into_float(input); + handles.register_float_tensor(&args.out.id, output); + } } - fn int_mul( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(MulOps, B::int_mul); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Mul( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::IntOps( + graph::IntOpsDescription::IntoFloat( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoFloatOps::), + ), + )); + + out + } + + fn int_swap_dims( + tensor: IntTensor, + dim1: usize, + dim2: usize, + ) -> IntTensor { + struct SwapDimsOps; + + impl Ops for SwapDimsOps { + type Args = SwapDimsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = B::int_swap_dims(input, args.dim1, args.dim2); + handles.register_int_tensor(&args.out.id, output); + } } - fn int_mul_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(MulOps, B::int_mul_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MulScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: 
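// Element-wise meaning of the clamp family registered above (clamp_min,
// clamp_max, clamp), sketched on plain i64 values; the fused ops defer the
// same computation to B::int_clamp* at execution time and keep the input shape.
fn clamp(x: i64, min: i64, max: i64) -> i64 {
    x.max(min).min(max)
}

fn main() {
    assert_eq!(clamp(-5, 0, 10), 0);
    assert_eq!(clamp(42, 0, 10), 10);
    assert_eq!(clamp(7, 0, 10), 7);
}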
out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt( + BaseOpsDescription::SwapDims( + SwapDimsDescription { + input: tensor.into_description(), + dim1, + dim2, + out: out.to_description_out(), + }, + Box::new(SwapDimsOps::), + ), + )); + + out + } + + fn int_max(tensor: IntTensor) -> IntTensor { + unary_int_ops!(MaxOps, B::int_max); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Max( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MaxOps::), + ), + )); + + out + } + + fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(MaxDimOps, B::int_max_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaxDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MaxDimOps::), + ), + )); + + out + } + + fn int_max_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + struct MaxDimWithIndicesOps; + + impl Ops for MaxDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let (output, indices) = B::int_max_dim_with_indices(tensor, args.dim); + + handles.register_int_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } } - fn int_div( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(DivOps, B::int_div); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Div( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaxDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxDimWithIndicesOps::), + ), + )); + + (out, out_indices) + } + + fn int_min(tensor: IntTensor) -> IntTensor { + unary_int_ops!(MinOps, B::int_min); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Min( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MinOps::), + ), + )); + + out + } + + fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(MinDimOps, B::int_min_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = 
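// int_swap_dims above swaps exactly two entries of the shape before
// allocating the output; the data layout change itself is deferred to
// B::int_swap_dims. Helper name is illustrative.
fn swap_dims_shape(input_shape: &[usize], dim1: usize, dim2: usize) -> Vec<usize> {
    let mut shape = input_shape.to_vec();
    shape.swap(dim1, dim2);
    shape
}

fn main() {
    assert_eq!(swap_dims_shape(&[2, 3, 4], 0, 2), vec![4, 3, 2]);
}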
tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MinDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MinDimOps::), + ), + )); + + out + } + + fn int_min_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + struct MinDimWithIndicesOps; + + impl Ops for MinDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let (output, indices) = B::int_min_dim_with_indices(tensor, args.dim); + + handles.register_int_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } } - fn int_div_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(DivOps, B::int_div_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::DivScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - struct ZerosOps; - - impl Ops for ZerosOps { - type Args = TensorDescription; - - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::int_zeros::(shape, &handles.device); - handles.register_int_tensor(&out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), - )); - - out - } - - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - struct OnesOps; - - impl Ops for OnesOps { - type Args = TensorDescription; - - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::int_ones::(shape, &handles.device); - handles.register_int_tensor(&out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), - )); - - out - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - unary_int_ops!(SumOps, B::int_sum); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Sum( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SumOps::), - ), - )); - - out - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(SumDimOps, B::int_sum_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::SumDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(SumDimOps::), - ), - )); - - out - } - - fn int_mean(tensor: IntTensor) -> IntTensor { - 
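// 1-D sketch of the *_dim_with_indices reductions above: they return the
// reduced values together with the winning positions, which is why two
// outputs (out, out_indices) with the same reduced shape are registered at
// once. Tie-breaking here (last maximum wins) is incidental to the sketch.
fn max_with_index(values: &[i64]) -> (i64, usize) {
    values
        .iter()
        .copied()
        .enumerate()
        .map(|(i, v)| (v, i))
        .max_by_key(|&(v, _)| v)
        .expect("non-empty input")
}

fn main() {
    assert_eq!(max_with_index(&[3, 9, 1]), (9, 1));
}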
unary_int_ops!(MeanOps, B::int_mean); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Mean( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MeanOps::), - ), - )); - - out - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(MeanDimOps, B::int_mean_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MeanDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MeanDimOps::), - ), - )); - - out - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(ArgMaxOps, B::int_argmax, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ArgMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMaxOps::), - ), - )); - - out - } - - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(ArgMinOps, B::int_argmin, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ArgMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMinOps::), - ), - )); - - out - } - - fn int_clamp_min( - tensor: IntTensor, - min: IntElem, - ) -> IntTensor { - scalar_int_ops!(ClampMinOps, B::int_clamp_min); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ClampMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: min, - out: out.to_description_out(), - }, - Box::new(ClampMinOps::), - ), - )); - - out - } - - fn int_clamp_max( - tensor: IntTensor, - max: IntElem, - ) -> IntTensor { - scalar_int_ops!(ClampMaxOps, B::int_clamp_max); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ClampMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: max, - out: out.to_description_out(), - }, - Box::new(ClampMaxOps::), - ), - )); - - out - } - - fn int_clamp( - tensor: IntTensor, - min: IntElem, - max: IntElem, - ) -> IntTensor { - struct ClampOps; - - impl Ops for ClampOps { - type Args = ClampOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.tensor); - let output = B::int_clamp(input, args.min, args.max); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Clamp( - ClampOpsDescription { - tensor: tensor.into_description(), - min, - max, - out: out.to_description_out(), - }, - Box::new(ClampOps::), - ), - )); - - out - } - - fn int_abs(tensor: IntTensor) -> IntTensor { - unary_int_ops!(AbsOps, B::int_abs); - - let out = 
tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Abs( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(AbsOps::), - ), - )); - - out - } - - fn int_into_float(tensor: IntTensor) -> FloatTensor { - struct IntoFloatOps; - - impl Ops for IntoFloatOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = B::int_into_float(input); - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::IntOps( - graph::IntOpsDescription::IntoFloat( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoFloatOps::), - ), - )); - - out - } - - fn int_swap_dims( - tensor: IntTensor, - dim1: usize, - dim2: usize, - ) -> IntTensor { - struct SwapDimsOps; - - impl Ops for SwapDimsOps { - type Args = SwapDimsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = B::int_swap_dims(input, args.dim1, args.dim2); - handles.register_int_tensor(&args.out.id, output); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim1] = tensor.shape[dim2]; - shape[dim2] = tensor.shape[dim1]; - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt( - BaseOpsDescription::SwapDims( - SwapDimsDescription { - input: tensor.into_description(), - dim1, - dim2, - out: out.to_description_out(), - }, - Box::new(SwapDimsOps::), - ), - )); - - out - } - - fn int_max(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MaxOps, B::int_max); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Max( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MaxOps::), - ), - )); - - out - } - - fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(MaxDimOps, B::int_max_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaxDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MaxDimOps::), - ), - )); - - out - } - - fn int_max_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - struct MaxDimWithIndicesOps; - - impl Ops for MaxDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let (output, indices) = B::int_max_dim_with_indices(tensor, args.dim); - - handles.register_int_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - 
client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaxDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxDimWithIndicesOps::), - ), - )); - - (out, out_indices) - } - - fn int_min(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MinOps, B::int_min); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Min( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MinOps::), - ), - )); - - out - } - - fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(MinDimOps, B::int_min_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MinDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MinDimOps::), - ), - )); - - out - } - - fn int_min_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - struct MinDimWithIndicesOps; - - impl Ops for MinDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let (output, indices) = B::int_min_dim_with_indices(tensor, args.dim); - - handles.register_int_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MinDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MinDimWithIndicesOps::), - ), - )); - - (out, out_indices) - } + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MinDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MinDimWithIndicesOps::), + ), + )); + + (out, out_indices) + } } diff --git a/burn-fusion/src/ops/module.rs b/burn-fusion/src/ops/module.rs index 2eef4be4b3..d20d8ccd87 100644 --- a/burn-fusion/src/ops/module.rs +++ b/burn-fusion/src/ops/module.rs @@ -1,907 +1,900 @@ use crate::{ - client::FusionClient, - graph::{ - AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, - AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, - AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription, - AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, - ConvTranspose2dDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, - MaxPool1dWithIndicesDescription, 
MaxPool2dDescription, - MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Ops, - TensorOpsDescription, - }, - Fusion, FusionBackend, + client::FusionClient, + graph::{ + AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, + AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, + AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription, + AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, + ConvTranspose2dDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, + MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, + MaxPool2dWithIndicesDescription, Ops, TensorOpsDescription, + }, + Fusion, FusionBackend, }; use burn_tensor::ops::{ - conv::{ - calculate_conv_output_size, calculate_conv_transpose_output_size, - calculate_pool_output_size, - }, - ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool1dBackward, - MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + conv::{ + calculate_conv_output_size, calculate_conv_transpose_output_size, calculate_pool_output_size, + }, + ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool1dBackward, + MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; impl ModuleOps> for Fusion { - fn conv1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> FloatTensor { - struct Conv1dOps; - - impl Ops for Conv1dOps { - type Args = Conv1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv1d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size = calculate_conv_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.dilation[0], - x.shape[2], - ); + fn conv1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> FloatTensor { + struct Conv1dOps; + + impl Ops for Conv1dOps { + type Args = Conv1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv1d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } + } - let shape = vec![x.shape[0], weight.shape[0], size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::Conv1d( - Conv1dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(Conv1dOps), - ), - )); - - out + let size = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + + let shape = vec![x.shape[0], weight.shape[0], size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::Conv1d( + Conv1dDescription { + 
x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(Conv1dOps), + ), + )); + + out + } + + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor { + struct Conv2dOps; + + impl Ops for Conv2dOps { + type Args = Conv2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv2d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } } - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor { - struct Conv2dOps; - - impl Ops for Conv2dOps { - type Args = Conv2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv2d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size_0 = calculate_conv_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.dilation[0], - x.shape[2], - ); - let size_1 = calculate_conv_output_size( - weight.shape[3], - options.stride[1], - options.padding[1], - options.dilation[1], - x.shape[3], - ); + let size_0 = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::Conv2d( + Conv2dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(Conv2dOps), + ), + )); + + out + } + + fn conv_transpose1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> FloatTensor { + struct ConvTranspose1dOps; + + impl Ops for ConvTranspose1dOps { + type Args = ConvTranspose1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv_transpose1d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } + } - let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::Conv2d( - Conv2dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(Conv2dOps), - ), - )); - 
- out + let size = calculate_conv_transpose_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.padding_out[0], + options.dilation[0], + x.shape[2], + ); + + let shape = vec![x.shape[0], weight.shape[1] * options.groups, size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::ConvTranspose1d( + ConvTranspose1dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(ConvTranspose1dOps), + ), + )); + + out + } + + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + struct ConvTranspose2dOps; + + impl Ops for ConvTranspose2dOps { + type Args = ConvTranspose2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv_transpose2d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } } - fn conv_transpose1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> FloatTensor { - struct ConvTranspose1dOps; - - impl Ops for ConvTranspose1dOps { - type Args = ConvTranspose1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv_transpose1d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size = calculate_conv_transpose_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.padding_out[0], - options.dilation[0], - x.shape[2], + let size_0 = calculate_conv_transpose_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.padding_out[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_transpose_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.padding_out[1], + options.dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::ConvTranspose2d( + ConvTranspose2dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(ConvTranspose2dOps), + ), + )); + + out + } + + fn avg_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool1dOps; + + impl Ops for AvgPool1dOps { + type Args = AvgPool1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::avg_pool1d( + x, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, ); - let shape = vec![x.shape[0], weight.shape[1] 
* options.groups, size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::ConvTranspose1d( - ConvTranspose1dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(ConvTranspose1dOps), - ), - )); - - out + handles.register_float_tensor(&args.out.id, output); + } } - fn conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor { - struct ConvTranspose2dOps; - - impl Ops for ConvTranspose2dOps { - type Args = ConvTranspose2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv_transpose2d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size_0 = calculate_conv_transpose_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.padding_out[0], - options.dilation[0], - x.shape[2], - ); - let size_1 = calculate_conv_transpose_output_size( - weight.shape[3], - options.stride[1], - options.padding[1], - options.padding_out[1], - options.dilation[1], - x.shape[3], + let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); + let shape = vec![x.shape[0], x.shape[1], size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool1d( + AvgPool1dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool1dOps), + ), + )); + + out + } + + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool2dOps; + + impl Ops for AvgPool2dOps { + type Args = AvgPool2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::avg_pool2d( + x, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, ); - let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::ConvTranspose2d( - ConvTranspose2dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(ConvTranspose2dOps), - ), - )); - - out + handles.register_float_tensor(&args.out.id, output); + } } - fn avg_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool1dOps; - - impl Ops for AvgPool1dOps { - type Args = AvgPool1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::avg_pool1d( - x, - args.kernel_size, - args.stride, - args.padding, - args.count_include_pad, - ); - - 
handles.register_float_tensor(&args.out.id, output); - } - } - - let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); - let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool1d( - AvgPool1dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool1dOps), - ), - )); - - out - } + let size_0 = calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]); + let size_1 = calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]); + + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool2d( + AvgPool2dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool2dOps), + ), + )); + + out + } + + fn avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool1dBackwardOps; + + impl Ops for AvgPool1dBackwardOps { + type Args = AvgPool1dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::avg_pool1d_backward( + x, + grad, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, + ); - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool2dOps; - - impl Ops for AvgPool2dOps { - type Args = AvgPool2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::avg_pool2d( - x, - args.kernel_size, - args.stride, - args.padding, - args.count_include_pad, - ); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size_0 = - calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]); - let size_1 = - calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]); - - let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool2d( - AvgPool2dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool2dOps), - ), - )); - - out + handles.register_float_tensor(&args.out.id, output); + } } - fn avg_pool1d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool1dBackwardOps; - - impl Ops for AvgPool1dBackwardOps { - type Args = AvgPool1dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let output = B::avg_pool1d_backward( - x, - grad, - args.kernel_size, - args.stride, - 
args.padding, - args.count_include_pad, - ); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool1dBackward( - AvgPool1dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool1dBackwardOps), - ), - )); - - out - } + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool1dBackward( + AvgPool1dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool1dBackwardOps), + ), + )); + + out + } + + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool2dBackwardOps; + + impl Ops for AvgPool2dBackwardOps { + type Args = AvgPool2dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::avg_pool2d_backward( + x, + grad, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, + ); - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool2dBackwardOps; - - impl Ops for AvgPool2dBackwardOps { - type Args = AvgPool2dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let output = B::avg_pool2d_backward( - x, - grad, - args.kernel_size, - args.stride, - args.padding, - args.count_include_pad, - ); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool2dBackward( - AvgPool2dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool2dBackwardOps), - ), - )); - - out + handles.register_float_tensor(&args.out.id, output); + } } - fn max_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> FloatTensor { - struct MaxPool1dOps; - - impl Ops for MaxPool1dOps { - type Args = MaxPool1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool1d( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - ); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); - - let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - 
crate::graph::ModuleOpsDescription::MaxPool1d( - MaxPool1dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool1dOps), - ), - )); - - out + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool2dBackward( + AvgPool2dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool2dBackwardOps), + ), + )); + + out + } + + fn max_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> FloatTensor { + struct MaxPool1dOps; + + impl Ops for MaxPool1dOps { + type Args = MaxPool1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool1d( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + ); + + handles.register_float_tensor(&args.out.id, output); + } } - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor { - struct MaxPool2dOps; - - impl Ops for MaxPool2dOps { - type Args = MaxPool2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool2d( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - ); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size_0 = calculate_pool_output_size( - kernel_size[0], - stride[0], - padding[0], - dilation[0], - x.shape[2], - ); - let size_1 = calculate_pool_output_size( - kernel_size[1], - stride[1], - padding[1], - dilation[1], - x.shape[3], + let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); + + let shape = vec![x.shape[0], x.shape[1], size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool1d( + MaxPool1dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool1dOps), + ), + )); + + out + } + + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor { + struct MaxPool2dOps; + + impl Ops for MaxPool2dOps { + type Args = MaxPool2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool2d( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, ); - let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool2d( - MaxPool2dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool2dOps), - ), - )); - - out + handles.register_float_tensor(&args.out.id, output); + } } - fn max_pool1d_with_indices( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices { - 
struct MaxPool1dWithIndicesOps; - - impl Ops for MaxPool1dWithIndicesOps { - type Args = MaxPool1dWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool1d_with_indices( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - ); - - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); - } - } - - let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); - let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool1dWithIndices( - MaxPool1dWithIndicesDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxPool1dWithIndicesOps), - ), - )); - - MaxPool1dWithIndices::new(out, out_indices) + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool2d( + MaxPool2dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool2dOps), + ), + )); + + out + } + + fn max_pool1d_with_indices( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices { + struct MaxPool1dWithIndicesOps; + + impl Ops for MaxPool1dWithIndicesOps { + type Args = MaxPool1dWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool1d_with_indices( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + ); + + handles.register_float_tensor(&args.out.id, output.output); + handles.register_int_tensor(&args.out_indices.id, output.indices); + } } - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices { - struct MaxPool2dWithIndicesOps; - - impl Ops for MaxPool2dWithIndicesOps { - type Args = MaxPool2dWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool2d_with_indices( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - ); - - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); - } - } - - let size_0 = calculate_pool_output_size( - kernel_size[0], - stride[0], - padding[0], - dilation[0], - x.shape[2], + let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); + let shape = vec![x.shape[0], x.shape[1], size]; + let out = x.client.tensor_uninitialized(shape.clone()); + let out_indices = 
x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool1dWithIndices( + MaxPool1dWithIndicesDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxPool1dWithIndicesOps), + ), + )); + + MaxPool1dWithIndices::new(out, out_indices) + } + + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices { + struct MaxPool2dWithIndicesOps; + + impl Ops for MaxPool2dWithIndicesOps { + type Args = MaxPool2dWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool2d_with_indices( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, ); - let size_1 = calculate_pool_output_size( - kernel_size[1], - stride[1], - padding[1], - dilation[1], - x.shape[3], + + handles.register_float_tensor(&args.out.id, output.output); + handles.register_int_tensor(&args.out_indices.id, output.indices); + } + } + + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape.clone()); + let out_indices = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool2dWithIndices( + MaxPool2dWithIndicesDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxPool2dWithIndicesOps), + ), + )); + + MaxPool2dWithIndices::new(out, out_indices) + } + + fn max_pool1d_with_indices_backward( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool1dBackward { + struct MaxPool1dWithIndicesBackwardOps; + + impl Ops for MaxPool1dWithIndicesBackwardOps { + type Args = MaxPool1dWithIndicesBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let indices = handles.get_int_tensor(&args.indices); + let output = B::max_pool1d_with_indices_backward( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + grad, + indices, ); - let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool2dWithIndices( - MaxPool2dWithIndicesDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxPool2dWithIndicesOps), - ), - )); - - MaxPool2dWithIndices::new(out, out_indices) + handles.register_float_tensor(&args.out.id, output.x_grad); + } } - fn max_pool1d_with_indices_backward( - x: 
FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool1dBackward { - struct MaxPool1dWithIndicesBackwardOps; - - impl Ops for MaxPool1dWithIndicesBackwardOps { - type Args = MaxPool1dWithIndicesBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); - let output = B::max_pool1d_with_indices_backward( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - grad, - indices, - ); - - handles.register_float_tensor(&args.out.id, output.x_grad); - } - } - - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool1dWithIndicesBackward( - MaxPool1dWithIndicesBackwardDescription { - x: x.into_description(), - grad: output_grad.into_description(), - indices: indices.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool1dWithIndicesBackwardOps), - ), - )); - - MaxPool1dBackward::new(out) + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool1dWithIndicesBackward( + MaxPool1dWithIndicesBackwardDescription { + x: x.into_description(), + grad: output_grad.into_description(), + indices: indices.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool1dWithIndicesBackwardOps), + ), + )); + + MaxPool1dBackward::new(out) + } + + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward { + struct MaxPool2dWithIndicesBackwardOps; + + impl Ops for MaxPool2dWithIndicesBackwardOps { + type Args = MaxPool2dWithIndicesBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let indices = handles.get_int_tensor(&args.indices); + let output = B::max_pool2d_with_indices_backward( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + grad, + indices, + ); + + handles.register_float_tensor(&args.out.id, output.x_grad); + } } - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward { - struct MaxPool2dWithIndicesBackwardOps; - - impl Ops for MaxPool2dWithIndicesBackwardOps { - type Args = MaxPool2dWithIndicesBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); - let output = B::max_pool2d_with_indices_backward( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - grad, - indices, - ); - - handles.register_float_tensor(&args.out.id, output.x_grad); - } - } - - let out = x.client.tensor_uninitialized(x.shape.clone()); - - 
x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool2dWithIndicesBackward( - MaxPool2dWithIndicesBackwardDescription { - x: x.into_description(), - grad: output_grad.into_description(), - indices: indices.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool2dWithIndicesBackwardOps), - ), - )); - - MaxPool2dBackward::new(out) + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool2dWithIndicesBackward( + MaxPool2dWithIndicesBackwardDescription { + x: x.into_description(), + grad: output_grad.into_description(), + indices: indices.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool2dWithIndicesBackwardOps), + ), + )); + + MaxPool2dBackward::new(out) + } + + fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { + struct AdaptiveAvgPool1dOps; + + impl Ops for AdaptiveAvgPool1dOps { + type Args = AdaptiveAvgPool1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::adaptive_avg_pool1d(x, args.output_size); + + handles.register_float_tensor(&args.out.id, output); + } } - fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { - struct AdaptiveAvgPool1dOps; + let shape = vec![x.shape[0], x.shape[1], output_size]; + let out = x.client.tensor_uninitialized(shape); - impl Ops for AdaptiveAvgPool1dOps { - type Args = AdaptiveAvgPool1dDescription; + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool1d( + AdaptiveAvgPool1dDescription { + x: x.into_description(), + output_size, + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool1dOps), + ), + )); - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::adaptive_avg_pool1d(x, args.output_size); + out + } - handles.register_float_tensor(&args.out.id, output); - } - } + fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { + struct AdaptiveAvgPool2dOps; - let shape = vec![x.shape[0], x.shape[1], output_size]; - let out = x.client.tensor_uninitialized(shape); + impl Ops for AdaptiveAvgPool2dOps { + type Args = AdaptiveAvgPool2dDescription; - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool1d( - AdaptiveAvgPool1dDescription { - x: x.into_description(), - output_size, - out: out.to_description_out(), - }, - Box::new(AdaptiveAvgPool1dOps), - ), - )); + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::adaptive_avg_pool2d(x, args.output_size); - out + handles.register_float_tensor(&args.out.id, output); + } } - fn adaptive_avg_pool2d( - x: FloatTensor, - output_size: [usize; 2], - ) -> FloatTensor { - struct AdaptiveAvgPool2dOps; - - impl Ops for AdaptiveAvgPool2dOps { - type Args = AdaptiveAvgPool2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::adaptive_avg_pool2d(x, args.output_size); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape = 
vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool2d( - AdaptiveAvgPool2dDescription { - x: x.into_description(), - output_size, - out: out.to_description_out(), - }, - Box::new(AdaptiveAvgPool2dOps), - ), - )); - - out + let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool2d( + AdaptiveAvgPool2dDescription { + x: x.into_description(), + output_size, + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool2dOps), + ), + )); + + out + } + + fn adaptive_avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + struct AdaptiveAvgPool1dBackwardOps; + + impl Ops for AdaptiveAvgPool1dBackwardOps { + type Args = AdaptiveAvgPool1dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::adaptive_avg_pool1d_backward(x, grad); + + handles.register_float_tensor(&args.out.id, output); + } } - fn adaptive_avg_pool1d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - struct AdaptiveAvgPool1dBackwardOps; - - impl Ops for AdaptiveAvgPool1dBackwardOps { - type Args = AdaptiveAvgPool1dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let output = B::adaptive_avg_pool1d_backward(x, grad); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool1dBackward( - AdaptiveAvgPool1dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - out: out.to_description_out(), - }, - Box::new(AdaptiveAvgPool1dBackwardOps), - ), - )); - - out + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool1dBackward( + AdaptiveAvgPool1dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool1dBackwardOps), + ), + )); + + out + } + + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + struct AdaptiveAvgPool2dBackwardOps; + + impl Ops for AdaptiveAvgPool2dBackwardOps { + type Args = AdaptiveAvgPool2dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::adaptive_avg_pool2d_backward(x, grad); + + handles.register_float_tensor(&args.out.id, output); + } } - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - struct AdaptiveAvgPool2dBackwardOps; - - impl Ops for AdaptiveAvgPool2dBackwardOps { - type Args = AdaptiveAvgPool2dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); 
- let grad = handles.get_float_tensor(&args.grad); - let output = B::adaptive_avg_pool2d_backward(x, grad); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool2dBackward( - AdaptiveAvgPool2dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - out: out.to_description_out(), - }, - Box::new(AdaptiveAvgPool2dBackwardOps), - ), - )); - - out - } + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool2dBackward( + AdaptiveAvgPool2dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool2dBackwardOps), + ), + )); + + out + } } diff --git a/burn-fusion/src/ops/unary.rs b/burn-fusion/src/ops/unary.rs index 84f6900b51..f35e00aef6 100644 --- a/burn-fusion/src/ops/unary.rs +++ b/burn-fusion/src/ops/unary.rs @@ -1,168 +1,168 @@ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_ops { - ( + ( $name:ident, $ops:expr ) => { - scalar_float_ops!($name, $ops, FloatElem); - }; - ( + scalar_float_ops!($name, $ops, FloatElem); + }; + ( $name:ident, $ops:expr, $elem:ty ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription<$elem>; + impl Ops for $name { + type Args = ScalarOpsDescription<$elem>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_float_tensor(&args.out.id, output); - } - } - }; + handles.register_float_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float2int_ops { - ( + ( $name:ident, $ops:expr, $elem:ty ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription<$elem>; + impl Ops for $name { + type Args = ScalarOpsDescription<$elem>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
unary_float_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = UnaryOpsDescription; + impl Ops for $name { + type Args = UnaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = $ops(input); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = $ops(input); - handles.register_float_tensor(&args.out.id, output); - } - } - }; + handles.register_float_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! unary_int_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = UnaryOpsDescription; + impl Ops for $name { + type Args = UnaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = $ops(input); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = $ops(input); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription>; + impl Ops for $name { + type Args = ScalarOpsDescription>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_int_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription>; + impl Ops for $name { + type Args = ScalarOpsDescription>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
scalar_int_ops { - ( + ( $name:ident, $ops:expr ) => { - scalar_int_ops!($name, $ops, IntElem); - }; - ( + scalar_int_ops!($name, $ops, IntElem); + }; + ( $name:ident, $ops:expr, $elem:ty ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription<$elem>; + impl Ops for $name { + type Args = ScalarOpsDescription<$elem>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } diff --git a/burn-fusion/src/server.rs b/burn-fusion/src/server.rs index 9b52d38295..f2dfb163d6 100644 --- a/burn-fusion/src/server.rs +++ b/burn-fusion/src/server.rs @@ -1,161 +1,162 @@ use crate::{ - graph::{Graph, GraphExecution, Optimization, TensorOpsDescription}, - FusionBackend, FusionProperties, FusionStatus, HandleContainer, TensorId, + graph::{Graph, GraphExecution, Optimization, TensorOpsDescription}, + FusionBackend, FusionProperties, FusionStatus, HandleContainer, TensorId, }; use burn_tensor::ops::{FloatElem, IntElem}; use std::sync::Arc; pub struct FusionServer where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - optimizations: Vec>, - graph: Graph, - pub(crate) handles: HandleContainer, - execution: G, - pub device: B::FusionDevice, + optimizations: Vec>, + graph: Graph, + pub(crate) handles: HandleContainer, + execution: G, + pub device: B::FusionDevice, } impl FusionServer where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - pub fn new(device: B::FusionDevice) -> Self { - let optimizations = B::operations(&device.clone().into()) - .into_iter() - .map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default()))) - .collect(); - - Self { - optimizations, - graph: Graph::new(), - handles: HandleContainer::new(device.clone()), - execution: G::default(), - device, - } + pub fn new(device: B::FusionDevice) -> Self { + let optimizations = B::operations(&device.clone().into()) + .into_iter() + .map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default()))) + .collect(); + + Self { + optimizations, + graph: Graph::new(), + handles: HandleContainer::new(device.clone()), + execution: G::default(), + device, } - - pub fn register(&mut self, ops: TensorOpsDescription) { - let ops = Arc::new(ops); - self.graph.add(ops.clone()); - - self.optimizations - .iter_mut() - .for_each(|optimization| optimization.register(&ops)); - - self.execution.maybe_execute( - &mut self.graph, - &mut self.handles, - &mut self.optimizations, - false, - ); + } + + pub fn register(&mut self, ops: TensorOpsDescription) { + let ops = Arc::new(ops); + self.graph.add(ops.clone()); + + self + .optimizations + .iter_mut() + .for_each(|optimization| optimization.register(&ops)); + + self.execution.maybe_execute( + &mut self.graph, + &mut self.handles, + &mut self.optimizations, + false, + ); + } + + pub fn drain_graph(&mut self) { + if self.graph.is_empty() { + return; } - pub fn drain_graph(&mut self) { - if self.graph.is_empty() { - return; - } - - self.execution.maybe_execute( - &mut self.graph, - &mut self.handles, - &mut self.optimizations, - true, - ); - } - - pub fn 
create_empty_handle(&mut self) -> Arc { - self.handles.create_tensor_uninit() - } - - pub fn read_float( - &mut self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> { - // Make sure all registered operations are executed. - // The underlying backend can still be async. - self.drain_graph(); - - let tensor = self.handles.get_float_tensor(&tensor); - B::into_data(tensor) - } - - pub fn read_int( - &mut self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> { - // Make sure all registered operations are executed. - // The underlying backend can still be async. - self.drain_graph(); - - let tensor = self.handles.get_int_tensor(&tensor); - B::int_into_data(tensor) - } - - pub fn read_bool( - &mut self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader> { - // Make sure all registered operations are executed. - // The underlying backend can still be async. - self.drain_graph(); - - let tensor = self.handles.get_bool_tensor(&tensor); - B::bool_into_data(tensor) - } - - pub fn change_server_float( - &mut self, - tensor: &crate::TensorDescription, - device: &B::Device, - server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_float_tensor::(tensor); - let tensor = B::to_device(tensor, device); - let id = server_device.create_empty_handle(); - - server_device - .handles - .register_float_tensor(&id, tensor.clone()); - - id - } - pub fn change_server_int( - &mut self, - tensor: &crate::TensorDescription, - device: &B::Device, - server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_int_tensor::(tensor); - let tensor = B::int_to_device(tensor, device); - let id = server_device.create_empty_handle(); - - server_device - .handles - .register_int_tensor(&id, tensor.clone()); - - id - } - pub fn change_server_bool( - &mut self, - tensor: &crate::TensorDescription, - device: &B::Device, - server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_bool_tensor::(tensor); - let tensor = B::bool_to_device(tensor, device); - let id = server_device.create_empty_handle(); - - server_device - .handles - .register_bool_tensor(&id, tensor.clone()); - - id - } - - pub fn drop_tensor_handle(&mut self, id: TensorId) { - self.handles.handles_orphan.push(id); - } + self.execution.maybe_execute( + &mut self.graph, + &mut self.handles, + &mut self.optimizations, + true, + ); + } + + pub fn create_empty_handle(&mut self) -> Arc { + self.handles.create_tensor_uninit() + } + + pub fn read_float( + &mut self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> { + // Make sure all registered operations are executed. + // The underlying backend can still be async. + self.drain_graph(); + + let tensor = self.handles.get_float_tensor(&tensor); + B::into_data(tensor) + } + + pub fn read_int( + &mut self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> { + // Make sure all registered operations are executed. + // The underlying backend can still be async. + self.drain_graph(); + + let tensor = self.handles.get_int_tensor(&tensor); + B::int_into_data(tensor) + } + + pub fn read_bool( + &mut self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader> { + // Make sure all registered operations are executed. + // The underlying backend can still be async. 
+ self.drain_graph(); + + let tensor = self.handles.get_bool_tensor(&tensor); + B::bool_into_data(tensor) + } + + pub fn change_server_float( + &mut self, + tensor: &crate::TensorDescription, + device: &B::Device, + server_device: &mut Self, + ) -> Arc { + let tensor = self.handles.get_float_tensor::(tensor); + let tensor = B::to_device(tensor, device); + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_float_tensor(&id, tensor.clone()); + + id + } + pub fn change_server_int( + &mut self, + tensor: &crate::TensorDescription, + device: &B::Device, + server_device: &mut Self, + ) -> Arc { + let tensor = self.handles.get_int_tensor::(tensor); + let tensor = B::int_to_device(tensor, device); + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_int_tensor(&id, tensor.clone()); + + id + } + pub fn change_server_bool( + &mut self, + tensor: &crate::TensorDescription, + device: &B::Device, + server_device: &mut Self, + ) -> Arc { + let tensor = self.handles.get_bool_tensor::(tensor); + let tensor = B::bool_to_device(tensor, device); + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_bool_tensor(&id, tensor.clone()); + + id + } + + pub fn drop_tensor_handle(&mut self, id: TensorId) { + self.handles.handles_orphan.push(id); + } } diff --git a/burn-fusion/src/tensor.rs b/burn-fusion/src/tensor.rs index 70ffcf3937..7cac08af54 100644 --- a/burn-fusion/src/tensor.rs +++ b/burn-fusion/src/tensor.rs @@ -1,140 +1,140 @@ use crate::client::FusionClient; use burn_tensor::{ - backend::Backend, - ops::{FloatElem, IntElem}, - Data, Reader, Shape, + backend::Backend, + ops::{FloatElem, IntElem}, + Data, Reader, Shape, }; use std::sync::Arc; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. #[derive(Clone)] pub struct FusionTensor { - /// Tensor id. - pub id: Arc, - /// The shape of the tensor. - pub shape: Vec, - /// The [fusion client](FusionClient). - pub client: C, - // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. - // - // When a tensor is dropped and is still an orphan, we need to register it as such to avoid - // memory leak. Otherwise, the cleanup is going to happen during a graph execution. - pub(crate) is_orphan: bool, + /// Tensor id. + pub id: Arc, + /// The shape of the tensor. + pub shape: Vec, + /// The [fusion client](FusionClient). + pub client: C, + // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. + // + // When a tensor is dropped and is still an orphan, we need to register it as such to avoid + // memory leak. Otherwise, the cleanup is going to happen during a graph execution. 
+ pub(crate) is_orphan: bool, } impl core::fmt::Debug for FusionTensor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str( - format!( - "{{ id: {:?}, shape: {:?}, should_drop: {:?}, backend: {:?}, device: {:?} }}", - self.id, - self.shape, - self.is_orphan, - ::name(), - self.client.device().clone().into(), - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + format!( + "{{ id: {:?}, shape: {:?}, should_drop: {:?}, backend: {:?}, device: {:?} }}", + self.id, + self.shape, + self.is_orphan, + ::name(), + self.client.device().clone().into(), + ) + .as_str(), + ) + } } impl FusionTensor { - pub(crate) fn new(id: Arc, shape: Vec, client: C) -> Self { - Self { - id, - shape, - client, - is_orphan: true, - } - } - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.shape.clone()) + pub(crate) fn new(id: Arc, shape: Vec, client: C) -> Self { + Self { + id, + shape, + client, + is_orphan: true, } - - fn status(&self) -> TensorStatus { - if Arc::strong_count(&self.id) <= 1 { - TensorStatus::ReadWrite - } else { - TensorStatus::ReadOnly - } - } - - /// Description to be used when using an uninitialized tensor as output. - pub(crate) fn to_description_out(&self) -> TensorDescription { - TensorDescription { - status: TensorStatus::NotInit, - shape: self.shape.clone(), - id: self.id.as_ref().clone(), - } + } + pub(crate) fn shape(&self) -> Shape { + Shape::from(self.shape.clone()) + } + + fn status(&self) -> TensorStatus { + if Arc::strong_count(&self.id) <= 1 { + TensorStatus::ReadWrite + } else { + TensorStatus::ReadOnly } - - /// Description to be used when using an initialized tensor used as input. - pub(crate) fn into_description(mut self) -> TensorDescription { - let status = self.status(); - let mut shape_out = Vec::new(); - core::mem::swap(&mut self.shape, &mut shape_out); - - if let TensorStatus::ReadWrite = status { - self.is_orphan = false; - } - - TensorDescription { - status, - shape: shape_out, - id: self.id.as_ref().clone(), - } + } + + /// Description to be used when using an uninitialized tensor as output. + pub(crate) fn to_description_out(&self) -> TensorDescription { + TensorDescription { + status: TensorStatus::NotInit, + shape: self.shape.clone(), + id: self.id.as_ref().clone(), } + } - pub(crate) fn into_data(self) -> Reader, D>> { - self.client - .clone() - .read_tensor_float(self.into_description()) - } + /// Description to be used when using an initialized tensor used as input. 
+ pub(crate) fn into_description(mut self) -> TensorDescription { + let status = self.status(); + let mut shape_out = Vec::new(); + core::mem::swap(&mut self.shape, &mut shape_out); - pub(crate) fn int_into_data( - self, - ) -> Reader, D>> { - self.client.clone().read_tensor_int(self.into_description()) + if let TensorStatus::ReadWrite = status { + self.is_orphan = false; } - pub(crate) fn bool_into_data(self) -> Reader> { - self.client - .clone() - .read_tensor_bool(self.into_description()) + TensorDescription { + status, + shape: shape_out, + id: self.id.as_ref().clone(), } + } + + pub(crate) fn into_data(self) -> Reader, D>> { + self + .client + .clone() + .read_tensor_float(self.into_description()) + } + + pub(crate) fn int_into_data(self) -> Reader, D>> { + self.client.clone().read_tensor_int(self.into_description()) + } + + pub(crate) fn bool_into_data(self) -> Reader> { + self + .client + .clone() + .read_tensor_bool(self.into_description()) + } } impl Drop for FusionTensor { - fn drop(&mut self) { - if !self.is_orphan { - return; - } - - match self.status() { - TensorStatus::ReadWrite => { - self.client.register_orphan(&self.id); - } - TensorStatus::ReadOnly => {} - TensorStatus::NotInit => {} - } + fn drop(&mut self) { + if !self.is_orphan { + return; } + + match self.status() { + TensorStatus::ReadWrite => { + self.client.register_orphan(&self.id); + } + TensorStatus::ReadOnly => {} + TensorStatus::NotInit => {} + } + } } /// The tensor unique identifier. #[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct TensorId { - value: u64, + value: u64, } /// The status of the current tensor. #[derive(Hash, Clone, Debug, PartialEq, Eq)] pub enum TensorStatus { - /// The tensor can be read, but not written. - ReadOnly, - /// The tensor can be mutated inplace. - ReadWrite, - /// No handle exists for that tensor. - NotInit, + /// The tensor can be read, but not written. + ReadOnly, + /// The tensor can be mutated inplace. + ReadWrite, + /// No handle exists for that tensor. + NotInit, } /// A tensor definition represents a snapshot of a tensor when it was used. @@ -149,17 +149,17 @@ pub enum TensorStatus { /// 4. Status::ReadWrite #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct TensorDescription { - /// The [tensor id](TensorId). - pub id: TensorId, - /// The shape of the tensor. - pub shape: Vec, - /// The [status](TensorStatus) of the tensor when it was used. - pub status: TensorStatus, + /// The [tensor id](TensorId). + pub id: TensorId, + /// The shape of the tensor. + pub shape: Vec, + /// The [status](TensorStatus) of the tensor when it was used. + pub status: TensorStatus, } impl TensorId { - /// Create a new tensor id. - pub fn new(value: u64) -> Self { - Self { value } - } + /// Create a new tensor id. 
+ pub fn new(value: u64) -> Self { + Self { value } + } } diff --git a/burn-import/build.rs b/burn-import/build.rs index f6cd681658..da847f00dc 100644 --- a/burn-import/build.rs +++ b/burn-import/build.rs @@ -1,11 +1,11 @@ fn main() { - if cfg!(feature = "onnx") { - // Generate the onnx protobuf files - protobuf_codegen::Codegen::new() - .pure() - .includes(["src"]) - .input("src/onnx/protos/onnx.proto") - .cargo_out_dir("onnx-protos") - .run_from_script(); - } + if cfg!(feature = "onnx") { + // Generate the onnx protobuf files + protobuf_codegen::Codegen::new() + .pure() + .includes(["src"]) + .input("src/onnx/protos/onnx.proto") + .cargo_out_dir("onnx-protos") + .run_from_script(); + } } diff --git a/burn-import/onnx-tests/build.rs b/burn-import/onnx-tests/build.rs index dc1576099e..62afd67f7d 100644 --- a/burn-import/onnx-tests/build.rs +++ b/burn-import/onnx-tests/build.rs @@ -1,114 +1,114 @@ use burn_import::onnx::{ModelGen, RecordType}; fn main() { - // Re-run this build script if the onnx-tests directory changes. - println!("cargo:rerun-if-changed=tests"); + // Re-run this build script if the onnx-tests directory changes. + println!("cargo:rerun-if-changed=tests"); - // Add onnx models. - ModelGen::new() - .input("tests/add/add_int.onnx") - .input("tests/add/add.onnx") - .input("tests/avg_pool2d/avg_pool2d.onnx") - .input("tests/batch_norm/batch_norm.onnx") - .input("tests/clip/clip_opset16.onnx") - .input("tests/clip/clip_opset7.onnx") - .input("tests/concat/concat.onnx") - .input("tests/conv1d/conv1d.onnx") - .input("tests/conv2d/conv2d.onnx") - .input("tests/div/div.onnx") - .input("tests/dropout/dropout_opset16.onnx") - .input("tests/dropout/dropout_opset7.onnx") - .input("tests/equal/equal.onnx") - .input("tests/erf/erf.onnx") - .input("tests/flatten/flatten.onnx") - .input("tests/gather/gather.onnx") - .input("tests/global_avr_pool/global_avr_pool.onnx") - .input("tests/linear/linear.onnx") - .input("tests/log_softmax/log_softmax.onnx") - .input("tests/maxpool2d/maxpool2d.onnx") - .input("tests/mul/mul.onnx") - .input("tests/recip/recip.onnx") - .input("tests/relu/relu.onnx") - .input("tests/reshape/reshape.onnx") - .input("tests/sigmoid/sigmoid.onnx") - .input("tests/softmax/softmax.onnx") - .input("tests/sub/sub_int.onnx") - .input("tests/sub/sub.onnx") - .input("tests/tanh/tanh.onnx") - .input("tests/transpose/transpose.onnx") - .out_dir("model/") - .run_from_script(); + // Add onnx models. 
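For context on the burn-import build script above: with `cargo_out_dir("onnx-protos")`, protobuf-codegen writes the generated Rust bindings (including a `mod.rs`) under that sub-directory of `OUT_DIR`, and the crate then typically pulls them in with an `include!`. A hedged sketch of that consumption side; the wrapper module name is illustrative, not taken from this patch:

// Somewhere in the crate that uses the generated bindings.
#[cfg(feature = "onnx")]
pub mod onnx_protos {
    // The sub-directory matches the `cargo_out_dir("onnx-protos")` call in build.rs.
    include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs"));
}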
+ ModelGen::new() + .input("tests/add/add_int.onnx") + .input("tests/add/add.onnx") + .input("tests/avg_pool2d/avg_pool2d.onnx") + .input("tests/batch_norm/batch_norm.onnx") + .input("tests/clip/clip_opset16.onnx") + .input("tests/clip/clip_opset7.onnx") + .input("tests/concat/concat.onnx") + .input("tests/conv1d/conv1d.onnx") + .input("tests/conv2d/conv2d.onnx") + .input("tests/div/div.onnx") + .input("tests/dropout/dropout_opset16.onnx") + .input("tests/dropout/dropout_opset7.onnx") + .input("tests/equal/equal.onnx") + .input("tests/erf/erf.onnx") + .input("tests/flatten/flatten.onnx") + .input("tests/gather/gather.onnx") + .input("tests/global_avr_pool/global_avr_pool.onnx") + .input("tests/linear/linear.onnx") + .input("tests/log_softmax/log_softmax.onnx") + .input("tests/maxpool2d/maxpool2d.onnx") + .input("tests/mul/mul.onnx") + .input("tests/recip/recip.onnx") + .input("tests/relu/relu.onnx") + .input("tests/reshape/reshape.onnx") + .input("tests/sigmoid/sigmoid.onnx") + .input("tests/softmax/softmax.onnx") + .input("tests/sub/sub_int.onnx") + .input("tests/sub/sub.onnx") + .input("tests/tanh/tanh.onnx") + .input("tests/transpose/transpose.onnx") + .out_dir("model/") + .run_from_script(); - // The following tests are used to generate the model with different record types. - // (e.g. bincode, pretty_json, etc.) Do not need to add new tests here, just use the default - // record type to the ModelGen::new() call above. + // The following tests are used to generate the model with different record types. + // (e.g. bincode, pretty_json, etc.) Do not need to add new tests here, just use the default + // record type to the ModelGen::new() call above. - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk/") - .record_type(RecordType::NamedMpk) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk/") + .record_type(RecordType::NamedMpk) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk_half/") - .record_type(RecordType::NamedMpk) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk_half/") + .record_type(RecordType::NamedMpk) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/pretty_json/") - .record_type(RecordType::PrettyJson) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/pretty_json/") + .record_type(RecordType::PrettyJson) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/pretty_json_half/") - .record_type(RecordType::PrettyJson) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/pretty_json_half/") + .record_type(RecordType::PrettyJson) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk_gz/") - .record_type(RecordType::NamedMpkGz) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk_gz/") + .record_type(RecordType::NamedMpkGz) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk_gz_half/") - .record_type(RecordType::NamedMpkGz) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk_gz_half/") + 
.record_type(RecordType::NamedMpkGz) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode/") - .record_type(RecordType::Bincode) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode/") + .record_type(RecordType::Bincode) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode_half/") - .record_type(RecordType::Bincode) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode_half/") + .record_type(RecordType::Bincode) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode_embedded/") - .embed_states(true) - .record_type(RecordType::Bincode) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode_embedded/") + .embed_states(true) + .record_type(RecordType::Bincode) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode_embedded_half/") - .embed_states(true) - .half_precision(true) - .record_type(RecordType::Bincode) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode_embedded_half/") + .embed_states(true) + .half_precision(true) + .record_type(RecordType::Bincode) + .run_from_script(); - // panic!("Purposefully failing build to output logs."); + // panic!("Purposefully failing build to output logs."); } diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs index b10c45849c..3863bc9204 100644 --- a/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/burn-import/onnx-tests/tests/onnx_tests.rs @@ -13,599 +13,599 @@ macro_rules! include_models { // ATTENTION: Modify this macro to include all models in the `model` directory. 
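The body of `include_models!` sits outside this hunk. As a rough sketch of what such a macro usually expands to (the exact path layout is an assumption, modeled on the `test_model!` macro shown later in this patch), each listed name becomes a module that includes the model source generated into `OUT_DIR` by the build script above:

macro_rules! include_models {
    ($($model:ident),*) => {
        $(
            pub mod $model {
                // e.g. $OUT_DIR/model/conv1d.rs for the `conv1d` entry.
                include!(concat!(env!("OUT_DIR"), "/model/", stringify!($model), ".rs"));
            }
        )*
    };
}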
include_models!( - add_int, - add, - avg_pool2d, - batch_norm, - clip_opset16, - clip_opset7, - concat, - conv1d, - conv2d, - div, - dropout_opset16, - dropout_opset7, - equal, - erf, - flatten, - gather, - global_avr_pool, - linear, - log_softmax, - maxpool2d, - mul, - recip, - relu, - reshape, - sigmoid, - softmax, - sub_int, - sub, - tanh, - transpose + add_int, + add, + avg_pool2d, + batch_norm, + clip_opset16, + clip_opset7, + concat, + conv1d, + conv2d, + div, + dropout_opset16, + dropout_opset7, + equal, + erf, + flatten, + gather, + global_avr_pool, + linear, + log_softmax, + maxpool2d, + mul, + recip, + relu, + reshape, + sigmoid, + softmax, + sub_int, + sub, + tanh, + transpose ); #[cfg(test)] mod tests { - use core::f64::consts; - - use super::*; - - use burn::tensor::{Data, Int, Shape, Tensor}; + use core::f64::consts; + + use super::*; + + use burn::tensor::{Data, Int, Shape, Tensor}; - use float_cmp::ApproxEq; - - type Backend = burn_ndarray::NdArray; + use float_cmp::ApproxEq; + + type Backend = burn_ndarray::NdArray; - #[test] - fn add_scalar_to_tensor_and_tensor_to_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: add::Model = add::Model::default(); + #[test] + fn add_scalar_to_tensor_and_tensor_to_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: add::Model = add::Model::default(); - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let scalar = 2f64; - let output = model.forward(input, scalar); - let expected = Data::from([[[[9., 10., 11., 12.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let scalar = 2f64; + let output = model.forward(input, scalar); + let expected = Data::from([[[[9., 10., 11., 12.]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn add_scalar_to_int_tensor_and_int_tensor_to_int_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: add_int::Model = add_int::Model::default(); - - // Run the model - let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); - let scalar = 2; - let output = model.forward(input, scalar); - let expected = Data::from([[[[9, 11, 13, 15]]]]); + #[test] + fn add_scalar_to_int_tensor_and_int_tensor_to_int_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: add_int::Model = add_int::Model::default(); + + // Run the model + let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); + let scalar = 2; + let output = model.forward(input, scalar); + let expected = Data::from([[[[9, 11, 13, 15]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn sub_scalar_from_tensor_and_tensor_from_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: sub::Model = sub::Model::default(); + #[test] + fn sub_scalar_from_tensor_and_tensor_from_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: sub::Model = sub::Model::default(); - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let scalar = 3.0f64; - let output = model.forward(input, scalar); - let expected = Data::from([[[[6., 7., 8., 9.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let scalar = 3.0f64; + let output = model.forward(input, scalar); + let expected = Data::from([[[[6., 7., 8., 9.]]]]); 
- assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn sub_scalar_from_int_tensor_and_int_tensor_from_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: sub_int::Model = sub_int::Model::default(); + #[test] + fn sub_scalar_from_int_tensor_and_int_tensor_from_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: sub_int::Model = sub_int::Model::default(); - // Run the model - let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); - let scalar = 3; - let output = model.forward(input, scalar); - let expected = Data::from([[[[6, 6, 6, 6]]]]); + // Run the model + let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); + let scalar = 3; + let output = model.forward(input, scalar); + let expected = Data::from([[[[6, 6, 6, 6]]]]); - assert_eq!(output.to_data(), expected); - } - #[test] - fn mul_scalar_with_tensor_and_tensor_with_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: mul::Model = mul::Model::default(); + assert_eq!(output.to_data(), expected); + } + #[test] + fn mul_scalar_with_tensor_and_tensor_with_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: mul::Model = mul::Model::default(); - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let scalar = 6.0f64; - let output = model.forward(input, scalar); - let expected = Data::from([[[[126., 252., 378., 504.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let scalar = 6.0f64; + let output = model.forward(input, scalar); + let expected = Data::from([[[[126., 252., 378., 504.]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn div_tensor_by_scalar_and_tensor_by_tensor() { - // Initialize the model without weights (because the exported file does not contain them) - let model: div::Model = div::Model::new(); + #[test] + fn div_tensor_by_scalar_and_tensor_by_tensor() { + // Initialize the model without weights (because the exported file does not contain them) + let model: div::Model = div::Model::new(); - // Run the model - let input = Tensor::::from_floats([[[[3., 6., 6., 9.]]]]); - let scalar1 = 9.0f64; - let scalar2 = 3.0f64; - let output = model.forward(input, scalar1, scalar2); - let expected = Data::from([[[[1., 2., 2., 3.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[3., 6., 6., 9.]]]]); + let scalar1 = 9.0f64; + let scalar2 = 3.0f64; + let output = model.forward(input, scalar1, scalar2); + let expected = Data::from([[[[1., 2., 2., 3.]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn concat_tensors() { - // Initialize the model - let model: concat::Model = concat::Model::new(); + #[test] + fn concat_tensors() { + // Initialize the model + let model: concat::Model = concat::Model::new(); - // Run the model - let input = Tensor::::zeros([1, 2, 3, 5]); + // Run the model + let input = Tensor::::zeros([1, 2, 3, 5]); - let output = model.forward(input); + let output = model.forward(input); - let expected = Shape::from([1, 18, 3, 5]); + let expected = Shape::from([1, 18, 3, 5]); - assert_eq!(output.shape(), expected); - } + assert_eq!(output.shape(), expected); + } - #[test] - fn conv1d() { - // Initialize the model with weights (loaded from the exported file) - let model: conv1d::Model = conv1d::Model::default(); + 
#[test] + fn conv1d() { + // Initialize the model with weights (loaded from the exported file) + let model: conv1d::Model = conv1d::Model::default(); - // Run the model with pi as input for easier testing - let input = Tensor::::full([6, 4, 10], consts::PI); + // Run the model with pi as input for easier testing + let input = Tensor::::full([6, 4, 10], consts::PI); - let output = model.forward(input); + let output = model.forward(input); - // test the output shape - let expected_shape: Shape<3> = Shape::from([6, 2, 7]); - assert_eq!(output.shape(), expected_shape); + // test the output shape + let expected_shape: Shape<3> = Shape::from([6, 2, 7]); + assert_eq!(output.shape(), expected_shape); - // We are using the sum of the output tensor to test the correctness of the conv1d node - // because the output tensor is too large to compare with the expected tensor. - let output_sum = output.sum().into_scalar(); - let expected_sum = -54.549_243; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + // We are using the sum of the output tensor to test the correctness of the conv1d node + // because the output tensor is too large to compare with the expected tensor. + let output_sum = output.sum().into_scalar(); + let expected_sum = -54.549_243; // from pytorch + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn conv2d() { - // Initialize the model with weights (loaded from the exported file) - let model: conv2d::Model = conv2d::Model::default(); + #[test] + fn conv2d() { + // Initialize the model with weights (loaded from the exported file) + let model: conv2d::Model = conv2d::Model::default(); - // Run the model with ones as input for easier testing - let input = Tensor::::ones([2, 4, 10, 15]); + // Run the model with ones as input for easier testing + let input = Tensor::::ones([2, 4, 10, 15]); - let output = model.forward(input); + let output = model.forward(input); - let expected_shape = Shape::from([2, 6, 6, 15]); - assert_eq!(output.shape(), expected_shape); + let expected_shape = Shape::from([2, 6, 6, 15]); + assert_eq!(output.shape(), expected_shape); - // We are using the sum of the output tensor to test the correctness of the conv2d node - // because the output tensor is too large to compare with the expected tensor. - let output_sum = output.sum().into_scalar(); + // We are using the sum of the output tensor to test the correctness of the conv2d node + // because the output tensor is too large to compare with the expected tensor. 
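The `(1.0e-4, 2)` argument used throughout these tests is float_cmp's tuple-to-margin conversion: an absolute epsilon paired with a ULPs budget, and the comparison passes when either bound is satisfied. A tiny self-contained check of the same pattern, assuming the float_cmp crate exactly as the tests here use it:

use float_cmp::ApproxEq;

fn main() {
    let a: f64 = 0.1 + 0.2; // 0.30000000000000004
    let b: f64 = 0.3;

    // Well within the 1e-4 epsilon (and only one ULP apart), so this passes.
    assert!(a.approx_eq(b, (1.0e-4, 2)));

    // A zero-epsilon, zero-ULP margin demands exact equality, which fails here.
    assert!(!a.approx_eq(b, (0.0, 0)));
}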
+ let output_sum = output.sum().into_scalar(); - let expected_sum = -113.869_99; // from pytorch + let expected_sum = -113.869_99; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn dropout_opset16() { - let model: dropout_opset16::Model = dropout_opset16::Model::default(); + #[test] + fn dropout_opset16() { + let model: dropout_opset16::Model = dropout_opset16::Model::default(); - // Run the model with ones as input for easier testing - let input = Tensor::::ones([2, 4, 10, 15]); + // Run the model with ones as input for easier testing + let input = Tensor::::ones([2, 4, 10, 15]); - let output = model.forward(input); + let output = model.forward(input); - let expected_shape = Shape::from([2, 4, 10, 15]); - assert_eq!(output.shape(), expected_shape); + let expected_shape = Shape::from([2, 4, 10, 15]); + assert_eq!(output.shape(), expected_shape); - let output_sum = output.sum().into_scalar(); + let output_sum = output.sum().into_scalar(); - let expected_sum = 1200.0; // from pytorch + let expected_sum = 1200.0; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn dropout_opset7() { - let model: dropout_opset7::Model = dropout_opset7::Model::default(); + #[test] + fn dropout_opset7() { + let model: dropout_opset7::Model = dropout_opset7::Model::default(); - // Run the model with ones as input for easier testing - let input = Tensor::::ones([2, 4, 10, 15]); + // Run the model with ones as input for easier testing + let input = Tensor::::ones([2, 4, 10, 15]); - let output = model.forward(input); + let output = model.forward(input); - let expected_shape = Shape::from([2, 4, 10, 15]); - assert_eq!(output.shape(), expected_shape); + let expected_shape = Shape::from([2, 4, 10, 15]); + assert_eq!(output.shape(), expected_shape); - let output_sum = output.sum().into_scalar(); + let output_sum = output.sum().into_scalar(); - let expected_sum = 1200.0; // from pytorch + let expected_sum = 1200.0; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn erf() { - let model: erf::Model = erf::Model::default(); + #[test] + fn erf() { + let model: erf::Model = erf::Model::default(); - let input = Tensor::::from_data([[[[1.0, 2.0, 3.0, 4.0]]]]); - let output = model.forward(input); - let expected = Tensor::::from_data([[[[0.8427, 0.9953, 1.0000, 1.0000]]]]); + let input = Tensor::::from_data([[[[1.0, 2.0, 3.0, 4.0]]]]); + let output = model.forward(input); + let expected = Tensor::::from_data([[[[0.8427, 0.9953, 1.0000, 1.0000]]]]); - output.to_data().assert_approx_eq(&expected.to_data(), 4); - } + output.to_data().assert_approx_eq(&expected.to_data(), 4); + } - #[test] - fn gather() { - // Initialize the model with weights (loaded from the exported file) - let model: gather::Model = gather::Model::default(); + #[test] + fn gather() { + // Initialize the model with weights (loaded from the exported file) + let model: gather::Model = gather::Model::default(); - // Run the model - let input = Tensor::::from_floats([[1., 2.], [3., 4.]]); - let index = Tensor::::from_ints([[0, 0], [1, 0]]); - let output = model.forward(input, index); - let expected = Data::from([[1., 1.], [4., 3.]]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn globalavrpool_1d_2d() { - // The model contains 1d 
and 2d global average pooling nodes - let model: global_avr_pool::Model = global_avr_pool::Model::default(); - - // Run the model with ones as input for easier testing - let input_1d = Tensor::::ones([2, 4, 10]); - let input_2d = Tensor::::ones([3, 10, 3, 15]); - - let (output_1d, output_2d) = model.forward(input_1d, input_2d); - - let expected_shape_1d = Shape::from([2, 4, 1]); - let expected_shape_2d = Shape::from([3, 10, 1, 1]); - assert_eq!(output_1d.shape(), expected_shape_1d); - assert_eq!(output_2d.shape(), expected_shape_2d); - - let output_sum_1d = output_1d.sum().into_scalar(); - let output_sum_2d = output_2d.sum().into_scalar(); - - let expected_sum_1d = 8.0; // from pytorch - let expected_sum_2d = 30.0; // from pytorch - - assert!(expected_sum_1d.approx_eq(output_sum_1d, (1.0e-4, 2))); - assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2))); - } - - #[test] - fn softmax() { - // Initialize the model without weights (because the exported file does not contain them) - let model: softmax::Model = softmax::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.36830685, 0.29917702, 0.33251613], - [0.521_469_2, 0.13475533, 0.343_775_5], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn log_softmax() { - // Initialize the model without weights (because the exported file does not contain them) - let model: log_softmax::Model = log_softmax::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [-0.998_838_9, -1.206_719_9, -1.101_067], - [-0.651_105_1, -2.004_294_6, -1.067_766_4], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn maxpool2d() { - // Initialize the model without weights (because the exported file does not contain them) - let model: maxpool2d::Model = maxpool2d::Model::new(); - - // Run the model - let input = Tensor::::from_floats([[[ - [1.927, 1.487, 0.901, -2.106, 0.678], - [-1.235, -0.043, -1.605, -0.752, -0.687], - [-0.493, 0.241, -1.111, 0.092, -2.317], - [-0.217, -1.385, -0.396, 0.803, -0.622], - [-0.592, -0.063, -0.829, 0.331, -1.558], - ]]]); - let output = model.forward(input); - let expected = Data::from([[[ - [0.901, 1.927, 1.487, 0.901], - [0.901, 1.927, 1.487, 0.901], - [-0.396, 0.803, 0.241, -0.396], - ]]]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn avg_pool2d() { - // Initialize the model without weights (because the exported file does not contain them) - let model: avg_pool2d::Model = avg_pool2d::Model::new(); - - // Run the model - let input = Tensor::::from_floats([[[ - [-0.077, 0.360, -0.782, 0.072, 0.665], - [-0.287, 1.621, -1.597, -0.052, 0.611], - [0.760, -0.034, -0.345, 0.494, -0.078], - [-1.805, -0.476, 0.205, 0.338, 1.353], - [0.374, 0.013, 0.774, -0.109, -0.271], - ]]]); - let output = model.forward(input); - let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]); - - output.to_data().assert_approx_eq(&expected, 3); - } - - #[test] - fn reshape() { - // Initialize the model without weights (because the exported file does not contain them) - let model: reshape::Model = reshape::Model::new(); - - // Run the model - let input = Tensor::::from_floats([0., 1., 2., 3.]); - let output = model.forward(input); - let 
expected = Data::from([[0., 1., 2., 3.]]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn flatten() { - // Initialize the model without weights (because the exported file does not contain them) - let model: flatten::Model = flatten::Model::new(); - - // Run the model - let input = Tensor::::ones([1, 5, 15]); - let output = model.forward(input); - - let expected_shape = Shape::from([1, 75]); - assert_eq!(expected_shape, output.shape()); - } - - #[test] - fn batch_norm() { - let model: batch_norm::Model = batch_norm::Model::default(); - - // Run the model with ones as input for easier testing - let input = Tensor::::ones([1, 20, 1]); - let output = model.forward(input); - - let expected_shape = Shape::from([1, 5, 2, 2]); - assert_eq!(output.shape(), expected_shape); - - let output_sum = output.sum().into_scalar(); - let expected_sum = 19.999_802; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2))); - } - - #[test] - fn relu() { - // Initialize the model without weights (because the exported file does not contain them) - let model: relu::Model = relu::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, 0.00000000, 0.00000000], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn sigmoid() { - // Initialize the model without weights (because the exported file does not contain them) - let model: sigmoid::Model = sigmoid::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.58338636, 0.532_157_9, 0.55834854], - [0.557_33, 0.24548186, 0.45355222], - ]); - - output.to_data().assert_approx_eq(&expected, 7); - } - - #[test] - fn transpose() { - // Initialize the model without weights (because the exported file does not contain them) - let model: transpose::Model = transpose::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.33669037, 0.23033303], - [0.128_809_4, -1.122_856_4], - [0.23446237, -0.18632829], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn equal_scalar_to_scalar_and_tensor_to_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: equal::Model = equal::Model::default(); - - // Run the model - let input = Tensor::::from_floats([[[[1., 1., 1., 1.]]]]); - - let scalar = 2f64; - let (tensor_out, scalar_out) = model.forward(input, scalar); - let expected_tensor = Data::from([[[[true, true, true, true]]]]); - let expected_scalar = false; - - assert_eq!(tensor_out.to_data(), expected_tensor); - assert_eq!(scalar_out, expected_scalar); - } - - #[test] - fn clip_opset16() { - // Initialize the model without weights (because the exported file does not contain them) - let model: clip_opset16::Model = clip_opset16::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - 0.88226926, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let (output1, output2, output3) = model.forward(input); - let expected1 = Data::from([ - 0.88226926, - 0.91500396, 
- 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); - let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); - - assert_eq!(output1.to_data(), expected1); - assert_eq!(output2.to_data(), expected2); - assert_eq!(output3.to_data(), expected3); - } - - #[test] - fn clip_opset7() { - // Initialize the model without weights (because the exported file does not contain them) - let model: clip_opset7::Model = clip_opset7::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - 0.88226926, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let (output1, output2, output3) = model.forward(input); - let expected1 = Data::from([ - 0.88226926, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); - let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); - - assert_eq!(output1.to_data(), expected1); - assert_eq!(output2.to_data(), expected2); - assert_eq!(output3.to_data(), expected3); - } - - #[test] - fn linear() { - // Initialize the model with weights (loaded from the exported file) - let model: linear::Model = linear::Model::default(); - #[allow(clippy::approx_constant)] - let input1 = Tensor::::full([4, 3], 3.14); - #[allow(clippy::approx_constant)] - let input2 = Tensor::::full([2, 5], 3.14); - #[allow(clippy::approx_constant)] - let input3 = Tensor::::full([3, 2, 7], 3.14); - - let (output1, output2, output3) = model.forward(input1, input2, input3); - - // test the output shape - let expected_shape1: Shape<2> = Shape::from([4, 4]); - let expected_shape2: Shape<2> = Shape::from([2, 6]); - let expected_shape3: Shape<3> = Shape::from([3, 2, 8]); - assert_eq!(output1.shape(), expected_shape1); - assert_eq!(output2.shape(), expected_shape2); - assert_eq!(output3.shape(), expected_shape3); - - // We are using the sum of the output tensor to test the correctness of the conv1d node - // because the output tensor is too large to compare with the expected tensor. 
- let output_sum1 = output1.sum().into_scalar(); - let output_sum2 = output2.sum().into_scalar(); - let output_sum3 = output3.sum().into_scalar(); - - let expected_sum1 = -9.655_477; // from pytorch - let expected_sum2 = -8.053_822; // from pytorch - let expected_sum3 = 27.575_281; // from pytorch - - assert!(expected_sum1.approx_eq(output_sum1, (1.0e-6, 2))); - assert!(expected_sum2.approx_eq(output_sum2, (1.0e-6, 2))); - assert!(expected_sum3.approx_eq(output_sum3, (1.0e-6, 2))); - } - - #[test] - fn tanh() { - // Initialize the model - let model = tanh::Model::::new(); - - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let output = model.forward(input); - // data from pyTorch - let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]); - output.to_data().assert_approx_eq(&expected, 4); - } - - #[test] - fn recip() { - // Initialize the model - let model = recip::Model::::new(); - - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let output = model.forward(input); - // data from pyTorch - let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]); - output.to_data().assert_approx_eq(&expected, 4); - } + // Run the model + let input = Tensor::::from_floats([[1., 2.], [3., 4.]]); + let index = Tensor::::from_ints([[0, 0], [1, 0]]); + let output = model.forward(input, index); + let expected = Data::from([[1., 1.], [4., 3.]]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn globalavrpool_1d_2d() { + // The model contains 1d and 2d global average pooling nodes + let model: global_avr_pool::Model = global_avr_pool::Model::default(); + + // Run the model with ones as input for easier testing + let input_1d = Tensor::::ones([2, 4, 10]); + let input_2d = Tensor::::ones([3, 10, 3, 15]); + + let (output_1d, output_2d) = model.forward(input_1d, input_2d); + + let expected_shape_1d = Shape::from([2, 4, 1]); + let expected_shape_2d = Shape::from([3, 10, 1, 1]); + assert_eq!(output_1d.shape(), expected_shape_1d); + assert_eq!(output_2d.shape(), expected_shape_2d); + + let output_sum_1d = output_1d.sum().into_scalar(); + let output_sum_2d = output_2d.sum().into_scalar(); + + let expected_sum_1d = 8.0; // from pytorch + let expected_sum_2d = 30.0; // from pytorch + + assert!(expected_sum_1d.approx_eq(output_sum_1d, (1.0e-4, 2))); + assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2))); + } + + #[test] + fn softmax() { + // Initialize the model without weights (because the exported file does not contain them) + let model: softmax::Model = softmax::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.36830685, 0.29917702, 0.33251613], + [0.521_469_2, 0.13475533, 0.343_775_5], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn log_softmax() { + // Initialize the model without weights (because the exported file does not contain them) + let model: log_softmax::Model = log_softmax::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [-0.998_838_9, -1.206_719_9, -1.101_067], + [-0.651_105_1, -2.004_294_6, -1.067_766_4], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn maxpool2d() { + // Initialize the model 
without weights (because the exported file does not contain them) + let model: maxpool2d::Model = maxpool2d::Model::new(); + + // Run the model + let input = Tensor::::from_floats([[[ + [1.927, 1.487, 0.901, -2.106, 0.678], + [-1.235, -0.043, -1.605, -0.752, -0.687], + [-0.493, 0.241, -1.111, 0.092, -2.317], + [-0.217, -1.385, -0.396, 0.803, -0.622], + [-0.592, -0.063, -0.829, 0.331, -1.558], + ]]]); + let output = model.forward(input); + let expected = Data::from([[[ + [0.901, 1.927, 1.487, 0.901], + [0.901, 1.927, 1.487, 0.901], + [-0.396, 0.803, 0.241, -0.396], + ]]]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn avg_pool2d() { + // Initialize the model without weights (because the exported file does not contain them) + let model: avg_pool2d::Model = avg_pool2d::Model::new(); + + // Run the model + let input = Tensor::::from_floats([[[ + [-0.077, 0.360, -0.782, 0.072, 0.665], + [-0.287, 1.621, -1.597, -0.052, 0.611], + [0.760, -0.034, -0.345, 0.494, -0.078], + [-1.805, -0.476, 0.205, 0.338, 1.353], + [0.374, 0.013, 0.774, -0.109, -0.271], + ]]]); + let output = model.forward(input); + let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]); + + output.to_data().assert_approx_eq(&expected, 3); + } + + #[test] + fn reshape() { + // Initialize the model without weights (because the exported file does not contain them) + let model: reshape::Model = reshape::Model::new(); + + // Run the model + let input = Tensor::::from_floats([0., 1., 2., 3.]); + let output = model.forward(input); + let expected = Data::from([[0., 1., 2., 3.]]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn flatten() { + // Initialize the model without weights (because the exported file does not contain them) + let model: flatten::Model = flatten::Model::new(); + + // Run the model + let input = Tensor::::ones([1, 5, 15]); + let output = model.forward(input); + + let expected_shape = Shape::from([1, 75]); + assert_eq!(expected_shape, output.shape()); + } + + #[test] + fn batch_norm() { + let model: batch_norm::Model = batch_norm::Model::default(); + + // Run the model with ones as input for easier testing + let input = Tensor::::ones([1, 20, 1]); + let output = model.forward(input); + + let expected_shape = Shape::from([1, 5, 2, 2]); + assert_eq!(output.shape(), expected_shape); + + let output_sum = output.sum().into_scalar(); + let expected_sum = 19.999_802; // from pytorch + assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2))); + } + + #[test] + fn relu() { + // Initialize the model without weights (because the exported file does not contain them) + let model: relu::Model = relu::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, 0.00000000, 0.00000000], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn sigmoid() { + // Initialize the model without weights (because the exported file does not contain them) + let model: sigmoid::Model = sigmoid::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.58338636, 0.532_157_9, 0.55834854], + [0.557_33, 0.24548186, 0.45355222], + ]); + + output.to_data().assert_approx_eq(&expected, 7); + } + + #[test] 
+ fn transpose() { + // Initialize the model without weights (because the exported file does not contain them) + let model: transpose::Model = transpose::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.33669037, 0.23033303], + [0.128_809_4, -1.122_856_4], + [0.23446237, -0.18632829], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn equal_scalar_to_scalar_and_tensor_to_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: equal::Model = equal::Model::default(); + + // Run the model + let input = Tensor::::from_floats([[[[1., 1., 1., 1.]]]]); + + let scalar = 2f64; + let (tensor_out, scalar_out) = model.forward(input, scalar); + let expected_tensor = Data::from([[[[true, true, true, true]]]]); + let expected_scalar = false; + + assert_eq!(tensor_out.to_data(), expected_tensor); + assert_eq!(scalar_out, expected_scalar); + } + + #[test] + fn clip_opset16() { + // Initialize the model without weights (because the exported file does not contain them) + let model: clip_opset16::Model = clip_opset16::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let (output1, output2, output3) = model.forward(input); + let expected1 = Data::from([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); + let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); + + assert_eq!(output1.to_data(), expected1); + assert_eq!(output2.to_data(), expected2); + assert_eq!(output3.to_data(), expected3); + } + + #[test] + fn clip_opset7() { + // Initialize the model without weights (because the exported file does not contain them) + let model: clip_opset7::Model = clip_opset7::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let (output1, output2, output3) = model.forward(input); + let expected1 = Data::from([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); + let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); + + assert_eq!(output1.to_data(), expected1); + assert_eq!(output2.to_data(), expected2); + assert_eq!(output3.to_data(), expected3); + } + + #[test] + fn linear() { + // Initialize the model with weights (loaded from the exported file) + let model: linear::Model = linear::Model::default(); + #[allow(clippy::approx_constant)] + let input1 = Tensor::::full([4, 3], 3.14); + #[allow(clippy::approx_constant)] + let input2 = Tensor::::full([2, 5], 3.14); + #[allow(clippy::approx_constant)] + let input3 = Tensor::::full([3, 2, 7], 3.14); + + let (output1, output2, output3) = model.forward(input1, input2, input3); + + // test the output shape + let expected_shape1: Shape<2> = Shape::from([4, 4]); + let expected_shape2: Shape<2> = Shape::from([2, 6]); + let expected_shape3: Shape<3> = Shape::from([3, 2, 8]); + assert_eq!(output1.shape(), expected_shape1); + assert_eq!(output2.shape(), expected_shape2); + assert_eq!(output3.shape(), expected_shape3); + + // We are using the sum of the 
output tensor to test the correctness of the conv1d node + // because the output tensor is too large to compare with the expected tensor. + let output_sum1 = output1.sum().into_scalar(); + let output_sum2 = output2.sum().into_scalar(); + let output_sum3 = output3.sum().into_scalar(); + + let expected_sum1 = -9.655_477; // from pytorch + let expected_sum2 = -8.053_822; // from pytorch + let expected_sum3 = 27.575_281; // from pytorch + + assert!(expected_sum1.approx_eq(output_sum1, (1.0e-6, 2))); + assert!(expected_sum2.approx_eq(output_sum2, (1.0e-6, 2))); + assert!(expected_sum3.approx_eq(output_sum3, (1.0e-6, 2))); + } + + #[test] + fn tanh() { + // Initialize the model + let model = tanh::Model::::new(); + + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let output = model.forward(input); + // data from pyTorch + let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]); + output.to_data().assert_approx_eq(&expected, 4); + } + + #[test] + fn recip() { + // Initialize the model + let model = recip::Model::::new(); + + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let output = model.forward(input); + // data from pyTorch + let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]); + output.to_data().assert_approx_eq(&expected, 4); + } } diff --git a/burn-import/onnx-tests/tests/record_type_tests.rs b/burn-import/onnx-tests/tests/record_type_tests.rs index 98558b5019..d532672125 100644 --- a/burn-import/onnx-tests/tests/record_type_tests.rs +++ b/burn-import/onnx-tests/tests/record_type_tests.rs @@ -6,58 +6,58 @@ // different. macro_rules! test_model { - ($mod_name:ident) => { - test_model!($mod_name, 1.0e-4); // Default tolerance - }; - ($mod_name:ident, $tolerance:expr) => { - pub mod $mod_name { - include!(concat!( - env!("OUT_DIR"), - "/model/", - stringify!($mod_name), - "/conv1d.rs" - )); - } - - #[test] - fn $mod_name() { - // Initialize the model with weights (loaded from the exported file) - let model: $mod_name::Model = $mod_name::Model::default(); - - // Run the model with pi as input for easier testing - let input = Tensor::::full([6, 4, 10], consts::PI); - - let output = model.forward(input); - - // test the output shape - let expected_shape: Shape<3> = Shape::from([6, 2, 7]); - assert_eq!(output.shape(), expected_shape); - - // We are using the sum of the output tensor to test the correctness of the conv1d node - // because the output tensor is too large to compare with the expected tensor. 
- let output_sum = output.sum().into_scalar(); - let expected_sum = -54.549_243; // from pytorch - assert!(expected_sum.approx_eq(output_sum, ($tolerance, 2))); - } - }; + ($mod_name:ident) => { + test_model!($mod_name, 1.0e-4); // Default tolerance + }; + ($mod_name:ident, $tolerance:expr) => { + pub mod $mod_name { + include!(concat!( + env!("OUT_DIR"), + "/model/", + stringify!($mod_name), + "/conv1d.rs" + )); + } + + #[test] + fn $mod_name() { + // Initialize the model with weights (loaded from the exported file) + let model: $mod_name::Model = $mod_name::Model::default(); + + // Run the model with pi as input for easier testing + let input = Tensor::::full([6, 4, 10], consts::PI); + + let output = model.forward(input); + + // test the output shape + let expected_shape: Shape<3> = Shape::from([6, 2, 7]); + assert_eq!(output.shape(), expected_shape); + + // We are using the sum of the output tensor to test the correctness of the conv1d node + // because the output tensor is too large to compare with the expected tensor. + let output_sum = output.sum().into_scalar(); + let expected_sum = -54.549_243; // from pytorch + assert!(expected_sum.approx_eq(output_sum, ($tolerance, 2))); + } + }; } #[cfg(test)] mod tests { - use burn::tensor::{Shape, Tensor}; - use float_cmp::ApproxEq; - use std::f64::consts; - - type Backend = burn_ndarray::NdArray; - - test_model!(named_mpk); - test_model!(named_mpk_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(pretty_json); - test_model!(pretty_json_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(named_mpk_gz); - test_model!(named_mpk_gz_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(bincode); - test_model!(bincode_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(bincode_embedded); - test_model!(bincode_embedded_half, 1.0e-2); // Reduce tolerance for half precision + use burn::tensor::{Shape, Tensor}; + use float_cmp::ApproxEq; + use std::f64::consts; + + type Backend = burn_ndarray::NdArray; + + test_model!(named_mpk); + test_model!(named_mpk_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(pretty_json); + test_model!(pretty_json_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(named_mpk_gz); + test_model!(named_mpk_gz_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(bincode); + test_model!(bincode_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(bincode_embedded); + test_model!(bincode_embedded_half, 1.0e-2); // Reduce tolerance for half precision } diff --git a/burn-import/src/burn/codegen.rs b/burn-import/src/burn/codegen.rs index ed617e8086..2f61292b2b 100644 --- a/burn-import/src/burn/codegen.rs +++ b/burn-import/src/burn/codegen.rs @@ -5,89 +5,89 @@ use burn::nn::PaddingConfig1d; use burn::nn::PaddingConfig2d; fn convert_primitive(primitive: T) -> TokenStream { - let value = primitive.to_string(); - value.parse().unwrap() + let value = primitive.to_string(); + value.parse().unwrap() } fn convert_to_array<'a, I, T: ToTokens>(list: I) -> TokenStream where - I: Iterator, - T: 'a, + I: Iterator, + T: 'a, { - let mut body = quote! {}; + let mut body = quote! {}; - list.for_each(|item| { - let elem = item.to_tokens(); - body.extend(quote! {#elem,}); - }); + list.for_each(|item| { + let elem = item.to_tokens(); + body.extend(quote! {#elem,}); + }); - quote! { - [#body] - } + quote! 
{ + [#body] + } } pub trait ToTokens { - fn to_tokens(&self) -> TokenStream; + fn to_tokens(&self) -> TokenStream; } impl ToTokens for [T; N] { - fn to_tokens(&self) -> TokenStream { - convert_to_array(self.iter()) - } + fn to_tokens(&self) -> TokenStream { + convert_to_array(self.iter()) + } } impl ToTokens for Vec { - fn to_tokens(&self) -> TokenStream { - convert_to_array(self.iter()) - } + fn to_tokens(&self) -> TokenStream { + convert_to_array(self.iter()) + } } /// Prettier output for `usize` impl ToTokens for usize { - fn to_tokens(&self) -> TokenStream { - convert_primitive(self) - } + fn to_tokens(&self) -> TokenStream { + convert_primitive(self) + } } /// Prettier output for `i64` impl ToTokens for i64 { - fn to_tokens(&self) -> TokenStream { - convert_primitive(self) - } + fn to_tokens(&self) -> TokenStream { + convert_primitive(self) + } } /// Prettier output for `f64` impl ToTokens for f64 { - fn to_tokens(&self) -> TokenStream { - convert_primitive(self) - } + fn to_tokens(&self) -> TokenStream { + convert_primitive(self) + } } /// Padding configuration impl ToTokens for PaddingConfig1d { - fn to_tokens(&self) -> TokenStream { - match self { - Self::Same => quote! { PaddingConfig1d::Same }, - Self::Valid => quote! { PaddingConfig1d::Valid }, - Self::Explicit(padding) => { - let padding = padding.to_tokens(); - quote! { PaddingConfig1d::Explicit(#padding) } - } - } + fn to_tokens(&self) -> TokenStream { + match self { + Self::Same => quote! { PaddingConfig1d::Same }, + Self::Valid => quote! { PaddingConfig1d::Valid }, + Self::Explicit(padding) => { + let padding = padding.to_tokens(); + quote! { PaddingConfig1d::Explicit(#padding) } + } } + } } /// Padding configuration impl ToTokens for PaddingConfig2d { - fn to_tokens(&self) -> TokenStream { - match self { - Self::Same => quote! { PaddingConfig2d::Same }, - Self::Valid => quote! { PaddingConfig2d::Valid }, - Self::Explicit(padding1, padding2) => { - let padding1 = padding1.to_tokens(); - let padding2 = padding2.to_tokens(); - quote! { PaddingConfig2d::Explicit(#padding1, #padding2) } - } - } + fn to_tokens(&self) -> TokenStream { + match self { + Self::Same => quote! { PaddingConfig2d::Same }, + Self::Valid => quote! { PaddingConfig2d::Valid }, + Self::Explicit(padding1, padding2) => { + let padding1 = padding1.to_tokens(); + let padding2 = padding2.to_tokens(); + quote! { PaddingConfig2d::Explicit(#padding1, #padding2) } + } } + } } diff --git a/burn-import/src/burn/graph.rs b/burn-import/src/burn/graph.rs index 98437c3877..ea0c2dfc9c 100644 --- a/burn-import/src/burn/graph.rs +++ b/burn-import/src/burn/graph.rs @@ -1,606 +1,618 @@ use super::{BurnImports, Scope, Type}; use crate::burn::{ - node::{Node, NodeCodegen}, - TensorKind, TensorType, + node::{Node, NodeCodegen}, + TensorKind, TensorType, }; use burn::record::{ - BinFileRecorder, BurnRecord, FileRecorder, NamedMpkFileRecorder, NamedMpkGzFileRecorder, - PrecisionSettings, PrettyJsonFileRecorder, Recorder, + BinFileRecorder, BurnRecord, FileRecorder, NamedMpkFileRecorder, NamedMpkGzFileRecorder, + PrecisionSettings, PrettyJsonFileRecorder, Recorder, }; use proc_macro2::TokenStream; use quote::quote; use serde::{ - ser::{SerializeMap, SerializeTuple}, - Serialize, + ser::{SerializeMap, SerializeTuple}, + Serialize, }; use std::{any::type_name, collections::HashMap, path::PathBuf}; /// Type of the record to be saved. #[derive(Debug, Clone, Default, Copy)] pub enum RecordType { - /// Pretty JSON format (useful for debugging). 
- PrettyJson, + /// Pretty JSON format (useful for debugging). + PrettyJson, - #[default] - /// Compressed Named MessagePack. - NamedMpkGz, + #[default] + /// Compressed Named MessagePack. + NamedMpkGz, - /// Uncompressed Named MessagePack. - NamedMpk, + /// Uncompressed Named MessagePack. + NamedMpk, - /// Bincode format (useful for embedding and for no-std support). - Bincode, + /// Bincode format (useful for embedding and for no-std support). + Bincode, } /// Burn graph intermediate representation of modules and tensor operations. #[derive(Default, Debug)] pub struct BurnGraph { - nodes: Vec>, - scope: Scope, - imports: BurnImports, - top_comment: Option, - default: Option, - blank_spaces: bool, - gen_new_fn: bool, - graph_input_types: Vec, - graph_output_types: Vec, + nodes: Vec>, + scope: Scope, + imports: BurnImports, + top_comment: Option, + default: Option, + blank_spaces: bool, + gen_new_fn: bool, + graph_input_types: Vec, + graph_output_types: Vec, } impl BurnGraph { - /// Register a new operation node into the graph. - /// - /// # Notes - /// - /// The node must be registered in the same order they will be executed in the forward pass. - pub fn register + 'static>(&mut self, node: N) { - let node = node.into_node(); - log::debug!("Registering node => '{}'", node.name()); - self.nodes.push(node); - } - - /// Generate a function `Model::new()` without any argument when `gen_new_fn` is `true`. - /// - /// This is useful if you intend to train the model generated. - pub fn with_new_fn(mut self, gen_new_fn: bool) -> Self { - self.gen_new_fn = gen_new_fn; - self - } - - /// Save the state of each node in a record file. - /// - /// The `Default` trait will be implemented for the generated model, which will load the record - /// saved at the provided path. In case of `embed_states` is true, the record will be embedded - /// in the generated code (useful for no-std support). - /// - /// # Arguments - /// - /// * `out_file` - The path to the record file. - /// * `record_type` - The type of the record to be saved. - /// * `embed_states` - Embed the record in the generated code. - /// - /// # Panics - /// - /// Panics if the record type is not `RecordType::Bincode` and `embed_states` is `true`. - pub fn with_record( - mut self, - out_file: PathBuf, - record_type: RecordType, - embed_states: bool, - ) -> Self { - let precision_ty_str = extract_type_name_by_type::(); - self.imports - .register(format!("burn::record::{precision_ty_str}")); - - match record_type { - RecordType::PrettyJson => { - PrettyJsonFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructMap( - BurnGraphState::new(&self.nodes), - )), - out_file.clone(), - ) - .unwrap(); - - assert!( - !embed_states, - "Embedding states is not supported for PrettyJsonFileRecorder." - ); - - self.register_record_file( - out_file, - &format!("burn::record::PrettyJsonFileRecorder::<{precision_ty_str}>"), - ); - } - RecordType::NamedMpkGz => { - NamedMpkGzFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructMap( - BurnGraphState::new(&self.nodes), - )), - out_file.clone(), - ) - .unwrap(); - - assert!( - !embed_states, - "Embedding states is not supported for NamedMpkGzFileRecorder." - ); - self.register_record_file( - out_file, - &format!("burn::record::NamedMpkGzFileRecorder::<{precision_ty_str}>"), - ); - } + /// Register a new operation node into the graph. + /// + /// # Notes + /// + /// The node must be registered in the same order they will be executed in the forward pass. 
+ pub fn register + 'static>(&mut self, node: N) { + let node = node.into_node(); + log::debug!("Registering node => '{}'", node.name()); + self.nodes.push(node); + } + + /// Generate a function `Model::new()` without any argument when `gen_new_fn` is `true`. + /// + /// This is useful if you intend to train the model generated. + pub fn with_new_fn(mut self, gen_new_fn: bool) -> Self { + self.gen_new_fn = gen_new_fn; + self + } + + /// Save the state of each node in a record file. + /// + /// The `Default` trait will be implemented for the generated model, which will load the record + /// saved at the provided path. In case of `embed_states` is true, the record will be embedded + /// in the generated code (useful for no-std support). + /// + /// # Arguments + /// + /// * `out_file` - The path to the record file. + /// * `record_type` - The type of the record to be saved. + /// * `embed_states` - Embed the record in the generated code. + /// + /// # Panics + /// + /// Panics if the record type is not `RecordType::Bincode` and `embed_states` is `true`. + pub fn with_record( + mut self, + out_file: PathBuf, + record_type: RecordType, + embed_states: bool, + ) -> Self { + let precision_ty_str = extract_type_name_by_type::(); + self + .imports + .register(format!("burn::record::{precision_ty_str}")); + + match record_type { + RecordType::PrettyJson => { + PrettyJsonFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructMap(BurnGraphState::new( + &self.nodes, + ))), + out_file.clone(), + ) + .unwrap(); - RecordType::NamedMpk => { - NamedMpkFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructMap( - BurnGraphState::new(&self.nodes), - )), - out_file.clone(), - ) - .unwrap(); - - assert!( - !embed_states, - "Embedding states is not supported for NamedMpkFileRecorder." - ); - - self.register_record_file( - out_file, - &format!("burn::record::NamedMpkFileRecorder::<{precision_ty_str}>"), - ); - } + assert!( + !embed_states, + "Embedding states is not supported for PrettyJsonFileRecorder." + ); - RecordType::Bincode => { - BinFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructTuple(BurnGraphState::new( - &self.nodes, - ))), - out_file.clone(), - ) - .unwrap(); - - if embed_states { - self.register_record_embed(out_file); - } else { - self.register_record_file( - out_file, - &format!("burn::record::BinFileRecorder::<{precision_ty_str}>"), - ); - } - } - } + self.register_record_file( + out_file, + &format!("burn::record::PrettyJsonFileRecorder::<{precision_ty_str}>"), + ); + } + RecordType::NamedMpkGz => { + NamedMpkGzFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructMap(BurnGraphState::new( + &self.nodes, + ))), + out_file.clone(), + ) + .unwrap(); - self - } + assert!( + !embed_states, + "Embedding states is not supported for NamedMpkGzFileRecorder." + ); + self.register_record_file( + out_file, + &format!("burn::record::NamedMpkGzFileRecorder::<{precision_ty_str}>"), + ); + } + + RecordType::NamedMpk => { + NamedMpkFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructMap(BurnGraphState::new( + &self.nodes, + ))), + out_file.clone(), + ) + .unwrap(); - /// Add blank spaces in some places - /// - /// # Notes - /// - /// It can be problematic when testing. - pub fn with_blank_space(mut self, blank_spaces: bool) -> Self { - self.blank_spaces = blank_spaces; - self - } + assert!( + !embed_states, + "Embedding states is not supported for NamedMpkFileRecorder." + ); - /// Add a comment at the top of the generated file. 
- pub fn with_top_comment(mut self, top_comment: Option) -> Self { - self.top_comment = top_comment; - self + self.register_record_file( + out_file, + &format!("burn::record::NamedMpkFileRecorder::<{precision_ty_str}>"), + ); + } + + RecordType::Bincode => { + BinFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructTuple(BurnGraphState::new(&self.nodes))), + out_file.clone(), + ) + .unwrap(); + + if embed_states { + self.register_record_embed(out_file); + } else { + self.register_record_file( + out_file, + &format!("burn::record::BinFileRecorder::<{precision_ty_str}>"), + ); + } + } } - /// Generate tokens reprensenting the graph with Burn modules and tensor operations. - pub fn codegen(mut self) -> TokenStream { - self.build_scope(); - - self.register_imports(); - - let codegen_imports = self.imports.codegen(); - let codegen_struct = self.codegen_struct(); - let codegen_new_record = self.codegen_new_record(); - let codegen_forward = self.codegen_forward(); - - let maybe_blank = match self.blank_spaces { - true => quote! { - _blank_!(); - }, - false => quote! {}, - }; - let codegen_new = match self.gen_new_fn { - true => { - let new_fn = self.codegen_new(); - quote! { - #new_fn - #maybe_blank - } - } - false => quote! {}, - }; - let codegen_default = match self.default { - Some(default) => quote! { - #default - #maybe_blank - }, - None => quote! {}, - }; - - let maybe_top_file_comment = match self.top_comment { - Some(comment) => quote! { - _comment_!(#comment); - }, - None => quote! {}, - }; - + self + } + + /// Add blank spaces in some places + /// + /// # Notes + /// + /// It can be problematic when testing. + pub fn with_blank_space(mut self, blank_spaces: bool) -> Self { + self.blank_spaces = blank_spaces; + self + } + + /// Add a comment at the top of the generated file. + pub fn with_top_comment(mut self, top_comment: Option) -> Self { + self.top_comment = top_comment; + self + } + + /// Generate tokens reprensenting the graph with Burn modules and tensor operations. + pub fn codegen(mut self) -> TokenStream { + self.build_scope(); + + self.register_imports(); + + let codegen_imports = self.imports.codegen(); + let codegen_struct = self.codegen_struct(); + let codegen_new_record = self.codegen_new_record(); + let codegen_forward = self.codegen_forward(); + + let maybe_blank = match self.blank_spaces { + true => quote! { + _blank_!(); + }, + false => quote! {}, + }; + let codegen_new = match self.gen_new_fn { + true => { + let new_fn = self.codegen_new(); quote! { - #maybe_top_file_comment - #codegen_imports + #new_fn #maybe_blank + } + } + false => quote! {}, + }; + let codegen_default = match self.default { + Some(default) => quote! { + #default + #maybe_blank + }, + None => quote! {}, + }; + + let maybe_top_file_comment = match self.top_comment { + Some(comment) => quote! { + _comment_!(#comment); + }, + None => quote! {}, + }; + + quote! 
{ + #maybe_top_file_comment + #codegen_imports + #maybe_blank + #maybe_blank + + #codegen_struct + #maybe_blank + + #codegen_default + + impl Model { + #codegen_new_record #maybe_blank - #codegen_struct - #maybe_blank - - #codegen_default - - impl Model { - #codegen_new_record - #maybe_blank - - #codegen_new - #codegen_forward - } + #codegen_new + #codegen_forward } } - - fn register_imports(&mut self) { - // Register imports from nodes - self.nodes - .iter() - .for_each(|node| node.register_imports(&mut self.imports)); - - // Combine input and output types into a single vector - let all_types = self - .graph_input_types - .iter() - .chain(&self.graph_output_types); - - // Register imports for bool and int tensors - for ty in all_types { - match ty { - Type::Tensor(TensorType { - kind: TensorKind::Bool, - .. - }) => { - self.imports.register("burn::tensor::Bool"); - } - Type::Tensor(TensorType { - kind: TensorKind::Int, - .. - }) => { - self.imports.register("burn::tensor::Int"); - } - _ => {} - } + } + + fn register_imports(&mut self) { + // Register imports from nodes + self + .nodes + .iter() + .for_each(|node| node.register_imports(&mut self.imports)); + + // Combine input and output types into a single vector + let all_types = self + .graph_input_types + .iter() + .chain(&self.graph_output_types); + + // Register imports for bool and int tensors + for ty in all_types { + match ty { + Type::Tensor(TensorType { + kind: TensorKind::Bool, + .. + }) => { + self.imports.register("burn::tensor::Bool"); } - } - /// Build the scope state to make sure tensor clones are added where needed. - fn build_scope(&mut self) { - log::debug!("Building the scope nodes len => '{}'", self.nodes.len()); - - fn to_tensor(ty: Type) -> Option { - match ty { - Type::Tensor(tensor) => Some(tensor), - Type::Scalar(_) => None, - Type::Other(_) => None, - } + Type::Tensor(TensorType { + kind: TensorKind::Int, + .. + }) => { + self.imports.register("burn::tensor::Int"); } - - // Register graph tensor input with 0 as node position - self.graph_input_types - .clone() - .into_iter() - .flat_map(to_tensor) - .for_each(|tensor| { - self.scope.tensor_register_variable(&tensor, 0); - }); - - self.nodes - .iter() - .enumerate() - .for_each(|(node_position, node)| { - node.output_types() - .into_iter() - .flat_map(to_tensor) - .for_each(|tensor| { - self.scope - .tensor_register_variable(&tensor, node_position + 1) - }) - }); - - self.nodes - .iter() - .enumerate() - .for_each(|(node_position, node)| { - node.input_types() - .into_iter() - .flat_map(to_tensor) - .for_each(|tensor| { - self.scope - .tensor_register_future_use(&tensor, node_position) - }) - }); + _ => {} + } + } + } + /// Build the scope state to make sure tensor clones are added where needed. + fn build_scope(&mut self) { + log::debug!("Building the scope nodes len => '{}'", self.nodes.len()); + + fn to_tensor(ty: Type) -> Option { + match ty { + Type::Tensor(tensor) => Some(tensor), + Type::Scalar(_) => None, + Type::Other(_) => None, + } } - fn register_record_file(&mut self, file: PathBuf, recorder_str: &str) { - self.imports.register("burn::record::Recorder"); - - let recorder_ty = syn::parse_str::(recorder_str).unwrap(); - - // Add default implementation - let file = file.to_str().unwrap(); - self.default = Some(quote! 
{ - _blank_!(); - impl Default for Model { - fn default() -> Self { - Self::from_file(#file) - } + // Register graph tensor input with 0 as node position + self + .graph_input_types + .clone() + .into_iter() + .flat_map(to_tensor) + .for_each(|tensor| { + self.scope.tensor_register_variable(&tensor, 0); + }); + + self + .nodes + .iter() + .enumerate() + .for_each(|(node_position, node)| { + node + .output_types() + .into_iter() + .flat_map(to_tensor) + .for_each(|tensor| { + self + .scope + .tensor_register_variable(&tensor, node_position + 1) + }) + }); + + self + .nodes + .iter() + .enumerate() + .for_each(|(node_position, node)| { + node + .input_types() + .into_iter() + .flat_map(to_tensor) + .for_each(|tensor| { + self + .scope + .tensor_register_future_use(&tensor, node_position) + }) + }); + } + + fn register_record_file(&mut self, file: PathBuf, recorder_str: &str) { + self.imports.register("burn::record::Recorder"); + + let recorder_ty = syn::parse_str::(recorder_str).unwrap(); + + // Add default implementation + let file = file.to_str().unwrap(); + self.default = Some(quote! { + _blank_!(); + impl Default for Model { + fn default() -> Self { + Self::from_file(#file) } - _blank_!(); - impl Model { - pub fn from_file(file: &str) -> Self { - let record = #recorder_ty::new() - .load(file.into()) - .expect("Record file to exist."); - Self::new_with(record) - } + } + _blank_!(); + impl Model { + pub fn from_file(file: &str) -> Self { + let record = #recorder_ty::new() + .load(file.into()) + .expect("Record file to exist."); + Self::new_with(record) } - }); - } - - fn register_record_embed(&mut self, file: PathBuf) { - self.imports.register("burn::record::Recorder"); - - // NOTE: Bincode format is used for embedding states for now. - let precision = extract_type_name_by_type::(); - let precision_ty = syn::parse_str::(&precision).unwrap(); - self.imports.register("burn::record::BinBytesRecorder"); - - let mut file = file; - file.set_extension(BinFileRecorder::::file_extension()); - let file = file.to_str().unwrap(); - self.default = Some(quote! { - _blank_!(); - static EMBEDDED_STATES: &[u8] = include_bytes!(#file); - _blank_!(); - impl Default for Model { - fn default() -> Self { - Self::from_embedded() - } + } + }); + } + + fn register_record_embed(&mut self, file: PathBuf) { + self.imports.register("burn::record::Recorder"); + + // NOTE: Bincode format is used for embedding states for now. + let precision = extract_type_name_by_type::(); + let precision_ty = syn::parse_str::(&precision).unwrap(); + self.imports.register("burn::record::BinBytesRecorder"); + + let mut file = file; + file.set_extension(BinFileRecorder::::file_extension()); + let file = file.to_str().unwrap(); + self.default = Some(quote! { + _blank_!(); + static EMBEDDED_STATES: &[u8] = include_bytes!(#file); + _blank_!(); + impl Default for Model { + fn default() -> Self { + Self::from_embedded() } - _blank_!(); - impl Model { - pub fn from_embedded() -> Self { - let record = BinBytesRecorder::<#precision_ty>::default() - .load(EMBEDDED_STATES.to_vec()) - .expect("Failed to decode state"); - - Self::new_with(record) - } + } + _blank_!(); + impl Model { + pub fn from_embedded() -> Self { + let record = BinBytesRecorder::<#precision_ty>::default() + .load(EMBEDDED_STATES.to_vec()) + .expect("Failed to decode state"); + + Self::new_with(record) } + } - }); + }); + } + + fn codegen_struct(&self) -> TokenStream { + let mut body = quote! 
{}; + self + .nodes + .iter() + .filter_map(|node| node.field_type()) + .map(|field| { + let name = field.name(); + let ty = field.ty(); + + if matches!(&field, Type::Tensor(_)) { + quote! { + #name: burn::module::Param<#ty>, + } + } else { + quote! { + #name: #ty, + } + } + }) + .for_each(|code| body.extend(code)); + + // Extend with phantom data to avoid unused generic type. + body.extend(quote! { + phantom: core::marker::PhantomData, + }); + + quote! { + #[derive(Module, Debug)] + pub struct Model { + #body + } } - - fn codegen_struct(&self) -> TokenStream { - let mut body = quote! {}; - self.nodes - .iter() - .filter_map(|node| node.field_type()) - .map(|field| { - let name = field.name(); - let ty = field.ty(); - - if matches!(&field, Type::Tensor(_)) { - quote! { - #name: burn::module::Param<#ty>, - } - } else { - quote! { - #name: #ty, - } - } - }) - .for_each(|code| body.extend(code)); - - // Extend with phantom data to avoid unused generic type. - body.extend(quote! { - phantom: core::marker::PhantomData, - }); - - quote! { - #[derive(Module, Debug)] - pub struct Model { - #body + } + + fn codegen_new(&self) -> TokenStream { + let mut body = quote! {}; + + self + .nodes + .iter() + .map(|node| node.field_init(false)) + .for_each(|code| body.extend(code)); + + let fields = self + .nodes + .iter() + .flat_map(|node| node.field_type()) + .map(|field| field.name().clone()) + .collect::>(); + + quote! { + #[allow(dead_code)] + pub fn new() -> Self { + #body + + Self { + #(#fields,)* + phantom: core::marker::PhantomData, } } } - - fn codegen_new(&self) -> TokenStream { - let mut body = quote! {}; - - self.nodes - .iter() - .map(|node| node.field_init(false)) - .for_each(|code| body.extend(code)); - - let fields = self - .nodes - .iter() - .flat_map(|node| node.field_type()) - .map(|field| field.name().clone()) - .collect::>(); - - quote! { - #[allow(dead_code)] - pub fn new() -> Self { - #body - - Self { - #(#fields,)* - phantom: core::marker::PhantomData, - } + } + fn codegen_new_record(&self) -> TokenStream { + let mut body = quote! {}; + + self + .nodes + .iter() + .map(|node| node.field_init(true)) + .for_each(|code| body.extend(code)); + + let fields = self + .nodes + .iter() + .flat_map(|node| node.field_type()) + .map(|field| field.name().clone()) + .collect::>(); + + quote! { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + #body + + Self { + #(#fields,)* + phantom: core::marker::PhantomData, } } } - fn codegen_new_record(&self) -> TokenStream { - let mut body = quote! {}; + } - self.nodes - .iter() - .map(|node| node.field_init(true)) - .for_each(|code| body.extend(code)); + fn codegen_forward(&mut self) -> TokenStream { + let mut input_def = quote! {}; + let mut output_type_def = quote! {}; + let mut output_return_def = quote! {}; - let fields = self - .nodes - .iter() - .flat_map(|node| node.field_type()) - .map(|field| field.name().clone()) - .collect::>(); + self.graph_input_types.iter().for_each(|input| { + let name = input.name().clone(); + let ty = input.ty(); - quote! { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - #body - - Self { - #(#fields,)* - phantom: core::marker::PhantomData, - } - } - } - } + input_def.extend(quote! { + #name: #ty, - fn codegen_forward(&mut self) -> TokenStream { - let mut input_def = quote! {}; - let mut output_type_def = quote! {}; - let mut output_return_def = quote! 
{}; + }) + }); - self.graph_input_types.iter().for_each(|input| { - let name = input.name().clone(); - let ty = input.ty(); + let multiple_output = self.graph_output_types.len() > 1; - input_def.extend(quote! { - #name: #ty, + self.graph_output_types.iter().for_each(|output| { + let name = output.name(); + let ty = output.ty(); - }) + if multiple_output { + output_type_def.extend(quote! { + #ty, }); - - let multiple_output = self.graph_output_types.len() > 1; - - self.graph_output_types.iter().for_each(|output| { - let name = output.name(); - let ty = output.ty(); - - if multiple_output { - output_type_def.extend(quote! { - #ty, - }); - output_return_def.extend(quote! { - #name, - }); - } else { - output_type_def.extend(quote! { - #ty - }); - output_return_def.extend(quote! { - #name - }); - } + output_return_def.extend(quote! { + #name, }); - - if multiple_output { - output_return_def = quote! { - (#output_return_def) - }; - output_type_def = quote! { - (#output_type_def) - }; - } - - let mut body = quote! {}; - self.nodes - .iter() - .enumerate() - .map(|(index, node)| node.forward(&mut self.scope, index)) - .for_each(|code| body.extend(code)); - - // TODO Return the result without a `let` binding from a block, - // otherwise let_and_return error will be triggered by clippy. - // For now, we just disable the warning. - quote! { - #[allow(clippy::let_and_return)] - pub fn forward(&self, #input_def) -> #output_type_def { - #body - - #output_return_def - } - } + } else { + output_type_def.extend(quote! { + #ty + }); + output_return_def.extend(quote! { + #name + }); + } + }); + + if multiple_output { + output_return_def = quote! { + (#output_return_def) + }; + output_type_def = quote! { + (#output_type_def) + }; } - /// Register the input and output types of the graph using the passed in names. - /// The names must be unique and match the names of the inputs and outputs of the nodes. - /// The order will be preserved. - /// - /// # Arguments - /// - /// * `input_names` - The names of the inputs of the graph. - /// * `output_names` - The names of the outputs of the graph. - /// - /// # Panics - /// - /// Panics if the graph is empty. - pub fn register_input_output(&mut self, input_names: Vec, output_names: Vec) { - assert!( - !self.nodes.is_empty(), - "Cannot register input and output types for an empty graph." - ); - - // Get the unique names of each input of the nodes - let mut inputs = HashMap::new(); - let mut outputs = HashMap::new(); - for node in self.nodes.iter() { - for input in node.input_types() { - inputs.insert(input.name().to_string(), input); - } - for output in node.output_types() { - outputs.insert(output.name().to_string(), output); - } + let mut body = quote! {}; + self + .nodes + .iter() + .enumerate() + .map(|(index, node)| node.forward(&mut self.scope, index)) + .for_each(|code| body.extend(code)); + + // TODO Return the result without a `let` binding from a block, + // otherwise let_and_return error will be triggered by clippy. + // For now, we just disable the warning. + quote! 
{ + #[allow(clippy::let_and_return)] + pub fn forward(&self, #input_def) -> #output_type_def { + #body + + #output_return_def } - - // Get the input and output types of the graph using passed in names - input_names.iter().for_each(|input| { - self.graph_input_types - .push(inputs.get(input).unwrap().clone()); - }); - - output_names.iter().for_each(|output| { - self.graph_output_types.push( - outputs - .get(output) - .unwrap_or_else(|| panic!("Output type is not found for {output}")) - .clone(), - ); - }); } + } + + /// Register the input and output types of the graph using the passed in names. + /// The names must be unique and match the names of the inputs and outputs of the nodes. + /// The order will be preserved. + /// + /// # Arguments + /// + /// * `input_names` - The names of the inputs of the graph. + /// * `output_names` - The names of the outputs of the graph. + /// + /// # Panics + /// + /// Panics if the graph is empty. + pub fn register_input_output(&mut self, input_names: Vec, output_names: Vec) { + assert!( + !self.nodes.is_empty(), + "Cannot register input and output types for an empty graph." + ); + + // Get the unique names of each input of the nodes + let mut inputs = HashMap::new(); + let mut outputs = HashMap::new(); + for node in self.nodes.iter() { + for input in node.input_types() { + inputs.insert(input.name().to_string(), input); + } + for output in node.output_types() { + outputs.insert(output.name().to_string(), output); + } + } + + // Get the input and output types of the graph using passed in names + input_names.iter().for_each(|input| { + self + .graph_input_types + .push(inputs.get(input).unwrap().clone()); + }); + + output_names.iter().for_each(|output| { + self.graph_output_types.push( + outputs + .get(output) + .unwrap_or_else(|| panic!("Output type is not found for {output}")) + .clone(), + ); + }); + } } #[derive(new, Debug)] struct BurnGraphState<'a, PS: PrecisionSettings> { - nodes: &'a Vec>, + nodes: &'a Vec>, } /// Represents a custom serialization strategy for the graph state in the module struct. @@ -618,24 +630,24 @@ struct BurnGraphState<'a, PS: PrecisionSettings> { struct StructMap<'a, PS: PrecisionSettings>(BurnGraphState<'a, PS>); impl<'a, PS: PrecisionSettings> Serialize for StructMap<'a, PS> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let nodes_with_names = self - .0 - .nodes - .iter() - .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) - .collect::>(); - let mut map = serializer.serialize_map(Some(nodes_with_names.len()))?; - - for (node, name) in nodes_with_names.iter() { - map.serialize_entry(&name.to_string(), &node)?; - } - - map.end() + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let nodes_with_names = self + .0 + .nodes + .iter() + .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) + .collect::>(); + let mut map = serializer.serialize_map(Some(nodes_with_names.len()))?; + + for (node, name) in nodes_with_names.iter() { + map.serialize_entry(&name.to_string(), &node)?; } + + map.end() + } } /// Represents a custom serialization strategy for the graph state in the module struct. 
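// Sketch: the names passed to `register_input_output` must match tensor names
// that already appear on the registered nodes; unknown names panic. The order
// given here becomes the order of `forward`'s parameters and return values.
// `graph` is a `BurnGraph` built as above.
graph.register(MatmulNode::new(
    TensorType::new_float("tensor1", 4),
    TensorType::new_float("tensor2", 4),
    TensorType::new_float("tensor3", 4),
));
graph.register_input_output(
    vec!["tensor1".to_string(), "tensor2".to_string()], // forward(tensor1, tensor2)
    vec!["tensor3".to_string()],                        // -> tensor3
);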
@@ -652,31 +664,31 @@ impl<'a, PS: PrecisionSettings> Serialize for StructMap<'a, PS> { struct StructTuple<'a, PS: PrecisionSettings>(BurnGraphState<'a, PS>); impl<'a, PS: PrecisionSettings> Serialize for StructTuple<'a, PS> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let nodes_with_names = self - .0 - .nodes - .iter() - .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) - .collect::>(); - let mut map = serializer.serialize_tuple(nodes_with_names.len())?; - - for (node, _name) in nodes_with_names.iter() { - map.serialize_element(&node)?; - } - - map.end() + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let nodes_with_names = self + .0 + .nodes + .iter() + .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) + .collect::>(); + let mut map = serializer.serialize_tuple(nodes_with_names.len())?; + + for (node, _name) in nodes_with_names.iter() { + map.serialize_element(&node)?; } + + map.end() + } } fn extract_type_name_by_type() -> String { - let full_type_name = type_name::(); - full_type_name - .rsplit("::") - .next() - .unwrap_or(full_type_name) - .to_string() + let full_type_name = type_name::(); + full_type_name + .rsplit("::") + .next() + .unwrap_or(full_type_name) + .to_string() } diff --git a/burn-import/src/burn/imports.rs b/burn-import/src/burn/imports.rs index 540a45529b..c091eb0d24 100644 --- a/burn-import/src/burn/imports.rs +++ b/burn-import/src/burn/imports.rs @@ -5,38 +5,37 @@ use std::collections::HashSet; /// Keep track of imported modules. #[derive(Debug, Default)] pub struct BurnImports { - imports: HashSet, + imports: HashSet, } impl BurnImports { - /// Register an import type. - /// - /// # Notes - /// - /// Each import statement will be generated just once no matter how many times it was - /// registered. - pub fn register>(&mut self, import: S) { - self.imports.insert(import.into()); - } + /// Register an import type. + /// + /// # Notes + /// + /// Each import statement will be generated just once no matter how many times it was + /// registered. + pub fn register>(&mut self, import: S) { + self.imports.insert(import.into()); + } - /// Generate the import tokens. - pub fn codegen(&self) -> TokenStream { - let mut import_tokens = vec![]; + /// Generate the import tokens. + pub fn codegen(&self) -> TokenStream { + let mut import_tokens = vec![]; - for import in self.imports.iter() { - let path: syn::Path = - syn::parse_str(import).expect("Unable to parse input string as a path"); + for import in self.imports.iter() { + let path: syn::Path = syn::parse_str(import).expect("Unable to parse input string as a path"); - import_tokens.push(quote! { #path }); - } + import_tokens.push(quote! { #path }); + } - quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #(use #import_tokens;)* - } + #(use #import_tokens;)* } + } } diff --git a/burn-import/src/burn/node/avg_pool2d.rs b/burn-import/src/burn/node/avg_pool2d.rs index 3f12457f20..b35cfdcdc0 100644 --- a/burn-import/src/burn/node/avg_pool2d.rs +++ b/burn-import/src/burn/node/avg_pool2d.rs @@ -8,151 +8,151 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; #[derive(Debug, Clone)] pub struct AvgPool2dNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub config: AvgPool2dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub config: AvgPool2dConfig, } impl AvgPool2dNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - config: AvgPool2dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - AvgPool2d - }, - ), - input, - output, - config, - } + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + config: AvgPool2dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + AvgPool2d + }, + ), + input, + output, + config, } + } } impl NodeCodegen for AvgPool2dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; + let kernel_size = self.config.kernel_size.to_tokens(); + let strides = self.config.strides.to_tokens(); + let padding = self.config.padding.to_tokens(); + + let init_line = quote! { + init(); + }; - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - let kernel_size = self.config.kernel_size.to_tokens(); - let strides = self.config.strides.to_tokens(); - let padding = self.config.padding.to_tokens(); + let tokens = quote! { + let #name = AvgPool2dConfig::new(#kernel_size) + .with_strides(#strides) + .with_padding(#padding) + .#init_line + }; - let init_line = quote! { - init(); - }; + Some(tokens) + } - let tokens = quote! { - let #name = AvgPool2dConfig::new(#kernel_size) - .with_strides(#strides) - .with_padding(#padding) - .#init_line - }; + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - Some(tokens) + quote! { + let #output = self.#field.forward(#input); } + } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig2d"); + imports.register("burn::nn::pool::AvgPool2d"); + imports.register("burn::nn::pool::AvgPool2dConfig"); + } - quote! 
{ - let #output = self.#field.forward(#input); - } - } - - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig2d"); - imports.register("burn::nn::pool::AvgPool2d"); - imports.register("burn::nn::pool::AvgPool2dConfig"); - } + fn into_node(self) -> Node { + Node::AvgPool2d(self) + } - fn into_node(self) -> Node { - Node::AvgPool2d(self) - } - - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) - } + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{avg_pool2d::AvgPool2dNode, test::assert_tokens}, - TensorType, - }; - use burn::{nn::pool::AvgPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(AvgPool2dNode::new( - "avg_pool2d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - AvgPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig2d; - use burn::nn::pool::AvgPool2d; - use burn::nn::pool::AvgPool2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - avg_pool2d: AvgPool2d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{avg_pool2d::AvgPool2dNode, test::assert_tokens}, + TensorType, + }; + use burn::{nn::pool::AvgPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(AvgPool2dNode::new( + "avg_pool2d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + AvgPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::pool::AvgPool2d; + use burn::nn::pool::AvgPool2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + avg_pool2d: AvgPool2d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let avg_pool2d = AvgPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .init(); - - Self { - avg_pool2d, - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let avg_pool2d = AvgPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .init(); + + Self { + avg_pool2d, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.avg_pool2d.forward(input); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.avg_pool2d.forward(input); - output - } + output } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/base.rs b/burn-import/src/burn/node/base.rs index 944156306b..f4262da460 100644 --- a/burn-import/src/burn/node/base.rs +++ b/burn-import/src/burn/node/base.rs @@ -1,9 +1,8 @@ use super::{ - avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, - concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, - dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, - linear::LinearNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode, - unary::UnaryNode, + avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, + concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, + dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, linear::LinearNode, + matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::record::PrecisionSettings; @@ -16,371 +15,371 @@ pub type SerializationBackend = NdArray; /// Codegen trait that should be implemented by all [node](Node) entries. pub trait NodeCodegen: std::fmt::Debug { - /// All types that are used as inputs during the forward pass. - /// - /// # Notes - /// The vec should not include types that are accessible with `self`. - /// See [field type](NodeCodegen::field_type). - fn input_types(&self) -> Vec; - - /// All types that are produced during the forward pass. - fn output_types(&self) -> Vec; - - /// The forward pass implementation of the node. - /// - /// # Notes - /// - /// The [Scope](Scope) struct should be used for [input tensor type](Type::Tensor) access. - /// The method [use_owned_tensor](Scope::use_owned_tensor) keeps track of tensor reference - /// count and insert `clone` with necessary. - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream; - - /// Convert the node implementation into a [node entry](Node). - fn into_node(self) -> Node; - - /// Register the necessary imports. 
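// Sketch of a typical `forward` implementation for a node with one tensor
// input and one tensor output (hypothetical unary op). `tensor_use_owned`
// resolves the input and inserts `.clone()` automatically when a later node
// still needs the same tensor.
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
    let input = scope.tensor_use_owned(&self.input, node_position);
    let output = &self.output.name;

    // The generated line; the actual operation depends on the node.
    quote! {
        let #output = #input.exp();
    }
}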
- fn register_imports(&self, _imports: &mut BurnImports) {} - - /// (Optional) Declare the type of the field - /// - /// # Notes - /// - /// This should be implemented when the node has some parameters. - /// Just one field per type is possible, if the node has multiple types for its parameters, a - /// tuple can be used. - /// - /// Other field functions should be implemented when this one returns something other than None. - /// * [field_init](NodeCodegen::field_init) to initialize parameters. - /// * [field_serialize](NodeCodegen::field_serialize) to create the model record. - fn field_type(&self) -> Option { - None - } - - /// (Optional) Declare how the parameters are initialized with and without a record. - /// - /// The function should be implemented along [field_type](NodeCodegen::field_type). - fn field_init(&self, _with_record: bool) -> Option { - None - } - - /// (Optional) Declare how the parameters are serialized in a record. - /// - /// The function should be implemented along [field_type](NodeCodegen::field_type). - fn field_serialize(&self, _serializer: S) -> Result { - panic!("Serialization should be implemented when field_type is not None."); - } + /// All types that are used as inputs during the forward pass. + /// + /// # Notes + /// The vec should not include types that are accessible with `self`. + /// See [field type](NodeCodegen::field_type). + fn input_types(&self) -> Vec; + + /// All types that are produced during the forward pass. + fn output_types(&self) -> Vec; + + /// The forward pass implementation of the node. + /// + /// # Notes + /// + /// The [Scope](Scope) struct should be used for [input tensor type](Type::Tensor) access. + /// The method [use_owned_tensor](Scope::use_owned_tensor) keeps track of tensor reference + /// count and insert `clone` with necessary. + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream; + + /// Convert the node implementation into a [node entry](Node). + fn into_node(self) -> Node; + + /// Register the necessary imports. + fn register_imports(&self, _imports: &mut BurnImports) {} + + /// (Optional) Declare the type of the field + /// + /// # Notes + /// + /// This should be implemented when the node has some parameters. + /// Just one field per type is possible, if the node has multiple types for its parameters, a + /// tuple can be used. + /// + /// Other field functions should be implemented when this one returns something other than None. + /// * [field_init](NodeCodegen::field_init) to initialize parameters. + /// * [field_serialize](NodeCodegen::field_serialize) to create the model record. + fn field_type(&self) -> Option { + None + } + + /// (Optional) Declare how the parameters are initialized with and without a record. + /// + /// The function should be implemented along [field_type](NodeCodegen::field_type). + fn field_init(&self, _with_record: bool) -> Option { + None + } + + /// (Optional) Declare how the parameters are serialized in a record. + /// + /// The function should be implemented along [field_type](NodeCodegen::field_type). 
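// Sketch: when a node owns parameters, `field_init` usually switches between
// restoring from the record and a fresh initialization. `MyNodeConfig` is a
// stand-in for whatever config type the node wraps.
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
    let name = &self.field.name;
    let init_line = match with_record {
        true => quote! { init_with(record.#name) },
        false => quote! { init() },
    };

    Some(quote! {
        let #name = MyNodeConfig::new().#init_line;
    })
}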
+ fn field_serialize(&self, _serializer: S) -> Result { + panic!("Serialization should be implemented when field_type is not None."); + } } #[derive(Debug, Clone)] pub enum Node { - AvgPool2d(AvgPool2dNode), - BatchNorm(BatchNormNode), - Binary(BinaryNode), - Clip(ClipNode), - Concat(ConcatNode), - Constant(ConstantNode), - Conv1d(Conv1dNode), - Conv2d(Conv2dNode), - Dropout(DropoutNode), - Gather(GatherNode), - GlobalAvgPool(GlobalAvgPoolNode), - Linear(LinearNode), - Matmul(MatmulNode), - MaxPool2d(MaxPool2dNode), - Reshape(ReshapeNode), - Unary(UnaryNode), + AvgPool2d(AvgPool2dNode), + BatchNorm(BatchNormNode), + Binary(BinaryNode), + Clip(ClipNode), + Concat(ConcatNode), + Constant(ConstantNode), + Conv1d(Conv1dNode), + Conv2d(Conv2dNode), + Dropout(DropoutNode), + Gather(GatherNode), + GlobalAvgPool(GlobalAvgPoolNode), + Linear(LinearNode), + Matmul(MatmulNode), + MaxPool2d(MaxPool2dNode), + Reshape(ReshapeNode), + Unary(UnaryNode), } macro_rules! match_all { - ($self:expr, $func:expr) => {{ - #[allow(clippy::redundant_closure_call)] - match $self { - Node::AvgPool2d(node) => $func(node), - Node::BatchNorm(node) => $func(node), - Node::Binary(node) => $func(node), - Node::Clip(node) => $func(node), - Node::Concat(node) => $func(node), - Node::Constant(node) => $func(node), - Node::Conv1d(node) => $func(node), - Node::Conv2d(node) => $func(node), - Node::Dropout(node) => $func(node), - Node::Gather(node) => $func(node), - Node::GlobalAvgPool(node) => $func(node), - Node::Linear(node) => $func(node), - Node::Matmul(node) => $func(node), - Node::MaxPool2d(node) => $func(node), - Node::Reshape(node) => $func(node), - Node::Unary(node) => $func(node), - } - }}; + ($self:expr, $func:expr) => {{ + #[allow(clippy::redundant_closure_call)] + match $self { + Node::AvgPool2d(node) => $func(node), + Node::BatchNorm(node) => $func(node), + Node::Binary(node) => $func(node), + Node::Clip(node) => $func(node), + Node::Concat(node) => $func(node), + Node::Constant(node) => $func(node), + Node::Conv1d(node) => $func(node), + Node::Conv2d(node) => $func(node), + Node::Dropout(node) => $func(node), + Node::Gather(node) => $func(node), + Node::GlobalAvgPool(node) => $func(node), + Node::Linear(node) => $func(node), + Node::Matmul(node) => $func(node), + Node::MaxPool2d(node) => $func(node), + Node::Reshape(node) => $func(node), + Node::Unary(node) => $func(node), + } + }}; } impl Serialize for Node { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.field_serialize(serializer) - } + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.field_serialize(serializer) + } } impl Node { - pub fn name(&self) -> &str { - match self { - Node::AvgPool2d(_) => "avg_pool2d", - Node::BatchNorm(_) => "batch_norm", - Node::Binary(binary) => binary.binary_type.as_str(), - Node::Concat(_) => "concat", - Node::Clip(_) => "clip", - Node::Constant(_) => "constant", - Node::Conv1d(_) => "conv1d", - Node::Conv2d(_) => "conv2d", - Node::Dropout(_) => "dropout", - Node::Gather(_) => "gather", - Node::GlobalAvgPool(_) => "global_avg_pool", - Node::Linear(_) => "linear", - Node::Matmul(_) => "matmul", - Node::MaxPool2d(_) => "max_pool2d", - Node::Reshape(_) => "reshape", - Node::Unary(unary) => unary.kind.as_str(), - } + pub fn name(&self) -> &str { + match self { + Node::AvgPool2d(_) => "avg_pool2d", + Node::BatchNorm(_) => "batch_norm", + Node::Binary(binary) => binary.binary_type.as_str(), + Node::Concat(_) => "concat", + Node::Clip(_) => "clip", + 
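// Sketch: nodes without learnable state still have to satisfy the `Serialize`
// delegation above; they simply serialize `None`, as the pooling nodes do.
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
    S::serialize_none(serializer)
}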
Node::Constant(_) => "constant", + Node::Conv1d(_) => "conv1d", + Node::Conv2d(_) => "conv2d", + Node::Dropout(_) => "dropout", + Node::Gather(_) => "gather", + Node::GlobalAvgPool(_) => "global_avg_pool", + Node::Linear(_) => "linear", + Node::Matmul(_) => "matmul", + Node::MaxPool2d(_) => "max_pool2d", + Node::Reshape(_) => "reshape", + Node::Unary(unary) => unary.kind.as_str(), } + } } impl NodeCodegen for Node { - fn output_types(&self) -> Vec { - match_all!(self, NodeCodegen::::output_types) - } - - fn input_types(&self) -> Vec { - match_all!(self, NodeCodegen::::input_types) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - match_all!(self, |node| NodeCodegen::::forward( - node, - scope, - node_position - )) - } - - fn field_type(&self) -> Option { - match_all!(self, NodeCodegen::::field_type) - } - - fn field_init(&self, with_record: bool) -> Option { - match_all!(self, |node| NodeCodegen::::field_init( - node, - with_record - )) - } - - fn register_imports(&self, imports: &mut BurnImports) { - match_all!(self, |node| NodeCodegen::::register_imports( - node, imports - )) - } - - fn into_node(self) -> Node { - self - } - - fn field_serialize(&self, serializer: S) -> Result { - match_all!(self, |node| NodeCodegen::::field_serialize( - node, serializer - )) - } + fn output_types(&self) -> Vec { + match_all!(self, NodeCodegen::::output_types) + } + + fn input_types(&self) -> Vec { + match_all!(self, NodeCodegen::::input_types) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + match_all!(self, |node| NodeCodegen::::forward( + node, + scope, + node_position + )) + } + + fn field_type(&self) -> Option { + match_all!(self, NodeCodegen::::field_type) + } + + fn field_init(&self, with_record: bool) -> Option { + match_all!(self, |node| NodeCodegen::::field_init( + node, + with_record + )) + } + + fn register_imports(&self, imports: &mut BurnImports) { + match_all!(self, |node| NodeCodegen::::register_imports( + node, imports + )) + } + + fn into_node(self) -> Node { + self + } + + fn field_serialize(&self, serializer: S) -> Result { + match_all!(self, |node| NodeCodegen::::field_serialize( + node, serializer + )) + } } #[cfg(test)] pub(crate) mod tests { - use crate::burn::{ - graph::BurnGraph, - node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen}, - TensorType, - }; - use burn::{ - nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, - }; - use proc_macro2::TokenStream; - use quote::quote; - - pub(crate) fn one_node_graph + 'static>( - node_gen: T, - forward: TokenStream, - input_names: Vec, - output_names: Vec, - ) { - let mut graph = BurnGraph::::default(); - - graph.register(node_gen); - - graph.register_input_output(input_names, output_names); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } - - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } - } - - #[allow(clippy::let_and_return)] - #forward - } + use crate::burn::{ + graph::BurnGraph, + node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen}, + TensorType, + }; + use burn::{ + nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, + }; + use proc_macro2::TokenStream; + use quote::quote; + + pub(crate) fn one_node_graph + 'static>( + node_gen: T, + forward: TokenStream, + input_names: Vec, + output_names: Vec, + ) { + let mut graph = BurnGraph::::default(); + + graph.register(node_gen); + + graph.register_input_output(input_names, output_names); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, }; - assert_tokens(graph.codegen(), expected); - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - #[test] - fn test_codegen_two_nodes() { - let mut graph = BurnGraph::::default(); - - graph.register(MatmulNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor3", 4), - )); - graph.register(Conv2dNode::new( - "conv2d", - TensorType::new_float("tensor3", 4), - TensorType::new_float("tensor4", 4), - Data::from([2.]).serialize(), - None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor4".to_string()], - ); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::conv::Conv2dConfig; - use burn::nn::conv::Conv2d; - use burn::nn::PaddingConfig2d; - - #[derive(Module, Debug)] - pub struct Model { - conv2d: Conv2d, - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv2d = Conv2dConfig::new([3, 3], [3, 3]) - .with_stride([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .with_groups(1) - .with_bias(true) - .init_with(record.conv2d); - - Self { - conv2d, - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.matmul(tensor2); - let tensor4 = self.conv2d.forward(tensor3); + #[allow(clippy::let_and_return)] + #forward + } + }; - tensor4 - } - } + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_two_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(MatmulNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor3", 4), + )); + graph.register(Conv2dNode::new( + "conv2d", + TensorType::new_float("tensor3", 4), + TensorType::new_float("tensor4", 4), + Data::from([2.]).serialize(), + None, + Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor4".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, }; + use burn::nn::conv::Conv2dConfig; + use burn::nn::conv::Conv2d; + use burn::nn::PaddingConfig2d; + + #[derive(Module, Debug)] + pub struct Model { + conv2d: Conv2d, + phantom: core::marker::PhantomData, + } - assert_tokens(graph.codegen(), expected); - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv2d = Conv2dConfig::new([3, 3], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .with_groups(1) + .with_bias(true) + .init_with(record.conv2d); + + Self { + conv2d, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.matmul(tensor2); + let tensor4 = self.conv2d.forward(tensor3); - #[test] - fn test_codegen_clone_tensor() { - let mut graph = BurnGraph::::default(); - - graph.register(MatmulNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor3", 4), - )); - graph.register(Conv2dNode::new( - "conv2d", - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor4", 4), - Data::from([2.]).serialize(), - None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), - )); - graph.register(MatmulNode::new( - TensorType::new_float("tensor3", 4), - TensorType::new_float("tensor4", 4), - TensorType::new_float("output", 4), - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["output".to_string()], - ); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig2d; - use burn::nn::conv::Conv2d; - use burn::nn::conv::Conv2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - conv2d: Conv2d, - phantom: core::marker::PhantomData, + tensor4 } + } + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv2d = Conv2dConfig::new([3, 3], [3, 3]) - .with_stride([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .with_groups(1) - .with_bias(true) - .init_with(record.conv2d); - - Self { - conv2d, - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.matmul(tensor2.clone()); - let tensor4 = self.conv2d.forward(tensor2); - let output = tensor3.matmul(tensor4); + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_clone_tensor() { + let mut graph = BurnGraph::::default(); + + graph.register(MatmulNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor3", 4), + )); + graph.register(Conv2dNode::new( + "conv2d", + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor4", 4), + Data::from([2.]).serialize(), + None, + Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + )); + graph.register(MatmulNode::new( + TensorType::new_float("tensor3", 4), + TensorType::new_float("tensor4", 4), + TensorType::new_float("output", 4), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["output".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::conv::Conv2d; + use burn::nn::conv::Conv2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + conv2d: Conv2d, + phantom: core::marker::PhantomData, + } - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv2d = Conv2dConfig::new([3, 3], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .with_groups(1) + .with_bias(true) + .init_with(record.conv2d); + + Self { + conv2d, + phantom: core::marker::PhantomData, } } - }; + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.matmul(tensor2.clone()); + let tensor4 = self.conv2d.forward(tensor2); + let output = tensor3.matmul(tensor4); - assert_tokens(graph.codegen(), expected); - } + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/batch_norm.rs b/burn-import/src/burn/node/batch_norm.rs index b706c47cc2..6f2174bcc4 100644 --- a/burn-import/src/burn/node/batch_norm.rs +++ b/burn-import/src/burn/node/batch_norm.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{ConstantRecord, Param, ParamId}, - nn::{BatchNormConfig, BatchNormRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{ConstantRecord, Param, ParamId}, + nn::{BatchNormConfig, BatchNormRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,49 +12,49 @@ use 
serde::Serialize; #[derive(Debug, Clone)] pub struct BatchNormNode { - pub dim: usize, - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub gamma: DataSerialize, - pub beta: DataSerialize, - pub running_mean: DataSerialize, - pub running_var: DataSerialize, - pub config: BatchNormConfig, + pub dim: usize, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub gamma: DataSerialize, + pub beta: DataSerialize, + pub running_mean: DataSerialize, + pub running_var: DataSerialize, + pub config: BatchNormConfig, } impl BatchNormNode { - #[allow(clippy::too_many_arguments)] - pub fn new>( - dim: usize, - name: S, - input: TensorType, - output: TensorType, - gamma: DataSerialize, - beta: DataSerialize, - running_mean: DataSerialize, - running_var: DataSerialize, - config: BatchNormConfig, - ) -> Self { - let dim_tokens = dim.to_tokens(); - - Self { - dim, - field: OtherType::new( - name, - quote! { - BatchNorm - }, - ), - input, - output, - gamma, - beta, - running_mean, - running_var, - config, - } + #[allow(clippy::too_many_arguments)] + pub fn new>( + dim: usize, + name: S, + input: TensorType, + output: TensorType, + gamma: DataSerialize, + beta: DataSerialize, + running_mean: DataSerialize, + running_var: DataSerialize, + config: BatchNormConfig, + ) -> Self { + let dim_tokens = dim.to_tokens(); + + Self { + dim, + field: OtherType::new( + name, + quote! { + BatchNorm + }, + ), + input, + output, + gamma, + beta, + running_mean, + running_var, + config, } + } } macro_rules! batch_norm_serialize { @@ -101,124 +101,124 @@ macro_rules! batch_norm_serialize { } impl NodeCodegen for BatchNormNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let num_features = self.config.num_features.to_tokens(); - let epsilon = self.config.epsilon; - let momentum = self.config.momentum; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; - - let tokens = quote! { - let #name = BatchNormConfig::new(#num_features) - .with_epsilon(#epsilon) - .with_momentum(#momentum) - .#init_line - }; - - Some(tokens) - } - - fn field_serialize(&self, serializer: S) -> Result { - batch_norm_serialize!(self, serializer) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; - - quote! { - let #output = self.#field.forward(#input); - } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::BatchNorm"); - imports.register("burn::nn::BatchNormConfig"); - } - - fn into_node(self) -> Node { - Node::BatchNorm(self) + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let num_features = self.config.num_features.to_tokens(); + let epsilon = self.config.epsilon; + let momentum = self.config.momentum; + + let init_line = match with_record { + true => quote! 
{ + init_with(record.#name); + }, + false => quote! { + init(); + }, + }; + + let tokens = quote! { + let #name = BatchNormConfig::new(#num_features) + .with_epsilon(#epsilon) + .with_momentum(#momentum) + .#init_line + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + batch_norm_serialize!(self, serializer) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + + quote! { + let #output = self.#field.forward(#input); } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::BatchNorm"); + imports.register("burn::nn::BatchNormConfig"); + } + + fn into_node(self) -> Node { + Node::BatchNorm(self) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - use burn::{record::FullPrecisionSettings, tensor::Data}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(BatchNormNode::new( - 2, // Batch norm 2d - "norm", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - Data::from([2.]).serialize(), - Data::from([2.]).serialize(), - Data::from([2.]).serialize(), - BatchNormConfig::new(128), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::BatchNorm; - use burn::nn::BatchNormConfig; - - #[derive(Module, Debug)] - pub struct Model { - norm: BatchNorm, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use burn::{record::FullPrecisionSettings, tensor::Data}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(BatchNormNode::new( + 2, // Batch norm 2d + "norm", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + Data::from([2.]).serialize(), + Data::from([2.]).serialize(), + Data::from([2.]).serialize(), + BatchNormConfig::new(128), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::BatchNorm; + use burn::nn::BatchNormConfig; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let norm = BatchNormConfig::new(128) - .with_epsilon(0.00001f64) - .with_momentum(0.1f64) - .init_with(record.norm); - - Self { - norm, - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.norm.forward(input); + #[derive(Module, Debug)] + pub struct Model { + norm: BatchNorm, + phantom: core::marker::PhantomData, + } - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let norm = BatchNormConfig::new(128) + .with_epsilon(0.00001f64) + .with_momentum(0.1f64) + .init_with(record.norm); + + Self { + norm, + phantom: core::marker::PhantomData, } } - }; + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.norm.forward(input); - assert_tokens(graph.codegen(), expected); - } + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/binary.rs b/burn-import/src/burn/node/binary.rs index 2e237339a2..e01c0d302f 100644 --- a/burn-import/src/burn/node/binary.rs +++ b/burn-import/src/burn/node/binary.rs @@ -7,23 +7,23 @@ use std::sync::Arc; #[derive(Clone)] pub enum BinaryType { - Add, - Sub, - Mul, - Div, - Equal, + Add, + Sub, + Mul, + Div, + Equal, } impl BinaryType { - pub(crate) fn as_str(&self) -> &str { - match self { - BinaryType::Add => "add", - BinaryType::Sub => "sub", - BinaryType::Mul => "mul", - BinaryType::Div => "div", - BinaryType::Equal => "equal", - } + pub(crate) fn as_str(&self) -> &str { + match self { + BinaryType::Add => "add", + BinaryType::Sub => "sub", + BinaryType::Mul => "mul", + BinaryType::Div => "div", + BinaryType::Equal => "equal", } + } } // Simple fn pointer that receive input as a token stream and return function call. @@ -32,313 +32,313 @@ type FnPointer = Arc TokenStream>; /// Node for all binary operators. #[derive(Clone, new)] pub struct BinaryNode { - pub lhs: Type, - pub rhs: Type, - pub output: Type, - pub binary_type: BinaryType, - function: FnPointer, + pub lhs: Type, + pub rhs: Type, + pub output: Type, + pub binary_type: BinaryType, + function: FnPointer, } impl std::fmt::Debug for BinaryNode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str( - format!( - "BinaryNode {{ lhs: {:?}, rhs: {:?}, output: {:?}, name: {:?} }}", - self.lhs, - self.rhs, - self.output, - self.binary_type.as_str() - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + format!( + "BinaryNode {{ lhs: {:?}, rhs: {:?}, output: {:?}, name: {:?} }}", + self.lhs, + self.rhs, + self.output, + self.binary_type.as_str() + ) + .as_str(), + ) + } } impl NodeCodegen for BinaryNode { - fn output_types(&self) -> Vec { - vec![self.output.clone()] + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + vec![self.lhs.clone(), self.rhs.clone()] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + // Get the lhs name in the form of token stream. + let lhs = match &self.lhs { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! 
{ #name } + } + _ => panic!("lhs must be a tensor or scalar"), + }; + + // Get the rhs name in the form of token stream + let rhs = match &self.rhs { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("rhs must be a tensor or scalar"), + }; + + let output = &self.output.name(); + let function = (self.function)(lhs, rhs); + + quote! { + let #output = #function; } + } - fn input_types(&self) -> Vec { - vec![self.lhs.clone(), self.rhs.clone()] - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - // Get the lhs name in the form of token stream. - let lhs = match &self.lhs { - Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), - Type::Scalar(scalar) => { - let name = scalar.name.clone(); - quote! { #name } - } - _ => panic!("lhs must be a tensor or scalar"), - }; - - // Get the rhs name in the form of token stream - let rhs = match &self.rhs { - Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), - Type::Scalar(scalar) => { - let name = scalar.name.clone(); - quote! { #name } - } - _ => panic!("rhs must be a tensor or scalar"), - }; - - let output = &self.output.name(); - let function = (self.function)(lhs, rhs); - - quote! { - let #output = #function; - } - } - - fn into_node(self) -> Node { - Node::Binary(self) - } + fn into_node(self) -> Node { + Node::Binary(self) + } } impl BinaryNode { - pub(crate) fn add(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.add(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.add_scalar(#rhs) }, - (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.add_scalar(#lhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs + #rhs }, - _ => panic!("Addition is supported for tensor and scalar only"), - }; - - Self::new(lhs, rhs, output, BinaryType::Add, Arc::new(function)) - } - - pub(crate) fn sub(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, - _ => panic!("Subtraction is supported for tensor and scalar only"), - }; - - Self::new(lhs, rhs, output, BinaryType::Sub, Arc::new(function)) - } - - pub(crate) fn mul(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.mul(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.mul_scalar(#rhs) }, - (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.mul_scalar(#lhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs * #rhs }, - _ => panic!("Multiplication is supported for tensor and scalar only"), - }; - - Self::new(lhs, rhs, output, BinaryType::Mul, Arc::new(function)) - } - - pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! 
{ #lhs / #rhs }, - _ => panic!("Division is supported for tensor and scalar only"), - }; - - Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function)) - } - - pub(crate) fn equal(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.equal(#rhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs == #rhs }, - _ => panic!("Comparison is supported for tensor to tensor and scalar to scalar only"), - }; - - Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function)) - } + pub(crate) fn add(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.add(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.add_scalar(#rhs) }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.add_scalar(#lhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs + #rhs }, + _ => panic!("Addition is supported for tensor and scalar only"), + }; + + Self::new(lhs, rhs, output, BinaryType::Add, Arc::new(function)) + } + + pub(crate) fn sub(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, + _ => panic!("Subtraction is supported for tensor and scalar only"), + }; + + Self::new(lhs, rhs, output, BinaryType::Sub, Arc::new(function)) + } + + pub(crate) fn mul(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.mul(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.mul_scalar(#rhs) }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.mul_scalar(#lhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs * #rhs }, + _ => panic!("Multiplication is supported for tensor and scalar only"), + }; + + Self::new(lhs, rhs, output, BinaryType::Mul, Arc::new(function)) + } + + pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs }, + _ => panic!("Division is supported for tensor and scalar only"), + }; + + Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function)) + } + + pub(crate) fn equal(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.equal(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs == #rhs }, + _ => panic!("Comparison is supported for tensor to tensor and scalar to scalar only"), + }; + + Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function)) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::graph::BurnGraph; - use crate::burn::node::test::assert_tokens; - use crate::burn::node::tests::one_node_graph; - use crate::burn::{ScalarKind, ScalarType, TensorType}; - - macro_rules! 
test_binary_operator_on_tensors { - ($operator:ident) => {{ - one_node_graph( - BinaryNode::$operator( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - Type::Tensor(TensorType::new_float("tensor3", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.$operator(tensor2); - - tensor3 - } - }, - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - }}; - } - - macro_rules! test_binary_operator_on_tensor_and_scalar { - ($operator:ident, $burn_operator:ident) => {{ - one_node_graph( - BinaryNode::$operator( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), - Type::Tensor(TensorType::new_float("tensor3", 4)), - ), - quote! { - pub fn forward(&self, scalar1: f32, tensor1: Tensor) -> Tensor { - let tensor3 = tensor1.$burn_operator(scalar1); - - tensor3 - } - }, - vec!["scalar1".to_string(), "tensor1".to_string()], - vec!["tensor3".to_string()], - ); - }}; - } - - macro_rules! test_binary_operator_on_scalar_and_scalar { - ($operator:ident, $scalar_operator:tt) => {{ - one_node_graph( - BinaryNode::$operator( - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), - Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), - Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float32)), - ), - quote! { - pub fn forward(&self, scalar1: f32, scalar2: f32) -> f32 { - let scalar3 = scalar1 $scalar_operator scalar2; - - scalar3 - } - }, - vec!["scalar1".to_string(), "scalar2".to_string()], - vec!["scalar3".to_string()], - ); - }}; - } - - #[test] - fn test_binary_codegen_add() { - test_binary_operator_on_tensors!(add); - } - - #[test] - fn test_binary_codegen_add_scalar() { - test_binary_operator_on_tensor_and_scalar!(add, add_scalar); - } - - #[test] - fn test_binary_codegen_add_scalars() { - test_binary_operator_on_scalar_and_scalar!(add, +); - } - - #[test] - fn test_binary_codegen_sub() { - test_binary_operator_on_tensors!(sub); - } - - #[test] - fn test_binary_codegen_sub_scalar() { - test_binary_operator_on_tensor_and_scalar!(sub, sub_scalar); - } - - #[test] - fn test_binary_codegen_sub_scalars() { - test_binary_operator_on_scalar_and_scalar!(sub, -); - } - - #[test] - fn test_binary_codegen_mul() { - test_binary_operator_on_tensors!(mul); - } - - #[test] - fn test_binary_codegen_mul_scalar() { - test_binary_operator_on_tensor_and_scalar!(mul, mul_scalar); - } - - #[test] - fn test_binary_codegen_mul_scalars() { - test_binary_operator_on_scalar_and_scalar!(mul, *); - } - - #[test] - fn test_binary_codegen_div() { - test_binary_operator_on_tensors!(div); - } + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::graph::BurnGraph; + use crate::burn::node::test::assert_tokens; + use crate::burn::node::tests::one_node_graph; + use crate::burn::{ScalarKind, ScalarType, TensorType}; + + macro_rules! test_binary_operator_on_tensors { + ($operator:ident) => {{ + one_node_graph( + BinaryNode::$operator( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + Type::Tensor(TensorType::new_float("tensor3", 4)), + ), + quote! 
{ + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.$operator(tensor2); - #[test] - fn test_binary_codegen_div_scalar() { - test_binary_operator_on_tensor_and_scalar!(div, div_scalar); - } + tensor3 + } + }, + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + }}; + } + + macro_rules! test_binary_operator_on_tensor_and_scalar { + ($operator:ident, $burn_operator:ident) => {{ + one_node_graph( + BinaryNode::$operator( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), + Type::Tensor(TensorType::new_float("tensor3", 4)), + ), + quote! { + pub fn forward(&self, scalar1: f32, tensor1: Tensor) -> Tensor { + let tensor3 = tensor1.$burn_operator(scalar1); - #[test] - fn test_binary_codegen_div_scalars() { - test_binary_operator_on_scalar_and_scalar!(div, /); - } + tensor3 + } + }, + vec!["scalar1".to_string(), "tensor1".to_string()], + vec!["tensor3".to_string()], + ); + }}; + } + + macro_rules! test_binary_operator_on_scalar_and_scalar { + ($operator:ident, $scalar_operator:tt) => {{ + one_node_graph( + BinaryNode::$operator( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), + Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float32)), + ), + quote! { + pub fn forward(&self, scalar1: f32, scalar2: f32) -> f32 { + let scalar3 = scalar1 $scalar_operator scalar2; - #[test] - fn test_binary_codegen_equal_tensors() { - let mut graph = BurnGraph::::default(); - let node_gen = BinaryNode::equal( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - Type::Tensor(TensorType::new_bool("tensor3", 4)), - ); - - graph.register(node_gen); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! 
{ - use burn::tensor::Bool; - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, + scalar3 } + }, + vec!["scalar1".to_string(), "scalar2".to_string()], + vec!["scalar3".to_string()], + ); + }}; + } + + #[test] + fn test_binary_codegen_add() { + test_binary_operator_on_tensors!(add); + } + + #[test] + fn test_binary_codegen_add_scalar() { + test_binary_operator_on_tensor_and_scalar!(add, add_scalar); + } + + #[test] + fn test_binary_codegen_add_scalars() { + test_binary_operator_on_scalar_and_scalar!(add, +); + } + + #[test] + fn test_binary_codegen_sub() { + test_binary_operator_on_tensors!(sub); + } + + #[test] + fn test_binary_codegen_sub_scalar() { + test_binary_operator_on_tensor_and_scalar!(sub, sub_scalar); + } + + #[test] + fn test_binary_codegen_sub_scalars() { + test_binary_operator_on_scalar_and_scalar!(sub, -); + } + + #[test] + fn test_binary_codegen_mul() { + test_binary_operator_on_tensors!(mul); + } + + #[test] + fn test_binary_codegen_mul_scalar() { + test_binary_operator_on_tensor_and_scalar!(mul, mul_scalar); + } + + #[test] + fn test_binary_codegen_mul_scalars() { + test_binary_operator_on_scalar_and_scalar!(mul, *); + } + + #[test] + fn test_binary_codegen_div() { + test_binary_operator_on_tensors!(div); + } + + #[test] + fn test_binary_codegen_div_scalar() { + test_binary_operator_on_tensor_and_scalar!(div, div_scalar); + } + + #[test] + fn test_binary_codegen_div_scalars() { + test_binary_operator_on_scalar_and_scalar!(div, /); + } + + #[test] + fn test_binary_codegen_equal_tensors() { + let mut graph = BurnGraph::::default(); + let node_gen = BinaryNode::equal( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + Type::Tensor(TensorType::new_bool("tensor3", 4)), + ); + + graph.register(node_gen); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! 
{ + use burn::tensor::Bool; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } + } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.equal(tensor2); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.equal(tensor2); - tensor3 - } + tensor3 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } - #[test] - fn test_binary_codegen_equal_scalars() { - test_binary_operator_on_scalar_and_scalar!(equal, ==); - } + #[test] + fn test_binary_codegen_equal_scalars() { + test_binary_operator_on_scalar_and_scalar!(equal, ==); + } } diff --git a/burn-import/src/burn/node/clip.rs b/burn-import/src/burn/node/clip.rs index 3156ab6d73..69254a3d3a 100644 --- a/burn-import/src/burn/node/clip.rs +++ b/burn-import/src/burn/node/clip.rs @@ -6,182 +6,182 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct ClipNode { - pub input: TensorType, - pub output: TensorType, - pub min: Option, // Should be elem Type - pub max: Option, + pub input: TensorType, + pub output: TensorType, + pub min: Option, // Should be elem Type + pub max: Option, } impl NodeCodegen for ClipNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - - if let Some(min) = self.min { - if let Some(max) = self.max { - quote! { - let #output = #input.clamp(#min, #max); - } - } else { - quote! { - let #output = #input.clamp_min(#min); - } - } - } else if let Some(max) = self.max { - return quote! { - let #output = #input.clamp_max(#max); - }; - } else { - panic!("Clip node must have at least one min or max value"); + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + if let Some(min) = self.min { + if let Some(max) = self.max { + quote! { + let #output = #input.clamp(#min, #max); + } + } else { + quote! { + let #output = #input.clamp_min(#min); } + } + } else if let Some(max) = self.max { + return quote! 
{ + let #output = #input.clamp_max(#max); + }; + } else { + panic!("Clip node must have at least one min or max value"); } + } - fn into_node(self) -> Node { - Node::Clip(self) - } + fn into_node(self) -> Node { + Node::Clip(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; + use burn::record::FullPrecisionSettings; - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - #[test] - fn codegen_nodes_min_max() { - let mut graph = BurnGraph::::default(); + #[test] + fn codegen_nodes_min_max() { + let mut graph = BurnGraph::::default(); - graph.register(ClipNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - Some(0.0), - Some(1.0), - )); + graph.register(ClipNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + Some(0.0), + Some(1.0), + )); - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.clamp(0f64, 1f64); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.clamp(0f64, 1f64); - tensor2 - } + tensor2 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } - #[test] - fn codegen_nodes_min() { - let mut graph = BurnGraph::::default(); + #[test] + fn codegen_nodes_min() { + let mut graph = BurnGraph::::default(); - graph.register(ClipNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - Some(0.0), - None, - )); + graph.register(ClipNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + Some(0.0), + None, + )); - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.clamp_min(0f64); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.clamp_min(0f64); - tensor2 - } + tensor2 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } - #[test] - fn codegen_nodes_max() { - let mut graph = BurnGraph::::default(); + #[test] + fn codegen_nodes_max() { + let mut graph = BurnGraph::::default(); - graph.register(ClipNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - None, - Some(1.0), - )); + graph.register(ClipNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + None, + Some(1.0), + )); - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.clamp_max(1f64); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.clamp_max(1f64); - tensor2 - } + tensor2 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/concat.rs b/burn-import/src/burn/node/concat.rs index a0cb6e1893..20ecc2b0fe 100644 --- a/burn-import/src/burn/node/concat.rs +++ b/burn-import/src/burn/node/concat.rs @@ -7,100 +7,101 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct ConcatNode { - pub inputs: Vec, - pub output: TensorType, - pub dim: usize, + pub inputs: Vec, + pub output: TensorType, + pub dim: usize, } impl NodeCodegen for ConcatNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + self + .inputs + .iter() + .map(|t| Type::Tensor(t.clone())) + .collect() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let dim = self.dim.to_tokens(); + let inputs = self + .inputs + .iter() + .map(|t| scope.tensor_use_owned(t, 
node_position)); + + let output = &self.output.name; + + quote! { + let #output = burn::tensor::Tensor::cat([#(#inputs),*].into(), #dim); } + } - fn input_types(&self) -> Vec { - self.inputs - .iter() - .map(|t| Type::Tensor(t.clone())) - .collect() - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let dim = self.dim.to_tokens(); - let inputs = self - .inputs - .iter() - .map(|t| scope.tensor_use_owned(t, node_position)); - - let output = &self.output.name; - - quote! { - let #output = burn::tensor::Tensor::cat([#(#inputs),*].into(), #dim); - } - } - - fn into_node(self) -> Node { - Node::Concat(self) - } + fn into_node(self) -> Node { + Node::Concat(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{concat::ConcatNode, test::assert_tokens}, - TensorType, - }; + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{concat::ConcatNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_concat() { + let mut graph = BurnGraph::::default(); + + graph.register(ConcatNode::new( + vec![ + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + ], + TensorType::new_float("tensor3", 4), + 1, + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[test] - fn test_codegen_concat() { - let mut graph = BurnGraph::::default(); - - graph.register(ConcatNode::new( - vec![ - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - ], - TensorType::new_float("tensor3", 4), - 1, - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } + } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1); - tensor3 - } + tensor3 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/constant.rs b/burn-import/src/burn/node/constant.rs index d091529e1e..886b6a85f9 100644 --- a/burn-import/src/burn/node/constant.rs +++ b/burn-import/src/burn/node/constant.rs @@ -1,9 +1,9 @@ use super::{Node, NodeCodegen}; use crate::burn::{ScalarKind, ScalarType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::ParamId, - record::{ParamSerde, PrecisionSettings}, - tensor::DataSerialize, + module::ParamId, + record::{ParamSerde, PrecisionSettings}, + tensor::DataSerialize, }; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -11,178 +11,178 @@ use serde::Serialize; #[derive(Debug, Clone)] pub struct ConstantNode { - pub name: String, - pub value: ConstantValue, - pub output: Type, + pub name: String, + pub value: ConstantValue, + pub output: Type, } #[derive(Debug, Clone)] pub enum TensorValue { - Float(DataSerialize), - Int(DataSerialize), - // TODO Support bool serialization (@antimora 8/26/2023) + Float(DataSerialize), + Int(DataSerialize), + // TODO Support bool serialization (@antimora 8/26/2023) } #[derive(Debug, Clone, new)] pub enum ConstantValue { - /// Float constant. - Float32(f32), - Float64(f64), + /// Float constant. + Float32(f32), + Float64(f64), - /// Integer constant. - Int32(i32), - Int64(i64), + /// Integer constant. + Int32(i32), + Int64(i64), - // Boolean constant. - Bool(bool), + // Boolean constant. + Bool(bool), - /// Tensor constant. - Tensor(TensorType, TensorValue), + /// Tensor constant. + Tensor(TensorType, TensorValue), } impl ConstantValue { - pub fn ty_tokens(&self) -> TokenStream { - match self { - ConstantValue::Float32(_) => quote! { f32 }, - ConstantValue::Float64(_) => quote! { f64 }, - ConstantValue::Int32(_) => quote! { i32 }, - ConstantValue::Int64(_) => quote! { i64 }, - ConstantValue::Bool(_) => quote! { bool }, - ConstantValue::Tensor(tensor_type, _) => { - let ty = tensor_type.ty(); - quote! { burn::module::Param<#ty>} - } - } + pub fn ty_tokens(&self) -> TokenStream { + match self { + ConstantValue::Float32(_) => quote! { f32 }, + ConstantValue::Float64(_) => quote! { f64 }, + ConstantValue::Int32(_) => quote! { i32 }, + ConstantValue::Int64(_) => quote! { i64 }, + ConstantValue::Bool(_) => quote! { bool }, + ConstantValue::Tensor(tensor_type, _) => { + let ty = tensor_type.ty(); + quote! { burn::module::Param<#ty>} + } } - pub fn val_tokens(&self) -> TokenStream { - match self { - ConstantValue::Float32(val) => quote! 
{ #val }, - ConstantValue::Float64(val) => quote! { #val }, - ConstantValue::Int32(val) => quote! { #val }, - ConstantValue::Int64(val) => quote! { #val }, - ConstantValue::Bool(val) => quote! { #val }, - ConstantValue::Tensor(_, _) => { - panic!("Tensor constant is not assignable.") - } - } + } + pub fn val_tokens(&self) -> TokenStream { + match self { + ConstantValue::Float32(val) => quote! { #val }, + ConstantValue::Float64(val) => quote! { #val }, + ConstantValue::Int32(val) => quote! { #val }, + ConstantValue::Int64(val) => quote! { #val }, + ConstantValue::Bool(val) => quote! { #val }, + ConstantValue::Tensor(_, _) => { + panic!("Tensor constant is not assignable.") + } } + } } impl ConstantNode { - pub fn new(name: String, value: ConstantValue, output: Type) -> Self { - Self { - name, - value, - output, - } + pub fn new(name: String, value: ConstantValue, output: Type) -> Self { + Self { + name, + value, + output, } - pub fn constant_value_into_type(&self) -> Type { - let name = Ident::new(self.name.as_str(), Span::call_site()); - match &self.value { - ConstantValue::Float32(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Float32, - }), - ConstantValue::Float64(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Float64, - }), - ConstantValue::Int32(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Int32, - }), - ConstantValue::Int64(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Int64, - }), - ConstantValue::Bool(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Bool, - }), - - ConstantValue::Tensor(tensor_type, _) => Type::Tensor(tensor_type.clone()), - } + } + pub fn constant_value_into_type(&self) -> Type { + let name = Ident::new(self.name.as_str(), Span::call_site()); + match &self.value { + ConstantValue::Float32(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Float32, + }), + ConstantValue::Float64(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Float64, + }), + ConstantValue::Int32(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Int32, + }), + ConstantValue::Int64(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Int64, + }), + ConstantValue::Bool(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Bool, + }), + + ConstantValue::Tensor(tensor_type, _) => Type::Tensor(tensor_type.clone()), } + } } impl NodeCodegen for ConstantNode { - fn output_types(&self) -> Vec { - vec![self.output.clone()] + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + vec![] + } + + fn field_type(&self) -> Option { + match &self.value { + ConstantValue::Tensor(tensor_type, _) => Some(Type::Tensor(tensor_type.clone())), + _ => None, } + } - fn input_types(&self) -> Vec { - vec![] - } - - fn field_type(&self) -> Option { - match &self.value { - ConstantValue::Tensor(tensor_type, _) => Some(Type::Tensor(tensor_type.clone())), - _ => None, + fn field_init(&self, with_record: bool) -> Option { + match &self.value { + ConstantValue::Tensor(tensor_type, _) => { + let ty = tensor_type.ty(); + let name = Ident::new(self.name.as_ref(), Span::call_site()); + let shape = tensor_type.clone().shape.unwrap().to_tokens(); + let dim = tensor_type.clone().dim.to_tokens(); + + if with_record { + Some(quote! { + let #name = record.#name.map(|tensor| tensor.set_require_grad(false)); + }) + } else { + Some(quote! 
{ + let #name: burn::module::Param<#ty> = burn::module::Param::new( + burn::module::ParamId::new(), + Tensor::::zeros(#shape).set_require_grad(false), + ); + }) } + } + _ => None, } + } - fn field_init(&self, with_record: bool) -> Option { - match &self.value { - ConstantValue::Tensor(tensor_type, _) => { - let ty = tensor_type.ty(); - let name = Ident::new(self.name.as_ref(), Span::call_site()); - let shape = tensor_type.clone().shape.unwrap().to_tokens(); - let dim = tensor_type.clone().dim.to_tokens(); - - if with_record { - Some(quote! { - let #name = record.#name.map(|tensor| tensor.set_require_grad(false)); - }) - } else { - Some(quote! { - let #name: burn::module::Param<#ty> = burn::module::Param::new( - burn::module::ParamId::new(), - Tensor::::zeros(#shape).set_require_grad(false), - ); - }) - } - } - _ => None, + fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { + let name = Ident::new(self.name.as_ref(), Span::call_site()); + let output = self.output.name(); + + match &self.value { + ConstantValue::Tensor(_, _) => { + quote! { + let #output = self.#name.val(); } - } + } + _ => { + let val = self.value.val_tokens(); + let ty = self.value.ty_tokens(); - fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { - let name = Ident::new(self.name.as_ref(), Span::call_site()); - let output = self.output.name(); - - match &self.value { - ConstantValue::Tensor(_, _) => { - quote! { - let #output = self.#name.val(); - } - } - _ => { - let val = self.value.val_tokens(); - let ty = self.value.ty_tokens(); - - quote! { - let #output: #ty = #val; - } - } + quote! { + let #output: #ty = #val; } + } } - - fn into_node(self) -> Node { - Node::Constant(self) + } + + fn into_node(self) -> Node { + Node::Constant(self) + } + + fn field_serialize(&self, serializer: S) -> Result { + if let ConstantValue::Tensor(_, ds) = &self.value { + let data: DataSerialize = match ds { + TensorValue::Float(data) => data.clone().convert(), + TensorValue::Int(data) => data.clone().convert(), + }; + let data = ParamSerde::new(ParamId::new().into_string(), data); + return data.serialize(serializer); } - fn field_serialize(&self, serializer: S) -> Result { - if let ConstantValue::Tensor(_, ds) = &self.value { - let data: DataSerialize = match ds { - TensorValue::Float(data) => data.clone().convert(), - TensorValue::Int(data) => data.clone().convert(), - }; - let data = ParamSerde::new(ParamId::new().into_string(), data); - return data.serialize(serializer); - } - - S::serialize_none(serializer) - } + S::serialize_none(serializer) + } } // TODO add test missing for constant node (@antimora 8/2/2023) diff --git a/burn-import/src/burn/node/conv1d.rs b/burn-import/src/burn/node/conv1d.rs index 0c3e568880..e3d5b38bc2 100644 --- a/burn-import/src/burn/node/conv1d.rs +++ b/burn-import/src/burn/node/conv1d.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv1dConfig, Conv1dRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{ConstantRecord, Param, ParamId}, + nn::conv::{Conv1dConfig, Conv1dRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,191 +12,191 @@ use serde::Serialize; #[derive(Clone, Debug)] pub struct Conv1dNode { - pub field: OtherType, - pub input: TensorType, - pub 
output: TensorType, - pub data_weights: DataSerialize, - pub data_bias: Option>, - pub config: Conv1dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub data_weights: DataSerialize, + pub data_bias: Option>, + pub config: Conv1dConfig, } impl Conv1dNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - data_weights: DataSerialize, - data_bias: Option>, - config: Conv1dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Conv1d - }, - ), - input, - output, - data_weights, - data_bias, - config, - } + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + data_weights: DataSerialize, + data_bias: Option>, + config: Conv1dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + Conv1d + }, + ), + input, + output, + data_weights, + data_bias, + config, } + } } impl NodeCodegen for Conv1dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let channels_in = self.config.channels_in.to_tokens(); - let channels_out = self.config.channels_out.to_tokens(); - let kernel_size = self.config.kernel_size.to_tokens(); - let stride = self.config.stride.to_tokens(); - let dilation = self.config.dilation.to_tokens(); - let groups = self.config.groups.to_tokens(); - let padding = self.config.padding.to_tokens(); - let bias = self.config.bias; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; - - let tokens = quote! { - let #name = Conv1dConfig::new(#channels_in, #channels_out, #kernel_size) - .with_stride(#stride) - .with_padding(#padding) - .with_dilation(#dilation) - .with_groups(#groups) - .with_bias(#bias) - .#init_line - }; - - Some(tokens) - } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let channels_in = self.config.channels_in.to_tokens(); + let channels_out = self.config.channels_out.to_tokens(); + let kernel_size = self.config.kernel_size.to_tokens(); + let stride = self.config.stride.to_tokens(); + let dilation = self.config.dilation.to_tokens(); + let groups = self.config.groups.to_tokens(); + let padding = self.config.padding.to_tokens(); + let bias = self.config.bias; + + let init_line = match with_record { + true => quote! { + init_with(record.#name); + }, + false => quote! { + init(); + }, + }; - fn field_serialize(&self, serializer: S) -> Result { - let record = Conv1dRecord:: { - weight: Param::new( - ParamId::new(), - Tensor::from_data(self.data_weights.clone().convert()), - ), - bias: self - .data_bias - .as_ref() - .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), - stride: ConstantRecord::new(), - kernel_size: ConstantRecord::new(), - dilation: ConstantRecord::new(), - groups: ConstantRecord::new(), - padding: ConstantRecord::new(), - }; + let tokens = quote! 
{ + let #name = Conv1dConfig::new(#channels_in, #channels_out, #kernel_size) + .with_stride(#stride) + .with_padding(#padding) + .with_dilation(#dilation) + .with_groups(#groups) + .with_bias(#bias) + .#init_line + }; - let item = Record::into_item::(record); - item.serialize(serializer) - } + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let record = Conv1dRecord:: { + weight: Param::new( + ParamId::new(), + Tensor::from_data(self.data_weights.clone().convert()), + ), + bias: self + .data_bias + .as_ref() + .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), + stride: ConstantRecord::new(), + kernel_size: ConstantRecord::new(), + dilation: ConstantRecord::new(), + groups: ConstantRecord::new(), + padding: ConstantRecord::new(), + }; - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + let item = Record::into_item::(record); + item.serialize(serializer) + } - quote! { - let #output = self.#field.forward(#input); - } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig1d"); - imports.register("burn::nn::conv::Conv1d"); - imports.register("burn::nn::conv::Conv1dConfig"); - } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - fn into_node(self) -> Node { - Node::Conv1d(self) + quote! { + let #output = self.#field.forward(#input); } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig1d"); + imports.register("burn::nn::conv::Conv1d"); + imports.register("burn::nn::conv::Conv1dConfig"); + } + + fn into_node(self) -> Node { + Node::Conv1d(self) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{conv1d::Conv1dNode, test::assert_tokens}, - TensorType, - }; - use burn::{ - nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data, - }; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(Conv1dNode::new( - "conv1d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - None, - Conv1dConfig::new(3, 3, 3).with_padding(PaddingConfig1d::Valid), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig1d; - use burn::nn::conv::Conv1d; - use burn::nn::conv::Conv1dConfig; - - #[derive(Module, Debug)] - pub struct Model { - conv1d: Conv1d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{conv1d::Conv1dNode, test::assert_tokens}, + TensorType, + }; + use burn::{ + nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data, + }; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(Conv1dNode::new( + "conv1d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + None, + Conv1dConfig::new(3, 3, 3).with_padding(PaddingConfig1d::Valid), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig1d; + use burn::nn::conv::Conv1d; + use burn::nn::conv::Conv1dConfig; + + #[derive(Module, Debug)] + pub struct Model { + conv1d: Conv1d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv1d = Conv1dConfig::new(3, 3, 3) - .with_stride(1) - .with_padding(PaddingConfig1d::Valid) - .with_dilation(1) - .with_groups(1) - .with_bias(true) - .init_with(record.conv1d); - - Self { - conv1d, - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv1d = Conv1dConfig::new(3, 3, 3) + .with_stride(1) + .with_padding(PaddingConfig1d::Valid) + .with_dilation(1) + .with_groups(1) + .with_bias(true) + .init_with(record.conv1d); + + Self { + conv1d, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.conv1d.forward(input); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.conv1d.forward(input); - output - } + output } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/conv2d.rs b/burn-import/src/burn/node/conv2d.rs index 9b3c9f4408..d6b0059794 100644 --- a/burn-import/src/burn/node/conv2d.rs +++ b/burn-import/src/burn/node/conv2d.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv2dConfig, Conv2dRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{ConstantRecord, Param, ParamId}, + nn::conv::{Conv2dConfig, Conv2dRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,190 +12,190 @@ use serde::Serialize; #[derive(Debug, Clone)] pub struct Conv2dNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub data_weights: DataSerialize, - pub data_bias: Option>, - pub config: Conv2dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub data_weights: DataSerialize, + pub data_bias: Option>, + pub config: Conv2dConfig, } impl Conv2dNode { - pub fn new>( - name: S, - input: TensorType, - output: 
TensorType, - data_weights: DataSerialize, - data_bias: Option>, - config: Conv2dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Conv2d - }, - ), - input, - output, - data_weights, - data_bias, - config, - } + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + data_weights: DataSerialize, + data_bias: Option>, + config: Conv2dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + Conv2d + }, + ), + input, + output, + data_weights, + data_bias, + config, } + } } impl NodeCodegen for Conv2dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let channels = self.config.channels.to_tokens(); - let kernel_size = self.config.kernel_size.to_tokens(); - let stride = self.config.stride.to_tokens(); - let dilation = self.config.dilation.to_tokens(); - let groups = self.config.groups.to_tokens(); - let padding = self.config.padding.to_tokens(); - let bias = self.config.bias; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; - - let tokens = quote! { - let #name = Conv2dConfig::new(#channels, #kernel_size) - .with_stride(#stride) - .with_padding(#padding) - .with_dilation(#dilation) - .with_groups(#groups) - .with_bias(#bias) - .#init_line - }; - - Some(tokens) - } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let channels = self.config.channels.to_tokens(); + let kernel_size = self.config.kernel_size.to_tokens(); + let stride = self.config.stride.to_tokens(); + let dilation = self.config.dilation.to_tokens(); + let groups = self.config.groups.to_tokens(); + let padding = self.config.padding.to_tokens(); + let bias = self.config.bias; + + let init_line = match with_record { + true => quote! { + init_with(record.#name); + }, + false => quote! { + init(); + }, + }; - fn field_serialize(&self, serializer: S) -> Result { - let record = Conv2dRecord:: { - weight: Param::new( - ParamId::new(), - Tensor::from_data(self.data_weights.clone().convert()), - ), - bias: self - .data_bias - .as_ref() - .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), - stride: [ConstantRecord::new(); 2], - kernel_size: [ConstantRecord::new(); 2], - dilation: [ConstantRecord::new(); 2], - groups: ConstantRecord::new(), - padding: ConstantRecord::new(), - }; + let tokens = quote! 
{ + let #name = Conv2dConfig::new(#channels, #kernel_size) + .with_stride(#stride) + .with_padding(#padding) + .with_dilation(#dilation) + .with_groups(#groups) + .with_bias(#bias) + .#init_line + }; - let item = Record::into_item::(record); - item.serialize(serializer) - } + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let record = Conv2dRecord:: { + weight: Param::new( + ParamId::new(), + Tensor::from_data(self.data_weights.clone().convert()), + ), + bias: self + .data_bias + .as_ref() + .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), + stride: [ConstantRecord::new(); 2], + kernel_size: [ConstantRecord::new(); 2], + dilation: [ConstantRecord::new(); 2], + groups: ConstantRecord::new(), + padding: ConstantRecord::new(), + }; - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + let item = Record::into_item::(record); + item.serialize(serializer) + } - quote! { - let #output = self.#field.forward(#input); - } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig2d"); - imports.register("burn::nn::conv::Conv2d"); - imports.register("burn::nn::conv::Conv2dConfig"); - } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - fn into_node(self) -> Node { - Node::Conv2d(self) + quote! { + let #output = self.#field.forward(#input); } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig2d"); + imports.register("burn::nn::conv::Conv2d"); + imports.register("burn::nn::conv::Conv2dConfig"); + } + + fn into_node(self) -> Node { + Node::Conv2d(self) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{conv2d::Conv2dNode, test::assert_tokens}, - TensorType, - }; - use burn::{ - nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, - }; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(Conv2dNode::new( - "conv2d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig2d; - use burn::nn::conv::Conv2d; - use burn::nn::conv::Conv2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - conv2d: Conv2d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{conv2d::Conv2dNode, test::assert_tokens}, + TensorType, + }; + use burn::{ + nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, + }; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(Conv2dNode::new( + "conv2d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + None, + Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::conv::Conv2d; + use burn::nn::conv::Conv2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + conv2d: Conv2d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv2d = Conv2dConfig::new([3, 3], [3, 3]) - .with_stride([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .with_groups(1) - .with_bias(true) - .init_with(record.conv2d); - - Self { - conv2d, - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv2d = Conv2dConfig::new([3, 3], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .with_groups(1) + .with_bias(true) + .init_with(record.conv2d); + + Self { + conv2d, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.conv2d.forward(input); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.conv2d.forward(input); - output - } + output } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/dropout.rs b/burn-import/src/burn/node/dropout.rs index de61653e13..5c8efa3045 100644 --- a/burn-import/src/burn/node/dropout.rs +++ b/burn-import/src/burn/node/dropout.rs @@ -8,138 +8,138 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; #[derive(Debug, Clone)] pub struct DropoutNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub config: DropoutConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub config: DropoutConfig, } impl DropoutNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - config: DropoutConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Dropout - }, - ), - input, - output, - config, - } + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + config: DropoutConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! 
{ + Dropout + }, + ), + input, + output, + config, } + } } impl NodeCodegen for DropoutNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - - let prob = self.config.prob.to_tokens(); - - let init_line = quote! { - init(); - }; - - let tokens = quote! { - let #name = DropoutConfig::new(#prob) - .#init_line - }; - - Some(tokens) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; - - quote! { - let #output = self.#field.forward(#input); - } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::Dropout"); - imports.register("burn::nn::DropoutConfig"); - } - - fn into_node(self) -> Node { - Node::Dropout(self) - } - - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; + + let prob = self.config.prob.to_tokens(); + + let init_line = quote! { + init(); + }; + + let tokens = quote! { + let #name = DropoutConfig::new(#prob) + .#init_line + }; + + Some(tokens) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + + quote! { + let #output = self.#field.forward(#input); } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::Dropout"); + imports.register("burn::nn::DropoutConfig"); + } + + fn into_node(self) -> Node { + Node::Dropout(self) + } + + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - use burn::{nn::DropoutConfig, record::FullPrecisionSettings}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(DropoutNode::new( - "dropout", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - DropoutConfig::new(0.5), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::Dropout; - use burn::nn::DropoutConfig; - - #[derive(Module, Debug)] - pub struct Model { - dropout: Dropout, - phantom: core::marker::PhantomData, + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use burn::{nn::DropoutConfig, record::FullPrecisionSettings}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(DropoutNode::new( + "dropout", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + DropoutConfig::new(0.5), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::Dropout; + use burn::nn::DropoutConfig; - } + #[derive(Module, Debug)] + pub struct Model { + dropout: Dropout, + phantom: core::marker::PhantomData, - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let dropout = DropoutConfig::new(0.5) - .init(); + } - Self { - dropout, - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.dropout.forward(input); + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let dropout = DropoutConfig::new(0.5) + .init(); - output + Self { + dropout, + phantom: core::marker::PhantomData, } } - }; + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.dropout.forward(input); - assert_tokens(graph.codegen(), expected); - } + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/gather.rs b/burn-import/src/burn/node/gather.rs index 2c0e6bc9ea..933b76f94d 100644 --- a/burn-import/src/burn/node/gather.rs +++ b/burn-import/src/burn/node/gather.rs @@ -6,101 +6,101 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct GatherNode { - pub input: TensorType, - pub index: TensorType, - pub output: TensorType, - pub dim: usize, + pub input: TensorType, + pub index: TensorType, + pub output: TensorType, + pub dim: usize, } impl NodeCodegen for GatherNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![ + Type::Tensor(self.input.clone()), + Type::Tensor(self.index.clone()), + ] + } + + fn forward( + &self, + scope: &mut crate::burn::Scope, + node_position: usize, + ) -> proc_macro2::TokenStream { + let dim = self.dim.to_tokens(); + let input = scope.tensor_use_owned(&self.input, node_position); + let index = scope.tensor_use_owned(&self.index, node_position); + let output = &self.output.name; + + quote! { + let #output = #input.gather(#dim, #index); } + } - fn input_types(&self) -> Vec { - vec![ - Type::Tensor(self.input.clone()), - Type::Tensor(self.index.clone()), - ] - } - - fn forward( - &self, - scope: &mut crate::burn::Scope, - node_position: usize, - ) -> proc_macro2::TokenStream { - let dim = self.dim.to_tokens(); - let input = scope.tensor_use_owned(&self.input, node_position); - let index = scope.tensor_use_owned(&self.index, node_position); - let output = &self.output.name; - - quote! 
{ - let #output = #input.gather(#dim, #index); - } - } - - fn into_node(self) -> super::Node { - Node::Gather(self) - } + fn into_node(self) -> super::Node { + Node::Gather(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{gather::GatherNode, test::assert_tokens}, - TensorType, - }; + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{gather::GatherNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_gather() { + let mut graph = BurnGraph::::default(); + + graph.register(GatherNode::new( + TensorType::new_float("tensor1", 2), + TensorType::new_int("tensor2", 2), + TensorType::new_float("tensor3", 2), + 1, + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[test] - fn test_codegen_gather() { - let mut graph = BurnGraph::::default(); - - graph.register(GatherNode::new( - TensorType::new_float("tensor1", 2), - TensorType::new_int("tensor2", 2), - TensorType::new_float("tensor3", 2), - 1, - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! { - use burn::tensor::Int; - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } + } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.gather(1, tensor2); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.gather(1, tensor2); - tensor3 - } + tensor3 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/global_avg_pool.rs b/burn-import/src/burn/node/global_avg_pool.rs index 80d6f3cec2..76c18992ec 100644 --- a/burn-import/src/burn/node/global_avg_pool.rs +++ b/burn-import/src/burn/node/global_avg_pool.rs @@ -13,203 +13,203 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type}; /// is equivalent to global average pooling. #[derive(Debug, Clone)] pub struct GlobalAvgPoolNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, } impl GlobalAvgPoolNode { - pub fn new>(name: S, input: TensorType, output: TensorType) -> Self { - // Depending on the input dimension, we need to use a different type nn module - let field_type = match input.dim { - 3 => quote! { - AdaptiveAvgPool1d - }, - 4 => quote! 
{ - AdaptiveAvgPool2d - }, - dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), - }; + pub fn new>(name: S, input: TensorType, output: TensorType) -> Self { + // Depending on the input dimension, we need to use a different type nn module + let field_type = match input.dim { + 3 => quote! { + AdaptiveAvgPool1d + }, + 4 => quote! { + AdaptiveAvgPool2d + }, + dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), + }; - Self { - field: OtherType::new(name, field_type), - input, - output, - } + Self { + field: OtherType::new(name, field_type), + input, + output, } + } } impl NodeCodegen for GlobalAvgPoolNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - - let tokens = match self.input.dim { - 3 => { - quote! { - let #name = AdaptiveAvgPool1dConfig::new(1) - .init(); - } - } - 4 => { - quote! { - let #name = AdaptiveAvgPool2dConfig::new([1,1]) - .init(); - } - } - dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), - }; + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; + + let tokens = match self.input.dim { + 3 => { + quote! { + let #name = AdaptiveAvgPool1dConfig::new(1) + .init(); + } + } + 4 => { + quote! { + let #name = AdaptiveAvgPool2dConfig::new([1,1]) + .init(); + } + } + dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), + }; - Some(tokens) - } + Some(tokens) + } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - quote! { - let #output = self.#field.forward(#input); - } + quote! 
{ + let #output = self.#field.forward(#input); } - - fn register_imports(&self, imports: &mut BurnImports) { - match self.input.dim { - 3 => { - imports.register("burn::nn::pool::AdaptiveAvgPool1d"); - imports.register("burn::nn::pool::AdaptiveAvgPool1dConfig"); - } - 4 => { - imports.register("burn::nn::pool::AdaptiveAvgPool2d"); - imports.register("burn::nn::pool::AdaptiveAvgPool2dConfig"); - } - dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), - } + } + + fn register_imports(&self, imports: &mut BurnImports) { + match self.input.dim { + 3 => { + imports.register("burn::nn::pool::AdaptiveAvgPool1d"); + imports.register("burn::nn::pool::AdaptiveAvgPool1dConfig"); + } + 4 => { + imports.register("burn::nn::pool::AdaptiveAvgPool2d"); + imports.register("burn::nn::pool::AdaptiveAvgPool2dConfig"); + } + dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), } + } - fn into_node(self) -> Node { - Node::GlobalAvgPool(self) - } + fn into_node(self) -> Node { + Node::GlobalAvgPool(self) + } - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) - } + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{global_avg_pool::GlobalAvgPoolNode, test::assert_tokens}, - TensorType, - }; - use burn::record::FullPrecisionSettings; - - #[test] - fn test_codegen_2d() { - let mut graph = BurnGraph::::default(); - - graph.register(GlobalAvgPoolNode::new( - "global_avg_pool1", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::pool::AdaptiveAvgPool2d; - use burn::nn::pool::AdaptiveAvgPool2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - global_avg_pool1: AdaptiveAvgPool2d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{global_avg_pool::GlobalAvgPoolNode, test::assert_tokens}, + TensorType, + }; + use burn::record::FullPrecisionSettings; + + #[test] + fn test_codegen_2d() { + let mut graph = BurnGraph::::default(); + + graph.register(GlobalAvgPoolNode::new( + "global_avg_pool1", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::pool::AdaptiveAvgPool2d; + use burn::nn::pool::AdaptiveAvgPool2dConfig; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let global_avg_pool1 = AdaptiveAvgPool2dConfig::new([1, 1]) - .init(); + #[derive(Module, Debug)] + pub struct Model { + global_avg_pool1: AdaptiveAvgPool2d, + phantom: core::marker::PhantomData, + } - Self { - global_avg_pool1, - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.global_avg_pool1.forward(input); + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let global_avg_pool1 = AdaptiveAvgPool2dConfig::new([1, 1]) + .init(); - output + Self { + global_avg_pool1, + phantom: core::marker::PhantomData, } } - }; + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.global_avg_pool1.forward(input); - assert_tokens(graph.codegen(), expected); - } - - #[test] - fn test_codegen_1d() { - let mut graph = BurnGraph::::default(); - - graph.register(GlobalAvgPoolNode::new( - "global_avg_pool1", - TensorType::new_float("input", 3), - TensorType::new_float("output", 3), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::pool::AdaptiveAvgPool1d; - use burn::nn::pool::AdaptiveAvgPool1dConfig; - - #[derive(Module, Debug)] - pub struct Model { - global_avg_pool1: AdaptiveAvgPool1d, - phantom: core::marker::PhantomData, + output } + } + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let global_avg_pool1 = AdaptiveAvgPool1dConfig::new(1) - .init(); + assert_tokens(graph.codegen(), expected); + } - Self { - global_avg_pool1, - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.global_avg_pool1.forward(input); + #[test] + fn test_codegen_1d() { + let mut graph = BurnGraph::::default(); + + graph.register(GlobalAvgPoolNode::new( + "global_avg_pool1", + TensorType::new_float("input", 3), + TensorType::new_float("output", 3), + )); - output + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::pool::AdaptiveAvgPool1d; + use burn::nn::pool::AdaptiveAvgPool1dConfig; + + #[derive(Module, Debug)] + pub struct Model { + global_avg_pool1: AdaptiveAvgPool1d, + phantom: core::marker::PhantomData, + } + + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let global_avg_pool1 = AdaptiveAvgPool1dConfig::new(1) + .init(); + + Self { + global_avg_pool1, + phantom: core::marker::PhantomData, } } - }; + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.global_avg_pool1.forward(input); - assert_tokens(graph.codegen(), expected); - } + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/linear.rs b/burn-import/src/burn/node/linear.rs index b413c2c4a0..1e9bd5a1c2 100644 --- a/burn-import/src/burn/node/linear.rs +++ b/burn-import/src/burn/node/linear.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{Param, ParamId}, - nn::{LinearConfig, LinearRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{Param, ParamId}, + nn::{LinearConfig, LinearRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,167 +12,167 @@ use serde::Serialize; #[derive(Debug, Clone)] pub struct LinearNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub data_weights: DataSerialize, - pub data_bias: Option>, - pub config: LinearConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub data_weights: DataSerialize, + pub data_bias: Option>, + pub config: LinearConfig, } impl LinearNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - data_weights: DataSerialize, - data_bias: Option>, - config: LinearConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Linear - }, - ), - input, - output, - data_weights, - data_bias, - config, - } + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + data_weights: DataSerialize, + data_bias: Option>, + config: LinearConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + Linear + }, + ), + input, + output, + data_weights, + data_bias, + config, } + } } impl NodeCodegen for LinearNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let d_input = self.config.d_input.to_tokens(); + let d_output = self.config.d_output.to_tokens(); + let bias = self.config.bias; + + let init_line = match with_record { + true => quote! { + init_with(record.#name); + }, + false => quote! { + init(); + }, + }; + + let tokens = quote! 
{ + let #name = LinearConfig::new(#d_input, #d_output) + .with_bias(#bias) + .#init_line + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let record = LinearRecord:: { + weight: Param::new( + ParamId::new(), + Tensor::from_data(self.data_weights.clone().convert()), + ), + bias: self + .data_bias + .as_ref() + .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), + }; + + let item = Record::into_item::(record); + item.serialize(serializer) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + + quote! { + let #output = self.#field.forward(#input); } + } - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let d_input = self.config.d_input.to_tokens(); - let d_output = self.config.d_output.to_tokens(); - let bias = self.config.bias; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; - - let tokens = quote! { - let #name = LinearConfig::new(#d_input, #d_output) - .with_bias(#bias) - .#init_line - }; + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::Linear"); + imports.register("burn::nn::LinearConfig"); + } - Some(tokens) - } + fn into_node(self) -> Node { + Node::Linear(self) + } +} - fn field_serialize(&self, serializer: S) -> Result { - let record = LinearRecord:: { - weight: Param::new( - ParamId::new(), - Tensor::from_data(self.data_weights.clone().convert()), - ), - bias: self - .data_bias - .as_ref() - .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use burn::{record::FullPrecisionSettings, tensor::Data}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(LinearNode::new( + "linear", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + None, + LinearConfig::new(128, 128), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, }; + use burn::nn::Linear; + use burn::nn::LinearConfig; - let item = Record::into_item::(record); - item.serialize(serializer) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; - - quote! 
{ - let #output = self.#field.forward(#input); + #[derive(Module, Debug)] + pub struct Model { + linear: Linear, + phantom: core::marker::PhantomData, } - } - - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::Linear"); - imports.register("burn::nn::LinearConfig"); - } - fn into_node(self) -> Node { - Node::Linear(self) - } -} + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let linear = LinearConfig::new(128, 128) + .with_bias(true) + .init_with(record.linear); -#[cfg(test)] -mod tests { - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - use burn::{record::FullPrecisionSettings, tensor::Data}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(LinearNode::new( - "linear", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - None, - LinearConfig::new(128, 128), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::Linear; - use burn::nn::LinearConfig; - - #[derive(Module, Debug)] - pub struct Model { - linear: Linear, - phantom: core::marker::PhantomData, - } - - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let linear = LinearConfig::new(128, 128) - .with_bias(true) - .init_with(record.linear); - - Self { - linear, - phantom: core::marker::PhantomData, - } + Self { + linear, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.linear.forward(input); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.linear.forward(input); - output - } + output } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/matmul.rs b/burn-import/src/burn/node/matmul.rs index b7b1eea97b..077eb272ef 100644 --- a/burn-import/src/burn/node/matmul.rs +++ b/burn-import/src/burn/node/matmul.rs @@ -6,93 +6,93 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct MatmulNode { - pub lhs: TensorType, - pub rhs: TensorType, - pub output: TensorType, + pub lhs: TensorType, + pub rhs: TensorType, + pub output: TensorType, } impl NodeCodegen for MatmulNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![ + Type::Tensor(self.lhs.clone()), + Type::Tensor(self.rhs.clone()), + ] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let lhs = scope.tensor_use_owned(&self.lhs, node_position); + let rhs = scope.tensor_use_owned(&self.rhs, node_position); + let output = &self.output.name; + + quote! { + let #output = #lhs.matmul(#rhs); } + } - fn input_types(&self) -> Vec { - vec![ - Type::Tensor(self.lhs.clone()), - Type::Tensor(self.rhs.clone()), - ] - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let lhs = scope.tensor_use_owned(&self.lhs, node_position); - let rhs = scope.tensor_use_owned(&self.rhs, node_position); - let output = &self.output.name; - - quote! 
{ - let #output = #lhs.matmul(#rhs); - } - } - - fn into_node(self) -> Node { - Node::Matmul(self) - } + fn into_node(self) -> Node { + Node::Matmul(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{matmul::MatmulNode, test::assert_tokens}, - TensorType, - }; + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{matmul::MatmulNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_two_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(MatmulNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor3", 4), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[test] - fn test_codegen_two_nodes() { - let mut graph = BurnGraph::::default(); - - graph.register(MatmulNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor3", 4), - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } + } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.matmul(tensor2); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.matmul(tensor2); - tensor3 - } + tensor3 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/max_pool2d.rs b/burn-import/src/burn/node/max_pool2d.rs index 2bbf859bb9..61ba4845db 100644 --- a/burn-import/src/burn/node/max_pool2d.rs +++ b/burn-import/src/burn/node/max_pool2d.rs @@ -8,155 +8,155 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; #[derive(Debug, Clone)] pub struct MaxPool2dNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub config: MaxPool2dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub config: MaxPool2dConfig, } impl MaxPool2dNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - config: MaxPool2dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - MaxPool2d - }, - ), - input, - output, - config, - } + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + config: MaxPool2dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! 
{ + MaxPool2d + }, + ), + input, + output, + config, } + } } impl NodeCodegen for MaxPool2dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; + let kernel_size = self.config.kernel_size.to_tokens(); + let strides = self.config.strides.to_tokens(); + let padding = self.config.padding.to_tokens(); + let dilation = self.config.dilation.to_tokens(); + + let init_line = quote! { + init(); + }; - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - let kernel_size = self.config.kernel_size.to_tokens(); - let strides = self.config.strides.to_tokens(); - let padding = self.config.padding.to_tokens(); - let dilation = self.config.dilation.to_tokens(); + let tokens = quote! { + let #name = MaxPool2dConfig::new(#kernel_size) + .with_strides(#strides) + .with_padding(#padding) + .with_dilation(#dilation) + .#init_line + }; - let init_line = quote! { - init(); - }; + Some(tokens) + } - let tokens = quote! { - let #name = MaxPool2dConfig::new(#kernel_size) - .with_strides(#strides) - .with_padding(#padding) - .with_dilation(#dilation) - .#init_line - }; + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - Some(tokens) + quote! { + let #output = self.#field.forward(#input); } + } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig2d"); + imports.register("burn::nn::pool::MaxPool2d"); + imports.register("burn::nn::pool::MaxPool2dConfig"); + } - quote! 
{ - let #output = self.#field.forward(#input); - } - } - - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig2d"); - imports.register("burn::nn::pool::MaxPool2d"); - imports.register("burn::nn::pool::MaxPool2dConfig"); - } + fn into_node(self) -> Node { + Node::MaxPool2d(self) + } - fn into_node(self) -> Node { - Node::MaxPool2d(self) - } - - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) - } + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{max_pool2d::MaxPool2dNode, test::assert_tokens}, - TensorType, - }; - use burn::{nn::pool::MaxPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(MaxPool2dNode::new( - "max_pool2d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - MaxPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig2d; - use burn::nn::pool::MaxPool2d; - use burn::nn::pool::MaxPool2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - max_pool2d: MaxPool2d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{max_pool2d::MaxPool2dNode, test::assert_tokens}, + TensorType, + }; + use burn::{nn::pool::MaxPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(MaxPool2dNode::new( + "max_pool2d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + MaxPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::pool::MaxPool2d; + use burn::nn::pool::MaxPool2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + max_pool2d: MaxPool2d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let max_pool2d = MaxPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .init(); - - Self { - max_pool2d, - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let max_pool2d = MaxPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .init(); + + Self { + max_pool2d, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.max_pool2d.forward(input); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.max_pool2d.forward(input); - output - } + output } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/reshape.rs b/burn-import/src/burn/node/reshape.rs index df8959e90b..032f1a3c33 100644 --- a/burn-import/src/burn/node/reshape.rs +++ b/burn-import/src/burn/node/reshape.rs @@ -6,85 +6,85 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct ReshapeNode { - pub input: TensorType, - pub output: TensorType, - pub shape: Vec, + pub input: TensorType, + pub output: TensorType, + pub shape: Vec, } impl NodeCodegen for ReshapeNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let shape_values = &self.shape.to_tokens(); + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let shape_values = &self.shape.to_tokens(); - quote! { - let #output = #input.reshape(#shape_values); - } + quote! { + let #output = #input.reshape(#shape_values); } + } - fn into_node(self) -> Node { - Node::Reshape(self) - } + fn into_node(self) -> Node { + Node::Reshape(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{reshape::ReshapeNode, test::assert_tokens}, - TensorType, - }; - - #[test] - fn test_codegen_nodes() { - let mut graph = BurnGraph::::default(); - - graph.register(ReshapeNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - [4, 4, 4, 4].into(), - )); - - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{reshape::ReshapeNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(ReshapeNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + [4, 4, 4, 4].into(), + )); + + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.reshape([4, 4, 4, 4]); + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.reshape([4, 4, 4, 4]); - tensor2 - } + tensor2 } - }; + } + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/test.rs b/burn-import/src/burn/node/test.rs index 903248ec98..33fb5dd530 100644 --- a/burn-import/src/burn/node/test.rs +++ b/burn-import/src/burn/node/test.rs @@ -3,8 +3,8 @@ use proc_macro2::TokenStream; #[track_caller] pub fn assert_tokens(tokens1: TokenStream, tokens2: TokenStream) { - let tokens1 = format_tokens(tokens1); - let tokens2 = format_tokens(tokens2); + let tokens1 = format_tokens(tokens1); + let tokens2 = format_tokens(tokens2); - pretty_assertions::assert_eq!(tokens1, tokens2); + pretty_assertions::assert_eq!(tokens1, tokens2); } diff --git a/burn-import/src/burn/node/unary.rs b/burn-import/src/burn/node/unary.rs index 7f557b785c..d7fc3da925 100644 --- a/burn-import/src/burn/node/unary.rs +++ b/burn-import/src/burn/node/unary.rs @@ -11,386 +11,386 @@ type FnPointer = Rc TokenStream>; /// Node for all unary operators. #[derive(Clone, new)] pub struct UnaryNode { - pub input: Type, - pub output: Type, - pub kind: UnaryNodeKind, - function: FnPointer, + pub input: Type, + pub output: Type, + pub kind: UnaryNodeKind, + function: FnPointer, } /// Type of unary node. 
#[derive(Clone)] pub enum UnaryNodeKind { - Cast, - Erf, - Flatten, - LogSoftmax, - Softmax, - Relu, - Reciprocal, - Sigmoid, - Tanh, - Transpose, + Cast, + Erf, + Flatten, + LogSoftmax, + Softmax, + Relu, + Reciprocal, + Sigmoid, + Tanh, + Transpose, } impl UnaryNodeKind { - pub fn as_str(&self) -> &str { - match self { - Self::Cast => "cast", - Self::Erf => "erf", - Self::Flatten => "flatten", - Self::LogSoftmax => "log_softmax", - Self::Softmax => "softmax", - Self::Relu => "relu", - Self::Reciprocal => "reciprocal", - Self::Sigmoid => "sigmoid", - Self::Tanh => "tanh", - Self::Transpose => "transpose", - } + pub fn as_str(&self) -> &str { + match self { + Self::Cast => "cast", + Self::Erf => "erf", + Self::Flatten => "flatten", + Self::LogSoftmax => "log_softmax", + Self::Softmax => "softmax", + Self::Relu => "relu", + Self::Reciprocal => "reciprocal", + Self::Sigmoid => "sigmoid", + Self::Tanh => "tanh", + Self::Transpose => "transpose", } + } } impl std::fmt::Debug for UnaryNode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str( - format!( - "UnaryNode {{ input: {:?}, output: {:?}, name: {} }}", - self.input, - self.output, - self.kind.as_str() - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + format!( + "UnaryNode {{ input: {:?}, output: {:?}, name: {} }}", + self.input, + self.output, + self.kind.as_str() + ) + .as_str(), + ) + } } impl NodeCodegen for UnaryNode { - fn output_types(&self) -> Vec { - vec![self.output.clone()] - } - - fn input_types(&self) -> Vec { - vec![self.input.clone()] + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + vec![self.input.clone()] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + // Get the lhs name in the form of token stream. + let input = match &self.input { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("lhs must be a tensor or scalar"), + }; + + // let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name(); + let function = (self.function)(input); + + quote! { + let #output = #function; } + } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - // Get the lhs name in the form of token stream. - let input = match &self.input { - Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), - Type::Scalar(scalar) => { - let name = scalar.name.clone(); - quote! { #name } - } - _ => panic!("lhs must be a tensor or scalar"), - }; - - // let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name(); - let function = (self.function)(input); - - quote! { - let #output = #function; - } - } - - fn into_node(self) -> Node { - Node::Unary(self) - } + fn into_node(self) -> Node { + Node::Unary(self) + } } impl UnaryNode { - pub(crate) fn erf(input: Type, output: Type) -> Self { - let function = move |input| quote! { #input.erf() }; - Self::new(input, output, UnaryNodeKind::Erf, Rc::new(function)) - } - - pub(crate) fn flatten(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self { - let start_dim = start_dim.to_tokens(); - let end_dim = end_dim.to_tokens(); - let function = move |input| quote! 
{ #input.flatten(#start_dim, #end_dim) }; - - Self::new(input, output, UnaryNodeKind::Flatten, Rc::new(function)) - } - - pub(crate) fn relu(input: Type, output: Type) -> Self { - let function = move |input| quote! { burn::tensor::activation::relu(#input) }; - Self::new(input, output, UnaryNodeKind::Relu, Rc::new(function)) - } - - pub(crate) fn sigmoid(input: Type, output: Type) -> Self { - let function = move |input| quote! { burn::tensor::activation::sigmoid(#input) }; - Self::new(input, output, UnaryNodeKind::Sigmoid, Rc::new(function)) - } - - pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self { - let dim = dim.to_tokens(); - let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) }; - Self::new(input, output, UnaryNodeKind::LogSoftmax, Rc::new(function)) - } - - pub(crate) fn softmax(input: Type, output: Type, dim: usize) -> Self { - let dim = dim.to_tokens(); - let function = move |input| quote! { burn::tensor::activation::softmax(#input, #dim) }; - Self::new(input, output, UnaryNodeKind::Softmax, Rc::new(function)) - } - - pub(crate) fn tanh(input: Type, output: Type) -> Self { - let function = move |input| quote! { burn::tensor::activation::tanh(#input)}; - Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function)) - } - - pub(crate) fn transpose(input: Type, output: Type) -> Self { - let function = move |input| quote! { #input.transpose() }; - Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function)) - } - - pub(crate) fn reciprocal(input: Type, output: Type) -> Self { - let function = move |input| quote! { #input.recip() }; - Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function)) - } - - /// Casts the input to the output type. - /// - /// Currently this function only supports the following conversions: - /// 1) scalar -> scalar - /// - /// TODO: Implement the following conversions: - /// 2) tensor int -> tensor float - /// 3) tensor float -> tensor int - /// 4) tensor -> scalar - /// 5) scalar -> tensor - pub(crate) fn cast(input: Type, output: Type) -> Self { - let function = match output.clone() { - Type::Scalar(scalar) => { - let ty = scalar.ty(); - move |input| quote! { #input as #ty } - } - Type::Tensor(_tensor) => { - // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023) - // TODO: If the input is scalar and the output type is a tensor, - // we should generate another code block. (@antimora 8/4/2023) - // Tensor::from_data(Data::from([#input]).convert()).unsqueeze(); - todo!() - } - - _ => panic!("output must be a tensor"), - }; - - Self::new(input, output, UnaryNodeKind::Cast, Rc::new(function)) - } + pub(crate) fn erf(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.erf() }; + Self::new(input, output, UnaryNodeKind::Erf, Rc::new(function)) + } + + pub(crate) fn flatten(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self { + let start_dim = start_dim.to_tokens(); + let end_dim = end_dim.to_tokens(); + let function = move |input| quote! { #input.flatten(#start_dim, #end_dim) }; + + Self::new(input, output, UnaryNodeKind::Flatten, Rc::new(function)) + } + + pub(crate) fn relu(input: Type, output: Type) -> Self { + let function = move |input| quote! { burn::tensor::activation::relu(#input) }; + Self::new(input, output, UnaryNodeKind::Relu, Rc::new(function)) + } + + pub(crate) fn sigmoid(input: Type, output: Type) -> Self { + let function = move |input| quote! 
{ burn::tensor::activation::sigmoid(#input) }; + Self::new(input, output, UnaryNodeKind::Sigmoid, Rc::new(function)) + } + + pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self { + let dim = dim.to_tokens(); + let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) }; + Self::new(input, output, UnaryNodeKind::LogSoftmax, Rc::new(function)) + } + + pub(crate) fn softmax(input: Type, output: Type, dim: usize) -> Self { + let dim = dim.to_tokens(); + let function = move |input| quote! { burn::tensor::activation::softmax(#input, #dim) }; + Self::new(input, output, UnaryNodeKind::Softmax, Rc::new(function)) + } + + pub(crate) fn tanh(input: Type, output: Type) -> Self { + let function = move |input| quote! { burn::tensor::activation::tanh(#input)}; + Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function)) + } + + pub(crate) fn transpose(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.transpose() }; + Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function)) + } + + pub(crate) fn reciprocal(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.recip() }; + Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function)) + } + + /// Casts the input to the output type. + /// + /// Currently this function only supports the following conversions: + /// 1) scalar -> scalar + /// + /// TODO: Implement the following conversions: + /// 2) tensor int -> tensor float + /// 3) tensor float -> tensor int + /// 4) tensor -> scalar + /// 5) scalar -> tensor + pub(crate) fn cast(input: Type, output: Type) -> Self { + let function = match output.clone() { + Type::Scalar(scalar) => { + let ty = scalar.ty(); + move |input| quote! { #input as #ty } + } + Type::Tensor(_tensor) => { + // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023) + // TODO: If the input is scalar and the output type is a tensor, + // we should generate another code block. (@antimora 8/4/2023) + // Tensor::from_data(Data::from([#input]).convert()).unsqueeze(); + todo!() + } + + _ => panic!("output must be a tensor"), + }; + + Self::new(input, output, UnaryNodeKind::Cast, Rc::new(function)) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::node::tests::one_node_graph; - use crate::burn::{ScalarKind, ScalarType, TensorType}; - - #[test] - fn test_unary_codegen_flatten() { - one_node_graph( - UnaryNode::flatten( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - 1, - 2, - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.flatten(1, 2); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_erf() { - one_node_graph( - UnaryNode::erf( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.erf(); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_relu() { - one_node_graph( - UnaryNode::relu( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! 
{ - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::relu(tensor1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_sigmoid() { - one_node_graph( - UnaryNode::sigmoid( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::sigmoid(tensor1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_log_softmax() { - one_node_graph( - UnaryNode::log_softmax( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - 1, - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_softmax() { - one_node_graph( - UnaryNode::softmax( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - 1, - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::softmax(tensor1, 1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_tanh() { - one_node_graph( - UnaryNode::tanh( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::tanh(tensor1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_transpose() { - one_node_graph( - UnaryNode::transpose( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.transpose(); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_reciprocal() { - one_node_graph( - UnaryNode::reciprocal( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.recip(); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_cast() { - one_node_graph( - UnaryNode::cast( - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)), - Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), - ), - quote! { - pub fn forward(&self, scalar1: f64) -> f32 { - let scalar2 = scalar1 as f32; - - scalar2 - } - }, - vec!["scalar1".to_string()], - vec!["scalar2".to_string()], - ); - one_node_graph( - UnaryNode::cast( - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), - Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)), - ), - quote! 
{ - pub fn forward(&self, scalar1: f32) -> f64 { - let scalar2 = scalar1 as f64; - - scalar2 - } - }, - vec!["scalar1".to_string()], - vec!["scalar2".to_string()], - ); - } + use super::*; + use crate::burn::node::tests::one_node_graph; + use crate::burn::{ScalarKind, ScalarType, TensorType}; + + #[test] + fn test_unary_codegen_flatten() { + one_node_graph( + UnaryNode::flatten( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + 1, + 2, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.flatten(1, 2); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_erf() { + one_node_graph( + UnaryNode::erf( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.erf(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_relu() { + one_node_graph( + UnaryNode::relu( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::relu(tensor1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_sigmoid() { + one_node_graph( + UnaryNode::sigmoid( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::sigmoid(tensor1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_log_softmax() { + one_node_graph( + UnaryNode::log_softmax( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + 1, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_softmax() { + one_node_graph( + UnaryNode::softmax( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + 1, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::softmax(tensor1, 1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_tanh() { + one_node_graph( + UnaryNode::tanh( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::tanh(tensor1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_transpose() { + one_node_graph( + UnaryNode::transpose( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! 
{ + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.transpose(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_reciprocal() { + one_node_graph( + UnaryNode::reciprocal( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.recip(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_cast() { + one_node_graph( + UnaryNode::cast( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), + ), + quote! { + pub fn forward(&self, scalar1: f64) -> f32 { + let scalar2 = scalar1 as f32; + + scalar2 + } + }, + vec!["scalar1".to_string()], + vec!["scalar2".to_string()], + ); + one_node_graph( + UnaryNode::cast( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)), + ), + quote! { + pub fn forward(&self, scalar1: f32) -> f64 { + let scalar2 = scalar1 as f64; + + scalar2 + } + }, + vec!["scalar1".to_string()], + vec!["scalar2".to_string()], + ); + } } diff --git a/burn-import/src/burn/scope.rs b/burn-import/src/burn/scope.rs index 9e497c09aa..1ceaa0b972 100644 --- a/burn-import/src/burn/scope.rs +++ b/burn-import/src/burn/scope.rs @@ -7,78 +7,78 @@ use std::collections::HashMap; /// The scope struct ensures that ownership rules are respected during the forward pass. #[derive(Clone, Debug, Default)] pub struct Scope { - variables: HashMap>, + variables: HashMap>, } #[derive(Clone, Debug, new)] struct TensorVariable { - references: usize, - node_position: usize, + references: usize, + node_position: usize, } impl Scope { - /// Declare a new tensor variable. - pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) { - if let Some(variables) = self.variables.get_mut(&tensor.name) { - for variable in variables.iter_mut() { - if variable.node_position == node_position { - variable.references += 1; - return; - } - } - - variables.push(TensorVariable::new(0, node_position)); - } else { - self.variables.insert( - tensor.name.clone(), - vec![TensorVariable::new(0, node_position)], - ); + /// Declare a new tensor variable. + pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) { + if let Some(variables) = self.variables.get_mut(&tensor.name) { + for variable in variables.iter_mut() { + if variable.node_position == node_position { + variable.references += 1; + return; } + } + + variables.push(TensorVariable::new(0, node_position)); + } else { + self.variables.insert( + tensor.name.clone(), + vec![TensorVariable::new(0, node_position)], + ); } + } - /// Register a future use of a tensor variable. - /// - /// # Notes - /// - /// We need to know all futures use of a variable in advance. - pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) { - if let Some(variables) = self.variables.get_mut(&tensor.name) { - for variable in variables.iter_mut().rev() { - if node_position >= variable.node_position { - variable.references += 1; - break; - } - } - } else { - panic!("No variable with name {}", tensor.name); + /// Register a future use of a tensor variable. + /// + /// # Notes + /// + /// We need to know all futures use of a variable in advance. 
+ pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) { + if let Some(variables) = self.variables.get_mut(&tensor.name) { + for variable in variables.iter_mut().rev() { + if node_position >= variable.node_position { + variable.references += 1; + break; } + } + } else { + panic!("No variable with name {}", tensor.name); } + } - /// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward. - pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream { - if let Some(variables) = self.variables.get_mut(&tensor.name) { - let mut count = 0; - let name = &tensor.name; + /// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward. + pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream { + if let Some(variables) = self.variables.get_mut(&tensor.name) { + let mut count = 0; + let name = &tensor.name; - for variable in variables.iter_mut().rev() { - if node_position >= variable.node_position { - variable.references -= 1; - count = variable.references; - break; - } - } + for variable in variables.iter_mut().rev() { + if node_position >= variable.node_position { + variable.references -= 1; + count = variable.references; + break; + } + } - if count > 0 { - quote! { - #name.clone() - } - } else { - quote! { - #name - } - } - } else { - panic!("No variable with name {}", &tensor.name); + if count > 0 { + quote! { + #name.clone() + } + } else { + quote! { + #name } + } + } else { + panic!("No variable with name {}", &tensor.name); } + } } diff --git a/burn-import/src/burn/ty.rs b/burn-import/src/burn/ty.rs index 292523ecc7..fb710f872f 100644 --- a/burn-import/src/burn/ty.rs +++ b/burn-import/src/burn/ty.rs @@ -7,146 +7,146 @@ use crate::burn::ToTokens; #[derive(Debug, Clone)] pub struct TensorType { - pub name: Ident, - pub dim: usize, - pub kind: TensorKind, - pub shape: Option>, + pub name: Ident, + pub dim: usize, + pub kind: TensorKind, + pub shape: Option>, } #[derive(Debug, Clone, Copy)] pub enum TensorKind { - Int, - Float, - Bool, + Int, + Float, + Bool, } #[derive(Debug, Clone)] pub enum ScalarKind { - Int32, - Int64, - Float32, - Float64, - Bool, + Int32, + Int64, + Float32, + Float64, + Bool, } #[derive(Debug, Clone)] pub struct ScalarType { - pub name: Ident, - pub kind: ScalarKind, + pub name: Ident, + pub kind: ScalarKind, } #[derive(Debug, Clone)] pub struct OtherType { - pub name: Ident, - pub ty: TokenStream, + pub name: Ident, + pub ty: TokenStream, } #[derive(Debug, Clone)] pub enum Type { - /// Tensor type. - Tensor(TensorType), + /// Tensor type. + Tensor(TensorType), - /// Scalar type. - Scalar(ScalarType), + /// Scalar type. + Scalar(ScalarType), - // Other type (more flexible type). - Other(OtherType), + // Other type (more flexible type). 
+ Other(OtherType), } impl Type { - pub fn name(&self) -> &Ident { - match self { - Type::Tensor(tensor) => &tensor.name, - Type::Scalar(scalar) => &scalar.name, - Type::Other(other) => &other.name, - } + pub fn name(&self) -> &Ident { + match self { + Type::Tensor(tensor) => &tensor.name, + Type::Scalar(scalar) => &scalar.name, + Type::Other(other) => &other.name, } - pub fn ty(&self) -> TokenStream { - match self { - Type::Tensor(tensor) => tensor.ty(), - Type::Scalar(scalar) => scalar.ty(), - Type::Other(other) => other.ty(), - } + } + pub fn ty(&self) -> TokenStream { + match self { + Type::Tensor(tensor) => tensor.ty(), + Type::Scalar(scalar) => scalar.ty(), + Type::Other(other) => other.ty(), } + } } impl ScalarType { - pub fn new>(name: S, kind: ScalarKind) -> Self { - Self { - name: Ident::new(name.as_ref(), Span::call_site()), - kind, - } + pub fn new>(name: S, kind: ScalarKind) -> Self { + Self { + name: Ident::new(name.as_ref(), Span::call_site()), + kind, } - pub fn ty(&self) -> TokenStream { - match self.kind { - ScalarKind::Int32 => quote! { i32 }, - ScalarKind::Int64 => quote! { i64 }, - ScalarKind::Float32 => quote! { f32 }, - ScalarKind::Float64 => quote! { f64 }, - ScalarKind::Bool => quote! { bool }, - } + } + pub fn ty(&self) -> TokenStream { + match self.kind { + ScalarKind::Int32 => quote! { i32 }, + ScalarKind::Int64 => quote! { i64 }, + ScalarKind::Float32 => quote! { f32 }, + ScalarKind::Float64 => quote! { f64 }, + ScalarKind::Bool => quote! { bool }, } + } } impl TensorType { - pub fn new>( - name: S, - dim: usize, - kind: TensorKind, - shape: Option>, - ) -> Self { - Self { - name: Ident::new(name.as_ref(), Span::call_site()), - dim, - kind, - shape, - } + pub fn new>( + name: S, + dim: usize, + kind: TensorKind, + shape: Option>, + ) -> Self { + Self { + name: Ident::new(name.as_ref(), Span::call_site()), + dim, + kind, + shape, } - pub fn new_float>(name: S, dim: usize) -> Self { - Self::new(name, dim, TensorKind::Float, None) - } - - pub fn new_int>(name: S, dim: usize) -> Self { - Self::new(name, dim, TensorKind::Int, None) - } - - pub fn new_bool>(name: S, dim: usize) -> Self { - Self::new(name, dim, TensorKind::Bool, None) - } - - pub fn ty(&self) -> TokenStream { - let dim = self.dim.to_tokens(); - match self { - TensorType { - kind: TensorKind::Float, - .. - } => quote! { - Tensor - }, - TensorType { - kind: TensorKind::Int, - .. - } => quote! { - Tensor - }, - TensorType { - kind: TensorKind::Bool, - .. - } => quote! { - Tensor - }, - } + } + pub fn new_float>(name: S, dim: usize) -> Self { + Self::new(name, dim, TensorKind::Float, None) + } + + pub fn new_int>(name: S, dim: usize) -> Self { + Self::new(name, dim, TensorKind::Int, None) + } + + pub fn new_bool>(name: S, dim: usize) -> Self { + Self::new(name, dim, TensorKind::Bool, None) + } + + pub fn ty(&self) -> TokenStream { + let dim = self.dim.to_tokens(); + match self { + TensorType { + kind: TensorKind::Float, + .. + } => quote! { + Tensor + }, + TensorType { + kind: TensorKind::Int, + .. + } => quote! { + Tensor + }, + TensorType { + kind: TensorKind::Bool, + .. + } => quote! 
{ + Tensor + }, } + } } impl OtherType { - pub fn new>(name: S, tokens: TokenStream) -> Self { - Self { - name: Ident::new(name.as_ref(), Span::call_site()), - ty: tokens, - } - } - pub fn ty(&self) -> TokenStream { - self.ty.clone() + pub fn new>(name: S, tokens: TokenStream) -> Self { + Self { + name: Ident::new(name.as_ref(), Span::call_site()), + ty: tokens, } + } + pub fn ty(&self) -> TokenStream { + self.ty.clone() + } } diff --git a/burn-import/src/formatter.rs b/burn-import/src/formatter.rs index adc8545c38..bba4cdaeff 100644 --- a/burn-import/src/formatter.rs +++ b/burn-import/src/formatter.rs @@ -3,15 +3,15 @@ use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt}; /// Formats a token stream into a string. pub fn format_tokens(tokens: TokenStream) -> String { - let fmt = code_formatter(); + let fmt = code_formatter(); - fmt.format_tokens(tokens).expect("Valid token tree") + fmt.format_tokens(tokens).expect("Valid token tree") } fn code_formatter() -> RustFmt { - let config = Config::new_str() - .post_proc(PostProcess::ReplaceMarkersAndDocBlocks) - .edition(Edition::Rust2021); + let config = Config::new_str() + .post_proc(PostProcess::ReplaceMarkersAndDocBlocks) + .edition(Edition::Rust2021); - RustFmt::from_config(config) + RustFmt::from_config(config) } diff --git a/burn-import/src/logger.rs b/burn-import/src/logger.rs index 3378f17401..c5a279ef99 100644 --- a/burn-import/src/logger.rs +++ b/burn-import/src/logger.rs @@ -2,22 +2,22 @@ use std::error::Error; use tracing_core::LevelFilter; pub fn init_log() -> Result<(), Box> { - let result = tracing_subscriber::fmt() - .with_max_level(LevelFilter::DEBUG) - .without_time() - .try_init(); + let result = tracing_subscriber::fmt() + .with_max_level(LevelFilter::DEBUG) + .without_time() + .try_init(); - if result.is_ok() { - update_panic_hook(); - } - result + if result.is_ok() { + update_panic_hook(); + } + result } fn update_panic_hook() { - let hook = std::panic::take_hook(); + let hook = std::panic::take_hook(); - std::panic::set_hook(Box::new(move |info| { - log::error!("PANIC => {}", info.to_string()); - hook(info); - })); + std::panic::set_hook(Box::new(move |info| { + log::error!("PANIC => {}", info.to_string()); + hook(info); + })); } diff --git a/burn-import/src/main.rs b/burn-import/src/main.rs index 2601568250..4590e41b78 100644 --- a/burn-import/src/main.rs +++ b/burn-import/src/main.rs @@ -2,16 +2,16 @@ use burn_import::onnx::{ModelGen, RecordType}; /// Takes an ONNX file and generates a model from it fn main() { - let onnx_file = std::env::args().nth(1).expect("No input file provided"); - let output_dir = std::env::args() - .nth(2) - .expect("No output directory provided"); + let onnx_file = std::env::args().nth(1).expect("No input file provided"); + let output_dir = std::env::args() + .nth(2) + .expect("No output directory provided"); - // Generate the model code from the ONNX file. - ModelGen::new() - .input(onnx_file.as_str()) - .development(true) - .record_type(RecordType::PrettyJson) - .out_dir(output_dir.as_str()) - .run_from_cli(); + // Generate the model code from the ONNX file. 
+ ModelGen::new() + .input(onnx_file.as_str()) + .development(true) + .record_type(RecordType::PrettyJson) + .out_dir(output_dir.as_str()) + .run_from_cli(); } diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index 623d7584f2..fe08b05d79 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -5,110 +5,110 @@ use crate::onnx::ir::{ArgType, Data, TensorType}; /// The function transforms the graph into a new one where the nodes are coalesced into a single node. pub fn coalesce(nodes: &mut Vec) { - let mut iter_mut = nodes.iter_mut().peekable(); - let mut nodes_to_remove: Vec = vec![]; - while let Some(node) = iter_mut.next() { - match node.node_type { - NodeType::Gemm => convert_gemm_to_linear(node), - NodeType::MatMul => { - convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove); - } - _ => {} - } + let mut iter_mut = nodes.iter_mut().peekable(); + let mut nodes_to_remove: Vec = vec![]; + while let Some(node) = iter_mut.next() { + match node.node_type { + NodeType::Gemm => convert_gemm_to_linear(node), + NodeType::MatMul => { + convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove); + } + _ => {} } + } - // Remove nodes instructed by conversation functions - for node_to_remove in nodes_to_remove { - nodes.retain(|n| n.name != node_to_remove); - } + // Remove nodes instructed by conversation functions + for node_to_remove in nodes_to_remove { + nodes.retain(|n| n.name != node_to_remove); + } } /// This function converts a Gemm node into a Linear node /// /// PyTorch and other frameworks use Gemm node to represent Linear layer. fn convert_gemm_to_linear(node: &mut Node) { - if node.outputs.len() != 1 { - panic!("Gemm node must have 1 output"); - } - let straight_linear = match ( - node.attrs.get("alpha"), - node.attrs.get("beta"), - node.attrs.get("transB"), - ) { - ( - Some(AttributeValue::Float32(alpha)), - Some(AttributeValue::Float32(beta)), - Some(AttributeValue::Int64(trans_b)), - ) => *alpha == 1.0 && *beta == 1.0 && *trans_b == 1, - _ => false, - }; - - if straight_linear { - node.node_type = NodeType::Linear; - node.attrs.remove("alpha"); - node.attrs.remove("beta"); - node.attrs.remove("transB"); - - // Transpose the weights - transpose_linear_node_weights(node); - } else { - panic!("Full Gemm node not supported yet."); - } + if node.outputs.len() != 1 { + panic!("Gemm node must have 1 output"); + } + let straight_linear = match ( + node.attrs.get("alpha"), + node.attrs.get("beta"), + node.attrs.get("transB"), + ) { + ( + Some(AttributeValue::Float32(alpha)), + Some(AttributeValue::Float32(beta)), + Some(AttributeValue::Int64(trans_b)), + ) => *alpha == 1.0 && *beta == 1.0 && *trans_b == 1, + _ => false, + }; + + if straight_linear { + node.node_type = NodeType::Linear; + node.attrs.remove("alpha"); + node.attrs.remove("beta"); + node.attrs.remove("transB"); + + // Transpose the weights + transpose_linear_node_weights(node); + } else { + panic!("Full Gemm node not supported yet."); + } } // Transpose linear weights (required for Gemm -> Linear conversion) fn transpose_linear_node_weights(node: &mut Node) { - assert!( - node.inputs.len() > 1, - "Linear node must have at least 2 input" - ); - - assert!(node.inputs[1].value.is_some(), "Input must have a value"); - - let weight = node.inputs[1] - .clone() - .into_tensor() - .expect("Tensor input is expected"); - - assert_eq!(weight.dim, 2, "Weight must be a 2D tensor"); - - let shape = weight.shape.unwrap(); - - match weight.data.expect("Tensor 
must have data") { - Data::Float32s(data) => { - let data_t = transpose_flattened(data, shape[0], shape[1]); - node.inputs[1].value = Some(Data::Float32s(data_t)); - } - Data::Float64s(data) => { - let data_t = transpose_flattened(data, shape[0], shape[1]); - node.inputs[1].value = Some(Data::Float64s(data_t)); - } - Data::Float16s(data) => { - let data_t = transpose_flattened(data, shape[0], shape[1]); - node.inputs[1].value = Some(Data::Float16s(data_t)); - } - _ => panic!("Only float types are supported for Linear node"), + assert!( + node.inputs.len() > 1, + "Linear node must have at least 2 input" + ); + + assert!(node.inputs[1].value.is_some(), "Input must have a value"); + + let weight = node.inputs[1] + .clone() + .into_tensor() + .expect("Tensor input is expected"); + + assert_eq!(weight.dim, 2, "Weight must be a 2D tensor"); + + let shape = weight.shape.unwrap(); + + match weight.data.expect("Tensor must have data") { + Data::Float32s(data) => { + let data_t = transpose_flattened(data, shape[0], shape[1]); + node.inputs[1].value = Some(Data::Float32s(data_t)); + } + Data::Float64s(data) => { + let data_t = transpose_flattened(data, shape[0], shape[1]); + node.inputs[1].value = Some(Data::Float64s(data_t)); } - let shape = Some(vec![shape[1], shape[0]]); // Transpose the shape - node.inputs[1].ty = ArgType::Tensor(TensorType { - shape, - elem_type: weight.elem_type, - dim: 2, - }); + Data::Float16s(data) => { + let data_t = transpose_flattened(data, shape[0], shape[1]); + node.inputs[1].value = Some(Data::Float16s(data_t)); + } + _ => panic!("Only float types are supported for Linear node"), + } + let shape = Some(vec![shape[1], shape[0]]); // Transpose the shape + node.inputs[1].ty = ArgType::Tensor(TensorType { + shape, + elem_type: weight.elem_type, + dim: 2, + }); } fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec { - assert_eq!(matrix.len(), rows * cols, "Matrix must be flattened"); + assert_eq!(matrix.len(), rows * cols, "Matrix must be flattened"); - let mut transposed: Vec = vec![matrix[0]; matrix.len()]; + let mut transposed: Vec = vec![matrix[0]; matrix.len()]; - for i in 0..rows { - for j in 0..cols { - transposed[j * rows + i] = matrix[i * cols + j]; - } + for i in 0..rows { + for j in 0..cols { + transposed[j * rows + i] = matrix[i * cols + j]; } + } - transposed + transposed } /// This function converts a MatMul node into a Linear node if possible. @@ -118,65 +118,65 @@ fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec /// This function also converts the following Add node into a Linear node if possible. /// Add node is used to represent bias in PyTorch. 
fn convert_matmul_to_linear( - node: &mut Node, - iter_mut: &mut Peekable>, - nodes_to_remove: &mut Vec, + node: &mut Node, + iter_mut: &mut Peekable>, + nodes_to_remove: &mut Vec, ) { - if node.inputs.len() != 2 { - panic!("MatMul node must have 2 inputs"); - } - - // if the second input does not have a value, it is not a weight, then proceed to the next node - if node.inputs[1].value.is_none() { - return; - } - - // Check if the second input is a 2D tensor - if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { - assert_eq!(tensor_type.dim, 2, "Weight must be a 2D tensor"); - } else { - panic!("Tensor input is expected"); - } - - // Convert the node to Linear - node.node_type = NodeType::Linear; - - // Check the next node for potential conversion - if let Some(peek_node) = iter_mut.peek() { - if is_add_node_with_bias(peek_node, node) { - convert_and_remove_add_node(iter_mut, nodes_to_remove, node); - } + if node.inputs.len() != 2 { + panic!("MatMul node must have 2 inputs"); + } + + // if the second input does not have a value, it is not a weight, then proceed to the next node + if node.inputs[1].value.is_none() { + return; + } + + // Check if the second input is a 2D tensor + if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { + assert_eq!(tensor_type.dim, 2, "Weight must be a 2D tensor"); + } else { + panic!("Tensor input is expected"); + } + + // Convert the node to Linear + node.node_type = NodeType::Linear; + + // Check the next node for potential conversion + if let Some(peek_node) = iter_mut.peek() { + if is_add_node_with_bias(peek_node, node) { + convert_and_remove_add_node(iter_mut, nodes_to_remove, node); } + } } /// Helper function to check if the peeked node is an Add node with bias fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool { - peek_node.node_type == NodeType::Add - && peek_node.inputs.len() == 2 - && ((peek_node.inputs[0].name == current_node.outputs[0].name - && peek_node.inputs[1].value.is_some()) - || (peek_node.inputs[1].name == current_node.outputs[0].name - && peek_node.inputs[0].value.is_some())) + peek_node.node_type == NodeType::Add + && peek_node.inputs.len() == 2 + && ((peek_node.inputs[0].name == current_node.outputs[0].name + && peek_node.inputs[1].value.is_some()) + || (peek_node.inputs[1].name == current_node.outputs[0].name + && peek_node.inputs[0].value.is_some())) } /// Helper function to convert and remove the Add node fn convert_and_remove_add_node( - iter_mut: &mut Peekable>, - nodes_to_remove: &mut Vec, - current_node: &mut Node, + iter_mut: &mut Peekable>, + nodes_to_remove: &mut Vec, + current_node: &mut Node, ) { - let bias_node = iter_mut.next().unwrap(); + let bias_node = iter_mut.next().unwrap(); - let bias_input = if bias_node.inputs[0].value.is_some() { - bias_node.inputs[0].clone() - } else { - bias_node.inputs[1].clone() - }; + let bias_input = if bias_node.inputs[0].value.is_some() { + bias_node.inputs[0].clone() + } else { + bias_node.inputs[1].clone() + }; - // Push the bias input and update the output name - current_node.inputs.push(bias_input); - current_node.outputs[0].name = bias_node.outputs[0].name.clone(); + // Push the bias input and update the output name + current_node.inputs.push(bias_input); + current_node.outputs[0].name = bias_node.outputs[0].name.clone(); - // Remove the Add node - nodes_to_remove.push(bias_node.name.clone()); + // Remove the Add node + nodes_to_remove.push(bias_node.name.clone()); } diff --git a/burn-import/src/onnx/dim_inference.rs 
b/burn-import/src/onnx/dim_inference.rs index ce2a9ce4f3..a168a9a28f 100644 --- a/burn-import/src/onnx/dim_inference.rs +++ b/burn-import/src/onnx/dim_inference.rs @@ -3,279 +3,280 @@ use std::collections::HashMap; use protobuf::Enum; use super::{ - ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, - op_configuration::flatten_config, - protos::tensor_proto::DataType, + ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, + op_configuration::flatten_config, + protos::tensor_proto::DataType, }; struct TensorDimUpdater { - arguments: HashMap, + arguments: HashMap, } impl TensorDimUpdater { - fn new(inputs: &[Argument]) -> Self { - let mut arguments: HashMap = HashMap::with_capacity(inputs.len()); + fn new(inputs: &[Argument]) -> Self { + let mut arguments: HashMap = HashMap::with_capacity(inputs.len()); - inputs.iter().for_each(|input| { - arguments.insert(input.name.clone(), input.clone()); - }); - - Self { arguments } - } - /// Update tensor inputs from the registered arguments and returns the number of input - /// updated. - fn update_tensor_inputs(&self, node: &mut Node) -> usize { - self.update_arguments(&mut node.inputs) - } - - /// Update the arguments struct from the node output tensors and return the number of output - /// updated. - fn update_tensor_outputs(&mut self, node: &Node) -> usize { - node.outputs - .iter() - .map(|arg| { - self.arguments.insert(arg.name.clone(), arg.clone()); - }) - .count() - } + inputs.iter().for_each(|input| { + arguments.insert(input.name.clone(), input.clone()); + }); - fn update_arguments(&self, arguments: &mut [Argument]) -> usize { - arguments - .iter_mut() - .filter_map(|input| self.arguments.get(&input.name).map(|arg| (arg, input))) - .map(|(arg, input)| { - input.ty = arg.ty.clone(); - }) - .count() - } + Self { arguments } + } + /// Update tensor inputs from the registered arguments and returns the number of input + /// updated. + fn update_tensor_inputs(&self, node: &mut Node) -> usize { + self.update_arguments(&mut node.inputs) + } + + /// Update the arguments struct from the node output tensors and return the number of output + /// updated. + fn update_tensor_outputs(&mut self, node: &Node) -> usize { + node + .outputs + .iter() + .map(|arg| { + self.arguments.insert(arg.name.clone(), arg.clone()); + }) + .count() + } + + fn update_arguments(&self, arguments: &mut [Argument]) -> usize { + arguments + .iter_mut() + .filter_map(|input| self.arguments.get(&input.name).map(|arg| (arg, input))) + .map(|(arg, input)| { + input.ty = arg.ty.clone(); + }) + .count() + } } /// Infer the dimension of each output tensor and update them. 
pub fn dim_inference( - nodes: &mut Vec, - graph_inputs: &Vec, - graph_outputs: &mut Vec, + nodes: &mut Vec, + graph_inputs: &Vec, + graph_outputs: &mut Vec, ) { - let mut updater = TensorDimUpdater::new(graph_inputs); - - for node in nodes.iter_mut() { - updater.update_tensor_inputs(node); - - match node.node_type { - NodeType::Conv1d => conv1d_update_outputs(node), - NodeType::Conv2d => conv2d_update_outputs(node), - NodeType::MaxPool2d => same_as_input(node), - NodeType::Linear => linear_update_outputs(node), - NodeType::Flatten => flatten_update_outputs(node), - NodeType::GatherElements => same_as_input(node), - NodeType::Relu => same_as_input(node), - NodeType::LogSoftmax => same_as_input(node), - NodeType::BatchNormalization => same_as_input(node), - NodeType::Add => same_as_input(node), - NodeType::Sub => same_as_input(node), - NodeType::Mul => same_as_input(node), - NodeType::Cast => cast_update_outputs(node), - NodeType::Div => same_as_input(node), - NodeType::Erf => same_as_input(node), - NodeType::Sqrt => same_as_input(node), - NodeType::Tanh => same_as_input(node), - NodeType::Reciprocal => same_as_input(node), - NodeType::Softmax => same_as_input(node), - NodeType::ReduceMean => mean_update_outputs(node), - NodeType::Constant => constant_update_outputs(node), - NodeType::Equal => equal_update_outputs(node), - NodeType::Shape => shape_update_outputs(node), - NodeType::Unsqueeze => unsqueeze_update_outputs(node), - NodeType::Sigmoid => same_as_input(node), - NodeType::Transpose => same_as_input(node), - NodeType::Concat => concat_update_outputs(node), - NodeType::Reshape => reshape_update_outputs(node), - NodeType::Dropout => same_as_input(node), - NodeType::GlobalAveragePool => same_as_input(node), - NodeType::AveragePool2d => same_as_input(node), - NodeType::Clip => same_as_input(node), - // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
- _ => temporary_pass_through_stub(node), - } - - updater.update_tensor_outputs(node); + let mut updater = TensorDimUpdater::new(graph_inputs); + + for node in nodes.iter_mut() { + updater.update_tensor_inputs(node); + + match node.node_type { + NodeType::Conv1d => conv1d_update_outputs(node), + NodeType::Conv2d => conv2d_update_outputs(node), + NodeType::MaxPool2d => same_as_input(node), + NodeType::Linear => linear_update_outputs(node), + NodeType::Flatten => flatten_update_outputs(node), + NodeType::GatherElements => same_as_input(node), + NodeType::Relu => same_as_input(node), + NodeType::LogSoftmax => same_as_input(node), + NodeType::BatchNormalization => same_as_input(node), + NodeType::Add => same_as_input(node), + NodeType::Sub => same_as_input(node), + NodeType::Mul => same_as_input(node), + NodeType::Cast => cast_update_outputs(node), + NodeType::Div => same_as_input(node), + NodeType::Erf => same_as_input(node), + NodeType::Sqrt => same_as_input(node), + NodeType::Tanh => same_as_input(node), + NodeType::Reciprocal => same_as_input(node), + NodeType::Softmax => same_as_input(node), + NodeType::ReduceMean => mean_update_outputs(node), + NodeType::Constant => constant_update_outputs(node), + NodeType::Equal => equal_update_outputs(node), + NodeType::Shape => shape_update_outputs(node), + NodeType::Unsqueeze => unsqueeze_update_outputs(node), + NodeType::Sigmoid => same_as_input(node), + NodeType::Transpose => same_as_input(node), + NodeType::Concat => concat_update_outputs(node), + NodeType::Reshape => reshape_update_outputs(node), + NodeType::Dropout => same_as_input(node), + NodeType::GlobalAveragePool => same_as_input(node), + NodeType::AveragePool2d => same_as_input(node), + NodeType::Clip => same_as_input(node), + // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
+ _ => temporary_pass_through_stub(node), } - updater.update_arguments(graph_outputs); + updater.update_tensor_outputs(node); + } + + updater.update_arguments(graph_outputs); } fn constant_update_outputs(node: &mut Node) { - // Fix the tensor dimension of the output when the value is tensor - - let keys = [ - "value", - "value_float", - "value_floats", - "value_int", - "value_ints", - "value_string", - "value_strings", - "sparse_value", - ]; - - let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); - - node.outputs[0].ty = match matched_value { - Some(value) => match &value { - // The value is stored in an attribute - AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - dim: tensor.dim, - shape: tensor.shape.clone(), - }), - AttributeValue::Float32(_) => ArgType::Scalar(ElementType::Float32), - AttributeValue::Float32s(value) => ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - dim: 1, - shape: Some(vec![value.len()]), - }), - AttributeValue::Int64(_) => ArgType::Scalar(ElementType::Int64), - AttributeValue::Int64s(value) => ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - dim: 1, - shape: Some(vec![value.len()]), - }), - ty => panic!("Constant value of {:?} is not supported", ty), - }, - None => panic!("Constant node must have a value attribute"), - }; + // Fix the tensor dimension of the output when the value is tensor + + let keys = [ + "value", + "value_float", + "value_floats", + "value_int", + "value_ints", + "value_string", + "value_strings", + "sparse_value", + ]; + + let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); + + node.outputs[0].ty = match matched_value { + Some(value) => match &value { + // The value is stored in an attribute + AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + dim: tensor.dim, + shape: tensor.shape.clone(), + }), + AttributeValue::Float32(_) => ArgType::Scalar(ElementType::Float32), + AttributeValue::Float32s(value) => ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + dim: 1, + shape: Some(vec![value.len()]), + }), + AttributeValue::Int64(_) => ArgType::Scalar(ElementType::Int64), + AttributeValue::Int64s(value) => ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + dim: 1, + shape: Some(vec![value.len()]), + }), + ty => panic!("Constant value of {:?} is not supported", ty), + }, + None => panic!("Constant node must have a value attribute"), + }; } /// Infer the shape of the output tensor of a Conv2d node fn linear_update_outputs(node: &mut Node) { - // Extract the configuration of the linear layer (inputs are known) - let node_input = &node.inputs[0]; - let weight = &node.inputs[1]; - - // Calculate the output shape. Usually we do not use shapes, but since the input shape is - // known, we can calculate the output shape. 
- if let ArgType::Tensor(tensor) = node_input.clone().ty { - let mut tensor = tensor.clone(); - let mut shape = tensor.shape.clone().unwrap(); - - if let ArgType::Tensor(weight_tensor) = weight.clone().ty { - let last = shape.last_mut().unwrap(); - *last = *weight_tensor.shape.unwrap().first().unwrap(); - } else { - panic!("Weight must be a tensor"); - } - - tensor.shape = Some(shape); - - // Update the output tensor - node.outputs[0].ty = ArgType::Tensor(tensor); + // Extract the configuration of the linear layer (inputs are known) + let node_input = &node.inputs[0]; + let weight = &node.inputs[1]; + + // Calculate the output shape. Usually we do not use shapes, but since the input shape is + // known, we can calculate the output shape. + if let ArgType::Tensor(tensor) = node_input.clone().ty { + let mut tensor = tensor.clone(); + let mut shape = tensor.shape.clone().unwrap(); + + if let ArgType::Tensor(weight_tensor) = weight.clone().ty { + let last = shape.last_mut().unwrap(); + *last = *weight_tensor.shape.unwrap().first().unwrap(); } else { - panic!("Only tensor input is valid"); + panic!("Weight must be a tensor"); } + + tensor.shape = Some(shape); + + // Update the output tensor + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + panic!("Only tensor input is valid"); + } } /// Update the output type using "to" attribute fn cast_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Cast: multiple inputs are not supported"); + if node.inputs.len() != 1 { + panic!("Cast: multiple inputs are not supported"); + } + let output = &mut node.outputs[0]; + + // Extract cast type and update the output tensor + let elem_type = match node.attrs.get("to") { + Some(value) => match &value { + AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + _ => panic!("Cast: unsupported type"), + }, + _ => panic!("'to' attribute must be an Int64"), + }, + None => panic!("Constant node must have a value attribute"), + }; + + match output.ty.clone() { + ArgType::Tensor(tensor) => { + if tensor.dim == 0 { + // treat 0-dim tensor as scalar + output.ty = ArgType::Scalar(elem_type); + } else { + todo!("Cast: support casting from different tensor types"); + } } - let output = &mut node.outputs[0]; - - // Extract cast type and update the output tensor - let elem_type = match node.attrs.get("to") { - Some(value) => match &value { - AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - _ => panic!("Cast: unsupported type"), - }, - _ => panic!("'to' attribute must be an Int64"), - }, - None => panic!("Constant node must have a value attribute"), - }; - - match output.ty.clone() { - ArgType::Tensor(tensor) => { - if tensor.dim == 0 { - // treat 0-dim tensor as scalar - output.ty = ArgType::Scalar(elem_type); - } else { - todo!("Cast: support casting from different tensor types"); - } - } - ArgType::Scalar(_scalar) => { - output.ty = ArgType::Scalar(elem_type); - } - _ => panic!("Cast: only scalar input is valid"), + ArgType::Scalar(_scalar) => { + output.ty = ArgType::Scalar(elem_type); } + _ => panic!("Cast: only scalar input is valid"), + } } fn concat_update_outputs(node: &mut Node) { - 
let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor), - _ => None, - }) - .unwrap(); - - node.outputs[0].ty = ArgType::Tensor(tensor.clone()); + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor), + _ => None, + }) + .unwrap(); + + node.outputs[0].ty = ArgType::Tensor(tensor.clone()); } fn reshape_update_outputs(node: &mut Node) { - assert_eq!(node.inputs.len(), 2); - - let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value { - shape - } else { - panic!("Reshape: int64s shape is expected per ONNX spec"); - }; - - // The output dimension is the same as the shape length - let dim = shape.len(); - let elem_type = match node.inputs[0].ty.clone() { - ArgType::Tensor(tensor) => tensor.elem_type, - _ => panic!("Reshape: invalid input type"), - }; - - let shape = shape.iter().map(|&dim| dim as usize).collect(); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - dim, - shape: Some(shape), - }); + assert_eq!(node.inputs.len(), 2); + + let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value { + shape + } else { + panic!("Reshape: int64s shape is expected per ONNX spec"); + }; + + // The output dimension is the same as the shape length + let dim = shape.len(); + let elem_type = match node.inputs[0].ty.clone() { + ArgType::Tensor(tensor) => tensor.elem_type, + _ => panic!("Reshape: invalid input type"), + }; + + let shape = shape.iter().map(|&dim| dim as usize).collect(); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + dim, + shape: Some(shape), + }); } fn mean_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Mean: multiple inputs are not supported"); - } - - // Extract the configuration of the linear layer (inputs are known) - let node_input = &mut node.inputs[0]; - let tensor = match node_input.clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - if dim_only { - node.outputs[0].ty = ArgType::Tensor(tensor); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor }); - } + if node.inputs.len() != 1 { + panic!("Mean: multiple inputs are not supported"); + } + + // Extract the configuration of the linear layer (inputs are known) + let node_input = &mut node.inputs[0]; + let tensor = match node_input.clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + if dim_only { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor }); + } } /// Infers the shape of a Unsqueeze node and replaces the shape of the output tensor. /// @@ -283,116 +284,116 @@ fn mean_update_outputs(node: &mut Node) { /// /// Unsqueeze is not implemented fully. This is left WIP from the past. 
fn unsqueeze_update_outputs(node: &mut Node) { - let node_input = node - .inputs - .first_mut() - .expect("Unsqueeze: an input is required"); - - if let ArgType::Tensor(tensor) = &mut node_input.ty { - tensor.dim += 1; - - // add a new dimension to the input tensor by extending the shape - // TODO: support unsqueezing configurations - if let Some(shape) = &mut tensor.shape { - shape.insert(0, 1); - } else { - todo!("Unsqueeze: support unsqueezing a tensor without shape"); - } - - node.outputs[0].ty = ArgType::Tensor(tensor.clone()); + let node_input = node + .inputs + .first_mut() + .expect("Unsqueeze: an input is required"); + + if let ArgType::Tensor(tensor) = &mut node_input.ty { + tensor.dim += 1; + + // add a new dimension to the input tensor by extending the shape + // TODO: support unsqueezing configurations + if let Some(shape) = &mut tensor.shape { + shape.insert(0, 1); } else { - panic!("Only tensor input is valid"); + todo!("Unsqueeze: support unsqueezing a tensor without shape"); } + + node.outputs[0].ty = ArgType::Tensor(tensor.clone()); + } else { + panic!("Only tensor input is valid"); + } } fn same_as_input(node: &mut Node) { - node.outputs[0].ty = node.inputs[0].ty.clone(); + node.outputs[0].ty = node.inputs[0].ty.clone(); } /// Temporary pass-through stub for dimension inference so that we can export the IR model. fn temporary_pass_through_stub(node: &mut Node) { - log::warn!( - "Must implement dimension inference for {:?}", - node.node_type - ); + log::warn!( + "Must implement dimension inference for {:?}", + node.node_type + ); } fn equal_update_outputs(node: &mut Node) { - let input1_type = node.inputs[0].ty.clone(); - - match input1_type { - ArgType::Tensor(tensor) => { - // if the input is a tensor, the output is a tensor of bool - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Bool, - ..tensor - }); - } - ArgType::Scalar(_) => { - node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); - } - _ => panic!("Only tensor input is valid"), + let input1_type = node.inputs[0].ty.clone(); + + match input1_type { + ArgType::Tensor(tensor) => { + // if the input is a tensor, the output is a tensor of bool + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + ..tensor + }); } + ArgType::Scalar(_) => { + node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); + } + _ => panic!("Only tensor input is valid"), + } } fn shape_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Gather: multiple inputs are not supported: {:?}", node); - } - - // Extract the configuration of the linear layer (inputs are known) - let node_input = &mut node.inputs[0]; - if let ArgType::Tensor(tensor) = node_input.clone().ty { - // Update the output tensor - node.outputs[0].ty = ArgType::Shape(tensor.dim); - } else { - panic!("Only tensor input is valid"); - } + if node.inputs.len() != 1 { + panic!("Gather: multiple inputs are not supported: {:?}", node); + } + + // Extract the configuration of the linear layer (inputs are known) + let node_input = &mut node.inputs[0]; + if let ArgType::Tensor(tensor) = node_input.clone().ty { + // Update the output tensor + node.outputs[0].ty = ArgType::Shape(tensor.dim); + } else { + panic!("Only tensor input is valid"); + } } /// Infers the shape of a Flatten node and replaces the shape of the output tensor. 
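The rank arithmetic used by flatten_update_outputs just below is easy to check by hand. A minimal standalone sketch with plain integers rather than the crate's TensorType (helper name is illustrative only):

    fn flattened_rank(input_rank: usize, start_dim: usize, end_dim: usize) -> usize {
        // Collapsing the axes start_dim..=end_dim removes (end_dim - start_dim) of them.
        assert!(start_dim <= end_dim && end_dim < input_rank);
        input_rank - (end_dim - start_dim)
    }

    // e.g. a rank-4 input flattened over dims 1..=2 yields a rank-3 output:
    // flattened_rank(4, 1, 2) == 3
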
fn flatten_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Flatten: multiple inputs are not supported"); - } - let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor), - _ => None, - }) - .unwrap(); - - let input_dim = tensor.dim; - - let (start_dim, end_dim) = flatten_config(node); - - let collapsed_dims = end_dim - start_dim; - let output_dim = input_dim - collapsed_dims; - - node.outputs[0].ty = ArgType::Tensor(TensorType { - dim: output_dim, - ..tensor.clone() - }); + if node.inputs.len() != 1 { + panic!("Flatten: multiple inputs are not supported"); + } + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor), + _ => None, + }) + .unwrap(); + + let input_dim = tensor.dim; + + let (start_dim, end_dim) = flatten_config(node); + + let collapsed_dims = end_dim - start_dim; + let output_dim = input_dim - collapsed_dims; + + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim: output_dim, + ..tensor.clone() + }); } /// Infers the shape of a Conv1d node and replaces the shape of the output tensor. fn conv1d_update_outputs(node: &mut Node) { - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { - node.outputs[0].ty = ArgType::Tensor(tensor); - } else { - panic!("Only tensor input is valid"); - } + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] + if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + panic!("Only tensor input is valid"); + } } /// Infers the shape of a Conv2d node and replaces the shape of the output tensor. fn conv2d_update_outputs(node: &mut Node) { - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { - node.outputs[0].ty = ArgType::Tensor(tensor); - } else { - panic!("Only tensor input is valid"); - } + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
+ if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + panic!("Only tensor input is valid"); + } } diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 6b9db5b01b..d8a2354839 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -1,12 +1,12 @@ use std::{ - collections::{HashMap, HashSet}, - fs::File, - path::Path, + collections::{HashMap, HashSet}, + fs::File, + path::Path, }; use crate::onnx::{ - coalesce::coalesce, ir::TensorType, node_remap::remap_node_type, - proto_conversion::convert_node_proto, + coalesce::coalesce, ir::TensorType, node_remap::remap_node_type, + proto_conversion::convert_node_proto, }; use super::dim_inference::dim_inference; @@ -16,12 +16,12 @@ use super::protos::{ModelProto, TensorProto}; use protobuf::Message; const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 6] = [ - NodeType::BatchNormalization, - NodeType::Clip, - NodeType::Conv1d, - NodeType::Conv2d, - NodeType::Dropout, - NodeType::Reshape, + NodeType::BatchNormalization, + NodeType::Clip, + NodeType::Conv1d, + NodeType::Conv2d, + NodeType::Dropout, + NodeType::Reshape, ]; /// Open an onnx file and convert it to a Graph (intermediate representation) @@ -40,83 +40,83 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 6] = [ /// * If the file cannot be parsed /// * If the nodes are not topologically sorted pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { - log::info!("Parsing ONNX file: {}", onnx_path.display()); - - // Open the file - let mut file = File::open(onnx_path).expect("Unable to open file"); - let onnx_model: ModelProto = - Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file"); - - log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len()); - log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len()); - - log::debug!( - "Number of initializers: {:?}", - onnx_model.graph.initializer.len() - ); - - log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len()); - - // Convert the nodes - let mut nodes: Vec = vec![]; - for onnx_node in onnx_model.graph.node.iter() { - let mut node = convert_node_proto(onnx_node); - remap_node_type(&mut node); - nodes.push(node); - } - - // ONNX nodes must be topologically sorted per spec: - // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs - assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); - - // Move inputs with initializers to states - move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer); - - // Handle Identity nodes (expects inputs to be moved to states) - handle_identity(&mut nodes); - - // Lift constants to initializers (expects inputs to be moved to states) - lift_constants(&mut nodes); - - // Coalesce and transform nodes - coalesce(&mut nodes); - - // Rename nodes and inputs, save the mapping for later - let old_node_names = rename_nodes(&mut nodes); - - // This function collects the inputs of an ONNX model and returns them as a vector of Arguments. - let mut inputs = onnx_model - .graph - .input - .iter() - .map(|x| Argument::try_from(x.clone()).unwrap()) - .collect(); - - // Map each output in the model's graph to an Argument and collect them into a vector. 
- let mut outputs = onnx_model - .graph - .output - .iter() - .map(|x| Argument::try_from(x.clone()).unwrap()) - .collect(); - - let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs); - - // Infer shapes and update the inputs and outputs - dim_inference(&mut nodes, &inputs, &mut outputs); - - // Remove the graph inputs/output that are not used by any node - remove_unused_graph_inputs(&mut inputs, &mut outputs, &nodes); - - log::info!("Finished parsing ONNX file: {}", onnx_path.display()); - - ONNXGraph { - nodes, - inputs, - outputs, - old_node_names, - old_input_names, - } + log::info!("Parsing ONNX file: {}", onnx_path.display()); + + // Open the file + let mut file = File::open(onnx_path).expect("Unable to open file"); + let onnx_model: ModelProto = + Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file"); + + log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len()); + log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len()); + + log::debug!( + "Number of initializers: {:?}", + onnx_model.graph.initializer.len() + ); + + log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len()); + + // Convert the nodes + let mut nodes: Vec = vec![]; + for onnx_node in onnx_model.graph.node.iter() { + let mut node = convert_node_proto(onnx_node); + remap_node_type(&mut node); + nodes.push(node); + } + + // ONNX nodes must be topologically sorted per spec: + // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs + assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); + + // Move inputs with initializers to states + move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer); + + // Handle Identity nodes (expects inputs to be moved to states) + handle_identity(&mut nodes); + + // Lift constants to initializers (expects inputs to be moved to states) + lift_constants(&mut nodes); + + // Coalesce and transform nodes + coalesce(&mut nodes); + + // Rename nodes and inputs, save the mapping for later + let old_node_names = rename_nodes(&mut nodes); + + // This function collects the inputs of an ONNX model and returns them as a vector of Arguments. + let mut inputs = onnx_model + .graph + .input + .iter() + .map(|x| Argument::try_from(x.clone()).unwrap()) + .collect(); + + // Map each output in the model's graph to an Argument and collect them into a vector. 
+ let mut outputs = onnx_model + .graph + .output + .iter() + .map(|x| Argument::try_from(x.clone()).unwrap()) + .collect(); + + let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs); + + // Infer shapes and update the inputs and outputs + dim_inference(&mut nodes, &inputs, &mut outputs); + + // Remove the graph inputs/output that are not used by any node + remove_unused_graph_inputs(&mut inputs, &mut outputs, &nodes); + + log::info!("Finished parsing ONNX file: {}", onnx_path.display()); + + ONNXGraph { + nodes, + inputs, + outputs, + old_node_names, + old_input_names, + } } /// This function moves inputs that are also present @@ -128,49 +128,49 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { /// * `nodes` - A mutable reference to a vector of nodes /// * `initializers` - A vector of TensorProto fn move_inputs_to_state(nodes: &mut Vec, initializers: &[TensorProto]) { - // Convert initializers to hashmap for faster lookup - let initializers = initializers - .iter() - .map(|x| (x.name.clone(), x.clone())) - .collect::>(); - - // Iterate over each node in the graph - nodes.iter_mut().for_each(|node| { - for input in node.inputs.iter_mut() { - // If there is a corresponding initializer for the input, then move the data to the input value - if let Some(initializer) = initializers.get(&input.name) { - move_initializer_data(initializer, input); - } - } - }); + // Convert initializers to hashmap for faster lookup + let initializers = initializers + .iter() + .map(|x| (x.name.clone(), x.clone())) + .collect::>(); + + // Iterate over each node in the graph + nodes.iter_mut().for_each(|node| { + for input in node.inputs.iter_mut() { + // If there is a corresponding initializer for the input, then move the data to the input value + if let Some(initializer) = initializers.get(&input.name) { + move_initializer_data(initializer, input); + } + } + }); } fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) { - // If the input name matches the tensor name in the initializer - // Convert the initializer to a tensor - let tensor = Tensor::try_from(initializer.clone()).expect("Invalid tensor"); - - if tensor.dim == 0 { - // Convert zero dim tensor to scalar - if let Some(data) = tensor.data { - input.value = Some(data.into_scalar()); - } else { - input.value = None; - } - - // Update the input type - input.ty = ArgType::Scalar(tensor.elem_type); + // If the input name matches the tensor name in the initializer + // Convert the initializer to a tensor + let tensor = Tensor::try_from(initializer.clone()).expect("Invalid tensor"); + + if tensor.dim == 0 { + // Convert zero dim tensor to scalar + if let Some(data) = tensor.data { + input.value = Some(data.into_scalar()); } else { - // Move the tensor data to the input value - input.value = tensor.data.clone(); - - // Update the input type - input.ty = ArgType::Tensor(TensorType { - dim: tensor.dim, - elem_type: tensor.elem_type, - shape: tensor.shape, - }); + input.value = None; } + + // Update the input type + input.ty = ArgType::Scalar(tensor.elem_type); + } else { + // Move the tensor data to the input value + input.value = tensor.data.clone(); + + // Update the input type + input.ty = ArgType::Tensor(TensorType { + dim: tensor.dim, + elem_type: tensor.elem_type, + shape: tensor.shape, + }); + } } /// Lift constants from the graph into the states vector for known node types. 
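The constant-lifting pass in the next hunk copies a Constant node's value onto the input of the node that consumes it, after which the Constant node can be dropped from the graph. A simplified standalone sketch of that idea, using plain structs rather than the crate's Node/Argument IR (all names here are illustrative):

    use std::collections::HashMap;

    #[derive(Clone, Debug)]
    struct Input {
        name: String,
        value: Option<f32>, // stands in for the IR's attribute/tensor data
    }

    /// Copy known constant values onto a node's non-primary inputs and report
    /// which constants were consumed so the caller can remove them afterwards.
    fn lift_constants_into(inputs: &mut [Input], constants: &HashMap<String, f32>) -> Vec<String> {
        let mut lifted = Vec::new();
        // Skip the first input: it is the node's true data input, not a constant/state.
        for input in inputs.iter_mut().skip(1) {
            if let Some(value) = constants.get(&input.name) {
                input.value = Some(*value);
                lifted.push(input.name.clone());
            }
        }
        lifted
    }
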
@@ -196,117 +196,116 @@ fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) { /// /// Panics if the node's output is not a constant. fn lift_constants(nodes: &mut Vec) { - log::info!("Lifting constants into the states"); - - // create a set to hold the node types to process - let node_types_to_process: HashSet = - LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); - - // create a new vector to hold the graph's constants (index by the node's name) - let constants = nodes - .iter() - .filter(|node| node.node_type == NodeType::Constant || node.node_type == NodeType::Identity) - .map(|node| (node.outputs[0].name.clone(), node.clone())) - .collect::>(); - - // create a set to hold the IDs of constants to be removed - let mut constant_to_removed = HashSet::::new(); + log::info!("Lifting constants into the states"); + + // create a set to hold the node types to process + let node_types_to_process: HashSet = + LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); + + // create a new vector to hold the graph's constants (index by the node's name) + let constants = nodes + .iter() + .filter(|node| node.node_type == NodeType::Constant || node.node_type == NodeType::Identity) + .map(|node| (node.outputs[0].name.clone(), node.clone())) + .collect::>(); + + // create a set to hold the IDs of constants to be removed + let mut constant_to_removed = HashSet::::new(); + + for node in nodes.iter_mut() { + // Skip the node if it is not in the set of node types to process + if !node_types_to_process.contains(&node.node_type) { + continue; + } - for node in nodes.iter_mut() { - // Skip the node if it is not in the set of node types to process - if !node_types_to_process.contains(&node.node_type) { - continue; + // Skip the first input because it is the node's true input and not a constant/state + node + .inputs + .iter_mut() + .skip(1) // TODO make configurable + .for_each(|input| { + if let Some(constant) = constants.get(&input.name) { + if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { + // The value comes from Identity inputs + if let Some(constant_input) = constant.inputs.first() { + input.ty = constant_input.ty.clone(); + input.value = constant_input.value.clone(); + } + } else { + // The value comes from an attribute + let arg = convert_constant_value(constant); // get the value of the constant + + input.value = arg.value; // set the input's value to the constant's value + input.ty = arg.ty; // set the input's type to the constant's type + // remove the constant from the graph + } + constant_to_removed.insert(constant.name.clone()); } + }); + } - // Skip the first input because it is the node's true input and not a constant/state - node.inputs - .iter_mut() - .skip(1) // TODO make configurable - .for_each(|input| { - if let Some(constant) = constants.get(&input.name) { - if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { - // The value comes from Identity inputs - if let Some(constant_input) = constant.inputs.first() { - input.ty = constant_input.ty.clone(); - input.value = constant_input.value.clone(); - } - } else { - // The value comes from an attribute - let arg = convert_constant_value(constant); // get the value of the constant - - input.value = arg.value; // set the input's value to the constant's value - input.ty = arg.ty; // set the input's type to the constant's type - // remove the constant from the graph - } - constant_to_removed.insert(constant.name.clone()); - } - }); - } + // remove the constants that were moved to the states 
vector + nodes.retain(|node| !constant_to_removed.contains(&node.name)); - // remove the constants that were moved to the states vector - nodes.retain(|node| !constant_to_removed.contains(&node.name)); - - log::debug!( - "The number of constants lifted: {}", - constant_to_removed.len() - ); + log::debug!( + "The number of constants lifted: {}", + constant_to_removed.len() + ); } fn handle_identity(nodes: &mut Vec) { - log::info!("Handling identity nodes"); - - let mut nodes_to_remove = HashSet::new(); - - let identity_nodes = nodes - .iter() - .filter(|node| node.node_type == NodeType::Identity) - .cloned() - .collect::>(); - - // Handle pass-through nodes. - for identity_node in identity_nodes { - if identity_node.node_type == NodeType::Identity && identity_node.inputs[0].value.is_none() - { - let input_name = &identity_node.inputs[0].name; - let output_name = &identity_node.outputs[0].name; - - // Replace the identity node's output with its input in the connected nodes. - for node in nodes.iter_mut() { - if let Some(matched_input) = node.inputs.iter_mut().find(|x| x.name == *output_name) - { - matched_input.name = input_name.clone(); - } - } - - nodes_to_remove.insert(identity_node); + log::info!("Handling identity nodes"); + + let mut nodes_to_remove = HashSet::new(); + + let identity_nodes = nodes + .iter() + .filter(|node| node.node_type == NodeType::Identity) + .cloned() + .collect::>(); + + // Handle pass-through nodes. + for identity_node in identity_nodes { + if identity_node.node_type == NodeType::Identity && identity_node.inputs[0].value.is_none() { + let input_name = &identity_node.inputs[0].name; + let output_name = &identity_node.outputs[0].name; + + // Replace the identity node's output with its input in the connected nodes. + for node in nodes.iter_mut() { + if let Some(matched_input) = node.inputs.iter_mut().find(|x| x.name == *output_name) { + matched_input.name = input_name.clone(); } + } + + nodes_to_remove.insert(identity_node); } + } - // Remove the identity nodes. - nodes.retain(|node| !nodes_to_remove.contains(node)); + // Remove the identity nodes. + nodes.retain(|node| !nodes_to_remove.contains(node)); } /// Rename the nodes in the graph to be unique and return a map of the old names to the new names. fn rename_nodes(nodes: &mut Vec) -> HashMap { - let mut old_names = HashMap::new(); - let mut counter: HashMap = HashMap::new(); + let mut old_names = HashMap::new(); + let mut counter: HashMap = HashMap::new(); - for node in nodes.iter_mut() { - // keep track of the number of nodes of each type - counter - .entry(node.node_type.clone()) - .and_modify(|e| *e += 1) - .or_insert(1); + for node in nodes.iter_mut() { + // keep track of the number of nodes of each type + counter + .entry(node.node_type.clone()) + .and_modify(|e| *e += 1) + .or_insert(1); - let old_name = node.name.clone(); - let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase(); + let old_name = node.name.clone(); + let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase(); - node.name = new_name.clone(); + node.name = new_name.clone(); - old_names.insert(old_name, new_name); - } + old_names.insert(old_name, new_name); + } - old_names + old_names } /// Rename the inputs and output in the graph and return a map of @@ -316,60 +315,60 @@ fn rename_nodes(nodes: &mut Vec) -> HashMap { /// conv2_in1, conv2_in2, etc. This is done to be consistent with /// the naming convention of the nodes and allow to be used as rust identifiers. 
fn rename_inputs( - nodes: &mut Vec, - inputs: &mut Vec, - outputs: &mut Vec, + nodes: &mut Vec, + inputs: &mut Vec, + outputs: &mut Vec, ) -> HashMap { - let mut old_names = HashMap::new(); - - // rename all graph input names to follow input1, input2, input3, etc. - // (assumes the input names are already unique) + let mut old_names = HashMap::new(); + + // rename all graph input names to follow input1, input2, input3, etc. + // (assumes the input names are already unique) + let mut counter = 1; + for input in inputs.iter_mut() { + let old_name = input.name.clone(); + let new_name = format!("input{}", counter); + input.name = new_name.clone(); + old_names.insert(old_name, new_name); + counter += 1; + } + + for node in nodes.iter_mut() { let mut counter = 1; - for input in inputs.iter_mut() { - let old_name = input.name.clone(); - let new_name = format!("input{}", counter); - input.name = new_name.clone(); - old_names.insert(old_name, new_name); - counter += 1; - } - for node in nodes.iter_mut() { - let mut counter = 1; - - // loop through node outputs and rename them and store the new name <-> old name mapping - for output in node.outputs.iter_mut() { - let old_name = output.name.clone(); - let new_name = format!("{}_out{}", node.name, counter); - output.name = new_name.clone(); - old_names.insert(old_name, new_name); - counter += 1; - } + // loop through node outputs and rename them and store the new name <-> old name mapping + for output in node.outputs.iter_mut() { + let old_name = output.name.clone(); + let new_name = format!("{}_out{}", node.name, counter); + output.name = new_name.clone(); + old_names.insert(old_name, new_name); + counter += 1; } + } - for node in nodes.iter_mut() { - // loop through node inputs and rename them with previously replaced names - // and mark them as passed if they are in the old_names map (i.e. they are node outputs) - for input in node.inputs.iter_mut() { - if let Some(new_name) = old_names.get(&input.name) { - input.name = new_name.clone(); - input.passed = true; - } else { - input.name = "".to_string(); // Rename to a placeholder - input.passed = false; - } - } + for node in nodes.iter_mut() { + // loop through node inputs and rename them with previously replaced names + // and mark them as passed if they are in the old_names map (i.e. they are node outputs) + for input in node.inputs.iter_mut() { + if let Some(new_name) = old_names.get(&input.name) { + input.name = new_name.clone(); + input.passed = true; + } else { + input.name = "".to_string(); // Rename to a placeholder + input.passed = false; + } } + } - // Rename the graph outputs - for output in outputs.iter_mut() { - if let Some(new_name) = old_names.get(&output.name) { - output.name = new_name.clone(); - } else { - log::warn!("Output {:?} not found in old_names", output.name); - } + // Rename the graph outputs + for output in outputs.iter_mut() { + if let Some(new_name) = old_names.get(&output.name) { + output.name = new_name.clone(); + } else { + log::warn!("Output {:?} not found in old_names", output.name); } + } - old_names + old_names } /// Removes the graph inputs/output that are not used by any node. @@ -382,90 +381,90 @@ fn rename_inputs( /// Generally, it's a good idea to remove unused inputs/outputs because it makes the /// generated code cleaner and easier to read. 
fn remove_unused_graph_inputs( - inputs: &mut Vec, - outputs: &mut Vec, - nodes: &Vec, + inputs: &mut Vec, + outputs: &mut Vec, + nodes: &Vec, ) { - // Remove inputs that are not used by any node - inputs.retain(|input| { - for node in nodes.iter() { - if node - .inputs - .iter() - .any(|x| x.name == input.name && x.value.is_none()) - { - return true; - } - } - false - }); - - // Remove outputs that are not used by any node - outputs.retain(|output| { - for node in nodes.iter() { - if node.outputs.iter().any(|x| x.name == output.name) { - return true; - } - } - false - }); + // Remove inputs that are not used by any node + inputs.retain(|input| { + for node in nodes.iter() { + if node + .inputs + .iter() + .any(|x| x.name == input.name && x.value.is_none()) + { + return true; + } + } + false + }); + + // Remove outputs that are not used by any node + outputs.retain(|output| { + for node in nodes.iter() { + if node.outputs.iter().any(|x| x.name == output.name) { + return true; + } + } + false + }); } // Define a trait for topological sorting trait TopologicalSortable { - fn is_top_sorted(&self) -> bool; + fn is_top_sorted(&self) -> bool; } impl TopologicalSortable for Vec { - fn is_top_sorted(&self) -> bool { - // Create a hashmap to store the position of each node in the vector - let position: HashMap = self - .iter() - .enumerate() - .map(|(idx, node)| (node.name.clone(), idx)) - .collect(); - - // Iterate over each node in the vector - for node in self { - // Iterate over each output of the node - for output in &node.outputs { - // Iterate over each other node in the vector - for other_node in self { - // If the other node has an input that matches the current output - if other_node.inputs.contains(output) { - // If the position of the current node is greater than the position of the other node - if position[&node.name] > position[&other_node.name] { - // The vector is not topologically sorted - return false; - } - } - } + fn is_top_sorted(&self) -> bool { + // Create a hashmap to store the position of each node in the vector + let position: HashMap = self + .iter() + .enumerate() + .map(|(idx, node)| (node.name.clone(), idx)) + .collect(); + + // Iterate over each node in the vector + for node in self { + // Iterate over each output of the node + for output in &node.outputs { + // Iterate over each other node in the vector + for other_node in self { + // If the other node has an input that matches the current output + if other_node.inputs.contains(output) { + // If the position of the current node is greater than the position of the other node + if position[&node.name] > position[&other_node.name] { + // The vector is not topologically sorted + return false; } + } } - - // The vector is topologically sorted - true + } } + + // The vector is topologically sorted + true + } } /// Get the value of a constant node from its attributes pub(crate) fn convert_constant_value(node: &Node) -> Argument { - // A value can be stored in any of these attributes - let keys = [ - "value", - "value_float", - "value_floats", - "value_int", - "value_ints", - "value_string", - "value_strings", - "sparse_value", - ]; - - let value = keys - .iter() - .find_map(|&key| node.attrs.get(key).cloned()) - .expect("Constant should have a value"); - - Argument::from(value) + // A value can be stored in any of these attributes + let keys = [ + "value", + "value_float", + "value_floats", + "value_int", + "value_ints", + "value_string", + "value_strings", + "sparse_value", + ]; + + let value = keys + .iter() + 
.find_map(|&key| node.attrs.get(key).cloned()) + .expect("Constant should have a value"); + + Argument::from(value) } diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs index d4a2086ad5..29f9f857da 100644 --- a/burn-import/src/onnx/ir.rs +++ b/burn-import/src/onnx/ir.rs @@ -9,39 +9,39 @@ pub type Shape = Vec; /// A node input or output. #[derive(Debug, Clone)] pub struct Argument { - /// The name of the node input. - pub name: String, + /// The name of the node input. + pub name: String, - /// The type of the argument. - pub ty: ArgType, + /// The type of the argument. + pub ty: ArgType, - /// The data of the argument. - pub value: Option, + /// The data of the argument. + pub value: Option, - /// True if the argument is passed to node, false otherwise. We use it mainly for informational purposes. - /// The argument should contain a value if passed is false. - pub passed: bool, + /// True if the argument is passed to node, false otherwise. We use it mainly for informational purposes. + /// The argument should contain a value if passed is false. + pub passed: bool, } /// The type of an argument. #[derive(Debug, Clone)] pub enum ArgType { - Scalar(ElementType), - Shape(Dim), - Tensor(TensorType), + Scalar(ElementType), + Shape(Dim), + Tensor(TensorType), } /// The type of an attribute. #[derive(Debug, Clone)] pub enum AttributeValue { - Float32(f32), - Float32s(Vec), - Int64(i64), - Int64s(Vec), - String(String), - Strings(Vec), - Tensor(Tensor), - Tensors(Vec), + Float32(f32), + Float32s(Vec), + Int64(i64), + Int64s(Vec), + String(String), + Strings(Vec), + Tensor(Tensor), + Tensors(Vec), } pub type Attributes = HashMap; @@ -49,126 +49,126 @@ pub type Attributes = HashMap; /// The type of an element. #[derive(Debug, Clone)] pub enum ElementType { - Float32, - Float64, - Int32, - Int64, - String, - Float16, - Bool, + Float32, + Float64, + Int32, + Int64, + String, + Float16, + Bool, } #[derive(Debug, Clone, Default)] pub struct TensorType { - /// The type of the tensor. - pub elem_type: ElementType, + /// The type of the tensor. + pub elem_type: ElementType, - /// The dimension of the tensor. - pub dim: Dim, + /// The dimension of the tensor. + pub dim: Dim, - /// The shape of the tensor. - pub shape: Option, + /// The shape of the tensor. + pub shape: Option, } impl Default for ElementType { - fn default() -> Self { - Self::Float32 - } + fn default() -> Self { + Self::Float32 + } } impl Default for ArgType { - fn default() -> Self { - Self::Tensor(TensorType::default()) - } + fn default() -> Self { + Self::Tensor(TensorType::default()) + } } impl Argument { - pub fn new(name: String) -> Self { - Self { - name, - ty: ArgType::default(), - value: None, - passed: false, - } - } + pub fn new(name: String) -> Self { + Self { + name, + ty: ArgType::default(), + value: None, + passed: false, + } + } } #[derive(Debug, Clone, Default)] pub struct Tensor { - /// The type of the tensor. - pub elem_type: ElementType, + /// The type of the tensor. + pub elem_type: ElementType, - /// The dimension of the tensor. - pub dim: Dim, + /// The dimension of the tensor. + pub dim: Dim, - /// The data of the tensor. - pub data: Option, + /// The data of the tensor. + pub data: Option, - /// The shape of the tensor. - pub shape: Option, + /// The shape of the tensor. 
+ pub shape: Option, } /// Container to hold data for tensors and arguments #[derive(Clone)] pub enum Data { - Bool(bool), - Bools(Vec), - Float16(f16), - Float16s(Vec), - Float32(f32), - Float32s(Vec), - Float64(f64), - Float64s(Vec), - Int32(i32), - Int32s(Vec), - Int64(i64), - Int64s(Vec), - String(String), - Strings(Vec), + Bool(bool), + Bools(Vec), + Float16(f16), + Float16s(Vec), + Float32(f32), + Float32s(Vec), + Float64(f64), + Float64s(Vec), + Int32(i32), + Int32s(Vec), + Int64(i64), + Int64s(Vec), + String(String), + Strings(Vec), } /// ONNX graph representation #[derive(Debug, Clone)] pub struct ONNXGraph { - /// The nodes of the graph. - pub nodes: Vec, + /// The nodes of the graph. + pub nodes: Vec, - /// The inputs of the graph. - pub inputs: Vec, + /// The inputs of the graph. + pub inputs: Vec, - /// The outputs of the graph. - pub outputs: Vec, + /// The outputs of the graph. + pub outputs: Vec, - /// The original node names. - pub old_node_names: HashMap, + /// The original node names. + pub old_node_names: HashMap, - /// The original input names. - pub old_input_names: HashMap, + /// The original input names. + pub old_input_names: HashMap, } #[derive(Debug, Clone)] pub struct Node { - /// The type of the node. - pub node_type: NodeType, + /// The type of the node. + pub node_type: NodeType, - /// The name of the node. - pub name: String, + /// The name of the node. + pub name: String, - /// The inputs of the node. - pub inputs: Vec, + /// The inputs of the node. + pub inputs: Vec, - /// The outputs of the node. - pub outputs: Vec, + /// The outputs of the node. + pub outputs: Vec, - /// The attributes of the node. - pub attrs: Attributes, + /// The attributes of the node. + pub attrs: Attributes, } // Required by topological sort impl PartialEq for Node { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.node_type == other.node_type - } + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.node_type == other.node_type + } } // Required by topological sort @@ -176,596 +176,596 @@ impl Eq for Node {} // Required by topological sort impl core::hash::Hash for Node { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.node_type.hash(state); - self.inputs.hash(state); - self.outputs.hash(state); - } + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.node_type.hash(state); + self.inputs.hash(state); + self.outputs.hash(state); + } } // Required by topological sort impl core::hash::Hash for Argument { - fn hash(&self, state: &mut H) { - self.name.hash(state); - } + fn hash(&self, state: &mut H) { + self.name.hash(state); + } } impl Eq for Argument {} // Required by HashSet impl PartialEq for Argument { - fn eq(&self, other: &Self) -> bool { - self.name == other.name - } + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } } /// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops) #[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)] pub enum NodeType { - Abs, - Acos, - Acosh, - Add, - And, - ArgMax, - ArgMin, - Asin, - Asinh, - Atan, - Atanh, - AveragePool, - AveragePool1d, - AveragePool2d, - BatchNormalization, - Bernoulli, - BitShift, - BitwiseAnd, - BitwiseNot, - BitwiseOr, - BitwiseXor, - BlackmanWindow, - Cast, - CastLike, - Ceil, - Celu, - CenterCropPad, - Clip, - Col, - Compress, - Concat, - ConcatFromSequence, - Constant, - ConstantOfShape, - Conv, - Conv1d, - Conv2d, - ConvInteger, - ConvTranspose, - Cos, - Cosh, - CumSum, - 
DepthToSpace, - DequantizeLinear, - Det, - DFT, - Div, - Dropout, - DynamicQuantizeLinear, - Einsum, - Elu, - Equal, - Erf, - Exp, - Expand, - EyeLike, - Flatten, - Floor, - Gather, - GatherElements, - GatherND, - Gelu, - Gemm, - GlobalAveragePool, - GlobalLpPool, - GlobalMaxPool, - Greater, - GreaterOrEqual, - GridSample, - GroupNormalization, - GRU, - HammingWindow, - HannWindow, - Hardmax, - HardSigmoid, - HardSwish, - Identity, - If, - Im, - InstanceNormalization, - IsInf, - IsNaN, - LayerNormalization, - LeakyRelu, - Less, - LessOrEqual, - Linear, - Log, - LogSoftmax, - Loop, - LpNormalization, - LpPool, - LRN, - LSTM, - MatMul, - MatMulInteger, - Max, - MaxPool, - MaxPool1d, - MaxPool2d, - MaxRoiPool, - MaxUnpool, - Mean, - MeanVarianceNormalization, - MelWeightMatrix, - Min, - Mish, - Mod, - Mul, - Multinomial, - Neg, - NegativeLogLikelihoodLoss, - NonMaxSuppression, - NonZero, - Not, - OneHot, - Optional, - OptionalGetElement, - OptionalHasElement, - Or, - Pad, - Pow, - PRelu, - QLinearConv, - QLinearMatMul, - QuantizeLinear, - RandomNormal, - RandomNormalLike, - RandomUniform, - RandomUniformLike, - Range, - Reciprocal, - ReduceL, - ReduceLogSum, - ReduceLogSumExp, - ReduceMax, - ReduceMean, - ReduceMin, - ReduceProd, - ReduceSum, - ReduceSumSquare, - Relu, - Reshape, - Resize, - ReverseSequence, - RNN, - RoiAlign, - Round, - Scan, - Scatter, - ScatterElements, - ScatterND, - Selu, - SequenceAt, - SequenceConstruct, - SequenceEmpty, - SequenceErase, - SequenceInsert, - SequenceLength, - SequenceMap, - Shape, - Shrink, - Sigmoid, - Sign, - Sin, - Sinh, - Size, - Slice, - Softmax, - SoftmaxCrossEntropyLoss, - Softplus, - Softsign, - SpaceToDepth, - Split, - SplitToSequence, - Sqrt, - Squeeze, - STFT, - StringNormalizer, - Sub, - Sum, - Tan, - Tanh, - TfIdfVectorizer, - ThresholdedRelu, - Tile, - TopK, - Transpose, - Trilu, - Unique, - Unsqueeze, - Upsample, - Where, - Xor, + Abs, + Acos, + Acosh, + Add, + And, + ArgMax, + ArgMin, + Asin, + Asinh, + Atan, + Atanh, + AveragePool, + AveragePool1d, + AveragePool2d, + BatchNormalization, + Bernoulli, + BitShift, + BitwiseAnd, + BitwiseNot, + BitwiseOr, + BitwiseXor, + BlackmanWindow, + Cast, + CastLike, + Ceil, + Celu, + CenterCropPad, + Clip, + Col, + Compress, + Concat, + ConcatFromSequence, + Constant, + ConstantOfShape, + Conv, + Conv1d, + Conv2d, + ConvInteger, + ConvTranspose, + Cos, + Cosh, + CumSum, + DepthToSpace, + DequantizeLinear, + Det, + DFT, + Div, + Dropout, + DynamicQuantizeLinear, + Einsum, + Elu, + Equal, + Erf, + Exp, + Expand, + EyeLike, + Flatten, + Floor, + Gather, + GatherElements, + GatherND, + Gelu, + Gemm, + GlobalAveragePool, + GlobalLpPool, + GlobalMaxPool, + Greater, + GreaterOrEqual, + GridSample, + GroupNormalization, + GRU, + HammingWindow, + HannWindow, + Hardmax, + HardSigmoid, + HardSwish, + Identity, + If, + Im, + InstanceNormalization, + IsInf, + IsNaN, + LayerNormalization, + LeakyRelu, + Less, + LessOrEqual, + Linear, + Log, + LogSoftmax, + Loop, + LpNormalization, + LpPool, + LRN, + LSTM, + MatMul, + MatMulInteger, + Max, + MaxPool, + MaxPool1d, + MaxPool2d, + MaxRoiPool, + MaxUnpool, + Mean, + MeanVarianceNormalization, + MelWeightMatrix, + Min, + Mish, + Mod, + Mul, + Multinomial, + Neg, + NegativeLogLikelihoodLoss, + NonMaxSuppression, + NonZero, + Not, + OneHot, + Optional, + OptionalGetElement, + OptionalHasElement, + Or, + Pad, + Pow, + PRelu, + QLinearConv, + QLinearMatMul, + QuantizeLinear, + RandomNormal, + RandomNormalLike, + RandomUniform, + RandomUniformLike, + Range, + Reciprocal, + 
ReduceL, + ReduceLogSum, + ReduceLogSumExp, + ReduceMax, + ReduceMean, + ReduceMin, + ReduceProd, + ReduceSum, + ReduceSumSquare, + Relu, + Reshape, + Resize, + ReverseSequence, + RNN, + RoiAlign, + Round, + Scan, + Scatter, + ScatterElements, + ScatterND, + Selu, + SequenceAt, + SequenceConstruct, + SequenceEmpty, + SequenceErase, + SequenceInsert, + SequenceLength, + SequenceMap, + Shape, + Shrink, + Sigmoid, + Sign, + Sin, + Sinh, + Size, + Slice, + Softmax, + SoftmaxCrossEntropyLoss, + Softplus, + Softsign, + SpaceToDepth, + Split, + SplitToSequence, + Sqrt, + Squeeze, + STFT, + StringNormalizer, + Sub, + Sum, + Tan, + Tanh, + TfIdfVectorizer, + ThresholdedRelu, + Tile, + TopK, + Transpose, + Trilu, + Unique, + Unsqueeze, + Upsample, + Where, + Xor, } /// Truncate the vector display for debug display fn trunc(v: &Vec) -> String { - const BEGIN_INDEX: usize = 0; - const MAX_LEN: usize = 5; - let mut s = String::new(); - s.push('['); - for (i, item) in v.iter().enumerate() { - if i > BEGIN_INDEX { - s.push_str(", "); - } - s.push_str(&format!("{}", item)); - if i > MAX_LEN { - s.push_str(", ..."); - break; - } - } - s.push(']'); - s + const BEGIN_INDEX: usize = 0; + const MAX_LEN: usize = 5; + let mut s = String::new(); + s.push('['); + for (i, item) in v.iter().enumerate() { + if i > BEGIN_INDEX { + s.push_str(", "); + } + s.push_str(&format!("{}", item)); + if i > MAX_LEN { + s.push_str(", ..."); + break; + } + } + s.push(']'); + s } /// Shorten the tensor data for debug display impl fmt::Debug for Data { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Data::Float16s(v) => write!(f, "Float16s({})", trunc(v)), - Data::Float32s(v) => write!(f, "Float32s({})", trunc(v)), - Data::Float64s(v) => write!(f, "Float64s({})", trunc(v)), - Data::Int32s(v) => write!(f, "Int32s({})", trunc(v)), - Data::Int64s(v) => write!(f, "Int64s({})", trunc(v)), - Data::Strings(v) => write!(f, "Strings({})", trunc(v)), - Data::Bools(v) => write!(f, "Bools({})", trunc(v)), - Data::Float16(v) => write!(f, "Float16({})", v), - Data::Float32(v) => write!(f, "Float32({})", v), - Data::Float64(v) => write!(f, "Float64({})", v), - Data::Int32(v) => write!(f, "Int32({})", v), - Data::Int64(v) => write!(f, "Int64({})", v), - Data::String(v) => write!(f, "String({})", v), - Data::Bool(v) => write!(f, "Bool({})", v), - } - } + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Data::Float16s(v) => write!(f, "Float16s({})", trunc(v)), + Data::Float32s(v) => write!(f, "Float32s({})", trunc(v)), + Data::Float64s(v) => write!(f, "Float64s({})", trunc(v)), + Data::Int32s(v) => write!(f, "Int32s({})", trunc(v)), + Data::Int64s(v) => write!(f, "Int64s({})", trunc(v)), + Data::Strings(v) => write!(f, "Strings({})", trunc(v)), + Data::Bools(v) => write!(f, "Bools({})", trunc(v)), + Data::Float16(v) => write!(f, "Float16({})", v), + Data::Float32(v) => write!(f, "Float32({})", v), + Data::Float64(v) => write!(f, "Float64({})", v), + Data::Int32(v) => write!(f, "Int32({})", v), + Data::Int64(v) => write!(f, "Int64({})", v), + Data::String(v) => write!(f, "String({})", v), + Data::Bool(v) => write!(f, "Bool({})", v), + } + } } impl Data { - pub fn into_scalar(self) -> Self { - match self { - Data::Float16s(data) => { - assert_eq!(data.len(), 1); - Data::Float16(data[0]) - } - Data::Float32s(data) => { - assert_eq!(data.len(), 1); - Data::Float32(data[0]) - } - Data::Float64s(data) => { - assert_eq!(data.len(), 1); - Data::Float64(data[0]) - } - Data::Int32s(data) => { - 
assert_eq!(data.len(), 1); - Data::Int32(data[0]) - } - Data::Int64s(data) => { - assert_eq!(data.len(), 1); - Data::Int64(data[0]) - } - Data::Bools(data) => { - assert_eq!(data.len(), 1); - Data::Bool(data[0]) - } - Data::Strings(data) => { - assert_eq!(data.len(), 1); - Data::String(data[0].clone()) - } - _ => self, - } - } - - pub fn into_f16(self) -> f16 { - if let Data::Float16(elem) = self { - elem - } else { - panic!("Expected Float16, got {:?}", self); - } - } - - pub fn into_f32(self) -> f32 { - if let Data::Float32(elem) = self { - elem - } else { - panic!("Expected Float32, got {:?}", self); - } - } - - pub fn into_f64(self) -> f64 { - if let Data::Float64(elem) = self { - elem - } else { - panic!("Expected Float64, got {:?}", self); - } - } - - pub fn into_i32(self) -> i32 { - if let Data::Int32(elem) = self { - elem - } else { - panic!("Expected Int32, got {:?}", self); - } - } - - pub fn into_i64(self) -> i64 { - if let Data::Int64(elem) = self { - elem - } else { - panic!("Expected Int64, got {:?}", self); - } - } - - pub fn into_bool(self) -> bool { - if let Data::Bool(elem) = self { - elem - } else { - panic!("Expected Bool, got {:?}", self); - } - } - - pub fn into_string(self) -> String { - if let Data::String(elem) = self { - elem - } else { - panic!("Expected String, got {:?}", self); - } + pub fn into_scalar(self) -> Self { + match self { + Data::Float16s(data) => { + assert_eq!(data.len(), 1); + Data::Float16(data[0]) + } + Data::Float32s(data) => { + assert_eq!(data.len(), 1); + Data::Float32(data[0]) + } + Data::Float64s(data) => { + assert_eq!(data.len(), 1); + Data::Float64(data[0]) + } + Data::Int32s(data) => { + assert_eq!(data.len(), 1); + Data::Int32(data[0]) + } + Data::Int64s(data) => { + assert_eq!(data.len(), 1); + Data::Int64(data[0]) + } + Data::Bools(data) => { + assert_eq!(data.len(), 1); + Data::Bool(data[0]) + } + Data::Strings(data) => { + assert_eq!(data.len(), 1); + Data::String(data[0].clone()) + } + _ => self, + } + } + + pub fn into_f16(self) -> f16 { + if let Data::Float16(elem) = self { + elem + } else { + panic!("Expected Float16, got {:?}", self); + } + } + + pub fn into_f32(self) -> f32 { + if let Data::Float32(elem) = self { + elem + } else { + panic!("Expected Float32, got {:?}", self); + } + } + + pub fn into_f64(self) -> f64 { + if let Data::Float64(elem) = self { + elem + } else { + panic!("Expected Float64, got {:?}", self); + } + } + + pub fn into_i32(self) -> i32 { + if let Data::Int32(elem) = self { + elem + } else { + panic!("Expected Int32, got {:?}", self); + } + } + + pub fn into_i64(self) -> i64 { + if let Data::Int64(elem) = self { + elem + } else { + panic!("Expected Int64, got {:?}", self); + } + } + + pub fn into_bool(self) -> bool { + if let Data::Bool(elem) = self { + elem + } else { + panic!("Expected Bool, got {:?}", self); + } + } + + pub fn into_string(self) -> String { + if let Data::String(elem) = self { + elem + } else { + panic!("Expected String, got {:?}", self); + } + } + + pub fn into_f16s(self) -> Vec { + if let Data::Float16s(elem) = self { + elem + } else { + panic!("Expected Float16s, got {:?}", self); + } + } + + pub fn into_f32s(self) -> Vec { + if let Data::Float32s(elem) = self { + elem + } else { + panic!("Expected Float32s, got {:?}", self); + } + } + + pub fn into_f64s(self) -> Vec { + if let Data::Float64s(elem) = self { + elem + } else { + panic!("Expected Float64s, got {:?}", self); + } + } + + pub fn into_i32s(self) -> Vec { + if let Data::Int32s(elem) = self { + elem + } else { + 
panic!("Expected Int32s, got {:?}", self); + } + } + + pub fn into_i64s(self) -> Vec { + if let Data::Int64s(elem) = self { + elem + } else { + panic!("Expected Int64s, got {:?}", self); } + } - pub fn into_f16s(self) -> Vec { - if let Data::Float16s(elem) = self { - elem - } else { - panic!("Expected Float16s, got {:?}", self); - } - } - - pub fn into_f32s(self) -> Vec { - if let Data::Float32s(elem) = self { - elem - } else { - panic!("Expected Float32s, got {:?}", self); - } + pub fn into_bools(self) -> Vec { + if let Data::Bools(elem) = self { + elem + } else { + panic!("Expected Bools, got {:?}", self); } + } - pub fn into_f64s(self) -> Vec { - if let Data::Float64s(elem) = self { - elem - } else { - panic!("Expected Float64s, got {:?}", self); - } - } - - pub fn into_i32s(self) -> Vec { - if let Data::Int32s(elem) = self { - elem - } else { - panic!("Expected Int32s, got {:?}", self); - } - } - - pub fn into_i64s(self) -> Vec { - if let Data::Int64s(elem) = self { - elem - } else { - panic!("Expected Int64s, got {:?}", self); - } - } - - pub fn into_bools(self) -> Vec { - if let Data::Bools(elem) = self { - elem - } else { - panic!("Expected Bools, got {:?}", self); - } - } - - pub fn into_strings(self) -> Vec { - if let Data::Strings(elem) = self { - elem - } else { - panic!("Expected Strings, got {:?}", self); - } + pub fn into_strings(self) -> Vec { + if let Data::Strings(elem) = self { + elem + } else { + panic!("Expected Strings, got {:?}", self); } + } } impl AttributeValue { - pub fn into_f32(self) -> f32 { - if let AttributeValue::Float32(elem) = self { - elem - } else { - panic!("Expected Float32, got {:?}", self); - } + pub fn into_f32(self) -> f32 { + if let AttributeValue::Float32(elem) = self { + elem + } else { + panic!("Expected Float32, got {:?}", self); } + } - pub fn into_i32(self) -> i32 { - if let AttributeValue::Int64(elem) = self { - elem as i32 - } else { - panic!("Expected Int32, got {:?}", self); - } + pub fn into_i32(self) -> i32 { + if let AttributeValue::Int64(elem) = self { + elem as i32 + } else { + panic!("Expected Int32, got {:?}", self); } + } - pub fn into_i64(self) -> i64 { - if let AttributeValue::Int64(elem) = self { - elem - } else { - panic!("Expected Int64, got {:?}", self); - } + pub fn into_i64(self) -> i64 { + if let AttributeValue::Int64(elem) = self { + elem + } else { + panic!("Expected Int64, got {:?}", self); } + } - pub fn into_string(self) -> String { - if let AttributeValue::String(elem) = self { - elem - } else { - panic!("Expected String, got {:?}", self); - } + pub fn into_string(self) -> String { + if let AttributeValue::String(elem) = self { + elem + } else { + panic!("Expected String, got {:?}", self); } + } - pub fn into_tensor(self) -> Tensor { - if let AttributeValue::Tensor(elem) = self { - elem - } else { - panic!("Expected Tensor, got {:?}", self); - } + pub fn into_tensor(self) -> Tensor { + if let AttributeValue::Tensor(elem) = self { + elem + } else { + panic!("Expected Tensor, got {:?}", self); } + } - pub fn into_f32s(self) -> Vec { - if let AttributeValue::Float32s(elem) = self { - elem - } else { - panic!("Expected Float32s, got {:?}", self); - } + pub fn into_f32s(self) -> Vec { + if let AttributeValue::Float32s(elem) = self { + elem + } else { + panic!("Expected Float32s, got {:?}", self); } + } - pub fn into_i64s(self) -> Vec { - if let AttributeValue::Int64s(elem) = self { - elem - } else { - panic!("Expected Int64s, got {:?}", self); - } + pub fn into_i64s(self) -> Vec { + if let AttributeValue::Int64s(elem) = 
self { + elem + } else { + panic!("Expected Int64s, got {:?}", self); } + } - pub fn into_strings(self) -> Vec { - if let AttributeValue::Strings(elem) = self { - elem - } else { - panic!("Expected Strings, got {:?}", self); - } + pub fn into_strings(self) -> Vec { + if let AttributeValue::Strings(elem) = self { + elem + } else { + panic!("Expected Strings, got {:?}", self); } + } - pub fn into_tensors(self) -> Vec { - if let AttributeValue::Tensors(elem) = self { - elem - } else { - panic!("Expected Tensors, got {:?}", self); - } + pub fn into_tensors(self) -> Vec { + if let AttributeValue::Tensors(elem) = self { + elem + } else { + panic!("Expected Tensors, got {:?}", self); } + } } /// Convert AttributeValue to an Argument impl From for Argument { - fn from(attr: AttributeValue) -> Argument { - // "" is used as a placeholder for the name - let name = "".to_string(); - - match attr { - AttributeValue::Float32(value) => Argument { - ty: ArgType::Scalar(ElementType::Float32), - name, - value: Some(Data::Float32(value)), - passed: false, - }, - AttributeValue::Float32s(values) => Argument { - ty: ArgType::Tensor(TensorType { - dim: 1, - elem_type: ElementType::Float32, - shape: Some(vec![values.len()]), - }), - name, - value: Some(Data::Float32s(values)), - passed: false, - }, - AttributeValue::Int64(value) => Argument { - ty: ArgType::Scalar(ElementType::Int64), - name, - value: Some(Data::Int64(value)), - passed: false, - }, - AttributeValue::Int64s(values) => Argument { - ty: ArgType::Tensor(TensorType { - dim: 1, - elem_type: ElementType::Int64, - shape: Some(vec![values.len()]), - }), - name, - value: Some(Data::Int64s(values)), - passed: false, - }, - AttributeValue::String(value) => Argument { - ty: ArgType::Scalar(ElementType::String), - name, - value: Some(Data::String(value)), - passed: false, - }, - AttributeValue::Strings(values) => Argument { - ty: ArgType::Tensor(TensorType { - dim: 1, - elem_type: ElementType::String, - shape: Some(vec![values.len()]), - }), - name, - value: Some(Data::Strings(values)), - passed: false, - }, - AttributeValue::Tensor(tensor) => { - if tensor.dim == 0 { - // Convert zero dim tensor to scalar - if let Some(data) = tensor.data { - Argument { - ty: ArgType::Scalar(tensor.elem_type), - name, - value: Some(data.into_scalar()), - passed: false, - } - } else { - Argument { - ty: ArgType::Scalar(tensor.elem_type), - name, - value: None, - passed: false, - } - } - } else { - // Convert tensor to argument - Argument { - ty: ArgType::Tensor(TensorType { - dim: tensor.dim, - elem_type: tensor.elem_type, - shape: tensor.shape, - }), - name, - value: tensor.data, - passed: false, - } - } + fn from(attr: AttributeValue) -> Argument { + // "" is used as a placeholder for the name + let name = "".to_string(); + + match attr { + AttributeValue::Float32(value) => Argument { + ty: ArgType::Scalar(ElementType::Float32), + name, + value: Some(Data::Float32(value)), + passed: false, + }, + AttributeValue::Float32s(values) => Argument { + ty: ArgType::Tensor(TensorType { + dim: 1, + elem_type: ElementType::Float32, + shape: Some(vec![values.len()]), + }), + name, + value: Some(Data::Float32s(values)), + passed: false, + }, + AttributeValue::Int64(value) => Argument { + ty: ArgType::Scalar(ElementType::Int64), + name, + value: Some(Data::Int64(value)), + passed: false, + }, + AttributeValue::Int64s(values) => Argument { + ty: ArgType::Tensor(TensorType { + dim: 1, + elem_type: ElementType::Int64, + shape: Some(vec![values.len()]), + }), + name, + value: 
Some(Data::Int64s(values)), + passed: false, + }, + AttributeValue::String(value) => Argument { + ty: ArgType::Scalar(ElementType::String), + name, + value: Some(Data::String(value)), + passed: false, + }, + AttributeValue::Strings(values) => Argument { + ty: ArgType::Tensor(TensorType { + dim: 1, + elem_type: ElementType::String, + shape: Some(vec![values.len()]), + }), + name, + value: Some(Data::Strings(values)), + passed: false, + }, + AttributeValue::Tensor(tensor) => { + if tensor.dim == 0 { + // Convert zero dim tensor to scalar + if let Some(data) = tensor.data { + Argument { + ty: ArgType::Scalar(tensor.elem_type), + name, + value: Some(data.into_scalar()), + passed: false, + } + } else { + Argument { + ty: ArgType::Scalar(tensor.elem_type), + name, + value: None, + passed: false, } - _ => panic!("Unsupported attribute type"), + } + } else { + // Convert tensor to argument + Argument { + ty: ArgType::Tensor(TensorType { + dim: tensor.dim, + elem_type: tensor.elem_type, + shape: tensor.shape, + }), + name, + value: tensor.data, + passed: false, + } } + } + _ => panic!("Unsupported attribute type"), } + } } impl Argument { - pub fn into_tensor(self) -> Option { - if let ArgType::Tensor(tensor_type) = self.ty { - Some(Tensor { - elem_type: tensor_type.elem_type, - dim: tensor_type.dim, - data: self.value, - shape: tensor_type.shape, - }) - } else { - None - } - } + pub fn into_tensor(self) -> Option { + if let ArgType::Tensor(tensor_type) = self.ty { + Some(Tensor { + elem_type: tensor_type.elem_type, + dim: tensor_type.dim, + data: self.value, + shape: tensor_type.shape, + }) + } else { + None + } + } } diff --git a/burn-import/src/onnx/node_remap.rs b/burn-import/src/onnx/node_remap.rs index c87059280c..7b773f02aa 100644 --- a/burn-import/src/onnx/node_remap.rs +++ b/burn-import/src/onnx/node_remap.rs @@ -3,33 +3,33 @@ use super::ir::{AttributeValue, Node, NodeType}; /// Remap node type using kernel shape pub fn remap_node_with_kernel_shape(node: &mut Node, new_node_type: F) where - F: FnOnce(&Vec) -> NodeType, + F: FnOnce(&Vec) -> NodeType, { - if let AttributeValue::Int64s(ints) = node.attrs.get("kernel_shape").unwrap() { - node.node_type = new_node_type(ints); - } else { - panic!("kernel_shape is not an int64s"); - } + if let AttributeValue::Int64s(ints) = node.attrs.get("kernel_shape").unwrap() { + node.node_type = new_node_type(ints); + } else { + panic!("kernel_shape is not an int64s"); + } } /// Remap node type to a more specific one pub fn remap_node_type(node: &mut Node) { - match node.node_type { - NodeType::Conv => remap_node_with_kernel_shape(node, |ints| match ints.len() { - 1 => NodeType::Conv1d, - 2 => NodeType::Conv2d, - _ => panic!("Only conv 1d and 2d are supported"), - }), - NodeType::MaxPool => remap_node_with_kernel_shape(node, |ints| match ints.len() { - 1 => NodeType::MaxPool1d, - 2 => NodeType::MaxPool2d, - _ => panic!("Only max_pool 1d and 2d are supported"), - }), - NodeType::AveragePool => remap_node_with_kernel_shape(node, |ints| match ints.len() { - 1 => NodeType::AveragePool1d, - 2 => NodeType::AveragePool2d, - _ => panic!("Only avg_pool 1d and 2d are supported"), - }), - _ => (), - } + match node.node_type { + NodeType::Conv => remap_node_with_kernel_shape(node, |ints| match ints.len() { + 1 => NodeType::Conv1d, + 2 => NodeType::Conv2d, + _ => panic!("Only conv 1d and 2d are supported"), + }), + NodeType::MaxPool => remap_node_with_kernel_shape(node, |ints| match ints.len() { + 1 => NodeType::MaxPool1d, + 2 => NodeType::MaxPool2d, + _ => panic!("Only 
max_pool 1d and 2d are supported"), + }), + NodeType::AveragePool => remap_node_with_kernel_shape(node, |ints| match ints.len() { + 1 => NodeType::AveragePool1d, + 2 => NodeType::AveragePool2d, + _ => panic!("Only avg_pool 1d and 2d are supported"), + }), + _ => (), + } } diff --git a/burn-import/src/onnx/op_configuration.rs b/burn-import/src/onnx/op_configuration.rs index e9bf018781..cfcb69ec6d 100644 --- a/burn-import/src/onnx/op_configuration.rs +++ b/burn-import/src/onnx/op_configuration.rs @@ -1,8 +1,8 @@ use burn::nn::{ - conv::Conv1dConfig, - conv::Conv2dConfig, - pool::{AvgPool2dConfig, MaxPool2dConfig}, - BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d, + conv::Conv1dConfig, + conv::Conv2dConfig, + pool::{AvgPool2dConfig, MaxPool2dConfig}, + BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d, }; use crate::onnx::ir::Data; @@ -11,404 +11,404 @@ use super::ir::{ArgType, Node}; /// Create a Conv1dConfig from the attributes of the node pub fn conv1d_config(curr: &Node) -> Conv1dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1]; - let mut pads = vec![0, 0]; - let mut dilations = vec![1]; - let mut group: i64 = 1; - - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { - weight - } else { - panic!("Conv1d: weight tensor must be present"); - }; - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let shape = weight.shape.clone().unwrap(); - let channels_in = shape[1]; - let channels_out = shape[0]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64(), - _ => {} - } - } - - let padding = padding_config_1d(&pads); - - Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize) - .with_stride(strides[0] as usize) - .with_dilation(dilations[0] as usize) - .with_groups(group as usize) - .with_bias(bias) - .with_padding(padding) + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1]; + let mut pads = vec![0, 0]; + let mut dilations = vec![1]; + let mut group: i64 = 1; + + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
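// Editorial note: the ONNX Conv weight layout is [M, C/group, kernel_size], so shape[0]
// below is the output channel count and shape[1] is the (per-group) input channel count;
// with the default group = 1 this matches the [out_channels, in_channels, ...] comment above.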
+ let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { + weight + } else { + panic!("Conv1d: weight tensor must be present"); + }; + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let shape = weight.shape.clone().unwrap(); + let channels_in = shape[1]; + let channels_out = shape[0]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64(), + _ => {} + } + } + + let padding = padding_config_1d(&pads); + + Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize) + .with_stride(strides[0] as usize) + .with_dilation(dilations[0] as usize) + .with_groups(group as usize) + .with_bias(bias) + .with_padding(padding) } /// Create a Conv2dConfig from the attributes of the node pub fn conv2d_config(curr: &Node) -> Conv2dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - let mut group: i64 = 1; - - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { - weight - } else { - panic!("Conv1d: weight tensor must be present"); - }; - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let shape = weight.shape.clone().unwrap(); - let channels: [usize; 2] = [shape[1], shape[0]]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64(), - _ => {} - } - } - - let padding = padding_config(&pads); - - Conv2dConfig::new( - channels, - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([strides[0] as usize, strides[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group as usize) - .with_bias(bias) - .with_padding(padding) + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + let mut group: i64 = 1; + + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
+ let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { + weight + } else { + panic!("Conv1d: weight tensor must be present"); + }; + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let shape = weight.shape.clone().unwrap(); + let channels: [usize; 2] = [shape[1], shape[0]]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64(), + _ => {} + } + } + + let padding = padding_config(&pads); + + Conv2dConfig::new( + channels, + [kernel_shape[0] as usize, kernel_shape[1] as usize], + ) + .with_stride([strides[0] as usize, strides[1] as usize]) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) + .with_groups(group as usize) + .with_bias(bias) + .with_padding(padding) } /// Create a MaxPool2dConfig from the attributes of the node pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - _ => {} - } + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + _ => {} } + } - let padding = padding_config(&pads); + let padding = padding_config(&pads); - MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) + MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) } /// Create a AvgPool2dConfig from the attributes of the node pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut count_include_pad: i64 = 0; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "count_include_pad" => count_include_pad = value.clone().into_i64(), - _ => {} - } + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut count_include_pad: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "count_include_pad" => 
count_include_pad = value.clone().into_i64(), + _ => {} } + } - let padding = padding_config(&pads); + let padding = padding_config(&pads); - if count_include_pad == 1 && padding != PaddingConfig2d::Valid { - todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636"); - } + if count_include_pad == 1 && padding != PaddingConfig2d::Valid { + todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636"); + } - AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) + AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) } /// Create a FlattenConfig from the attributes of the node pub fn flatten_config(curr: &Node) -> (usize, usize) { - // the begin dimension is the first dimension (Default: 1 per ONNX spec) - let mut start_dim: i64 = 1; - - // check if the node has only one input - if curr.inputs.len() != 1 { - panic!( - "Flatten: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match curr.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // check if the input tensor has at least 2 dimensions - if tensor.dim < 2 { - panic!( - "Flatten: input tensor must have at least 2 dimensions (got {:?})", - tensor.dim - ); - } - - // the end dimension is the last dimension - let end_dim = tensor.dim - 1; - - // extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => start_dim = value.clone().into_i64(), - _ => {} - } - } - - // if beg_dim is negative, it is counted from the end - if start_dim < 0 { - start_dim += tensor.dim as i64; - } - - (start_dim as usize, end_dim) + // the begin dimension is the first dimension (Default: 1 per ONNX spec) + let mut start_dim: i64 = 1; + + // check if the node has only one input + if curr.inputs.len() != 1 { + panic!( + "Flatten: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // check if the input tensor has at least 2 dimensions + if tensor.dim < 2 { + panic!( + "Flatten: input tensor must have at least 2 dimensions (got {:?})", + tensor.dim + ); + } + + // the end dimension is the last dimension + let end_dim = tensor.dim - 1; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "axis" => start_dim = value.clone().into_i64(), + _ => {} + } + } + + // if beg_dim is negative, it is counted from the end + if start_dim < 0 { + start_dim += tensor.dim as i64; + } + + (start_dim as usize, end_dim) } /// Create a GatherConfig from the attributes of the node pub fn gather_config(curr: &Node) -> usize { - // Default: 0 per ONNX spec - let mut dim: i64 = 0; - - // check if the node has only one input - if curr.inputs.len() != 2 { - panic!("Gather: index tensor must be present"); - } - - // extract the shape of the input tensor - let tensor = match curr.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, 
value) in curr.attrs.iter() { - match key.as_str() { - "axis" => dim = value.clone().into_i64(), - _ => {} - } - } - - // if dim is negative, it is counted from the end - if dim < 0 { - dim += tensor.dim as i64; - } - - dim as usize + // Default: 0 per ONNX spec + let mut dim: i64 = 0; + + // check if the node has only one input + if curr.inputs.len() != 2 { + panic!("Gather: index tensor must be present"); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "axis" => dim = value.clone().into_i64(), + _ => {} + } + } + + // if dim is negative, it is counted from the end + if dim < 0 { + dim += tensor.dim as i64; + } + + dim as usize } /// Create a LinearConfig from the attributes of the node pub fn linear_config(node: &Node) -> LinearConfig { - if node.inputs.len() < 2 { - panic!("Linear: missing weight tensor"); - } - - // extract the shape of the weight tensor - let weight = if let ArgType::Tensor(ref weight) = node.inputs[1].ty { - weight - } else { - panic!("Linear: weight tensor must be present"); - }; - - // check if the weight tensor has at least 2 dimensions - if weight.dim < 2 { - panic!( - "Linear: weight tensor must have at least 2 dimensions (got {:?})", - weight.dim - ); - } - - let shape = weight.shape.clone().unwrap(); - let (in_size, out_size) = (shape[0], shape[1]); - - // check if the bias is present - let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); - - LinearConfig::new(in_size, out_size).with_bias(bias) + if node.inputs.len() < 2 { + panic!("Linear: missing weight tensor"); + } + + // extract the shape of the weight tensor + let weight = if let ArgType::Tensor(ref weight) = node.inputs[1].ty { + weight + } else { + panic!("Linear: weight tensor must be present"); + }; + + // check if the weight tensor has at least 2 dimensions + if weight.dim < 2 { + panic!( + "Linear: weight tensor must have at least 2 dimensions (got {:?})", + weight.dim + ); + } + + let shape = weight.shape.clone().unwrap(); + let (in_size, out_size) = (shape[0], shape[1]); + + // check if the bias is present + let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); + + LinearConfig::new(in_size, out_size).with_bias(bias) } /// Create a DropoutConfig from an attribute and state of the node pub fn dropout_config(node: &Node) -> DropoutConfig { - // Opset 7 and older store probability as an attribute - if node.attrs.contains_key("ratio") { - let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); - return DropoutConfig::new(prob as f64); - } - - if node.inputs.len() < 2 { - panic!("Dropout configuration must have at least 2 inputs"); - } - - let ratio = node.inputs[1] - .value - .clone() - .expect("Dropout ratio must be passed in the second input") - .into_scalar(); - - let prob = match ratio { - Data::Float16(ratio) => f64::from(f32::from(ratio)), - Data::Float32(ratio) => ratio as f64, - Data::Float64(ratio) => ratio, - _ => panic!("Dropout ratio must be a float"), - }; - - DropoutConfig::new(prob) + // Opset 7 and older store probability as an attribute + if node.attrs.contains_key("ratio") { + let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); + return DropoutConfig::new(prob as f64); + } + + if node.inputs.len() < 2 { + panic!("Dropout configuration must have at least 2 inputs"); + } + + let ratio = 
node.inputs[1] + .value + .clone() + .expect("Dropout ratio must be passed in the second input") + .into_scalar(); + + let prob = match ratio { + Data::Float16(ratio) => f64::from(f32::from(ratio)), + Data::Float32(ratio) => ratio as f64, + Data::Float64(ratio) => ratio, + _ => panic!("Dropout ratio must be a float"), + }; + + DropoutConfig::new(prob) } /// Create log_softmax config from the attributes of the node pub fn log_softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "LogSoftmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.dim as i64; - } - - axis as usize + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "LogSoftmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.dim as i64; + } + + axis as usize } /// Create softmax config from the attributes of the node pub fn softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Softmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.dim as i64; - } - - axis as usize + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Softmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.dim as i64; + } + + axis as usize } /// Create concat config from 
the attributes of the node pub fn concat_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // extract the shape of the input tensor - let tensor = match node.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // extract the shape of the input tensor + let tensor = match node.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.dim as i64; + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} } + } - axis as usize + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.dim as i64; + } + + axis as usize } /// Create a BatchNormConfig from the attributes of the node pub fn batch_norm_config(node: &Node) -> BatchNormConfig { - // extract the shape of the weight tensor - let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { - tensor_type - } else { - panic!("BatchNorm: weight tensor must be present"); - }; - - let num_features: usize = tensor_type.shape.clone().unwrap()[0]; - - let mut epsilon = 0f32; - let mut momentum = 0f32; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "momentum" => momentum = value.clone().into_f32(), - "epsilon" => epsilon = value.clone().into_f32(), - _ => {} - } + // extract the shape of the weight tensor + let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { + tensor_type + } else { + panic!("BatchNorm: weight tensor must be present"); + }; + + let num_features: usize = tensor_type.shape.clone().unwrap()[0]; + + let mut epsilon = 0f32; + let mut momentum = 0f32; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "momentum" => momentum = value.clone().into_f32(), + "epsilon" => epsilon = value.clone().into_f32(), + _ => {} } + } - BatchNormConfig::new(num_features) - .with_epsilon(epsilon as f64) - .with_momentum(momentum as f64) + BatchNormConfig::new(num_features) + .with_epsilon(epsilon as f64) + .with_momentum(momentum as f64) } /// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. @@ -431,110 +431,110 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig { /// This function is used when the padding is specified as a list of integers, /// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
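// Illustrative mapping (editorial sketch, mirrors the branches below):
//   pads = [0, 0, 0, 0] -> PaddingConfig2d::Valid
//   pads = [2, 3, 2, 3] -> PaddingConfig2d::Explicit(2, 3)
//   pads = [1, 2, 3, 4] -> panic, asymmetric padding is not supported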
fn padding_config(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == top && top == right && right == bottom && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } + let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } else if left == top && top == right && right == bottom && bottom == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig2d::Valid + } else if left == right && top == bottom { + // i.e [2, 3, 2, 3] + PaddingConfig2d::Explicit(left as usize, top as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } } pub fn reshape_config(node: &Node) -> Vec { - let mut allowzero = 0; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "allowzero" => allowzero = value.clone().into_i64(), - _ => {} - } - } - - // Burn does not support zero size shape (0 means false in ONNX) - // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) - if allowzero != 0 { - panic!("Zero shape size is not supported"); - } - - if node.inputs.len() != 2 || node.inputs[1].value.is_none() { - panic!("Reshape: shape tensor must be present"); - } - - let input_value = &node.inputs[1].value; - match &node.inputs[1].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.dim, 1, "Reshape: shape tensor must be 1D"); - - if let Some(Data::Int64s(shape)) = input_value.as_ref() { - shape.clone() - } else { - panic!("Tensor data type must be int64") - } - } - _ => panic!("Only tensor input is valid for shape"), - } + let mut allowzero = 0; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "allowzero" => allowzero = value.clone().into_i64(), + _ => {} + } + } + + // Burn does not support zero size shape (0 means false in ONNX) + // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) + if allowzero != 0 { + panic!("Zero shape size is not supported"); + } + + if node.inputs.len() != 2 || node.inputs[1].value.is_none() { + panic!("Reshape: shape tensor must be present"); + } + + let input_value = &node.inputs[1].value; + match &node.inputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.dim, 1, "Reshape: shape tensor must be 1D"); + + if let Some(Data::Int64s(shape)) = input_value.as_ref() { + shape.clone() + } else { + panic!("Tensor data type must be int64") + } + } + _ => panic!("Only tensor input is valid for shape"), + } } pub fn clip_config(node: &Node) -> (Option, Option) { - let mut min_result: Option = None; - let mut max_result: Option = None; - - // For Clip Opset 6+ , the min and max values are attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "min" => { - let min = value.clone().into_f32() as f64; - min_result = Some(min); - } - "max" => { - let max = value.clone().into_f32(); - max_result = Some(max as 
f64); - } - _ => {} - } - } - - // For Clip Opset 11+ , the min and max values are inputs - // Get the min and max values from the input values - if min_result.is_none() && max_result.is_none() { - let min = &node.inputs[1].value; - let max = &node.inputs[2].value; - - if min_result.is_none() && min.is_some() { - let min = min.clone().unwrap().into_scalar(); - min_result = match min { - Data::Float16(min) => Some(f32::from(min) as f64), - Data::Float32(min) => Some(min as f64), - Data::Float64(min) => Some(min), - _ => panic!("Clip: only float min is supported"), - }; - } - - if max_result.is_none() && max.is_some() { - let max = max.clone().unwrap().into_scalar(); - max_result = match max { - Data::Float16(max) => Some(f32::from(max) as f64), - Data::Float32(max) => Some(max as f64), - Data::Float64(max) => Some(max), - _ => panic!("Clip: only float max is supported"), - }; - } - } - - if min_result.is_none() && max_result.is_none() { - panic!("Clip: min and max values must be either attributes or inputs"); - } - - (min_result, max_result) + let mut min_result: Option = None; + let mut max_result: Option = None; + + // For Clip Opset 6+ , the min and max values are attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "min" => { + let min = value.clone().into_f32() as f64; + min_result = Some(min); + } + "max" => { + let max = value.clone().into_f32(); + max_result = Some(max as f64); + } + _ => {} + } + } + + // For Clip Opset 11+ , the min and max values are inputs + // Get the min and max values from the input values + if min_result.is_none() && max_result.is_none() { + let min = &node.inputs[1].value; + let max = &node.inputs[2].value; + + if min_result.is_none() && min.is_some() { + let min = min.clone().unwrap().into_scalar(); + min_result = match min { + Data::Float16(min) => Some(f32::from(min) as f64), + Data::Float32(min) => Some(min as f64), + Data::Float64(min) => Some(min), + _ => panic!("Clip: only float min is supported"), + }; + } + + if max_result.is_none() && max.is_some() { + let max = max.clone().unwrap().into_scalar(); + max_result = match max { + Data::Float16(max) => Some(f32::from(max) as f64), + Data::Float32(max) => Some(max as f64), + Data::Float64(max) => Some(max), + _ => panic!("Clip: only float max is supported"), + }; + } + } + + if min_result.is_none() && max_result.is_none() { + panic!("Clip: min and max values must be either attributes or inputs"); + } + + (min_result, max_result) } /// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. @@ -557,20 +557,20 @@ pub fn clip_config(node: &Node) -> (Option, Option) { /// This function is used when the padding is specified as a list of integers, /// and not used when the padding is specified as a string, e.g. "SAME_UPPER". fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { - let [left, right] = [pads[0], pads[1]]; - - if left < 0 || right < 0 { - panic!("Negative pad values are not supported"); - } else if left != right { - panic!("Asymmetric padding is not supported"); - } else if left == right && right == 0 { - // i.e. [0, 0] - PaddingConfig1d::Valid - } else if left == right { - // i.e. 
[2, 2] - PaddingConfig1d::Explicit(left as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } + let [left, right] = [pads[0], pads[1]]; + + if left < 0 || right < 0 { + panic!("Negative pad values are not supported"); + } else if left != right { + panic!("Asymmetric padding is not supported"); + } else if left == right && right == 0 { + // i.e. [0, 0] + PaddingConfig1d::Valid + } else if left == right { + // i.e. [2, 2] + PaddingConfig1d::Explicit(left as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } } diff --git a/burn-import/src/onnx/proto_conversion.rs b/burn-import/src/onnx/proto_conversion.rs index 1fe5aac54a..7dccdb50ba 100644 --- a/burn-import/src/onnx/proto_conversion.rs +++ b/burn-import/src/onnx/proto_conversion.rs @@ -4,11 +4,11 @@ use crate::onnx::ir::TensorType; use super::ir::Dim; use super::ir::{ - ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor, + ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor, }; use super::protos::{ - attribute_proto::AttributeType, tensor_proto::DataType, tensor_shape_proto::dimension::Value, - type_proto, AttributeProto, NodeProto, TensorProto, TensorShapeProto, ValueInfoProto, + attribute_proto::AttributeType, tensor_proto::DataType, tensor_shape_proto::dimension::Value, + type_proto, AttributeProto, NodeProto, TensorProto, TensorShapeProto, ValueInfoProto, }; use bytemuck::cast_slice; @@ -17,246 +17,244 @@ use protobuf::Enum; /// Error type for parsing ONNX model #[derive(Debug)] pub enum ParseError { - VariantNotFound, + VariantNotFound, } /// Convert a vector of AttributeProto to a HashMap of AttributeValue impl TryFrom for Tensor { - type Error = ParseError; - fn try_from(tensor: TensorProto) -> Result { - let (elem_type, data) = match DataType::from_i32(tensor.data_type).unwrap() { - DataType::FLOAT => ( - ElementType::Float32, - // Convert the raw data to a vector of floats - if !tensor.raw_data.is_empty() { - Data::Float32s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Float32s(tensor.float_data) - }, - ), - DataType::INT16 => { - // TODO : Add support for int16 by converting to int32 - todo!("Add support for int16"); - } - DataType::INT32 => ( - ElementType::Int32, - // Convert the raw data to a vector of ints - if !tensor.raw_data.is_empty() { - Data::Int32s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Int32s(tensor.int32_data) - }, - ), - DataType::INT64 => ( - ElementType::Int64, - // Convert the raw data to a vector of ints - if !tensor.raw_data.is_empty() { - Data::Int64s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Int64s(tensor.int64_data) - }, - ), - DataType::DOUBLE => ( - ElementType::Float64, - // Convert the raw data to a vector of floats - if !tensor.raw_data.is_empty() { - Data::Float64s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Float64s(tensor.double_data) - }, - ), - DataType::BOOL => (ElementType::Bool, { - assert!(!tensor.raw_data.is_empty()); - Data::Bools(tensor.raw_data.iter().map(|x| *x != 0).collect()) - }), - // TODO : Add more types - _ => { - return Err(ParseError::VariantNotFound); - } - }; - let shape = convert_shape(tensor.dims); - - Ok(Tensor { - elem_type, - dim: shape.len(), - shape: Some(shape), - data: Some(data), - }) - } + type Error = ParseError; + fn try_from(tensor: TensorProto) -> Result { + let 
(elem_type, data) = match DataType::from_i32(tensor.data_type).unwrap() { + DataType::FLOAT => ( + ElementType::Float32, + // Convert the raw data to a vector of floats + if !tensor.raw_data.is_empty() { + Data::Float32s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Float32s(tensor.float_data) + }, + ), + DataType::INT16 => { + // TODO : Add support for int16 by converting to int32 + todo!("Add support for int16"); + } + DataType::INT32 => ( + ElementType::Int32, + // Convert the raw data to a vector of ints + if !tensor.raw_data.is_empty() { + Data::Int32s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Int32s(tensor.int32_data) + }, + ), + DataType::INT64 => ( + ElementType::Int64, + // Convert the raw data to a vector of ints + if !tensor.raw_data.is_empty() { + Data::Int64s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Int64s(tensor.int64_data) + }, + ), + DataType::DOUBLE => ( + ElementType::Float64, + // Convert the raw data to a vector of floats + if !tensor.raw_data.is_empty() { + Data::Float64s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Float64s(tensor.double_data) + }, + ), + DataType::BOOL => (ElementType::Bool, { + assert!(!tensor.raw_data.is_empty()); + Data::Bools(tensor.raw_data.iter().map(|x| *x != 0).collect()) + }), + // TODO : Add more types + _ => { + return Err(ParseError::VariantNotFound); + } + }; + let shape = convert_shape(tensor.dims); + + Ok(Tensor { + elem_type, + dim: shape.len(), + shape: Some(shape), + data: Some(data), + }) + } } impl TryFrom for Vec { - type Error = ParseError; - fn try_from(shape: TensorShapeProto) -> Result, Self::Error> { - let mut result = Vec::new(); - - for dim in shape.dim { - if let Value::DimValue(value) = dim.value.unwrap() { - result.push(value as usize); - } - } + type Error = ParseError; + fn try_from(shape: TensorShapeProto) -> Result, Self::Error> { + let mut result = Vec::new(); - Ok(result) + for dim in shape.dim { + if let Value::DimValue(value) = dim.value.unwrap() { + result.push(value as usize); + } } + + Ok(result) + } } /// Convert a vector of AttributeProto to a HashMap of AttributeValue impl TryFrom<&type_proto::Tensor> for Tensor { - type Error = ParseError; - fn try_from(tensor: &type_proto::Tensor) -> Result { - let elem_type = match DataType::from_i32(tensor.elem_type).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - DataType::BOOL => ElementType::Bool, - - // TODO : Add more types - _ => { - return Err(ParseError::VariantNotFound); - } - }; - - let shape_proto = tensor.shape.clone().unwrap(); - let shape: Vec = shape_proto.try_into().unwrap(); - - Ok(Tensor { - elem_type, - dim: shape.len(), - shape: Some(shape), - data: None, - }) - } + type Error = ParseError; + fn try_from(tensor: &type_proto::Tensor) -> Result { + let elem_type = match DataType::from_i32(tensor.elem_type).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + DataType::BOOL => ElementType::Bool, + + // TODO : Add more types + _ => { + return Err(ParseError::VariantNotFound); + } + }; + + let shape_proto = tensor.shape.clone().unwrap(); + let shape: Vec = shape_proto.try_into().unwrap(); + + Ok(Tensor { + elem_type, + dim: shape.len(), + shape: Some(shape), + data: None, + }) + } } fn convert_vec_tensor_proto(tensors: Vec) 
-> Result, ParseError> { - let mut result = Vec::new(); - for tensor in tensors { - result.push(Tensor::try_from(tensor)?); - } - Ok(result) + let mut result = Vec::new(); + for tensor in tensors { + result.push(Tensor::try_from(tensor)?); + } + Ok(result) } /// Convert a vector of AttributeProto to a HashMap of AttributeValue impl TryFrom for AttributeValue { - type Error = ParseError; - - fn try_from(attr: AttributeProto) -> Result { - let value = match attr.type_.unwrap() { - AttributeType::FLOAT => AttributeValue::Float32(attr.f), - AttributeType::INT => AttributeValue::Int64(attr.i), - AttributeType::STRING => AttributeValue::String(to_string(attr.s)), - - // warning: tensor can be empty TODO: check if it is empty - AttributeType::TENSOR => AttributeValue::Tensor(Tensor::try_from(attr.t.unwrap())?), - - // Graph is not supported for now - // AttributeType::GRAPH => AttributeValue::Graph(attr.g), - AttributeType::FLOATS => AttributeValue::Float32s(attr.floats), - AttributeType::INTS => AttributeValue::Int64s(attr.ints), - AttributeType::STRINGS => AttributeValue::Strings(to_string_vec(attr.strings)), - AttributeType::TENSORS => { - AttributeValue::Tensors(convert_vec_tensor_proto(attr.tensors)?) - } - // AttributeType::GRAPHS => AttributeValue::Graphs(attr.graphs), - // AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(attr.sparse_tensors), - // AttributeType::SPARSE_TENSOR => AttributeValue::SparseTensor(attr.sparse_tensor), - _ => { - return Err(ParseError::VariantNotFound); - } - }; - - Ok(value) - } + type Error = ParseError; + + fn try_from(attr: AttributeProto) -> Result { + let value = match attr.type_.unwrap() { + AttributeType::FLOAT => AttributeValue::Float32(attr.f), + AttributeType::INT => AttributeValue::Int64(attr.i), + AttributeType::STRING => AttributeValue::String(to_string(attr.s)), + + // warning: tensor can be empty TODO: check if it is empty + AttributeType::TENSOR => AttributeValue::Tensor(Tensor::try_from(attr.t.unwrap())?), + + // Graph is not supported for now + // AttributeType::GRAPH => AttributeValue::Graph(attr.g), + AttributeType::FLOATS => AttributeValue::Float32s(attr.floats), + AttributeType::INTS => AttributeValue::Int64s(attr.ints), + AttributeType::STRINGS => AttributeValue::Strings(to_string_vec(attr.strings)), + AttributeType::TENSORS => AttributeValue::Tensors(convert_vec_tensor_proto(attr.tensors)?), + // AttributeType::GRAPHS => AttributeValue::Graphs(attr.graphs), + // AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(attr.sparse_tensors), + // AttributeType::SPARSE_TENSOR => AttributeValue::SparseTensor(attr.sparse_tensor), + _ => { + return Err(ParseError::VariantNotFound); + } + }; + + Ok(value) + } } /// Convert a vector of AttributeProto to a HashMap of AttributeValue pub fn convert_vec_attrs_proto(attrs: Vec) -> Attributes { - let mut result = Attributes::new(); - for attr in attrs { - result.insert(attr.name.clone(), AttributeValue::try_from(attr).unwrap()); - } - result + let mut result = Attributes::new(); + for attr in attrs { + result.insert(attr.name.clone(), AttributeValue::try_from(attr).unwrap()); + } + result } pub fn convert_node_proto(node: &NodeProto) -> Node { - let name = node.name.clone(); + let name = node.name.clone(); - log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str()); + log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str()); - let inputs = node.input.clone().into_iter().map(Argument::new).collect(); + let inputs = 
node.input.clone().into_iter().map(Argument::new).collect(); - let outputs = node.output.clone().into_iter().map(Argument::new).collect(); + let outputs = node.output.clone().into_iter().map(Argument::new).collect(); - let attrs = convert_vec_attrs_proto(node.attribute.clone()); + let attrs = convert_vec_attrs_proto(node.attribute.clone()); - let node_type = NodeType::from_str(node.op_type.as_str()).expect("Unknown node type"); + let node_type = NodeType::from_str(node.op_type.as_str()).expect("Unknown node type"); - Node { - node_type, - name, - inputs, - outputs, - attrs, - } + Node { + node_type, + name, + inputs, + outputs, + attrs, + } } fn to_string(bytes: Vec) -> String { - from_utf8(bytes.as_slice()).unwrap().to_string() + from_utf8(bytes.as_slice()).unwrap().to_string() } fn to_string_vec(bytes: Vec>) -> Vec { - bytes.iter().map(|b| to_string(b.clone())).collect() + bytes.iter().map(|b| to_string(b.clone())).collect() } fn convert_shape(shape: Vec) -> Vec { - shape.iter().map(|s| *s as usize).collect() + shape.iter().map(|s| *s as usize).collect() } impl TryFrom for Argument { - type Error = ParseError; - - fn try_from(value: ValueInfoProto) -> Result { - let name = value.name.clone(); - let proto_type = value.type_.unwrap(); - - if !proto_type.has_tensor_type() { - panic!("Unsupported argument type {:?}", proto_type); - } - - let tensor_proto = proto_type.tensor_type(); - - let elem_type = match DataType::from_i32(tensor_proto.elem_type).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - DataType::BOOL => ElementType::Bool, - _ => { - return Err(ParseError::VariantNotFound); - } - }; - - let tensor_type = TensorType { - dim: tensor_proto.shape.dim.len(), - elem_type, - shape: Some( - tensor_proto - .shape - .dim - .iter() - .map(|x| x.dim_value() as Dim) - .collect(), - ), - }; - - let ty = ArgType::Tensor(tensor_type); - - Ok(Argument { - ty, - name, - value: None, - passed: false, - }) + type Error = ParseError; + + fn try_from(value: ValueInfoProto) -> Result { + let name = value.name.clone(); + let proto_type = value.type_.unwrap(); + + if !proto_type.has_tensor_type() { + panic!("Unsupported argument type {:?}", proto_type); } + + let tensor_proto = proto_type.tensor_type(); + + let elem_type = match DataType::from_i32(tensor_proto.elem_type).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + DataType::BOOL => ElementType::Bool, + _ => { + return Err(ParseError::VariantNotFound); + } + }; + + let tensor_type = TensorType { + dim: tensor_proto.shape.dim.len(), + elem_type, + shape: Some( + tensor_proto + .shape + .dim + .iter() + .map(|x| x.dim_value() as Dim) + .collect(), + ), + }; + + let ty = ArgType::Tensor(tensor_type); + + Ok(Argument { + ty, + name, + value: None, + passed: false, + }) + } } diff --git a/burn-import/src/onnx/protos/mod.rs b/burn-import/src/onnx/protos/mod.rs index 328e850e76..b18e3c0908 100644 --- a/burn-import/src/onnx/protos/mod.rs +++ b/burn-import/src/onnx/protos/mod.rs @@ -1,5 +1,5 @@ mod inner { - include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs")); + include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs")); } pub use inner::onnx::*; diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs index 4bf912bd44..b2889e3ff9 100644 --- 
a/burn-import/src/onnx/to_burn.rs +++ b/burn-import/src/onnx/to_burn.rs @@ -1,56 +1,55 @@ use std::{ - env, - fs::{self, create_dir_all}, - path::{Path, PathBuf}, + env, + fs::{self, create_dir_all}, + path::{Path, PathBuf}, }; use burn::{ - record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings}, - tensor::{DataSerialize, Element}, + record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings}, + tensor::{DataSerialize, Element}, }; use crate::{ - burn::{ - graph::BurnGraph, - node::{ - avg_pool2d::AvgPool2dNode, - batch_norm::BatchNormNode, - binary::BinaryNode, - clip::ClipNode, - concat::ConcatNode, - constant::{ConstantNode, ConstantValue, TensorValue}, - conv1d::Conv1dNode, - conv2d::Conv2dNode, - dropout::DropoutNode, - gather::GatherNode, - global_avg_pool::GlobalAvgPoolNode, - linear::LinearNode, - matmul::MatmulNode, - max_pool2d::MaxPool2dNode, - reshape::ReshapeNode, - unary::UnaryNode, - }, - ScalarKind, ScalarType, TensorKind, TensorType, Type, + burn::{ + graph::BurnGraph, + node::{ + avg_pool2d::AvgPool2dNode, + batch_norm::BatchNormNode, + binary::BinaryNode, + clip::ClipNode, + concat::ConcatNode, + constant::{ConstantNode, ConstantValue, TensorValue}, + conv1d::Conv1dNode, + conv2d::Conv2dNode, + dropout::DropoutNode, + gather::GatherNode, + global_avg_pool::GlobalAvgPoolNode, + linear::LinearNode, + matmul::MatmulNode, + max_pool2d::MaxPool2dNode, + reshape::ReshapeNode, + unary::UnaryNode, }, - format_tokens, - logger::init_log, - onnx::{ - from_onnx::convert_constant_value, - ir::{Node, NodeType}, - op_configuration::{ - batch_norm_config, conv1d_config, conv2d_config, flatten_config, gather_config, - linear_config, log_softmax_config, max_pool2d_config, - }, + ScalarKind, ScalarType, TensorKind, TensorType, Type, + }, + format_tokens, + logger::init_log, + onnx::{ + from_onnx::convert_constant_value, + ir::{Node, NodeType}, + op_configuration::{ + batch_norm_config, conv1d_config, conv2d_config, flatten_config, gather_config, + linear_config, log_softmax_config, max_pool2d_config, }, + }, }; use super::{ - from_onnx::parse_onnx, - ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph}, - op_configuration::{ - avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config, - softmax_config, - }, + from_onnx::parse_onnx, + ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph}, + op_configuration::{ + avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config, softmax_config, + }, }; pub use crate::burn::graph::RecordType; @@ -58,547 +57,542 @@ pub use crate::burn::graph::RecordType; /// Generate code and states from `.onnx` files and save them to the `out_dir`. #[derive(Debug, Default)] pub struct ModelGen { - out_dir: Option, - /// List of onnx files to generate source code from. - inputs: Vec, - development: bool, - half_precision: bool, - record_type: RecordType, - embed_states: bool, + out_dir: Option, + /// List of onnx files to generate source code from. + inputs: Vec, + development: bool, + half_precision: bool, + record_type: RecordType, + embed_states: bool, } impl ModelGen { - /// Create a new `ModelGen`. - pub fn new() -> Self { - init_log().ok(); // Error when init multiple times are ignored. - Self::default() - } - - /// Set output directory. - pub fn out_dir(&mut self, out_dir: &str) -> &mut Self { - self.out_dir = Some(Path::new(out_dir).into()); - self - } - - /// Add input file. 
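// A minimal sketch of driving this builder from a crate's build.rs, assuming the
// usual `burn_import::onnx::ModelGen` re-export; the .onnx path and out_dir are
// illustrative, and `run_from_script` is the entry point shown further down:
//
//     // build.rs
//     use burn_import::onnx::ModelGen;
//
//     fn main() {
//         ModelGen::new()
//             .input("src/model/mnist.onnx") // hypothetical model file
//             .out_dir("model/")
//             .run_from_script();
//     }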
- pub fn input(&mut self, input: &str) -> &mut Self { - self.inputs.push(input.into()); - self - } - - /// Set development mode. - /// - /// If this is set to true, the generated model will be saved as `.graph.txt` files and model - /// states will be saved as `.json` file. - pub fn development(&mut self, development: bool) -> &mut Self { - self.development = development; - self - } - - /// Run code generation. - /// - /// This function is intended to be called from `build.rs` script. - pub fn run_from_script(&self) { - self.run(true); - } - - /// Run code generation. - /// - /// This function is intended to be called from CLI. - pub fn run_from_cli(&self) { - self.run(false); - } - - /// Specify parameter precision to be saved. - /// - /// # Arguments - /// - /// * `half_precision` - If true, half precision is saved. Otherwise, full precision is saved. - pub fn half_precision(&mut self, half_precision: bool) -> &mut Self { - self.half_precision = half_precision; - self - } - - /// Specify the type of the record to be saved. - /// - /// # Arguments - /// - /// * `record_type` - The type of the record to be saved. - pub fn record_type(&mut self, record_type: RecordType) -> &mut Self { - self.record_type = record_type; - self - } - - /// Specify whether to embed states in the generated code. - /// - /// # Arguments - /// - /// * `embed_states` - If true, states are embedded in the generated code. Otherwise, states are - /// saved as a separate file. - pub fn embed_states(&mut self, embed_states: bool) -> &mut Self { - self.embed_states = embed_states; - self - } - - /// Run code generation. - fn run(&self, is_build_script: bool) { - log::info!("Starting to convert ONNX to Burn"); - - // prepend the out_dir to the cargo_out_dir if this is a build script - let out_dir = if is_build_script { - let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set"); - let mut path = PathBuf::from(cargo_out_dir); - - // // Append the out_dir to the cargo_out_dir - path.push(self.out_dir.clone().unwrap()); - path - } else { - self.out_dir.as_ref().expect("out_dir is not set").clone() - }; - - log::debug!("Output directory: {:?}", out_dir); - - create_dir_all(&out_dir).unwrap(); - - for input in self.inputs.iter() { - let file_name = input.file_stem().unwrap(); - let out_file: PathBuf = out_dir.join(file_name); - - log::info!("Converting {:?}", input); - log::debug!("Input file name: {:?}", file_name); - log::debug!("Output file: {:?}", out_file); - - self.generate_model(input, out_file); - } - - log::info!("Finished converting ONNX to Burn"); - } - - /// Generate model source code and model state. 
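// A short note on where the generated code lands (paths illustrative): when `run`
// is called from a build script, the configured out_dir is appended to Cargo's
// $OUT_DIR and each input keeps only its file stem, so "src/model/mnist.onnx" with
// out_dir("model/") becomes "$OUT_DIR/model/mnist.rs" plus a record file unless the
// states are embedded. The generated module is then typically pulled in the same
// way the ONNX protos are included elsewhere in this crate:
//
//     include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));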
- fn generate_model(&self, input: &PathBuf, out_file: PathBuf) { - log::info!("Generating model from {:?}", input); - log::debug!("Development mode: {:?}", self.development); - log::debug!("Output file: {:?}", out_file); - - let graph = parse_onnx(input.as_ref()); - - if self.development { - // export the graph - let debug_graph = format!("{:#?}", graph); - let graph_file = out_file.with_extension("graph.txt"); - log::debug!("Writing debug graph file: {:?}", graph_file); - fs::write(graph_file, debug_graph).unwrap(); - } - - let new_fn = true; - let blank_space = true; - let top_comment = Some(format!("Generated from ONNX {input:?} by burn-import")); - - let code = if self.half_precision { - graph - .into_burn::() - .with_record(out_file.clone(), self.record_type, self.embed_states) - .with_new_fn(new_fn) - .with_blank_space(blank_space) - .with_top_comment(top_comment) - .codegen() - } else { - graph - .into_burn::() - .with_record(out_file.clone(), self.record_type, self.embed_states) - .with_new_fn(new_fn) - .with_blank_space(blank_space) - .with_top_comment(top_comment) - .codegen() - }; - - let code_str = format_tokens(code); - fs::write(out_file.with_extension("rs"), code_str).unwrap(); - - log::info!("Model generated"); - } + /// Create a new `ModelGen`. + pub fn new() -> Self { + init_log().ok(); // Error when init multiple times are ignored. + Self::default() + } + + /// Set output directory. + pub fn out_dir(&mut self, out_dir: &str) -> &mut Self { + self.out_dir = Some(Path::new(out_dir).into()); + self + } + + /// Add input file. + pub fn input(&mut self, input: &str) -> &mut Self { + self.inputs.push(input.into()); + self + } + + /// Set development mode. + /// + /// If this is set to true, the generated model will be saved as `.graph.txt` files and model + /// states will be saved as `.json` file. + pub fn development(&mut self, development: bool) -> &mut Self { + self.development = development; + self + } + + /// Run code generation. + /// + /// This function is intended to be called from `build.rs` script. + pub fn run_from_script(&self) { + self.run(true); + } + + /// Run code generation. + /// + /// This function is intended to be called from CLI. + pub fn run_from_cli(&self) { + self.run(false); + } + + /// Specify parameter precision to be saved. + /// + /// # Arguments + /// + /// * `half_precision` - If true, half precision is saved. Otherwise, full precision is saved. + pub fn half_precision(&mut self, half_precision: bool) -> &mut Self { + self.half_precision = half_precision; + self + } + + /// Specify the type of the record to be saved. + /// + /// # Arguments + /// + /// * `record_type` - The type of the record to be saved. + pub fn record_type(&mut self, record_type: RecordType) -> &mut Self { + self.record_type = record_type; + self + } + + /// Specify whether to embed states in the generated code. + /// + /// # Arguments + /// + /// * `embed_states` - If true, states are embedded in the generated code. Otherwise, states are + /// saved as a separate file. + pub fn embed_states(&mut self, embed_states: bool) -> &mut Self { + self.embed_states = embed_states; + self + } + + /// Run code generation. 
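// A sketch of how the optional settings above compose on the same builder chain
// (values are illustrative, not a recommendation):
//
//     ModelGen::new()
//         .input("src/model/net.onnx") // hypothetical model file
//         .out_dir("model/")
//         .half_precision(true)        // save parameters in half precision
//         .embed_states(true)          // bake the record into the generated source
//         .run_from_script();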
+ fn run(&self, is_build_script: bool) { + log::info!("Starting to convert ONNX to Burn"); + + // prepend the out_dir to the cargo_out_dir if this is a build script + let out_dir = if is_build_script { + let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set"); + let mut path = PathBuf::from(cargo_out_dir); + + // // Append the out_dir to the cargo_out_dir + path.push(self.out_dir.clone().unwrap()); + path + } else { + self.out_dir.as_ref().expect("out_dir is not set").clone() + }; + + log::debug!("Output directory: {:?}", out_dir); + + create_dir_all(&out_dir).unwrap(); + + for input in self.inputs.iter() { + let file_name = input.file_stem().unwrap(); + let out_file: PathBuf = out_dir.join(file_name); + + log::info!("Converting {:?}", input); + log::debug!("Input file name: {:?}", file_name); + log::debug!("Output file: {:?}", out_file); + + self.generate_model(input, out_file); + } + + log::info!("Finished converting ONNX to Burn"); + } + + /// Generate model source code and model state. + fn generate_model(&self, input: &PathBuf, out_file: PathBuf) { + log::info!("Generating model from {:?}", input); + log::debug!("Development mode: {:?}", self.development); + log::debug!("Output file: {:?}", out_file); + + let graph = parse_onnx(input.as_ref()); + + if self.development { + // export the graph + let debug_graph = format!("{:#?}", graph); + let graph_file = out_file.with_extension("graph.txt"); + log::debug!("Writing debug graph file: {:?}", graph_file); + fs::write(graph_file, debug_graph).unwrap(); + } + + let new_fn = true; + let blank_space = true; + let top_comment = Some(format!("Generated from ONNX {input:?} by burn-import")); + + let code = if self.half_precision { + graph + .into_burn::() + .with_record(out_file.clone(), self.record_type, self.embed_states) + .with_new_fn(new_fn) + .with_blank_space(blank_space) + .with_top_comment(top_comment) + .codegen() + } else { + graph + .into_burn::() + .with_record(out_file.clone(), self.record_type, self.embed_states) + .with_new_fn(new_fn) + .with_blank_space(blank_space) + .with_top_comment(top_comment) + .codegen() + }; + + let code_str = format_tokens(code); + fs::write(out_file.with_extension("rs"), code_str).unwrap(); + + log::info!("Model generated"); + } } impl ONNXGraph { - /// Converts ONNX graph to Burn graph. 
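// The conversion that follows is a flat NodeType -> conversion-function dispatch;
// anything not matched panics with "Unsupported node conversion". Extending it
// means adding one arm plus a conversion helper, roughly (op and helper names are
// hypothetical):
//
//     NodeType::SomeNewOp => graph.register(Self::some_new_op_conversion(node)),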
- pub fn into_burn(self) -> BurnGraph { - let mut graph = BurnGraph::::default(); - - for node in self.nodes { - match node.node_type { - NodeType::Add => graph.register(Self::add_conversion(node)), - NodeType::Sub => graph.register(Self::sub_conversion(node)), - NodeType::Mul => graph.register(Self::mul_conversion(node)), - NodeType::Div => graph.register(Self::div_conversion(node)), - NodeType::Equal => graph.register(Self::equal_conversion(node)), - NodeType::Erf => graph.register(Self::erf_conversion(node)), - NodeType::Clip => graph.register(Self::clip_conversion(node)), - NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), - NodeType::Conv2d => graph.register(Self::conv2d_conversion::(node)), - NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), - NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), - NodeType::MatMul => graph.register(Self::matmul_conversion(node)), - NodeType::Linear => graph.register(Self::linear_conversion::(node)), - NodeType::BatchNormalization => { - graph.register(Self::batch_norm_conversion::(node)) - } - NodeType::Relu => graph.register(Self::relu_conversion(node)), - NodeType::Flatten => graph.register(Self::flatten_conversion(node)), - NodeType::GatherElements => graph.register(Self::gather_conversion(node)), - NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)), - NodeType::Softmax => graph.register(Self::softmax_conversion(node)), - NodeType::Tanh => graph.register(Self::tanh_conversion(node)), - NodeType::Constant => graph.register(Self::constant_conversion::(node)), - NodeType::Reshape => graph.register(Self::reshape_conversion(node)), - NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)), - NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)), - NodeType::Transpose => graph.register(Self::transpose_conversion(node)), - NodeType::Concat => graph.register(Self::concat_conversion(node)), - NodeType::Cast => graph.register(Self::cast_conversion(node)), - NodeType::Dropout => graph.register(Self::dropout_conversion(node)), - NodeType::GlobalAveragePool => { - graph.register(Self::global_avg_pool_conversion(node)) - } - _ => panic!("Unsupported node conversion {}", node.node_type), - } - } - - // Get input and output names - let input_names = self - .inputs - .iter() - .map(|input| input.name.clone()) - .collect::>(); - let output_names = self - .outputs - .iter() - .map(|output| output.name.clone()) - .collect::>(); - - // Register inputs and outputs with the graph - graph.register_input_output(input_names, output_names); - - graph - } - - fn constant_conversion(node: Node) -> ConstantNode { - let output = node.outputs.get(0).unwrap(); - - let attr = convert_constant_value(&node); - - let const_value = match attr.ty { - ArgType::Tensor(tensor) => { - // Treat tensor with dim 0 as scalar - if tensor.dim == 0 { - panic!("Constant tensor with dim 0 should have been converted to scalar.") - } else { - let kind: TensorKind = tensor.elem_type.clone().into(); - let dim = tensor.dim; - let name = node.name.clone(); - let shape = tensor.shape.clone(); - - let tensor_value = match tensor.elem_type { - // TODO Review how double precision should be supported - ElementType::Float32 | ElementType::Float64 => { - TensorValue::Float(serialize_data::( - attr.value.unwrap(), - tensor.shape.unwrap(), - )) - } - ElementType::Int32 | ElementType::Int64 => { - TensorValue::Int(serialize_data::( - attr.value.unwrap(), - tensor.shape.unwrap(), - )) - } - 
// TODO support Bool tensor when it is supported by Burn - _ => panic!("Unsupported constant tensor type: {:?} ", tensor.elem_type), - }; - - ConstantValue::Tensor(TensorType::new(name, dim, kind, shape), tensor_value) + /// Converts ONNX graph to Burn graph. + pub fn into_burn(self) -> BurnGraph { + let mut graph = BurnGraph::::default(); + + for node in self.nodes { + match node.node_type { + NodeType::Add => graph.register(Self::add_conversion(node)), + NodeType::Sub => graph.register(Self::sub_conversion(node)), + NodeType::Mul => graph.register(Self::mul_conversion(node)), + NodeType::Div => graph.register(Self::div_conversion(node)), + NodeType::Equal => graph.register(Self::equal_conversion(node)), + NodeType::Erf => graph.register(Self::erf_conversion(node)), + NodeType::Clip => graph.register(Self::clip_conversion(node)), + NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), + NodeType::Conv2d => graph.register(Self::conv2d_conversion::(node)), + NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), + NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), + NodeType::MatMul => graph.register(Self::matmul_conversion(node)), + NodeType::Linear => graph.register(Self::linear_conversion::(node)), + NodeType::BatchNormalization => graph.register(Self::batch_norm_conversion::(node)), + NodeType::Relu => graph.register(Self::relu_conversion(node)), + NodeType::Flatten => graph.register(Self::flatten_conversion(node)), + NodeType::GatherElements => graph.register(Self::gather_conversion(node)), + NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)), + NodeType::Softmax => graph.register(Self::softmax_conversion(node)), + NodeType::Tanh => graph.register(Self::tanh_conversion(node)), + NodeType::Constant => graph.register(Self::constant_conversion::(node)), + NodeType::Reshape => graph.register(Self::reshape_conversion(node)), + NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)), + NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)), + NodeType::Transpose => graph.register(Self::transpose_conversion(node)), + NodeType::Concat => graph.register(Self::concat_conversion(node)), + NodeType::Cast => graph.register(Self::cast_conversion(node)), + NodeType::Dropout => graph.register(Self::dropout_conversion(node)), + NodeType::GlobalAveragePool => graph.register(Self::global_avg_pool_conversion(node)), + _ => panic!("Unsupported node conversion {}", node.node_type), + } + } + + // Get input and output names + let input_names = self + .inputs + .iter() + .map(|input| input.name.clone()) + .collect::>(); + let output_names = self + .outputs + .iter() + .map(|output| output.name.clone()) + .collect::>(); + + // Register inputs and outputs with the graph + graph.register_input_output(input_names, output_names); + + graph + } + + fn constant_conversion(node: Node) -> ConstantNode { + let output = node.outputs.get(0).unwrap(); + + let attr = convert_constant_value(&node); + + let const_value = + match attr.ty { + ArgType::Tensor(tensor) => { + // Treat tensor with dim 0 as scalar + if tensor.dim == 0 { + panic!("Constant tensor with dim 0 should have been converted to scalar.") + } else { + let kind: TensorKind = tensor.elem_type.clone().into(); + let dim = tensor.dim; + let name = node.name.clone(); + let shape = tensor.shape.clone(); + + let tensor_value = + match tensor.elem_type { + // TODO Review how double precision should be supported + ElementType::Float32 | 
ElementType::Float64 => TensorValue::Float( + serialize_data::(attr.value.unwrap(), tensor.shape.unwrap()), + ), + ElementType::Int32 | ElementType::Int64 => { + TensorValue::Int(serialize_data::( + attr.value.unwrap(), + tensor.shape.unwrap(), + )) } - } - ArgType::Scalar(elem_type) => match elem_type { - ElementType::Float64 => ConstantValue::Float64(attr.value.unwrap().into_f64()), - ElementType::Float32 => ConstantValue::Float32(attr.value.unwrap().into_f32()), - ElementType::Int32 => ConstantValue::Int32(attr.value.unwrap().into_i32()), - ElementType::Int64 => ConstantValue::Int64(attr.value.unwrap().into_i64()), - ElementType::Bool => ConstantValue::Bool(attr.value.unwrap().into_bool()), - _ => panic!("Unsupported constant tensor type: {:?} ", elem_type), - }, - ArgType::Shape(_) => panic!("Shape is not supported as constant value."), - }; - - ConstantNode::new(node.name.clone(), const_value, output.to_type()) - } - - fn add_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - BinaryNode::add(lhs, rhs, output) - } - - fn sub_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - BinaryNode::sub(lhs, rhs, output) - } - - fn mul_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - BinaryNode::mul(lhs, rhs, output) - } - - fn div_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - BinaryNode::div(lhs, rhs, output) - } - - fn matmul_conversion(node: Node) -> MatmulNode { - let lhs = node.inputs.get(0).unwrap().to_tensor_type(); - let rhs = node.inputs.get(1).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - - MatmulNode::new(lhs, rhs, output) - } + // TODO support Bool tensor when it is supported by Burn + _ => panic!("Unsupported constant tensor type: {:?} ", tensor.elem_type), + }; - fn equal_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - BinaryNode::equal(lhs, rhs, output) - } - - fn erf_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::erf(input, output) - } - - fn relu_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::relu(input, output) - } - - fn flatten_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - let (start_dim, end_dim) = flatten_config(&node); - - UnaryNode::flatten(input, output, start_dim, end_dim) - } - - fn gather_conversion(node: Node) -> GatherNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let index = node.inputs.get(1).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let dim = gather_config(&node); - - 
GatherNode::new(input, index, output, dim) - } - - fn transpose_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::transpose(input, output) - } - - fn cast_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::cast(input, output) - } - - fn reshape_conversion(node: Node) -> ReshapeNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let shape = reshape_config(&node); - - ReshapeNode::new(input, output, shape) - } - - fn clip_conversion(node: Node) -> ClipNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let (min, max) = clip_config(&node); - - ClipNode::new(input, output, min, max) - } - - fn sigmoid_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::sigmoid(input, output) - } - - fn reciprocal_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::reciprocal(input, output) - } - - fn log_softmax_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - let dim = log_softmax_config(&node); + ConstantValue::Tensor(TensorType::new(name, dim, kind, shape), tensor_value) + } + } + ArgType::Scalar(elem_type) => match elem_type { + ElementType::Float64 => ConstantValue::Float64(attr.value.unwrap().into_f64()), + ElementType::Float32 => ConstantValue::Float32(attr.value.unwrap().into_f32()), + ElementType::Int32 => ConstantValue::Int32(attr.value.unwrap().into_i32()), + ElementType::Int64 => ConstantValue::Int64(attr.value.unwrap().into_i64()), + ElementType::Bool => ConstantValue::Bool(attr.value.unwrap().into_bool()), + _ => panic!("Unsupported constant tensor type: {:?} ", elem_type), + }, + ArgType::Shape(_) => panic!("Shape is not supported as constant value."), + }; - UnaryNode::log_softmax(input, output, dim) - } + ConstantNode::new(node.name.clone(), const_value, output.to_type()) + } - fn softmax_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - let dim = softmax_config(&node); + fn add_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - UnaryNode::softmax(input, output, dim) - } + BinaryNode::add(lhs, rhs, output) + } - fn tanh_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + fn sub_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - UnaryNode::tanh(input, output) - } + BinaryNode::sub(lhs, rhs, output) + } - fn concat_conversion(node: Node) -> ConcatNode { - let inputs = node - .inputs - .iter() - .map(|input| input.to_tensor_type()) - .collect(); + fn mul_conversion(node: Node) -> BinaryNode { + let lhs = 
node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let dim = concat_config(&node); + BinaryNode::mul(lhs, rhs, output) + } - ConcatNode::new(inputs, output, dim) - } + fn div_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - fn linear_conversion(node: Node) -> LinearNode { - let name = &node.name; - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = linear_config(&node); + BinaryNode::div(lhs, rhs, output) + } - let weight = extract_data_serialize::(1, &node).expect("Weight is required"); + fn matmul_conversion(node: Node) -> MatmulNode { + let lhs = node.inputs.get(0).unwrap().to_tensor_type(); + let rhs = node.inputs.get(1).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); - let bias = extract_data_serialize::(2, &node); + MatmulNode::new(lhs, rhs, output) + } - LinearNode::new(name, input, output, weight, bias, config) - } + fn equal_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - fn dropout_conversion(node: Node) -> DropoutNode { - let name = &node.name; - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = dropout_config(&node); + BinaryNode::equal(lhs, rhs, output) + } - DropoutNode::new(name, input, output, config) - } + fn erf_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - fn batch_norm_conversion(node: Node) -> BatchNormNode { - let config = batch_norm_config(&node); - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let dim = input.dim - 2; - - let gamma = extract_data_serialize::(1, &node).expect("Gamma is required"); - let beta = extract_data_serialize::(2, &node).expect("Beta is required"); - let running_mean = - extract_data_serialize::(3, &node).expect("Running mean is required"); - let running_var = - extract_data_serialize::(4, &node).expect("Running var is required"); - - let name = &node.name; - - BatchNormNode::new( - dim, - name, - input, - output, - gamma, - beta, - running_mean, - running_var, - config, - ) - } + UnaryNode::erf(input, output) + } - fn conv1d_conversion(node: Node) -> Conv1dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = conv1d_config(&node); + fn relu_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - let bias = node.inputs.len() == 3; - let weight = extract_data_serialize::(1, &node).unwrap(); - let bias = match bias { - true => extract_data_serialize::(2, &node), - false => None, - }; + UnaryNode::relu(input, output) + } - let name = &node.name; - Conv1dNode::::new(name, input, output, weight, bias, config) - } + fn flatten_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = 
node.outputs.get(0).unwrap().to_type(); + let (start_dim, end_dim) = flatten_config(&node); - fn conv2d_conversion(node: Node) -> Conv2dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = conv2d_config(&node); + UnaryNode::flatten(input, output, start_dim, end_dim) + } - let bias = node.inputs.len() == 3; - let weight = extract_data_serialize::(1, &node).unwrap(); - let bias = match bias { - true => extract_data_serialize::(2, &node), - false => None, - }; + fn gather_conversion(node: Node) -> GatherNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let index = node.inputs.get(1).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let dim = gather_config(&node); - let name = &node.name; - Conv2dNode::::new(name, input, output, weight, bias, config) - } + GatherNode::new(input, index, output, dim) + } - fn max_pool2d_conversion(node: Node) -> MaxPool2dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = max_pool2d_config(&node); + fn transpose_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - let name = &node.name; - MaxPool2dNode::new(name, input, output, config) - } + UnaryNode::transpose(input, output) + } - fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = avg_pool2d_config(&node); + fn cast_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - let name = &node.name; - AvgPool2dNode::new(name, input, output, config) - } + UnaryNode::cast(input, output) + } - fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); + fn reshape_conversion(node: Node) -> ReshapeNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let shape = reshape_config(&node); - let name = &node.name; + ReshapeNode::new(input, output, shape) + } - GlobalAvgPoolNode::new(name, input, output) - } + fn clip_conversion(node: Node) -> ClipNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let (min, max) = clip_config(&node); + + ClipNode::new(input, output, min, max) + } + + fn sigmoid_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::sigmoid(input, output) + } + + fn reciprocal_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::reciprocal(input, output) + } + + fn log_softmax_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + let dim = log_softmax_config(&node); + + UnaryNode::log_softmax(input, output, dim) + } + + fn softmax_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = 
node.outputs.get(0).unwrap().to_type(); + let dim = softmax_config(&node); + + UnaryNode::softmax(input, output, dim) + } + + fn tanh_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::tanh(input, output) + } + + fn concat_conversion(node: Node) -> ConcatNode { + let inputs = node + .inputs + .iter() + .map(|input| input.to_tensor_type()) + .collect(); + + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let dim = concat_config(&node); + + ConcatNode::new(inputs, output, dim) + } + + fn linear_conversion(node: Node) -> LinearNode { + let name = &node.name; + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = linear_config(&node); + + let weight = extract_data_serialize::(1, &node).expect("Weight is required"); + + let bias = extract_data_serialize::(2, &node); + + LinearNode::new(name, input, output, weight, bias, config) + } + + fn dropout_conversion(node: Node) -> DropoutNode { + let name = &node.name; + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = dropout_config(&node); + + DropoutNode::new(name, input, output, config) + } + + fn batch_norm_conversion(node: Node) -> BatchNormNode { + let config = batch_norm_config(&node); + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let dim = input.dim - 2; + + let gamma = extract_data_serialize::(1, &node).expect("Gamma is required"); + let beta = extract_data_serialize::(2, &node).expect("Beta is required"); + let running_mean = + extract_data_serialize::(3, &node).expect("Running mean is required"); + let running_var = + extract_data_serialize::(4, &node).expect("Running var is required"); + + let name = &node.name; + + BatchNormNode::new( + dim, + name, + input, + output, + gamma, + beta, + running_mean, + running_var, + config, + ) + } + + fn conv1d_conversion(node: Node) -> Conv1dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = conv1d_config(&node); + + let bias = node.inputs.len() == 3; + let weight = extract_data_serialize::(1, &node).unwrap(); + let bias = match bias { + true => extract_data_serialize::(2, &node), + false => None, + }; + + let name = &node.name; + Conv1dNode::::new(name, input, output, weight, bias, config) + } + + fn conv2d_conversion(node: Node) -> Conv2dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = conv2d_config(&node); + + let bias = node.inputs.len() == 3; + let weight = extract_data_serialize::(1, &node).unwrap(); + let bias = match bias { + true => extract_data_serialize::(2, &node), + false => None, + }; + + let name = &node.name; + Conv2dNode::::new(name, input, output, weight, bias, config) + } + + fn max_pool2d_conversion(node: Node) -> MaxPool2dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = max_pool2d_config(&node); + + let name = &node.name; + MaxPool2dNode::new(name, input, output, config) + } + + fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = 
node.outputs.get(0).unwrap().to_tensor_type(); + let config = avg_pool2d_config(&node); + + let name = &node.name; + AvgPool2dNode::new(name, input, output, config) + } + + fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + + let name = &node.name; + + GlobalAvgPoolNode::new(name, input, output) + } } /// Extract data from node states and convert it to `DataSerialize`. @@ -609,108 +603,108 @@ impl ONNXGraph { /// * `node` - The node where value are stored. #[track_caller] fn extract_data_serialize(input_index: usize, node: &Node) -> Option> { - if node.inputs.is_empty() { - return None; - } - - let input = node.inputs.get(input_index); - input?; - let input = input.unwrap(); - input.value.as_ref()?; - let ty = input.ty.clone(); - - match ty { - ArgType::Tensor(tensor_type) => { - let value = input.value.as_ref().expect("Value to be provided.").clone(); - - Some(serialize_data( - value.clone(), - tensor_type.shape.unwrap().clone(), - )) - } - _ => panic!("Unsupported serialization type"), - } + if node.inputs.is_empty() { + return None; + } + + let input = node.inputs.get(input_index); + input?; + let input = input.unwrap(); + input.value.as_ref()?; + let ty = input.ty.clone(); + + match ty { + ArgType::Tensor(tensor_type) => { + let value = input.value.as_ref().expect("Value to be provided.").clone(); + + Some(serialize_data( + value.clone(), + tensor_type.shape.unwrap().clone(), + )) + } + _ => panic!("Unsupported serialization type"), + } } /// Convert data to `DataSerialize`. fn serialize_data(data: Data, shape: Vec) -> DataSerialize { - match data { - Data::Float16s(val) => DataSerialize::new(val, shape).convert(), - Data::Float32s(val) => DataSerialize::new(val, shape).convert(), - Data::Float64s(val) => DataSerialize::new(val, shape).convert(), - Data::Int32s(val) => DataSerialize::new(val, shape).convert(), - Data::Int64s(val) => DataSerialize::new(val, shape).convert(), - // TODO support Bool tensor when it is supported by Burn - _ => panic!("Unsupported tensor element type"), - } + match data { + Data::Float16s(val) => DataSerialize::new(val, shape).convert(), + Data::Float32s(val) => DataSerialize::new(val, shape).convert(), + Data::Float64s(val) => DataSerialize::new(val, shape).convert(), + Data::Int32s(val) => DataSerialize::new(val, shape).convert(), + Data::Int64s(val) => DataSerialize::new(val, shape).convert(), + // TODO support Bool tensor when it is supported by Burn + _ => panic!("Unsupported tensor element type"), + } } impl Argument { - pub fn to_tensor_type(&self) -> TensorType { - match &self.ty { - ArgType::Tensor(ir::TensorType { - elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64, - dim, - .. - }) => TensorType::new_float(self.name.clone(), *dim), - ArgType::Tensor(ir::TensorType { - elem_type: ElementType::Int32 | ElementType::Int64, - dim, - .. - }) => TensorType::new_int(self.name.clone(), *dim), - _ => panic!("Can't transform to tensor."), + pub fn to_tensor_type(&self) -> TensorType { + match &self.ty { + ArgType::Tensor(ir::TensorType { + elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64, + dim, + .. + }) => TensorType::new_float(self.name.clone(), *dim), + ArgType::Tensor(ir::TensorType { + elem_type: ElementType::Int32 | ElementType::Int64, + dim, + .. 
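// to_tensor_type collapses the ONNX element types into Burn tensor kinds:
// Float16/Float32/Float64 map to a float tensor, Int32/Int64 map to an int tensor,
// and every other element type panics because it cannot be represented as a plain
// tensor argument here.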
+ }) => TensorType::new_int(self.name.clone(), *dim), + _ => panic!("Can't transform to tensor."), + } + } + + pub fn to_type(&self) -> Type { + match &self.ty { + ArgType::Tensor(tensor) => { + // Treat tensor with dim 0 as scalar + if tensor.dim == 0 { + Type::Scalar(ScalarType::new( + self.name.clone(), + ScalarKind::from(&tensor.elem_type), + )) + } else { + let kind: TensorKind = tensor.elem_type.clone().into(); + let dim = tensor.dim; + let name = self.name.clone(); + let shape = tensor.shape.clone(); + Type::Tensor(TensorType::new(name, dim, kind, shape)) } - } + } - pub fn to_type(&self) -> Type { - match &self.ty { - ArgType::Tensor(tensor) => { - // Treat tensor with dim 0 as scalar - if tensor.dim == 0 { - Type::Scalar(ScalarType::new( - self.name.clone(), - ScalarKind::from(&tensor.elem_type), - )) - } else { - let kind: TensorKind = tensor.elem_type.clone().into(); - let dim = tensor.dim; - let name = self.name.clone(); - let shape = tensor.shape.clone(); - Type::Tensor(TensorType::new(name, dim, kind, shape)) - } - } - - ArgType::Scalar(elem_type) => { - Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into())) - } - ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."), - } + ArgType::Scalar(elem_type) => { + Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into())) + } + ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."), } + } } impl From<&ElementType> for ScalarKind { - fn from(elem_type: &ElementType) -> Self { - match elem_type { - ElementType::Float32 => ScalarKind::Float32, - ElementType::Float64 => ScalarKind::Float64, - ElementType::Int32 => ScalarKind::Int32, - ElementType::Int64 => ScalarKind::Int64, - ElementType::Bool => ScalarKind::Bool, - ElementType::String => panic!("String tensor unsupported"), - ElementType::Float16 => panic!("Float16 tensor unsupported"), - } - } + fn from(elem_type: &ElementType) -> Self { + match elem_type { + ElementType::Float32 => ScalarKind::Float32, + ElementType::Float64 => ScalarKind::Float64, + ElementType::Int32 => ScalarKind::Int32, + ElementType::Int64 => ScalarKind::Int64, + ElementType::Bool => ScalarKind::Bool, + ElementType::String => panic!("String tensor unsupported"), + ElementType::Float16 => panic!("Float16 tensor unsupported"), + } + } } impl From for TensorKind { - fn from(elem_type: ElementType) -> Self { - match elem_type { - ElementType::Float32 => TensorKind::Float, - ElementType::Float64 => TensorKind::Float, - ElementType::Int32 => TensorKind::Int, - ElementType::Int64 => TensorKind::Int, - ElementType::Bool => TensorKind::Bool, - _ => panic!("Unsupported tensor type"), - } - } + fn from(elem_type: ElementType) -> Self { + match elem_type { + ElementType::Float32 => TensorKind::Float, + ElementType::Float64 => TensorKind::Float, + ElementType::Int32 => TensorKind::Int, + ElementType::Int64 => TensorKind::Int, + ElementType::Bool => TensorKind::Bool, + _ => panic!("Unsupported tensor type"), + } + } } diff --git a/burn-ndarray/build.rs b/burn-ndarray/build.rs index dcb4354ca6..d70cd753ac 100644 --- a/burn-ndarray/build.rs +++ b/burn-ndarray/build.rs @@ -1,6 +1,6 @@ fn main() { - // https://github.com/rust-ndarray/ndarray/issues/1197 - if cfg!(feature = "blas-accelerate") { - println!("cargo:rustc-link-lib=framework=Accelerate"); - } + // https://github.com/rust-ndarray/ndarray/issues/1197 + if cfg!(feature = "blas-accelerate") { + println!("cargo:rustc-link-lib=framework=Accelerate"); + } } diff --git a/burn-ndarray/src/backend.rs b/burn-ndarray/src/backend.rs 
index 3137ccd893..01e629bde2 100644 --- a/burn-ndarray/src/backend.rs +++ b/burn-ndarray/src/backend.rs @@ -11,14 +11,14 @@ pub(crate) static SEED: Mutex> = Mutex::new(None); /// The device type for the ndarray backend. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NdArrayDevice { - /// The CPU device. - Cpu, + /// The CPU device. + Cpu, } impl Default for NdArrayDevice { - fn default() -> Self { - Self::Cpu - } + fn default() -> Self { + Self::Cpu + } } /// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations. @@ -27,33 +27,33 @@ impl Default for NdArrayDevice { /// `wasm`, `arm`, and `x86`. #[derive(Clone, Copy, Default, Debug)] pub struct NdArray { - phantom: PhantomData, + phantom: PhantomData, } impl Backend for NdArray { - type Device = NdArrayDevice; - type FullPrecisionElem = f32; - type FullPrecisionBackend = NdArray; + type Device = NdArrayDevice; + type FullPrecisionElem = f32; + type FullPrecisionBackend = NdArray; - type TensorPrimitive = NdArrayTensor; - type FloatElem = E; + type TensorPrimitive = NdArrayTensor; + type FloatElem = E; - type IntTensorPrimitive = NdArrayTensor; - type IntElem = i64; + type IntTensorPrimitive = NdArrayTensor; + type IntElem = i64; - type BoolTensorPrimitive = NdArrayTensor; + type BoolTensorPrimitive = NdArrayTensor; - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn name() -> String { - String::from("ndarray") - } + fn name() -> String { + String::from("ndarray") + } - fn seed(seed: u64) { - let rng = StdRng::seed_from_u64(seed); - let mut seed = SEED.lock().unwrap(); - *seed = Some(rng); - } + fn seed(seed: u64) { + let rng = StdRng::seed_from_u64(seed); + let mut seed = SEED.lock().unwrap(); + *seed = Some(rng); + } } diff --git a/burn-ndarray/src/element.rs b/burn-ndarray/src/element.rs index be08c76222..35381ae93e 100644 --- a/burn-ndarray/src/element.rs +++ b/burn-ndarray/src/element.rs @@ -6,147 +6,147 @@ use ndarray::LinalgScalar; /// A float element for ndarray backend. pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar where - Self: Sized, + Self: Sized, { } /// A general element for ndarray backend. pub trait NdArrayElement: - Element - + ndarray::LinalgScalar - + ndarray::ScalarOperand - + ExpElement - + num_traits::FromPrimitive - + core::ops::AddAssign - + core::cmp::PartialEq - + core::cmp::PartialOrd + Element + + ndarray::LinalgScalar + + ndarray::ScalarOperand + + ExpElement + + num_traits::FromPrimitive + + core::ops::AddAssign + + core::cmp::PartialEq + + core::cmp::PartialOrd { } /// A element for ndarray backend that supports exp ops. pub trait ExpElement { - fn exp_elem(self) -> Self; - fn log_elem(self) -> Self; - fn log1p_elem(self) -> Self; - fn powf_elem(self, value: f32) -> Self; - fn powi_elem(self, value: i32) -> Self; - fn sqrt_elem(self) -> Self; - fn abs_elem(self) -> Self; - fn int_abs_elem(self) -> Self; + fn exp_elem(self) -> Self; + fn log_elem(self) -> Self; + fn log1p_elem(self) -> Self; + fn powf_elem(self, value: f32) -> Self; + fn powi_elem(self, value: i32) -> Self; + fn sqrt_elem(self) -> Self; + fn abs_elem(self) -> Self; + fn int_abs_elem(self) -> Self; } impl FloatNdArrayElement for f64 {} impl FloatNdArrayElement for f32 {} macro_rules! 
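// make_elem! generates the NdArrayElement and ExpElement impls for one numeric type.
// The `double` arm routes every operation through the f64 libm routines (exp, log,
// log1p, pow, sqrt, fabs), while the `single` arm uses the f32 counterparts (expf,
// logf, log1pf, powf, sqrtf, fabsf). powi_elem uses std powi when available and
// falls back to powf_elem in no_std builds.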
make_elem { - ( + ( double $ty:ty ) => { - impl NdArrayElement for $ty {} - - impl ExpElement for $ty { - #[inline(always)] - fn exp_elem(self) -> Self { - exp(self as f64) as $ty - } - - #[inline(always)] - fn log_elem(self) -> Self { - log(self as f64) as $ty - } - - #[inline(always)] - fn log1p_elem(self) -> Self { - log1p(self as f64) as $ty - } - - #[inline(always)] - fn powf_elem(self, value: f32) -> Self { - pow(self as f64, value.into()) as $ty - } - - #[inline(always)] - fn powi_elem(self, value: i32) -> Self { - #[cfg(feature = "std")] - let val = f64::powi(self as f64, value) as $ty; - - #[cfg(not(feature = "std"))] - let val = Self::powf_elem(self, value as f32); - - val - } - - #[inline(always)] - fn sqrt_elem(self) -> Self { - sqrt(self as f64) as $ty - } - - #[inline(always)] - fn abs_elem(self) -> Self { - fabs(self as f64) as $ty - } - - #[inline(always)] - fn int_abs_elem(self) -> Self { - (self as i64).abs() as $ty - } - } - }; - ( + impl NdArrayElement for $ty {} + + impl ExpElement for $ty { + #[inline(always)] + fn exp_elem(self) -> Self { + exp(self as f64) as $ty + } + + #[inline(always)] + fn log_elem(self) -> Self { + log(self as f64) as $ty + } + + #[inline(always)] + fn log1p_elem(self) -> Self { + log1p(self as f64) as $ty + } + + #[inline(always)] + fn powf_elem(self, value: f32) -> Self { + pow(self as f64, value.into()) as $ty + } + + #[inline(always)] + fn powi_elem(self, value: i32) -> Self { + #[cfg(feature = "std")] + let val = f64::powi(self as f64, value) as $ty; + + #[cfg(not(feature = "std"))] + let val = Self::powf_elem(self, value as f32); + + val + } + + #[inline(always)] + fn sqrt_elem(self) -> Self { + sqrt(self as f64) as $ty + } + + #[inline(always)] + fn abs_elem(self) -> Self { + fabs(self as f64) as $ty + } + + #[inline(always)] + fn int_abs_elem(self) -> Self { + (self as i64).abs() as $ty + } + } + }; + ( single $ty:ty ) => { - impl NdArrayElement for $ty {} - - impl ExpElement for $ty { - #[inline(always)] - fn exp_elem(self) -> Self { - expf(self as f32) as $ty - } - - #[inline(always)] - fn log_elem(self) -> Self { - logf(self as f32) as $ty - } - - #[inline(always)] - fn log1p_elem(self) -> Self { - log1pf(self as f32) as $ty - } - - #[inline(always)] - fn powf_elem(self, value: f32) -> Self { - powf(self as f32, value.into()) as $ty - } - - #[inline(always)] - fn powi_elem(self, value: i32) -> Self { - #[cfg(feature = "std")] - let val = f32::powi(self as f32, value) as $ty; - - #[cfg(not(feature = "std"))] - let val = Self::powf_elem(self, value as f32); - - val - } - - #[inline(always)] - fn sqrt_elem(self) -> Self { - sqrtf(self as f32) as $ty - } - - #[inline(always)] - fn abs_elem(self) -> Self { - fabsf(self as f32) as $ty - } - - #[inline(always)] - fn int_abs_elem(self) -> Self { - (self as i32).abs() as $ty - } - } - }; + impl NdArrayElement for $ty {} + + impl ExpElement for $ty { + #[inline(always)] + fn exp_elem(self) -> Self { + expf(self as f32) as $ty + } + + #[inline(always)] + fn log_elem(self) -> Self { + logf(self as f32) as $ty + } + + #[inline(always)] + fn log1p_elem(self) -> Self { + log1pf(self as f32) as $ty + } + + #[inline(always)] + fn powf_elem(self, value: f32) -> Self { + powf(self as f32, value.into()) as $ty + } + + #[inline(always)] + fn powi_elem(self, value: i32) -> Self { + #[cfg(feature = "std")] + let val = f32::powi(self as f32, value) as $ty; + + #[cfg(not(feature = "std"))] + let val = Self::powf_elem(self, value as f32); + + val + } + + #[inline(always)] + fn sqrt_elem(self) -> Self { + 
sqrtf(self as f32) as $ty + } + + #[inline(always)] + fn abs_elem(self) -> Self { + fabsf(self as f32) as $ty + } + + #[inline(always)] + fn int_abs_elem(self) -> Self { + (self as i32).abs() as $ty + } + } + }; } make_elem!(double f64); diff --git a/burn-ndarray/src/lib.rs b/burn-ndarray/src/lib.rs index 0c85644116..1f4b9d3790 100644 --- a/burn-ndarray/src/lib.rs +++ b/burn-ndarray/src/lib.rs @@ -7,9 +7,9 @@ extern crate derive_new; #[cfg(any( - feature = "blas-netlib", - feature = "blas-openblas", - feature = "blas-openblas-system", + feature = "blas-netlib", + feature = "blas-openblas", + feature = "blas-openblas-system", ))] extern crate blas_src; @@ -29,14 +29,14 @@ extern crate alloc; #[cfg(test)] mod tests { - type TestBackend = crate::NdArray; - type TestTensor = burn_tensor::Tensor; - type TestTensorInt = burn_tensor::Tensor; - use alloc::format; - use alloc::vec; + type TestBackend = crate::NdArray; + type TestTensor = burn_tensor::Tensor; + type TestTensorInt = burn_tensor::Tensor; + use alloc::format; + use alloc::vec; - burn_tensor::testgen_all!(); + burn_tensor::testgen_all!(); - #[cfg(feature = "std")] - burn_autodiff::testgen_all!(); + #[cfg(feature = "std")] + burn_autodiff::testgen_all!(); } diff --git a/burn-ndarray/src/ops/activations.rs b/burn-ndarray/src/ops/activations.rs index 40d1e16337..4a6c33eee7 100644 --- a/burn-ndarray/src/ops/activations.rs +++ b/burn-ndarray/src/ops/activations.rs @@ -2,16 +2,16 @@ use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray}; use burn_tensor::{ops::ActivationOps, ElementConversion}; impl ActivationOps for NdArray { - fn relu(tensor: NdArrayTensor) -> NdArrayTensor { - let zero = 0.elem(); - let array = tensor - .array - .mapv_into(|elem| match elem < zero { - true => zero, - false => elem, - }) - .into_shared(); + fn relu(tensor: NdArrayTensor) -> NdArrayTensor { + let zero = 0.elem(); + let array = tensor + .array + .mapv_into(|elem| match elem < zero { + true => zero, + false => elem, + }) + .into_shared(); - NdArrayTensor::new(array) - } + NdArrayTensor::new(array) + } } diff --git a/burn-ndarray/src/ops/adaptive_avgpool.rs b/burn-ndarray/src/ops/adaptive_avgpool.rs index 1e91aa227e..dca187c160 100644 --- a/burn-ndarray/src/ops/adaptive_avgpool.rs +++ b/burn-ndarray/src/ops/adaptive_avgpool.rs @@ -1,103 +1,101 @@ use crate::{ - element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, - tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, + tensor::NdArrayTensor, }; use burn_tensor::ElementConversion; use ndarray::Array4; pub(crate) fn adaptive_avg_pool2d( - x: NdArrayTensor, - output_size: [usize; 2], + x: NdArrayTensor, + output_size: [usize; 2], ) -> NdArrayTensor { - let [batch_size, channels, input_height, input_width] = x.shape().dims; - - let x = x.array; - let mut output = Array4::from_elem( - (batch_size, channels, output_size[0], output_size[1]), - 0.elem(), - ); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - for h in 0..output_size[0] { - for w in 0..output_size[1] { - let ih_start = start_index(h, output_size[0], input_height); - let ih_end = end_index(h, output_size[0], input_height); - let iw_start = start_index(w, output_size[1], input_width); - let iw_end = end_index(w, output_size[1], input_width); - - let mut sum_val: E = 
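// Adaptive average pooling: every output cell (h, w) averages an input window whose
// bounds depend only on the output index, start_index(i) = floor(i * in / out) and
// end_index(i) = ceil((i + 1) * in / out) (see the helpers at the end of this file),
// so the whole input is covered for any requested output size.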
0.elem(); - - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - sum_val += x[[b, c, ih, iw]]; - } - } - - let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); - output[[b, c, h, w]] = sum_val / count.elem(); - } + let [batch_size, channels, input_height, input_width] = x.shape().dims; + + let x = x.array; + let mut output = Array4::from_elem( + (batch_size, channels, output_size[0], output_size[1]), + 0.elem(), + ); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + for h in 0..output_size[0] { + for w in 0..output_size[1] { + let ih_start = start_index(h, output_size[0], input_height); + let ih_end = end_index(h, output_size[0], input_height); + let iw_start = start_index(w, output_size[1], input_width); + let iw_end = end_index(w, output_size[1], input_width); + + let mut sum_val: E = 0.elem(); + + for ih in ih_start..ih_end { + for iw in iw_start..iw_end { + sum_val += x[[b, c, ih, iw]]; } - }) - }); + } - NdArrayTensor::new(output.into_dyn().into_shared()) + let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); + output[[b, c, h, w]] = sum_val / count.elem(); + } + } + }) + }); + + NdArrayTensor::new(output.into_dyn().into_shared()) } pub(crate) fn adaptive_avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, + x: NdArrayTensor, + grad: NdArrayTensor, ) -> NdArrayTensor { - let [_, _, input_height, input_width] = x.shape().dims; - let [batch_size, channels, output_height, output_width] = grad.shape().dims; - - let mut output_grad = - Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output_grad = unsafe_shared_out.get(); - for oh in 0..output_height { - for ow in 0..output_width { - let ih_start = start_index(oh, output_height, input_height); - let ih_end = end_index(oh, output_height, input_height); - - let iw_start = start_index(ow, output_width, input_width); - let iw_end = end_index(ow, output_width, input_width); - - let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); - - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - output_grad[[b, c, ih, iw]] += - grad.array[[b, c, oh, ow]] / count.elem(); - } - } - } + let [_, _, input_height, input_width] = x.shape().dims; + let [batch_size, channels, output_height, output_width] = grad.shape().dims; + + let mut output_grad = + Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output_grad = unsafe_shared_out.get(); + for oh in 0..output_height { + for ow in 0..output_width { + let ih_start = start_index(oh, output_height, input_height); + let ih_end = end_index(oh, output_height, input_height); + + let iw_start = start_index(ow, output_width, input_width); + let iw_end = end_index(ow, output_width, input_width); + + let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); + + for ih in ih_start..ih_end { + for iw in iw_start..iw_end { + output_grad[[b, c, ih, iw]] += 
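// The backward pass mirrors the forward windows: each output gradient is spread
// uniformly over its window, i.e. every contributing input position receives
// grad / count, where count is the number of elements in that window.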
grad.array[[b, c, oh, ow]] / count.elem(); } - }) - }); + } + } + } + }) + }); - NdArrayTensor::new(output_grad.into_dyn().into_shared()) + NdArrayTensor::new(output_grad.into_dyn().into_shared()) } fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { - libm::floorf((output_size_index as f32 * input_size as f32) / output_size as f32) as usize + libm::floorf((output_size_index as f32 * input_size as f32) / output_size as f32) as usize } fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { - let index = - libm::ceilf(((output_size_index + 1) as f32 * input_size as f32) / output_size as f32) - as usize; + let index = + libm::ceilf(((output_size_index + 1) as f32 * input_size as f32) / output_size as f32) as usize; - usize::min(index, input_size) + usize::min(index, input_size) } diff --git a/burn-ndarray/src/ops/avgpool.rs b/burn-ndarray/src/ops/avgpool.rs index 680c4e1175..b0778dcf84 100644 --- a/burn-ndarray/src/ops/avgpool.rs +++ b/burn-ndarray/src/ops/avgpool.rs @@ -1,135 +1,134 @@ use crate::{ - element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, - tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, + tensor::NdArrayTensor, }; use burn_tensor::ElementConversion; use ndarray::Array4; pub(crate) fn avg_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> NdArrayTensor { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - - let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1; - let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1; - - let x = x.array; - - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem()); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut sum_val: E = 0.elem(); - let mut count: E = 0.elem(); - - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let ih = oh * stride_height + kh; - let iw = ow * stride_width + kw; - - if ih >= x_height + padding_height - || iw >= x_width + padding_width - || ih < padding_height - || iw < padding_width - { - continue; - } - - let ih = ih - padding_height; - let iw = iw - padding_width; - - count += 1.elem(); - sum_val += x[[b, c, ih, iw]]; - } - } - - if count_include_pad { - count = ((kernel_height * kernel_width) as i32).elem(); - } - - output[[b, c, oh, ow]] = sum_val / count; - } + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + + let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1; + let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1; + + let x = x.array; + + let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 
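// avg_pool2d: the output spatial size follows the usual
// (in + 2 * padding - kernel) / stride + 1 formula computed above. Inside each
// window, positions that fall in the zero padding are skipped, and the divisor is
// the number of valid elements unless count_include_pad is set, in which case the
// full kernel_height * kernel_width is used.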
0.elem()); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + + for oh in 0..out_height { + for ow in 0..out_width { + let mut sum_val: E = 0.elem(); + let mut count: E = 0.elem(); + + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let ih = oh * stride_height + kh; + let iw = ow * stride_width + kw; + + if ih >= x_height + padding_height + || iw >= x_width + padding_width + || ih < padding_height + || iw < padding_width + { + continue; + } + + let ih = ih - padding_height; + let iw = iw - padding_width; + + count += 1.elem(); + sum_val += x[[b, c, ih, iw]]; } - }) - }); + } + + if count_include_pad { + count = ((kernel_height * kernel_width) as i32).elem(); + } + + output[[b, c, oh, ow]] = sum_val / count; + } + } + }) + }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } pub(crate) fn avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: NdArrayTensor, + grad: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> NdArrayTensor { - let [kernel_height, kernel_width] = kernel_size; - let [stride_height, stride_width] = stride; - let [padding_height, padding_width] = padding; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - let [_batch_size, _channels, out_height, out_width] = grad.shape().dims; + let [kernel_height, kernel_width] = kernel_size; + let [stride_height, stride_width] = stride; + let [padding_height, padding_width] = padding; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + let [_batch_size, _channels, out_height, out_width] = grad.shape().dims; - let grad = grad.array; + let grad = grad.array; - let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem()); - let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad); + let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem()); + let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad); - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; - let output_grad = unsafe_shared_grad.get(); + let output_grad = unsafe_shared_grad.get(); - for oh in 0..out_height { - for ow in 0..out_width { - let ih_start = oh * stride_height; - let iw_start = ow * stride_width; + for oh in 0..out_height { + for ow in 0..out_width { + let ih_start = oh * stride_height; + let iw_start = ow * stride_width; - let ih_end = ih_start + kernel_height; - let iw_end = iw_start + kernel_width; + let ih_end = ih_start + kernel_height; + let iw_end = iw_start + kernel_width; - let ih_start = usize::max(ih_start, padding_height); - let iw_start = usize::max(iw_start, padding_width); + let ih_start = usize::max(ih_start, padding_height); + let iw_start = usize::max(iw_start, padding_width); - let ih_end = usize::min(ih_end, x_height + padding_height); - let iw_end = usize::min(iw_end, x_width + padding_width); + let ih_end = usize::min(ih_end, x_height + padding_height); + let iw_end = usize::min(iw_end, x_width + 
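// avg_pool2d_backward recomputes each forward window, clamps it to the padded
// bounds, and adds grad / count to every input position inside it, with count again
// depending on count_include_pad.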
padding_width); - let count = match count_include_pad { - true => kernel_width * kernel_height, - false => (ih_end - ih_start) * (iw_end - iw_start), - }; + let count = match count_include_pad { + true => kernel_width * kernel_height, + false => (ih_end - ih_start) * (iw_end - iw_start), + }; - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - let ih = ih - padding_height; - let iw = iw - padding_width; + for ih in ih_start..ih_end { + for iw in iw_start..iw_end { + let ih = ih - padding_height; + let iw = iw - padding_width; - output_grad[[b, c, ih, iw]] += - grad[[b, c, oh, ow]] / (count as i32).elem(); - } - } - } + output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / (count as i32).elem(); } - }) - }); + } + } + } + }) + }); - NdArrayTensor::new(output_grad.into_dyn().into_shared()) + NdArrayTensor::new(output_grad.into_dyn().into_shared()) } diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs index bb73a02d43..cf2b2f5cd2 100644 --- a/burn-ndarray/src/ops/base.rs +++ b/burn-ndarray/src/ops/base.rs @@ -16,475 +16,457 @@ use crate::ops::macros::{keepdim, mean_dim, sum_dim}; use crate::{reshape, tensor::NdArrayTensor}; pub struct NdArrayOps { - e: PhantomData, + e: PhantomData, } pub(crate) struct NdArrayMathOps { - e: PhantomData, + e: PhantomData, } impl NdArrayOps where - E: Copy, + E: Copy, { - pub fn slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - let slices = Self::to_slice_args::(ranges); - let array = tensor.array.slice_move(slices.as_slice()).into_shared(); - - NdArrayTensor { array } - } - - pub fn slice_assign( - tensor: NdArrayTensor, - ranges: [Range; D2], - value: NdArrayTensor, - ) -> NdArrayTensor { - let slices = Self::to_slice_args::(ranges); - let mut array = tensor.array.into_owned(); - array.slice_mut(slices.as_slice()).assign(&value.array); - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - reshape!( - ty E, - shape shape, - array tensor.array, - d D2 - ) - } - - pub fn cat( - tensors: Vec>, - dim: usize, - ) -> NdArrayTensor { - let arrays: Vec> = - tensors.iter().map(|t| t.array.view()).collect(); - let array = ndarray::concatenate(Axis(dim), &arrays) - .unwrap() - .into_shared(); - - NdArrayTensor { array } - } - - fn to_slice_args( - ranges: [Range; D2], - ) -> [SliceInfoElem; D1] { - let mut slices = [SliceInfoElem::NewAxis; D1]; - for i in 0..D1 { - if i >= D2 { - slices[i] = SliceInfoElem::Slice { - start: 0, - end: None, - step: 1, - } - } else { - slices[i] = SliceInfoElem::Slice { - start: ranges[i].start as isize, - end: Some(ranges[i].end as isize), - step: 1, - } - } + pub fn slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + let slices = Self::to_slice_args::(ranges); + let array = tensor.array.slice_move(slices.as_slice()).into_shared(); + + NdArrayTensor { array } + } + + pub fn slice_assign( + tensor: NdArrayTensor, + ranges: [Range; D2], + value: NdArrayTensor, + ) -> NdArrayTensor { + let slices = Self::to_slice_args::(ranges); + let mut array = tensor.array.into_owned(); + array.slice_mut(slices.as_slice()).assign(&value.array); + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + reshape!( + ty E, + shape shape, + array tensor.array, + d D2 + ) + } + + pub fn cat(tensors: Vec>, dim: usize) -> NdArrayTensor { + let arrays: Vec> = + tensors.iter().map(|t| 
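// cat concatenates the tensors' views along Axis(dim). to_slice_args (just below)
// expands the D2 requested ranges into D1 slice arguments, selecting the full axis
// for every dimension beyond D2, which is how slice and slice_assign accept fewer
// ranges than the tensor has dimensions.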
t.array.view()).collect(); + let array = ndarray::concatenate(Axis(dim), &arrays) + .unwrap() + .into_shared(); + + NdArrayTensor { array } + } + + fn to_slice_args( + ranges: [Range; D2], + ) -> [SliceInfoElem; D1] { + let mut slices = [SliceInfoElem::NewAxis; D1]; + for i in 0..D1 { + if i >= D2 { + slices[i] = SliceInfoElem::Slice { + start: 0, + end: None, + step: 1, } - slices + } else { + slices[i] = SliceInfoElem::Slice { + start: ranges[i].start as isize, + end: Some(ranges[i].end as isize), + step: 1, + } + } } + slices + } - pub fn swap_dims( - tensor: NdArrayTensor, - dim1: usize, - dim2: usize, - ) -> NdArrayTensor { - let mut array = tensor.array; - array.swap_axes(dim1, dim2); - - NdArrayTensor::new(array) - } + pub fn swap_dims( + tensor: NdArrayTensor, + dim1: usize, + dim2: usize, + ) -> NdArrayTensor { + let mut array = tensor.array; + array.swap_axes(dim1, dim2); + + NdArrayTensor::new(array) + } } impl NdArrayMathOps where - E: Copy + NdArrayElement, + E: Copy + NdArrayElement, { - pub fn add( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = &lhs.array + &rhs.array; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array + rhs; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn sub( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = lhs.array - rhs.array; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array - rhs; - let array = array.into_shared(); - - NdArrayTensor { array } + pub fn add( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = &lhs.array + &rhs.array; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array + rhs; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn sub( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = lhs.array - rhs.array; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array - rhs; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn mul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = lhs.array * rhs.array; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array * rhs; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn div( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = lhs.array / rhs.array; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array / rhs; + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn recip(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.map(|x| 1.elem::() / *x); + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn mean(tensor: NdArrayTensor) -> NdArrayTensor { + let data = Data::from([tensor.array.mean().unwrap()]); + NdArrayTensor::from_data(data) + } + + pub fn sum(tensor: NdArrayTensor) -> NdArrayTensor { + let data = 
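// mean_dim and sum_dim dispatch on the const rank D through the keepdim! macro
// (ranks 1 through 6), keeping the reduced axis with length 1. gather moves the
// target dimension to the last axis, flattens the leading axes into a single batch,
// copies the indexed elements row by row, and then swaps the axes back to restore
// the original layout.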
Data::from([tensor.array.sum()]); + NdArrayTensor::from_data(data) + } + + pub fn mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + match D { + 1 => keepdim!(0, dim, tensor, mean), + 2 => keepdim!(1, dim, tensor, mean), + 3 => keepdim!(2, dim, tensor, mean), + 4 => keepdim!(3, dim, tensor, mean), + 5 => keepdim!(4, dim, tensor, mean), + 6 => keepdim!(5, dim, tensor, mean), + _ => panic!("Dim not supported {D}"), + } + } + + pub fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + match D { + 1 => keepdim!(0, dim, tensor, sum), + 2 => keepdim!(1, dim, tensor, sum), + 3 => keepdim!(2, dim, tensor, sum), + 4 => keepdim!(3, dim, tensor, sum), + 5 => keepdim!(4, dim, tensor, sum), + 6 => keepdim!(5, dim, tensor, sum), + _ => panic!("Dim not supported {D}"), + } + } + + pub fn gather( + dim: usize, + mut tensor: NdArrayTensor, + mut indices: NdArrayTensor, + ) -> NdArrayTensor { + if dim != D - 1 { + tensor.array.swap_axes(D - 1, dim); + indices.array.swap_axes(D - 1, dim); } + let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape()); + let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indices.dims[D - 1]); + let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); - pub fn mul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = lhs.array * rhs.array; - let array = array.into_shared(); + let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; + let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; + let mut output = Array2::zeros((batch_size, size_index)); - NdArrayTensor { array } - } + for b in 0..batch_size { + let indices = indices.slice(s!(b, ..)); - pub fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array * rhs; - let array = array.into_shared(); - - NdArrayTensor { array } + for (i, index) in indices.iter().enumerate() { + output[[b, i]] = tensor[[b, *index as usize]]; + } } - pub fn div( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = lhs.array / rhs.array; - let array = array.into_shared(); + let mut output = NdArrayOps::reshape( + NdArrayTensor::::new(output.into_shared().into_dyn()), + shape_indices, + ); - NdArrayTensor { array } + if dim != D - 1 { + output.array.swap_axes(D - 1, dim); } - pub fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array / rhs; - let array = array.into_shared(); + output + } - NdArrayTensor { array } + pub fn scatter( + dim: usize, + mut tensor: NdArrayTensor, + mut indices: NdArrayTensor, + mut value: NdArrayTensor, + ) -> NdArrayTensor { + if dim != D - 1 { + tensor.array.swap_axes(D - 1, dim); + indices.array.swap_axes(D - 1, dim); + value.array.swap_axes(D - 1, dim); } - pub fn recip(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.map(|x| 1.elem::() / *x); - let array = array.into_shared(); + let (shape_tensor, shape_indices, shape_value) = + (tensor.shape(), indices.shape(), value.shape()); + let (size_tensor, size_index, size_value) = ( + shape_tensor.dims[D - 1], + shape_indices.dims[D - 1], + shape_value.dims[D - 1], + ); + let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); - NdArrayTensor { array } + if shape_value != shape_indices { + panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indices.dims, shape_value.dims); } - pub fn mean(tensor: NdArrayTensor) -> NdArrayTensor 
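// scatter follows the same axis-swap/flatten scheme as gather, but accumulates:
// each value is added (+=) at the indexed position rather than overwriting it, and
// the call panics when the index and value shapes disagree.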
{ - let data = Data::from([tensor.array.mean().unwrap()]); - NdArrayTensor::from_data(data) - } - - pub fn sum(tensor: NdArrayTensor) -> NdArrayTensor { - let data = Data::from([tensor.array.sum()]); - NdArrayTensor::from_data(data) - } + let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; + let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array; + let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; - pub fn mean_dim( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - match D { - 1 => keepdim!(0, dim, tensor, mean), - 2 => keepdim!(1, dim, tensor, mean), - 3 => keepdim!(2, dim, tensor, mean), - 4 => keepdim!(3, dim, tensor, mean), - 5 => keepdim!(4, dim, tensor, mean), - 6 => keepdim!(5, dim, tensor, mean), - _ => panic!("Dim not supported {D}"), - } - } + for b in 0..batch_size { + let indices = indices.slice(s!(b, ..)); - pub fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - match D { - 1 => keepdim!(0, dim, tensor, sum), - 2 => keepdim!(1, dim, tensor, sum), - 3 => keepdim!(2, dim, tensor, sum), - 4 => keepdim!(3, dim, tensor, sum), - 5 => keepdim!(4, dim, tensor, sum), - 6 => keepdim!(5, dim, tensor, sum), - _ => panic!("Dim not supported {D}"), - } + for (i, index) in indices.iter().enumerate() { + let index = *index as usize; + tensor[[b, index]] += value[[b, i]]; + } } - pub fn gather( - dim: usize, - mut tensor: NdArrayTensor, - mut indices: NdArrayTensor, - ) -> NdArrayTensor { - if dim != D - 1 { - tensor.array.swap_axes(D - 1, dim); - indices.array.swap_axes(D - 1, dim); - } - let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape()); - let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indices.dims[D - 1]); - let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); - - let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; - let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; - let mut output = Array2::zeros((batch_size, size_index)); - - for b in 0..batch_size { - let indices = indices.slice(s!(b, ..)); - - for (i, index) in indices.iter().enumerate() { - output[[b, i]] = tensor[[b, *index as usize]]; - } - } - - let mut output = NdArrayOps::reshape( - NdArrayTensor::::new(output.into_shared().into_dyn()), - shape_indices, - ); - - if dim != D - 1 { - output.array.swap_axes(D - 1, dim); - } - - output + let mut output = NdArrayOps::reshape( + NdArrayTensor::::new(tensor.into_shared().into_dyn()), + shape_tensor, + ); + if dim != D - 1 { + output.array.swap_axes(D - 1, dim); } + output + } - pub fn scatter( - dim: usize, - mut tensor: NdArrayTensor, - mut indices: NdArrayTensor, - mut value: NdArrayTensor, - ) -> NdArrayTensor { - if dim != D - 1 { - tensor.array.swap_axes(D - 1, dim); - indices.array.swap_axes(D - 1, dim); - value.array.swap_axes(D - 1, dim); - } - - let (shape_tensor, shape_indices, shape_value) = - (tensor.shape(), indices.shape(), value.shape()); - let (size_tensor, size_index, size_value) = ( - shape_tensor.dims[D - 1], - shape_indices.dims[D - 1], - shape_value.dims[D - 1], - ); - let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); + pub fn mask_where( + tensor: NdArrayTensor, + mask: NdArrayTensor, + source: NdArrayTensor, + ) -> NdArrayTensor { + let mask_mul_4tensor = mask.array.mapv(|x| match x { + true => 0.elem(), + false => 1.elem(), + }); + let mask_mul_4source = mask.array.mapv(|x| match 
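// mask_where and mask_fill are expressed arithmetically: the boolean mask is mapped
// to 0/1 multipliers (plus an additive term for mask_fill), so masked positions take
// the replacement value and unmasked positions keep the original tensor, with no
// per-element branching.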
x { + true => 1.elem(), + false => 0.elem(), + }); + let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source); - if shape_value != shape_indices { - panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indices.dims, shape_value.dims); - } + NdArrayTensor::new(array) + } - let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; - let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array; - let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; + pub fn mask_fill( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: E, + ) -> NdArrayTensor { + let mask_mul = mask.array.mapv(|x| match x { + true => 0.elem(), + false => 1.elem(), + }); + let mask_add = mask.array.mapv(|x| match x { + true => value, + false => 0.elem(), + }); + let array = (tensor.array * mask_mul) + mask_add; - for b in 0..batch_size { - let indices = indices.slice(s!(b, ..)); + NdArrayTensor::new(array) + } - for (i, index) in indices.iter().enumerate() { - let index = *index as usize; - tensor[[b, index]] += value[[b, i]]; - } - } + fn gather_batch_size(shape_tensor: &Shape, shape_indices: &Shape) -> usize { + let mut batch_size = 1; - let mut output = NdArrayOps::reshape( - NdArrayTensor::::new(tensor.into_shared().into_dyn()), - shape_tensor, + for i in 0..D - 1 { + if shape_tensor.dims[i] != shape_indices.dims[i] { + panic!( + "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", + shape_tensor.dims, shape_indices.dims ); - if dim != D - 1 { - output.array.swap_axes(D - 1, dim); - } - output - } - - pub fn mask_where( - tensor: NdArrayTensor, - mask: NdArrayTensor, - source: NdArrayTensor, - ) -> NdArrayTensor { - let mask_mul_4tensor = mask.array.mapv(|x| match x { - true => 0.elem(), - false => 1.elem(), - }); - let mask_mul_4source = mask.array.mapv(|x| match x { - true => 1.elem(), - false => 0.elem(), - }); - let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source); - - NdArrayTensor::new(array) + } + batch_size *= shape_indices.dims[i]; } - pub fn mask_fill( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: E, - ) -> NdArrayTensor { - let mask_mul = mask.array.mapv(|x| match x { - true => 0.elem(), - false => 1.elem(), - }); - let mask_add = mask.array.mapv(|x| match x { - true => value, - false => 0.elem(), - }); - let array = (tensor.array * mask_mul) + mask_add; - - NdArrayTensor::new(array) - } - - fn gather_batch_size( - shape_tensor: &Shape, - shape_indices: &Shape, - ) -> usize { - let mut batch_size = 1; - - for i in 0..D - 1 { - if shape_tensor.dims[i] != shape_indices.dims[i] { - panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims); - } - batch_size *= shape_indices.dims[i]; - } + batch_size + } - batch_size - } + pub fn select( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + ) -> NdArrayTensor { + let array = tensor.array.select( + Axis(dim), + &indices + .array + .into_iter() + .map(|i| i as usize) + .collect::>(), + ); + + NdArrayTensor::new(array.into_shared()) + } + + pub fn select_assign( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + let mut output_array = tensor.array.into_owned(); - pub fn select( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - ) -> NdArrayTensor 
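// select copies whole sub-views along the requested axis via ndarray's select;
// select_assign accumulates, adding each value slice into the output slice picked by
// the corresponding index.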
{ - let array = tensor.array.select( - Axis(dim), - &indices - .array - .into_iter() - .map(|i| i as usize) - .collect::>(), - ); + for (index_value, index) in indices.array.into_iter().enumerate() { + let mut view = output_array.index_axis_mut(Axis(dim), index as usize); + let value = value.array.index_axis(Axis(dim), index_value); - NdArrayTensor::new(array.into_shared()) + view.zip_mut_with(&value, |a, b| *a += *b); } - pub fn select_assign( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - let mut output_array = tensor.array.into_owned(); - - for (index_value, index) in indices.array.into_iter().enumerate() { - let mut view = output_array.index_axis_mut(Axis(dim), index as usize); - let value = value.array.index_axis(Axis(dim), index_value); - - view.zip_mut_with(&value, |a, b| *a += *b); - } + NdArrayTensor::new(output_array.into_shared()) + } + pub fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + arg(tensor, dim, CmpType::Max) + } - NdArrayTensor::new(output_array.into_shared()) - } - pub fn argmax( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - arg(tensor, dim, CmpType::Max) - } - - pub fn argmin( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - arg(tensor, dim, CmpType::Min) - } + pub fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + arg(tensor, dim, CmpType::Min) + } - pub fn clamp_min( - mut tensor: NdArrayTensor, - min: E, - ) -> NdArrayTensor { - tensor.array.mapv_inplace(|x| match x < min { - true => min, - false => x, - }); + pub fn clamp_min(mut tensor: NdArrayTensor, min: E) -> NdArrayTensor { + tensor.array.mapv_inplace(|x| match x < min { + true => min, + false => x, + }); - tensor - } + tensor + } - pub fn clamp_max( - mut tensor: NdArrayTensor, - max: E, - ) -> NdArrayTensor { - tensor.array.mapv_inplace(|x| match x > max { - true => max, - false => x, - }); + pub fn clamp_max(mut tensor: NdArrayTensor, max: E) -> NdArrayTensor { + tensor.array.mapv_inplace(|x| match x > max { + true => max, + false => x, + }); - tensor - } + tensor + } + + pub fn clamp( + mut tensor: NdArrayTensor, + min: E, + max: E, + ) -> NdArrayTensor { + tensor.array.mapv_inplace(|x| match x < min { + true => min, + false => match x > max { + true => max, + false => x, + }, + }); - pub fn clamp( - mut tensor: NdArrayTensor, - min: E, - max: E, - ) -> NdArrayTensor { - tensor.array.mapv_inplace(|x| match x < min { - true => min, - false => match x > max { - true => max, - false => x, - }, - }); - - tensor - } + tensor + } } enum CmpType { - Min, - Max, + Min, + Max, } fn arg( - tensor: NdArrayTensor, - dim: usize, - cmp: CmpType, + tensor: NdArrayTensor, + dim: usize, + cmp: CmpType, ) -> NdArrayTensor { - let mut reshape = tensor.array.shape().to_vec(); - reshape[dim] = 1; - - let output = tensor.array.map_axis(Axis(dim), |arr| { - // Find the min/max value in the array, and return its index. - let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| { - let cmp = match cmp { - CmpType::Min => e < &acc.0, - CmpType::Max => e > &acc.0, - }; - - if cmp { - (*e, idx) - } else { - acc - } - }); - - idx as i64 + let mut reshape = tensor.array.shape().to_vec(); + reshape[dim] = 1; + + let output = tensor.array.map_axis(Axis(dim), |arr| { + // Find the min/max value in the array, and return its index. 
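// The fold below walks each lane once, keeping the running (value, index) pair that
// wins under the requested CmpType and emitting the index as i64; the reduced axis
// is then reinstated with length 1 via the reshape computed above, so argmax/argmin
// preserve the tensor rank.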
+ let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| { + let cmp = match cmp { + CmpType::Min => e < &acc.0, + CmpType::Max => e > &acc.0, + }; + + if cmp { + (*e, idx) + } else { + acc + } }); - let output = output.into_shape(Dim(reshape.as_slice())).unwrap(); + idx as i64 + }); - NdArrayTensor { - array: output.into_shared(), - } + let output = output.into_shape(Dim(reshape.as_slice())).unwrap(); + + NdArrayTensor { + array: output.into_shared(), + } } diff --git a/burn-ndarray/src/ops/bool_tensor.rs b/burn-ndarray/src/ops/bool_tensor.rs index adddb12621..8d2e163674 100644 --- a/burn-ndarray/src/ops/bool_tensor.rs +++ b/burn-ndarray/src/ops/bool_tensor.rs @@ -16,116 +16,116 @@ use burn_tensor::{backend::Backend, Data, Shape}; use super::NdArrayOps; impl BoolTensorOps for NdArray { - fn bool_from_data( - data: Data, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - NdArrayTensor::from_data(data) - } - - fn bool_shape( - tensor: & as Backend>::BoolTensorPrimitive, - ) -> Shape { - tensor.shape() - } - - fn bool_into_data( - tensor: as Backend>::BoolTensorPrimitive, - ) -> Reader> { - let shape = tensor.shape(); - let values = tensor.array.into_iter().collect(); - - Reader::Concrete(Data::new(values, shape)) - } - - fn bool_to_device( - tensor: NdArrayTensor, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - tensor - } - - fn bool_reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - NdArrayOps::reshape(tensor, shape) - } - - fn bool_slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - NdArrayOps::slice(tensor, ranges) - } - - fn bool_into_int( - tensor: as Backend>::BoolTensorPrimitive, - ) -> NdArrayTensor { - let data = Self::bool_into_data(tensor) - .read_sync() - .expect("Always sync with ndarray"); - NdArray::::int_from_data(data.convert(), &NdArrayDevice::Cpu) - } - - fn bool_device( - _tensor: & as Backend>::BoolTensorPrimitive, - ) -> as Backend>::Device { - NdArrayDevice::Cpu - } - - fn bool_empty( - shape: Shape, - _device: & as Backend>::Device, - ) -> as Backend>::BoolTensorPrimitive { - let values = vec![false; shape.num_elements()]; - NdArrayTensor::from_data(Data::new(values, shape)) - } - - fn bool_slice_assign( - tensor: as Backend>::BoolTensorPrimitive, - ranges: [Range; D2], - value: as Backend>::BoolTensorPrimitive, - ) -> as Backend>::BoolTensorPrimitive { - NdArrayOps::slice_assign(tensor, ranges, value) - } - - fn bool_cat( - tensors: Vec< as Backend>::BoolTensorPrimitive>, - dim: usize, - ) -> as Backend>::BoolTensorPrimitive { - NdArrayOps::cat(tensors, dim) - } - - fn bool_equal( - lhs: as Backend>::BoolTensorPrimitive, - rhs: as Backend>::BoolTensorPrimitive, - ) -> as Backend>::BoolTensorPrimitive { - let mut array = lhs.array; - array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b); - - NdArrayTensor { array } - } - - fn bool_not( - tensor: as Backend>::BoolTensorPrimitive, - ) -> as Backend>::BoolTensorPrimitive { - let array = tensor.array.mapv(|a| !a).into_shared(); - NdArrayTensor { array } - } - - fn bool_into_float( - tensor: as Backend>::BoolTensorPrimitive, - ) -> as Backend>::TensorPrimitive { - let array = tensor.array.mapv(|a| (a as i32).elem()).into_shared(); - NdArrayTensor { array } - } - - fn bool_swap_dims( - tensor: as Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::BoolTensorPrimitive { - NdArrayOps::swap_dims(tensor, dim1, dim2) - } + fn bool_from_data( + data: Data, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + 
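// The bool tensor ops are thin wrappers: equality and negation are element-wise
// maps, bool_into_int round-trips through Data and convert(), bool_into_float casts
// each element to 0/1 through i32, and device handling is a no-op since the ndarray
// backend only exposes a CPU device.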
NdArrayTensor::from_data(data) + } + + fn bool_shape( + tensor: & as Backend>::BoolTensorPrimitive, + ) -> Shape { + tensor.shape() + } + + fn bool_into_data( + tensor: as Backend>::BoolTensorPrimitive, + ) -> Reader> { + let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + + Reader::Concrete(Data::new(values, shape)) + } + + fn bool_to_device( + tensor: NdArrayTensor, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + tensor + } + + fn bool_reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + NdArrayOps::reshape(tensor, shape) + } + + fn bool_slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + NdArrayOps::slice(tensor, ranges) + } + + fn bool_into_int( + tensor: as Backend>::BoolTensorPrimitive, + ) -> NdArrayTensor { + let data = Self::bool_into_data(tensor) + .read_sync() + .expect("Always sync with ndarray"); + NdArray::::int_from_data(data.convert(), &NdArrayDevice::Cpu) + } + + fn bool_device( + _tensor: & as Backend>::BoolTensorPrimitive, + ) -> as Backend>::Device { + NdArrayDevice::Cpu + } + + fn bool_empty( + shape: Shape, + _device: & as Backend>::Device, + ) -> as Backend>::BoolTensorPrimitive { + let values = vec![false; shape.num_elements()]; + NdArrayTensor::from_data(Data::new(values, shape)) + } + + fn bool_slice_assign( + tensor: as Backend>::BoolTensorPrimitive, + ranges: [Range; D2], + value: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + NdArrayOps::slice_assign(tensor, ranges, value) + } + + fn bool_cat( + tensors: Vec< as Backend>::BoolTensorPrimitive>, + dim: usize, + ) -> as Backend>::BoolTensorPrimitive { + NdArrayOps::cat(tensors, dim) + } + + fn bool_equal( + lhs: as Backend>::BoolTensorPrimitive, + rhs: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + let mut array = lhs.array; + array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b); + + NdArrayTensor { array } + } + + fn bool_not( + tensor: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + let array = tensor.array.mapv(|a| !a).into_shared(); + NdArrayTensor { array } + } + + fn bool_into_float( + tensor: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::TensorPrimitive { + let array = tensor.array.mapv(|a| (a as i32).elem()).into_shared(); + NdArrayTensor { array } + } + + fn bool_swap_dims( + tensor: as Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::BoolTensorPrimitive { + NdArrayOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-ndarray/src/ops/conv.rs b/burn-ndarray/src/ops/conv.rs index 1d4fcc259c..bd300bc5a2 100644 --- a/burn-ndarray/src/ops/conv.rs +++ b/burn-ndarray/src/ops/conv.rs @@ -1,252 +1,243 @@ use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions}, - ElementConversion, + ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions}, + ElementConversion, }; use ndarray::{s, Array3, Array4, ArrayView2, ArrayViewMut2, Axis, Dim}; use crate::{ - element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d, - run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d, run_par, + sharing::UnsafeSharedRef, tensor::NdArrayTensor, }; #[inline(always)] fn conv2d_mad_inner( - mut output: ArrayViewMut2, - x: ArrayView2, - k: E, - k_xy: (usize, usize), - out_xy: (usize, usize), - stride: (usize, usize), - dilation: (usize, usize), + 
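// conv2d_mad_inner performs the multiply-accumulate for a single kernel tap
// (kh, kw): for every output row it takes contiguous slices of the input and output
// rows up front so the inner loop over output columns carries no bounds checks and
// can be auto-vectorized.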
mut output: ArrayViewMut2, + x: ArrayView2, + k: E, + k_xy: (usize, usize), + out_xy: (usize, usize), + stride: (usize, usize), + dilation: (usize, usize), ) { - let (kh, kw) = k_xy; - let (out_width, out_height) = out_xy; - let (stride_width, stride_height) = stride; - let (dilation_width, dilation_height) = dilation; - - for oh in 0..out_height { - // Construct a sub-slice view of the input row. - // This is done upfront so that rustc does not have to emit bounds checks - // in the hot loop below. - let ir = x - .row(oh * stride_height + kh * dilation_height) - .to_slice() - .unwrap(); - - // Ditto. Construct a sub-slice view of the output row, and explicitly specify - // the bounds upfront as 0..out_width so that rustc can make the assumption - // that all accesses are in-bounds in the below loop. - let mut or = output.row_mut(oh); - let or = &mut or.as_slice_mut().unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - let iw = (ow * stride_width) + (kw * dilation_width); - or[ow] += ir[iw] * k; - } + let (kh, kw) = k_xy; + let (out_width, out_height) = out_xy; + let (stride_width, stride_height) = stride; + let (dilation_width, dilation_height) = dilation; + + for oh in 0..out_height { + // Construct a sub-slice view of the input row. + // This is done upfront so that rustc does not have to emit bounds checks + // in the hot loop below. + let ir = x + .row(oh * stride_height + kh * dilation_height) + .to_slice() + .unwrap(); + + // Ditto. Construct a sub-slice view of the output row, and explicitly specify + // the bounds upfront as 0..out_width so that rustc can make the assumption + // that all accesses are in-bounds in the below loop. + let mut or = output.row_mut(oh); + let or = &mut or.as_slice_mut().unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + let iw = (ow * stride_width) + (kw * dilation_width); + or[ow] += ir[iw] * k; } + } } pub(crate) fn conv2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvOptions<2>, + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvOptions<2>, ) -> NdArrayTensor { - let [dilation_height, dilation_width] = options.dilation; - let [padding_height, padding_width] = options.padding; - let [stride_height, stride_width] = options.stride; - let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; - let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims; - - let out_height = calculate_conv_output_size( - kernel_height, - stride_height, - padding_height, - dilation_height, - in_height, - ); - let out_width = calculate_conv_output_size( - kernel_width, - stride_width, - padding_width, - dilation_width, - in_width, - ); - - let x = apply_padding_4d(x, options.padding, 0i32.elem()).array; - - // Convert inputs from dynamic indexes to static to improve perf. 
- let x = x.into_dimensionality::().unwrap(); - let weights = weight.array.into_dimensionality::().unwrap(); - - let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width])); - - run_par!(|| { - iter_par!(output.axis_iter_mut(Axis(0))) - .enumerate() - .for_each( - #[inline(never)] - |(k, mut output)| { - let b = k / out_channels; - let oc = k % out_channels; - let g = k % options.groups; - - for ic in (in_channels * g)..(in_channels * (g + 1)) { - let weight_ic = ic - (g * in_channels); - - let x = x.slice(s![b, ic, .., ..]); - let k = weights.slice(s![oc, weight_ic, .., ..]); - - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let k = k[[kh, kw]]; - - // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization - // in the case that the stride/dilation is 1. - #[allow(clippy::if_same_then_else)] - if (1, 1, 1, 1) - == ( - stride_width, - stride_height, - dilation_width, - dilation_height, - ) - { - conv2d_mad_inner( - output.view_mut(), - x.view(), - k, - (kh, kw), - (out_width, out_height), - (stride_width, stride_height), - (dilation_width, dilation_height), - ); - } else { - conv2d_mad_inner( - output.view_mut(), - x.view(), - k, - (kh, kw), - (out_width, out_height), - (stride_width, stride_height), - (dilation_width, dilation_height), - ); - } - } - } - } - - if let Some(bias) = &bias { - let bias = bias.array[oc]; - - for oh in 0..out_height { - // Get a mutable slice reference to the row we're looping over. - // We explicitly define the bounds to 0..out_width so that rustc can make - // the assumption that all accesses are in-bounds. - let mut or = output.row_mut(oh); - let or = &mut or.as_slice_mut().unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - or[ow] += bias; - } - } - } - }, - ); - }); - - let output = output - .into_shape([batch_size, out_channels, out_height, out_width]) - .unwrap() - .into_dyn() - .into_shared(); - - NdArrayTensor::new(output) + let [dilation_height, dilation_width] = options.dilation; + let [padding_height, padding_width] = options.padding; + let [stride_height, stride_width] = options.stride; + let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; + let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims; + + let out_height = calculate_conv_output_size( + kernel_height, + stride_height, + padding_height, + dilation_height, + in_height, + ); + let out_width = calculate_conv_output_size( + kernel_width, + stride_width, + padding_width, + dilation_width, + in_width, + ); + + let x = apply_padding_4d(x, options.padding, 0i32.elem()).array; + + // Convert inputs from dynamic indexes to static to improve perf. 
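// Sketch of the output-size arithmetic that `calculate_conv_output_size` is assumed to
// perform here (the same expression is written inline in `max_pool2d` further down in
// this patch):
//   out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
// using integer (floor) division. Parameter order follows the call sites above.
fn conv_output_size(kernel: usize, stride: usize, padding: usize, dilation: usize, input: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // 3x3 kernel, stride 1, padding 1, dilation 1 keeps the spatial size.
    assert_eq!(conv_output_size(3, 1, 1, 1, 32), 32);
    // Stride 2 halves it (rounding down).
    assert_eq!(conv_output_size(3, 2, 1, 1, 32), 16);
}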
+ let x = x.into_dimensionality::().unwrap(); + let weights = weight.array.into_dimensionality::().unwrap(); + + let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width])); + + run_par!(|| { + iter_par!(output.axis_iter_mut(Axis(0))) + .enumerate() + .for_each( + #[inline(never)] + |(k, mut output)| { + let b = k / out_channels; + let oc = k % out_channels; + let g = k % options.groups; + + for ic in (in_channels * g)..(in_channels * (g + 1)) { + let weight_ic = ic - (g * in_channels); + + let x = x.slice(s![b, ic, .., ..]); + let k = weights.slice(s![oc, weight_ic, .., ..]); + + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let k = k[[kh, kw]]; + + // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization + // in the case that the stride/dilation is 1. + #[allow(clippy::if_same_then_else)] + if (1, 1, 1, 1) == (stride_width, stride_height, dilation_width, dilation_height) { + conv2d_mad_inner( + output.view_mut(), + x.view(), + k, + (kh, kw), + (out_width, out_height), + (stride_width, stride_height), + (dilation_width, dilation_height), + ); + } else { + conv2d_mad_inner( + output.view_mut(), + x.view(), + k, + (kh, kw), + (out_width, out_height), + (stride_width, stride_height), + (dilation_width, dilation_height), + ); + } + } + } + } + + if let Some(bias) = &bias { + let bias = bias.array[oc]; + + for oh in 0..out_height { + // Get a mutable slice reference to the row we're looping over. + // We explicitly define the bounds to 0..out_width so that rustc can make + // the assumption that all accesses are in-bounds. + let mut or = output.row_mut(oh); + let or = &mut or.as_slice_mut().unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + or[ow] += bias; + } + } + } + }, + ); + }); + + let output = output + .into_shape([batch_size, out_channels, out_height, out_width]) + .unwrap() + .into_dyn() + .into_shared(); + + NdArrayTensor::new(output) } pub(crate) fn conv_transpose2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvTransposeOptions<2>, + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvTransposeOptions<2>, ) -> NdArrayTensor { - let [dilation_height, dilation_width] = options.dilation; - let [padding_height, padding_width] = options.padding; - let [stride_height, stride_width] = options.stride; - let [out_padding_height, out_padding_width] = options.padding_out; - let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; - let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims; - - let out_height = (in_height - 1) * stride_height - + dilation_height * (kernel_height - 1) - + out_padding_height - - 2 * padding_height - + 1; - let out_width = - (in_width - 1) * stride_width + dilation_width * (kernel_width - 1) + out_padding_width - - 2 * padding_width - + 1; - - let x = x.array; - let mut output = Array4::zeros(Dim([ - batch_size, - out_channels * options.groups, - out_height, - out_width, - ])); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { - let b = k / (out_channels * options.groups); - let oc = k % out_channels; - let g = k % options.groups; - - let output = unsafe_shared_out.get(); - - let oc_out = oc + (out_channels * g); - let ic_start = g * (in_channels / options.groups); - let ic_end = ic_start + in_channels / options.groups; - - for 
ic in ic_start..ic_end { - for ih in 0..in_height { - for iw in 0..in_width { - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let oh = ih * stride_height + kh * dilation_height; - let ow = iw * stride_width + kw * dilation_width; - - if oh >= out_height + padding_height - || ow >= out_width + padding_width - || oh < padding_height - || ow < padding_width - { - continue; - } - - let oh = oh - padding_height; - let ow = ow - padding_width; - - output[[b, oc_out, oh, ow]] += - x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]]; - } - } - } + let [dilation_height, dilation_width] = options.dilation; + let [padding_height, padding_width] = options.padding; + let [stride_height, stride_width] = options.stride; + let [out_padding_height, out_padding_width] = options.padding_out; + let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; + let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims; + + let out_height = + (in_height - 1) * stride_height + dilation_height * (kernel_height - 1) + out_padding_height + - 2 * padding_height + + 1; + let out_width = + (in_width - 1) * stride_width + dilation_width * (kernel_width - 1) + out_padding_width + - 2 * padding_width + + 1; + + let x = x.array; + let mut output = Array4::zeros(Dim([ + batch_size, + out_channels * options.groups, + out_height, + out_width, + ])); + + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { + let b = k / (out_channels * options.groups); + let oc = k % out_channels; + let g = k % options.groups; + + let output = unsafe_shared_out.get(); + + let oc_out = oc + (out_channels * g); + let ic_start = g * (in_channels / options.groups); + let ic_end = ic_start + in_channels / options.groups; + + for ic in ic_start..ic_end { + for ih in 0..in_height { + for iw in 0..in_width { + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let oh = ih * stride_height + kh * dilation_height; + let ow = iw * stride_width + kw * dilation_width; + + if oh >= out_height + padding_height + || ow >= out_width + padding_width + || oh < padding_height + || ow < padding_width + { + continue; } - } - if let Some(bias) = &bias { - for oh in 0..out_height { - for ow in 0..out_width { - output[[b, oc_out, oh, ow]] += bias.array[oc_out]; - } - } + let oh = oh - padding_height; + let ow = ow - padding_width; + + output[[b, oc_out, oh, ow]] += x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]]; + } } - }); + } + } + } + + if let Some(bias) = &bias { + for oh in 0..out_height { + for ow in 0..out_width { + output[[b, oc_out, oh, ow]] += bias.array[oc_out]; + } + } + } }); + }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } diff --git a/burn-ndarray/src/ops/int_tensor.rs b/burn-ndarray/src/ops/int_tensor.rs index fb6adb5517..88b9989315 100644 --- a/burn-ndarray/src/ops/int_tensor.rs +++ b/burn-ndarray/src/ops/int_tensor.rs @@ -19,362 +19,350 @@ use burn_tensor::{backend::Backend, Data, Shape}; use super::{NdArrayMathOps, NdArrayOps}; impl IntTensorOps for NdArray { - fn int_from_data( - data: Data, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - NdArrayTensor::from_data(data) - } - - fn int_shape(tensor: &NdArrayTensor) -> Shape { - tensor.shape() - } - - fn int_into_data(tensor: NdArrayTensor) -> Reader> { - let shape = tensor.shape(); - let values = tensor.array.into_iter().collect(); - - 
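// The transposed-convolution output size computed in conv_transpose2d above, restated as
// a small one-dimensional helper (same expression, just factored out for illustration):
//   out = (in - 1) * stride + dilation * (kernel - 1) + output_padding - 2 * padding + 1
fn conv_transpose_output_size(
    kernel: usize,
    stride: usize,
    padding: usize,
    output_padding: usize,
    dilation: usize,
    input: usize,
) -> usize {
    (input - 1) * stride + dilation * (kernel - 1) + output_padding - 2 * padding + 1
}

fn main() {
    // With a 3x3 kernel, stride 2 and padding 1, a 16-wide input maps to 31
    // (or 32 with output_padding = 1), the usual inverse of a stride-2 convolution.
    assert_eq!(conv_transpose_output_size(3, 2, 1, 0, 1, 16), 31);
    assert_eq!(conv_transpose_output_size(3, 2, 1, 1, 1, 16), 32);
}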
Reader::Concrete(Data::new(values, shape)) - } - - fn int_to_device( - tensor: NdArrayTensor, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - tensor - } - - fn int_reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - NdArrayOps::reshape(tensor, shape) - } - - fn int_slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - NdArrayOps::slice(tensor, ranges) - } - - fn int_device( - _tensor: &NdArrayTensor, - ) -> as Backend>::Device { - NdArrayDevice::Cpu - } - - fn int_empty( - shape: Shape, - _device: & as Backend>::Device, - ) -> NdArrayTensor { - let values = vec![0; shape.num_elements()]; - NdArrayTensor::from_data(Data::new(values, shape)) - } - - fn int_mask_where( - tensor: NdArrayTensor, - mask: NdArrayTensor, - source: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::mask_where(tensor, mask, source) - } - - fn int_mask_fill( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: i64, - ) -> NdArrayTensor { - NdArrayMathOps::mask_fill(tensor, mask, value) - } - - fn int_slice_assign( - tensor: NdArrayTensor, - ranges: [Range; D2], - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayOps::slice_assign(tensor, ranges, value) - } - - fn int_cat( - tensors: Vec>, - dim: usize, - ) -> NdArrayTensor { - NdArrayOps::cat(tensors, dim) - } - - fn int_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - - Self::int_equal_elem(tensor, 0) - } - - fn int_equal_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a == rhs).into_shared(); - NdArrayTensor { array } - } - - fn int_greater( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_greater_elem(tensor, 0) - } - - fn int_greater_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a > rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_greater_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_greater_equal_elem(tensor, 0) - } - - fn int_greater_equal_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a >= rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_lower( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_lower_elem(tensor, 0) - } - - fn int_lower_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a < rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_lower_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_lower_equal_elem(tensor, 0) - } - - fn int_lower_equal_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a <= rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_add( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::add(lhs, rhs) - } - - fn int_add_scalar( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - NdArrayMathOps::add_scalar(lhs, rhs) - } - - fn int_sub( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::sub(lhs, rhs) - } - - fn int_sub_scalar( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - NdArrayMathOps::sub_scalar(lhs, rhs) - } - - fn int_mul( - lhs: NdArrayTensor, - 
rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::mul(lhs, rhs) - } - - fn int_mul_scalar( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - NdArrayMathOps::mul_scalar(lhs, rhs) - } - - fn int_div( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::div(lhs, rhs) - } - - fn int_div_scalar( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - NdArrayMathOps::div_scalar(lhs, rhs) - } - - fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor { - Self::int_mul_scalar(tensor, -1) - } - - fn int_zeros( - shape: Shape, - device: & as Backend>::Device, - ) -> NdArrayTensor { - Self::int_from_data(Data::zeros(shape), device) - } - - fn int_ones( - shape: Shape, - device: & as Backend>::Device, - ) -> NdArrayTensor { - Self::int_from_data(Data::ones(shape), device) - } - - fn int_full( - shape: Shape, - fill_value: i64, - device: & as Backend>::Device, - ) -> NdArrayTensor { - Self::int_from_data(Data::full(shape, fill_value), device) - } - - fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::sum(tensor) - } - - fn int_sum_dim( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::sum_dim(tensor, dim) - } - - fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::mean(tensor) - } - - fn int_mean_dim( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::mean_dim(tensor, dim) - } - - fn int_gather( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select(tensor, dim, indices) - } - - fn int_select_assign( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select_assign(tensor, dim, indices, value) - } - fn int_argmax( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::argmax(tensor, dim) - } - - fn int_argmin( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::argmin(tensor, dim) - } - - fn int_clamp_min( - tensor: NdArrayTensor, - min: i64, - ) -> NdArrayTensor { - NdArrayMathOps::clamp_min(tensor, min) - } - - fn int_clamp_max( - tensor: NdArrayTensor, - max: i64, - ) -> NdArrayTensor { - NdArrayMathOps::clamp_max(tensor, max) - } - - fn int_clamp( - tensor: NdArrayTensor, - min: i64, - max: i64, - ) -> NdArrayTensor { - NdArrayMathOps::clamp(tensor, min, max) - } - - fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn int_into_float( - tensor: as Backend>::IntTensorPrimitive, - ) -> as Backend>::TensorPrimitive { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - NdArrayTensor { array } - } - - fn int_swap_dims( - tensor: as Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::IntTensorPrimitive { - NdArrayOps::swap_dims(tensor, dim1, dim2) - } + fn int_from_data( + data: Data, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + NdArrayTensor::from_data(data) + } + + fn int_shape(tensor: &NdArrayTensor) -> Shape { + tensor.shape() + } + + fn int_into_data(tensor: 
NdArrayTensor) -> Reader> { + let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + + Reader::Concrete(Data::new(values, shape)) + } + + fn int_to_device( + tensor: NdArrayTensor, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + tensor + } + + fn int_reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + NdArrayOps::reshape(tensor, shape) + } + + fn int_slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + NdArrayOps::slice(tensor, ranges) + } + + fn int_device( + _tensor: &NdArrayTensor, + ) -> as Backend>::Device { + NdArrayDevice::Cpu + } + + fn int_empty( + shape: Shape, + _device: & as Backend>::Device, + ) -> NdArrayTensor { + let values = vec![0; shape.num_elements()]; + NdArrayTensor::from_data(Data::new(values, shape)) + } + + fn int_mask_where( + tensor: NdArrayTensor, + mask: NdArrayTensor, + source: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mask_where(tensor, mask, source) + } + + fn int_mask_fill( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: i64, + ) -> NdArrayTensor { + NdArrayMathOps::mask_fill(tensor, mask, value) + } + + fn int_slice_assign( + tensor: NdArrayTensor, + ranges: [Range; D2], + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayOps::slice_assign(tensor, ranges, value) + } + + fn int_cat( + tensors: Vec>, + dim: usize, + ) -> NdArrayTensor { + NdArrayOps::cat(tensors, dim) + } + + fn int_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + + Self::int_equal_elem(tensor, 0) + } + + fn int_equal_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a == rhs).into_shared(); + NdArrayTensor { array } + } + + fn int_greater( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_greater_elem(tensor, 0) + } + + fn int_greater_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a > rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_greater_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_greater_equal_elem(tensor, 0) + } + + fn int_greater_equal_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a >= rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_lower( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_lower_elem(tensor, 0) + } + + fn int_lower_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a < rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_lower_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_lower_equal_elem(tensor, 0) + } + + fn int_lower_equal_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a <= rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_add( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::add(lhs, rhs) + } + + fn int_add_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + NdArrayMathOps::add_scalar(lhs, rhs) + } + + fn int_sub( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::sub(lhs, rhs) + } + + fn int_sub_scalar(lhs: NdArrayTensor, rhs: i64) -> 
NdArrayTensor { + NdArrayMathOps::sub_scalar(lhs, rhs) + } + + fn int_mul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mul(lhs, rhs) + } + + fn int_mul_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + NdArrayMathOps::mul_scalar(lhs, rhs) + } + + fn int_div( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::div(lhs, rhs) + } + + fn int_div_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + NdArrayMathOps::div_scalar(lhs, rhs) + } + + fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor { + Self::int_mul_scalar(tensor, -1) + } + + fn int_zeros( + shape: Shape, + device: & as Backend>::Device, + ) -> NdArrayTensor { + Self::int_from_data(Data::zeros(shape), device) + } + + fn int_ones( + shape: Shape, + device: & as Backend>::Device, + ) -> NdArrayTensor { + Self::int_from_data(Data::ones(shape), device) + } + + fn int_full( + shape: Shape, + fill_value: i64, + device: & as Backend>::Device, + ) -> NdArrayTensor { + Self::int_from_data(Data::full(shape, fill_value), device) + } + + fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::sum(tensor) + } + + fn int_sum_dim( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::sum_dim(tensor, dim) + } + + fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::mean(tensor) + } + + fn int_mean_dim( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::mean_dim(tensor, dim) + } + + fn int_gather( + dim: usize, + tensor: NdArrayTensor, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: NdArrayTensor, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select(tensor, dim, indices) + } + + fn int_select_assign( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select_assign(tensor, dim, indices, value) + } + fn int_argmax( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::argmax(tensor, dim) + } + + fn int_argmin( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::argmin(tensor, dim) + } + + fn int_clamp_min( + tensor: NdArrayTensor, + min: i64, + ) -> NdArrayTensor { + NdArrayMathOps::clamp_min(tensor, min) + } + + fn int_clamp_max( + tensor: NdArrayTensor, + max: i64, + ) -> NdArrayTensor { + NdArrayMathOps::clamp_max(tensor, max) + } + + fn int_clamp( + tensor: NdArrayTensor, + min: i64, + max: i64, + ) -> NdArrayTensor { + NdArrayMathOps::clamp(tensor, min, max) + } + + fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn int_into_float( + tensor: as Backend>::IntTensorPrimitive, + ) -> as Backend>::TensorPrimitive { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + NdArrayTensor { array } + } + + fn int_swap_dims( + tensor: as Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::IntTensorPrimitive { + NdArrayOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-ndarray/src/ops/macros.rs b/burn-ndarray/src/ops/macros.rs index b92b37b82d..6895d8fc5d 100644 --- 
a/burn-ndarray/src/ops/macros.rs +++ b/burn-ndarray/src/ops/macros.rs @@ -1,26 +1,26 @@ macro_rules! keepdim { - ( + ( $D:expr, $dim:expr, $self:expr, mean ) => {{ - let tensor: NdArrayTensor = mean_dim($self.clone(), $dim); - let mut shape = $self.shape(); - shape.dims[$dim] = 1; - NdArrayOps::reshape(tensor.clone(), shape) - }}; - ( + let tensor: NdArrayTensor = mean_dim($self.clone(), $dim); + let mut shape = $self.shape(); + shape.dims[$dim] = 1; + NdArrayOps::reshape(tensor.clone(), shape) + }}; + ( $D:expr, $dim:expr, $self:expr, sum ) => {{ - let tensor: NdArrayTensor = sum_dim($self.clone(), $dim); - let mut shape = $self.shape(); - shape.dims[$dim] = 1; - NdArrayOps::reshape(tensor, shape) - }}; + let tensor: NdArrayTensor = sum_dim($self.clone(), $dim); + let mut shape = $self.shape(); + shape.dims[$dim] = 1; + NdArrayOps::reshape(tensor, shape) + }}; } pub(crate) use keepdim; @@ -29,19 +29,19 @@ use ndarray::Axis; use crate::{element::NdArrayElement, tensor::NdArrayTensor}; pub(crate) fn mean_dim( - tensor: NdArrayTensor, - dim: usize, + tensor: NdArrayTensor, + dim: usize, ) -> NdArrayTensor { - let array = tensor.array.mean_axis(Axis(dim)).unwrap().into_shared(); + let array = tensor.array.mean_axis(Axis(dim)).unwrap().into_shared(); - NdArrayTensor { array } + NdArrayTensor { array } } pub(crate) fn sum_dim( - tensor: NdArrayTensor, - dim: usize, + tensor: NdArrayTensor, + dim: usize, ) -> NdArrayTensor { - let array = tensor.array.sum_axis(Axis(dim)).into_shared(); + let array = tensor.array.sum_axis(Axis(dim)).into_shared(); - NdArrayTensor { array } + NdArrayTensor { array } } diff --git a/burn-ndarray/src/ops/matmul.rs b/burn-ndarray/src/ops/matmul.rs index 185a7567dd..4db10005ed 100644 --- a/burn-ndarray/src/ops/matmul.rs +++ b/burn-ndarray/src/ops/matmul.rs @@ -5,107 +5,101 @@ use burn_tensor::{ops::TensorOps, Shape}; use ndarray::s; pub(crate) fn matmul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, + lhs: NdArrayTensor, + rhs: NdArrayTensor, ) -> NdArrayTensor where - E: FloatNdArrayElement, + E: FloatNdArrayElement, { - let shape_ori_lhs = lhs.shape(); - let shape_ori_rhs = rhs.shape(); + let shape_ori_lhs = lhs.shape(); + let shape_ori_rhs = rhs.shape(); - let lhs = reshape(lhs); - let rhs = reshape(rhs); + let lhs = reshape(lhs); + let rhs = reshape(rhs); - let [batch_size_lhs, m, _] = lhs.shape().dims; - let [batch_size_rhs, _, n] = rhs.shape().dims; + let [batch_size_lhs, m, _] = lhs.shape().dims; + let [batch_size_rhs, _, n] = rhs.shape().dims; - let mut shape_out = match batch_size_lhs > batch_size_rhs { - true => shape_ori_lhs, - false => shape_ori_rhs, - }; - shape_out.dims[D - 2] = m; - shape_out.dims[D - 1] = n; + let mut shape_out = match batch_size_lhs > batch_size_rhs { + true => shape_ori_lhs, + false => shape_ori_rhs, + }; + shape_out.dims[D - 2] = m; + shape_out.dims[D - 1] = n; - let out = general_matmul(lhs, rhs); + let out = general_matmul(lhs, rhs); - NdArray::::reshape(out, shape_out) + NdArray::::reshape(out, shape_out) } fn general_matmul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, + lhs: NdArrayTensor, + rhs: NdArrayTensor, ) -> NdArrayTensor { - run_par!(|| { - let [batch_size_lhs, m, _] = lhs.shape().dims; - let [batch_size_rhs, k, n] = rhs.shape().dims; - let batch_size = usize::max(batch_size_rhs, batch_size_lhs); - - if batch_size_lhs > batch_size && batch_size_lhs != 1 { - panic!("Broadcast on multiple dimensions is not yet supported"); - } - - if batch_size_rhs > batch_size && batch_size_rhs != 1 { - panic!("Broadcast on multiple 
dimensions is not yet supported"); - } - - let alpha: E = 1.0.elem(); - let beta: E = 0.0.elem(); - - let mut out_array = ndarray::Array3::::zeros((batch_size, m, n)); - let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array); - - let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap(); - let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap(); - - iter_range_par!(0, batch_size).for_each(|b| { - let lhs_slice = match batch_size_lhs == 1 { - true => lhs_array.slice(s!(0, .., ..)), - false => lhs_array.slice(s!(b, .., ..)), - }; - let rhs_slice = match batch_size_rhs == 1 { - true => rhs_array.slice(s!(0, .., ..)), - false => rhs_array.slice(s!(b, .., ..)), - }; - - unsafe { - let mut out_slice = unsafe_shared_out_array.get().slice_mut(s!(b, .., ..)); - - ndarray::linalg::general_mat_mul( - alpha, - &lhs_slice, - &rhs_slice, - beta, - &mut out_slice, - ); - } - }); - - NdArrayTensor::new(out_array.into_shared().into_dyn()) - }) + run_par!(|| { + let [batch_size_lhs, m, _] = lhs.shape().dims; + let [batch_size_rhs, k, n] = rhs.shape().dims; + let batch_size = usize::max(batch_size_rhs, batch_size_lhs); + + if batch_size_lhs > batch_size && batch_size_lhs != 1 { + panic!("Broadcast on multiple dimensions is not yet supported"); + } + + if batch_size_rhs > batch_size && batch_size_rhs != 1 { + panic!("Broadcast on multiple dimensions is not yet supported"); + } + + let alpha: E = 1.0.elem(); + let beta: E = 0.0.elem(); + + let mut out_array = ndarray::Array3::::zeros((batch_size, m, n)); + let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array); + + let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap(); + let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap(); + + iter_range_par!(0, batch_size).for_each(|b| { + let lhs_slice = match batch_size_lhs == 1 { + true => lhs_array.slice(s!(0, .., ..)), + false => lhs_array.slice(s!(b, .., ..)), + }; + let rhs_slice = match batch_size_rhs == 1 { + true => rhs_array.slice(s!(0, .., ..)), + false => rhs_array.slice(s!(b, .., ..)), + }; + + unsafe { + let mut out_slice = unsafe_shared_out_array.get().slice_mut(s!(b, .., ..)); + + ndarray::linalg::general_mat_mul(alpha, &lhs_slice, &rhs_slice, beta, &mut out_slice); + } + }); + + NdArrayTensor::new(out_array.into_shared().into_dyn()) + }) } fn reshape( - tensor: NdArrayTensor, + tensor: NdArrayTensor, ) -> NdArrayTensor { - let shape = tensor.shape(); + let shape = tensor.shape(); - if D < 2 { - NdArray::::reshape(tensor, Shape::new([1, 1, shape.dims[0]])) - } else { - let batch_size = batch_size(&shape); - let size0 = shape.dims[D - 2]; - let size1 = shape.dims[D - 1]; + if D < 2 { + NdArray::::reshape(tensor, Shape::new([1, 1, shape.dims[0]])) + } else { + let batch_size = batch_size(&shape); + let size0 = shape.dims[D - 2]; + let size1 = shape.dims[D - 1]; - NdArray::::reshape(tensor, Shape::new([batch_size, size0, size1])) - } + NdArray::::reshape(tensor, Shape::new([batch_size, size0, size1])) + } } fn batch_size(shape: &Shape) -> usize { - let mut num_batch = 1; - for i in 0..D - 2 { - num_batch *= shape.dims[i]; - } + let mut num_batch = 1; + for i in 0..D - 2 { + num_batch *= shape.dims[i]; + } - num_batch + num_batch } diff --git a/burn-ndarray/src/ops/maxpool.rs b/burn-ndarray/src/ops/maxpool.rs index 948c942932..7546271205 100644 --- a/burn-ndarray/src/ops/maxpool.rs +++ b/burn-ndarray/src/ops/maxpool.rs @@ -1,183 +1,181 @@ use crate::{ - element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, 
run_par, - sharing::UnsafeSharedRef, tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, run_par, + sharing::UnsafeSharedRef, tensor::NdArrayTensor, }; use burn_tensor::ElementConversion; use ndarray::Array4; pub(crate) fn max_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> NdArrayTensor { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - let inf = (-f32::INFINITY).elem::(); + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + let inf = (-f32::INFINITY).elem::(); - let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) - / stride_height) - + 1; - let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) - / stride_width) - + 1; + let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) + / stride_height) + + 1; + let out_width = + ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; - let x = apply_padding_4d(x, padding, inf).array; + let x = apply_padding_4d(x, padding, inf).array; - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; - let output = unsafe_shared_out.get(); + let output = unsafe_shared_out.get(); - for oh in 0..out_height { - for ow in 0..out_width { - let mut max_val = inf; + for oh in 0..out_height { + for ow in 0..out_width { + let mut max_val = inf; - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height; + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height; - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width; + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width; - let val = x[[b, c, ih, iw]]; + let val = x[[b, c, ih, iw]]; - if val > max_val { - max_val = val; - } - } - } - - output[[b, c, oh, ow]] = max_val; - } + if val > max_val { + max_val = val; + } } - }) - }); + } + + output[[b, c, oh, ow]] = max_val; + } + } + }) + }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } pub(crate) fn max_pool2d_with_indices( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (NdArrayTensor, NdArrayTensor) { - let 
[kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - let inf = (-f32::INFINITY).elem::(); - - let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) - / stride_height) - + 1; - let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) - / stride_width) - + 1; - - let x = apply_padding_4d(x, padding, inf).array; - - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); - let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - let indices = unsafe_shared_indices.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut max_val = inf; - let mut index = 0; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width; - let val = x[[b, c, ih, iw]]; - - if val > max_val { - max_val = val; - - let ih = ih as i64 - padding_height as i64; - let iw = iw as i64 - padding_width as i64; - - index = ih * x_height as i64 + iw; - } - } - } - - output[[b, c, oh, ow]] = max_val; - indices[[b, c, oh, ow]] = index; - } + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + let inf = (-f32::INFINITY).elem::(); + + let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) + / stride_height) + + 1; + let out_width = + ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; + + let x = apply_padding_4d(x, padding, inf).array; + + let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); + let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); + + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + let indices = unsafe_shared_indices.get(); + + for oh in 0..out_height { + for ow in 0..out_width { + let mut max_val = inf; + let mut index = 0; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width; + let val = x[[b, c, ih, iw]]; + + if val > max_val { + max_val = val; + + let ih = ih as i64 - padding_height as i64; + let iw = iw as i64 - padding_width as i64; + + index = ih * x_height as i64 + iw; + } } - }) - }); + } - let output = NdArrayTensor::new(output.into_dyn().into_shared()); - let indices = NdArrayTensor::new(indices.into_dyn().into_shared()); + output[[b, c, oh, ow]] = max_val; + indices[[b, c, oh, ow]] = index; + } + } + }) + }); - (output, indices) + let 
output = NdArrayTensor::new(output.into_dyn().into_shared()); + let indices = NdArrayTensor::new(indices.into_dyn().into_shared()); + + (output, indices) } pub(crate) fn max_pool2d_backward( - x: NdArrayTensor, - _kernel_size: [usize; 2], - _stride: [usize; 2], - _padding: [usize; 2], - _dilation: [usize; 2], - output_grad: NdArrayTensor, - indices: NdArrayTensor, + x: NdArrayTensor, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + _dilation: [usize; 2], + output_grad: NdArrayTensor, + indices: NdArrayTensor, ) -> NdArrayTensor { - let [_batch_size, _channels, height, width] = output_grad.shape().dims; - let [batch_size, channels, height_x, width_x] = x.shape().dims; + let [_batch_size, _channels, height, width] = output_grad.shape().dims; + let [batch_size, channels, height_x, width_x] = x.shape().dims; - let output_grad = output_grad.array; - let indices = indices.array; + let output_grad = output_grad.array; + let indices = indices.array; - let mut output = Array4::zeros((batch_size, channels, height_x, width_x)); + let mut output = Array4::zeros((batch_size, channels, height_x, width_x)); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; - let output = unsafe_shared_out.get(); + let output = unsafe_shared_out.get(); - for h in 0..height { - for w in 0..width { - let index = indices[[b, c, h, w]]; - let grad = output_grad[[b, c, h, w]]; + for h in 0..height { + for w in 0..width { + let index = indices[[b, c, h, w]]; + let grad = output_grad[[b, c, h, w]]; - let index_h = index as usize / width_x; - let index_w = index as usize % width_x; + let index_h = index as usize / width_x; + let index_w = index as usize % width_x; - output[[b, c, index_h, index_w]] += grad; - } - } - }); + output[[b, c, index_h, index_w]] += grad; + } + } }); + }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } diff --git a/burn-ndarray/src/ops/module.rs b/burn-ndarray/src/ops/module.rs index 119f13657a..5d7f63378e 100644 --- a/burn-ndarray/src/ops/module.rs +++ b/burn-ndarray/src/ops/module.rs @@ -1,102 +1,102 @@ use super::{ - adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, - avgpool::{avg_pool2d, avg_pool2d_backward}, - conv::{conv2d, conv_transpose2d}, - maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, + adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, + avgpool::{avg_pool2d, avg_pool2d_backward}, + conv::{conv2d, conv_transpose2d}, + maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, }; use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray}; use burn_tensor::ops::*; impl ModuleOps for NdArray { - fn conv2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> NdArrayTensor { - conv2d(x, weight, bias, options) - } + fn conv2d( + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> NdArrayTensor { + conv2d(x, weight, bias, options) + } - fn conv_transpose2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> NdArrayTensor { - conv_transpose2d(x, weight, bias, 
options) - } + fn conv_transpose2d( + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> NdArrayTensor { + conv_transpose2d(x, weight, bias, options) + } - fn avg_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> NdArrayTensor { - avg_pool2d(x, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> NdArrayTensor { + avg_pool2d(x, kernel_size, stride, padding, count_include_pad) + } - fn avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> NdArrayTensor { - avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d_backward( + x: NdArrayTensor, + grad: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> NdArrayTensor { + avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) + } - fn max_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> NdArrayTensor { - max_pool2d(x, kernel_size, stride, padding, dilation) - } + fn max_pool2d( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> NdArrayTensor { + max_pool2d(x, kernel_size, stride, padding, dilation) + } - fn max_pool2d_with_indices( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); + fn max_pool2d_with_indices( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); - MaxPool2dWithIndices::new(output, indices) - } + MaxPool2dWithIndices::new(output, indices) + } - fn max_pool2d_with_indices_backward( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: NdArrayTensor, - indices: NdArrayTensor, - ) -> MaxPool2dBackward> { - MaxPool2dBackward::new(max_pool2d_backward( - x, - kernel_size, - stride, - padding, - dilation, - output_grad, - indices, - )) - } + fn max_pool2d_with_indices_backward( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: NdArrayTensor, + indices: NdArrayTensor, + ) -> MaxPool2dBackward> { + MaxPool2dBackward::new(max_pool2d_backward( + x, + kernel_size, + stride, + padding, + dilation, + output_grad, + indices, + )) + } - fn adaptive_avg_pool2d(x: NdArrayTensor, output_size: [usize; 2]) -> NdArrayTensor { - adaptive_avg_pool2d(x, output_size) - } + fn adaptive_avg_pool2d(x: NdArrayTensor, output_size: [usize; 2]) -> NdArrayTensor { + adaptive_avg_pool2d(x, output_size) + } - fn adaptive_avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, - ) -> NdArrayTensor { - adaptive_avg_pool2d_backward(x, grad) - } + fn adaptive_avg_pool2d_backward( + x: NdArrayTensor, + grad: NdArrayTensor, + ) -> NdArrayTensor { + 
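// max_pool2d_with_indices above stores, for every output position, a flat index into the
// un-padded input plane; max_pool2d_backward then routes each output gradient back through
// that index. A minimal sketch of the row-major encode/decode convention the backward pass
// assumes (`index / width`, `index % width`); the values below are made up for illustration.
fn encode(row: usize, col: usize, width: usize) -> i64 {
    (row * width + col) as i64
}

fn decode(index: i64, width: usize) -> (usize, usize) {
    let index = index as usize;
    (index / width, index % width)
}

fn main() {
    let width = 7;
    let idx = encode(3, 5, width);
    assert_eq!(decode(idx, width), (3, 5));
}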
adaptive_avg_pool2d_backward(x, grad) + } } diff --git a/burn-ndarray/src/ops/padding.rs b/burn-ndarray/src/ops/padding.rs index 790a291053..a072f3e77b 100644 --- a/burn-ndarray/src/ops/padding.rs +++ b/burn-ndarray/src/ops/padding.rs @@ -3,31 +3,31 @@ use burn_tensor::ops::TensorOps; use ndarray::Array4; pub(crate) fn apply_padding_4d( - x: NdArrayTensor, - padding: [usize; 2], - elem: E, + x: NdArrayTensor, + padding: [usize; 2], + elem: E, ) -> NdArrayTensor { - let [batch_size, input_channels, height, width] = x.shape().dims; - let [padding_height, padding_width] = padding; - let padded_height = height + 2 * padding_height; - let padded_width = width + 2 * padding_width; + let [batch_size, input_channels, height, width] = x.shape().dims; + let [padding_height, padding_width] = padding; + let padded_height = height + 2 * padding_height; + let padded_width = width + 2 * padding_width; - let x_new = Array4::from_elem( - (batch_size, input_channels, padded_height, padded_width), - elem, - ); - let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn()); + let x_new = Array4::from_elem( + (batch_size, input_channels, padded_height, padded_width), + elem, + ); + let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn()); - x_new = NdArray::slice_assign( - x_new, - [ - 0..batch_size, - 0..input_channels, - padding_height..(height + padding_height), - padding_width..width + padding_width, - ], - x, - ); + x_new = NdArray::slice_assign( + x_new, + [ + 0..batch_size, + 0..input_channels, + padding_height..(height + padding_height), + padding_width..width + padding_width, + ], + x, + ); - x_new + x_new } diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs index 0ac3f20226..6367aa7ae3 100644 --- a/burn-ndarray/src/ops/tensor.rs +++ b/burn-ndarray/src/ops/tensor.rs @@ -21,422 +21,419 @@ use libm::{cos, erf, sin, tanh}; use num_traits::Float; impl TensorOps for NdArray { - fn from_data(data: Data, _device: &NdArrayDevice) -> NdArrayTensor { - NdArrayTensor::from_data(data) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: &NdArrayDevice, - ) -> NdArrayTensor { - let mut seed = SEED.lock().unwrap(); - let mut rng = if let Some(rng_seeded) = seed.as_ref() { - rng_seeded.clone() - } else { - get_seeded_rng() - }; - let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device); - *seed = Some(rng); - tensor - } - - fn shape(tensor: &NdArrayTensor) -> Shape { - tensor.shape() - } - - fn into_data( - tensor: NdArrayTensor, - ) -> Reader as Backend>::FloatElem, D>> { - let shape = tensor.shape(); - let values = tensor.array.into_iter().collect(); - - Reader::Concrete(Data::new(values, shape)) - } - - fn device(_tensor: &NdArrayTensor) -> NdArrayDevice { - NdArrayDevice::Cpu - } - - fn to_device( - tensor: NdArrayTensor, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - tensor - } - - fn empty( - shape: Shape, - device: & as Backend>::Device, - ) -> NdArrayTensor { - NdArray::::zeros(shape, device) - } - - fn add( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::add(lhs, rhs) - } - - fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::add_scalar(lhs, rhs) - } - - fn sub( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::sub(lhs, rhs) - } - - fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::sub_scalar(lhs, rhs) - } - - fn mul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - 
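// apply_padding_4d above allocates a larger buffer filled with the padding element and then
// `slice_assign`s the original tensor into the centre region (padding..padding + size on each
// spatial axis). The same idea in one dimension with a plain Vec, as an illustrative sketch:
fn pad_1d(x: &[f32], padding: usize, fill: f32) -> Vec<f32> {
    let mut padded = vec![fill; x.len() + 2 * padding];
    // Equivalent of the slice_assign: write the input into the interior.
    padded[padding..padding + x.len()].copy_from_slice(x);
    padded
}

fn main() {
    assert_eq!(
        pad_1d(&[1.0, 2.0, 3.0], 2, 0.0),
        vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0]
    );
}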
NdArrayMathOps::mul(lhs, rhs) - } - - fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::mul_scalar(lhs, rhs) - } - - fn div( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::div(lhs, rhs) - } - - fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::div_scalar(lhs, rhs) - } - - fn matmul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - matmul(lhs, rhs) - } - - fn neg(tensor: NdArrayTensor) -> NdArrayTensor { - Self::mul_scalar(tensor, (-1f32).elem::()) - } - - fn recip(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::recip(tensor) - } - - fn swap_dims( - tensor: NdArrayTensor, - dim1: usize, - dim2: usize, - ) -> NdArrayTensor { - NdArrayOps::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - NdArrayOps::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::gather(dim, tensor, indices) - } - - fn scatter( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::scatter(dim, tensor, indices, value) - } - - fn select( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select(tensor, dim, indices) - } - - fn select_assign( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select_assign(tensor, dim, indices, value) - } - - fn slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - NdArrayOps::slice(tensor, ranges) - } - - fn slice_assign( - tensor: NdArrayTensor, - ranges: [Range; D2], - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayOps::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::mask_where(tensor, mask, value) - } - - fn mask_fill( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: E, - ) -> NdArrayTensor { - NdArrayMathOps::mask_fill(tensor, mask, value) - } - - fn equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - - Self::equal_elem(tensor, zero) - } - - fn equal_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a == rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn greater( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::greater_elem(tensor, zero) - } - - fn greater_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a > rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn greater_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::greater_equal_elem(tensor, zero) - } - - fn greater_equal_elem( - lhs: NdArrayTensor, - rhs: E, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a >= rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn lower( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::lower_elem(tensor, zero) - } - - fn lower_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array 
= lhs.array.mapv(|a| a < rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn lower_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::lower_equal_elem(tensor, zero) - } - - fn lower_equal_elem( - lhs: NdArrayTensor, - rhs: E, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a <= rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn detach(tensor: NdArrayTensor) -> NdArrayTensor { - tensor - } - - fn mean(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::mean(tensor) - } - - fn sum(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::sum(tensor) - } - - fn mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::mean_dim(tensor, dim) - } - - fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::sum_dim(tensor, dim) - } - - fn to_full_precision(tensor: &NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn from_full_precision(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::argmax(tensor, dim) - } - - fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::argmin(tensor, dim) - } - - fn exp(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn log(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn log1p(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn powf(tensor: NdArrayTensor, value: f32) -> NdArrayTensor { - let array = if value == 2.0 { - // Happens often and is faster. 
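// The powf implementation here special-cases the exponent: squaring when it is 2.0, integer
// powers via powi when the exponent has no fractional part, and the general powf otherwise.
// A standalone scalar sketch of that dispatch (f64 stands in for the element type):
fn powf_dispatch(a: f64, value: f32) -> f64 {
    if value == 2.0 {
        // Common case, cheaper than a general powf.
        a * a
    } else if value.floor() == value {
        // Integer exponent: powi is faster than powf.
        a.powi(value as i32)
    } else {
        a.powf(value as f64)
    }
}

fn main() {
    assert_eq!(powf_dispatch(3.0, 2.0), 9.0);
    assert_eq!(powf_dispatch(2.0, 3.0), 8.0);
    assert!((powf_dispatch(4.0, 0.5) - 2.0).abs() < 1e-12);
}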
- tensor.array.mapv_into(|a| a * a).into_shared() - } else if value.floor() == value { - // Is faster then powf - tensor - .array - .mapv_into(|a| a.powi_elem(value as i32)) - .into_shared() - } else { - // Default - tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared() - }; - - NdArrayTensor::new(array) - } - - fn sqrt(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn abs(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn cos(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| cos(a.to_f64().unwrap()).elem()) - .into_shared(); - - NdArrayTensor::new(array) - } + fn from_data(data: Data, _device: &NdArrayDevice) -> NdArrayTensor { + NdArrayTensor::from_data(data) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: &NdArrayDevice, + ) -> NdArrayTensor { + let mut seed = SEED.lock().unwrap(); + let mut rng = if let Some(rng_seeded) = seed.as_ref() { + rng_seeded.clone() + } else { + get_seeded_rng() + }; + let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device); + *seed = Some(rng); + tensor + } + + fn shape(tensor: &NdArrayTensor) -> Shape { + tensor.shape() + } + + fn into_data( + tensor: NdArrayTensor, + ) -> Reader as Backend>::FloatElem, D>> { + let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + + Reader::Concrete(Data::new(values, shape)) + } + + fn device(_tensor: &NdArrayTensor) -> NdArrayDevice { + NdArrayDevice::Cpu + } + + fn to_device( + tensor: NdArrayTensor, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + tensor + } + + fn empty( + shape: Shape, + device: & as Backend>::Device, + ) -> NdArrayTensor { + NdArray::::zeros(shape, device) + } + + fn add( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::add(lhs, rhs) + } + + fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::add_scalar(lhs, rhs) + } + + fn sub( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::sub(lhs, rhs) + } + + fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::sub_scalar(lhs, rhs) + } + + fn mul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mul(lhs, rhs) + } + + fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::mul_scalar(lhs, rhs) + } + + fn div( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::div(lhs, rhs) + } + + fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::div_scalar(lhs, rhs) + } + + fn matmul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + matmul(lhs, rhs) + } + + fn neg(tensor: NdArrayTensor) -> NdArrayTensor { + Self::mul_scalar(tensor, (-1f32).elem::()) + } + + fn recip(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::recip(tensor) + } + + fn swap_dims( + tensor: NdArrayTensor, + dim1: usize, + dim2: usize, + ) -> NdArrayTensor { + NdArrayOps::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + NdArrayOps::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: NdArrayTensor, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::gather(dim, tensor, indices) + } + + fn scatter( + dim: usize, + 
tensor: NdArrayTensor, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::scatter(dim, tensor, indices, value) + } + + fn select( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select(tensor, dim, indices) + } + + fn select_assign( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select_assign(tensor, dim, indices, value) + } + + fn slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + NdArrayOps::slice(tensor, ranges) + } + + fn slice_assign( + tensor: NdArrayTensor, + ranges: [Range; D2], + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayOps::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mask_where(tensor, mask, value) + } + + fn mask_fill( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: E, + ) -> NdArrayTensor { + NdArrayMathOps::mask_fill(tensor, mask, value) + } + + fn equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + + Self::equal_elem(tensor, zero) + } + + fn equal_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a == rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn greater( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::greater_elem(tensor, zero) + } + + fn greater_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a > rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn greater_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::greater_equal_elem(tensor, zero) + } + + fn greater_equal_elem( + lhs: NdArrayTensor, + rhs: E, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a >= rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn lower( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::lower_elem(tensor, zero) + } + + fn lower_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a < rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn lower_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::lower_equal_elem(tensor, zero) + } + + fn lower_equal_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a <= rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn detach(tensor: NdArrayTensor) -> NdArrayTensor { + tensor + } + + fn mean(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::mean(tensor) + } + + fn sum(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::sum(tensor) + } + + fn mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::mean_dim(tensor, dim) + } + + fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::sum_dim(tensor, dim) + } + + fn to_full_precision(tensor: &NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn from_full_precision(tensor: 
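// Plain-ndarray illustration of the `mask_fill` semantics delegated to
// NdArrayMathOps above: every element whose mask entry is true is replaced by
// `value`. This is a sketch, not the backend's actual implementation.
use ndarray::{ArcArray, IxDyn, Zip};

fn mask_fill(
    tensor: ArcArray<f32, IxDyn>,
    mask: &ArcArray<bool, IxDyn>,
    value: f32,
) -> ArcArray<f32, IxDyn> {
    let mut out = tensor.into_owned();
    Zip::from(&mut out).and(mask).for_each(|x, &m| {
        if m {
            *x = value;
        }
    });
    out.into_shared()
}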
NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::argmax(tensor, dim) + } + + fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::argmin(tensor, dim) + } + + fn exp(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn log(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn log1p(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn powf(tensor: NdArrayTensor, value: f32) -> NdArrayTensor { + let array = if value == 2.0 { + // Happens often and is faster. + tensor.array.mapv_into(|a| a * a).into_shared() + } else if value.floor() == value { + // Is faster then powf + tensor + .array + .mapv_into(|a| a.powi_elem(value as i32)) + .into_shared() + } else { + // Default + tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared() + }; + + NdArrayTensor::new(array) + } + + fn sqrt(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn abs(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn cos(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| cos(a.to_f64().unwrap()).elem()) + .into_shared(); - fn sin(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| sin(a.to_f64().unwrap()).elem()) - .into_shared(); - - NdArrayTensor::new(array) - } - - fn tanh(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| tanh(a.to_f64().unwrap()).elem()) - .into_shared(); + NdArrayTensor::new(array) + } - NdArrayTensor::new(array) - } - - fn erf(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| erf(a.to_f64().unwrap()).elem()) - .into_shared(); - - NdArrayTensor::new(array) - } - - fn cat(tensors: Vec>, dim: usize) -> NdArrayTensor { - NdArrayOps::cat(tensors, dim) - } - - fn clamp_min(tensor: NdArrayTensor, min: E) -> NdArrayTensor { - NdArrayMathOps::clamp_min(tensor, min) - } - - fn clamp_max(tensor: NdArrayTensor, max: E) -> NdArrayTensor { - NdArrayMathOps::clamp_max(tensor, max) - } - - fn clamp(tensor: NdArrayTensor, min: E, max: E) -> NdArrayTensor { - NdArrayMathOps::clamp(tensor, min, max) - } - - fn into_int( - tensor: as Backend>::TensorPrimitive, - ) -> as Backend>::IntTensorPrimitive { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - NdArrayTensor { array } - } + fn sin(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| sin(a.to_f64().unwrap()).elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + + fn tanh(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| tanh(a.to_f64().unwrap()).elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + + fn erf(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| erf(a.to_f64().unwrap()).elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + + fn cat(tensors: Vec>, dim: usize) -> 
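// Scalar illustration of the `powf` dispatch above: squaring and integer
// exponents are special-cased because a multiply or `powi` is typically
// cheaper than a general `powf`. Illustrative helper only.
fn powf_dispatch(x: f32, value: f32) -> f32 {
    if value == 2.0 {
        x * x // most common case, and the cheapest
    } else if value.floor() == value {
        x.powi(value as i32) // integer exponent: faster than powf
    } else {
        x.powf(value) // general case
    }
}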
NdArrayTensor { + NdArrayOps::cat(tensors, dim) + } + + fn clamp_min(tensor: NdArrayTensor, min: E) -> NdArrayTensor { + NdArrayMathOps::clamp_min(tensor, min) + } + + fn clamp_max(tensor: NdArrayTensor, max: E) -> NdArrayTensor { + NdArrayMathOps::clamp_max(tensor, max) + } + + fn clamp(tensor: NdArrayTensor, min: E, max: E) -> NdArrayTensor { + NdArrayMathOps::clamp(tensor, min, max) + } + + fn into_int( + tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + NdArrayTensor { array } + } } diff --git a/burn-ndarray/src/parallel.rs b/burn-ndarray/src/parallel.rs index 8229bfb396..5debb0f1b5 100644 --- a/burn-ndarray/src/parallel.rs +++ b/burn-ndarray/src/parallel.rs @@ -1,51 +1,51 @@ /// Macro for running a function in parallel. #[macro_export(local_inner_macros)] macro_rules! run_par { - ( + ( $func:expr ) => {{ - #[cfg(feature = "std")] - use rayon::prelude::*; + #[cfg(feature = "std")] + use rayon::prelude::*; - #[cfg(feature = "std")] - #[allow(clippy::redundant_closure_call)] - let output = rayon::scope(|_| $func()); + #[cfg(feature = "std")] + #[allow(clippy::redundant_closure_call)] + let output = rayon::scope(|_| $func()); - #[cfg(not(feature = "std"))] - let output = $func(); + #[cfg(not(feature = "std"))] + let output = $func(); - output - }}; + output + }}; } /// Macro for iterating in parallel. #[macro_export(local_inner_macros)] macro_rules! iter_par { - ( + ( $iter:expr ) => {{ - #[cfg(feature = "std")] - let output = $iter.into_par_iter(); + #[cfg(feature = "std")] + let output = $iter.into_par_iter(); - #[cfg(not(feature = "std"))] - let output = $iter; + #[cfg(not(feature = "std"))] + let output = $iter; - output - }}; + output + }}; } /// Macro for iterating over a range in parallel. #[macro_export(local_inner_macros)] macro_rules! iter_range_par { - ( + ( $start:expr, $end:expr ) => {{ - #[cfg(feature = "std")] - let output = ($start..$end).into_par_iter(); + #[cfg(feature = "std")] + let output = ($start..$end).into_par_iter(); - #[cfg(not(feature = "std"))] - let output = ($start..$end); + #[cfg(not(feature = "std"))] + let output = ($start..$end); - output - }}; + output + }}; } diff --git a/burn-ndarray/src/sharing.rs b/burn-ndarray/src/sharing.rs index 073f3d3872..be99c0999e 100644 --- a/burn-ndarray/src/sharing.rs +++ b/burn-ndarray/src/sharing.rs @@ -2,18 +2,18 @@ use core::cell::UnsafeCell; /// Similar to `SyncUnsafeCell` see [Rust issues](https://github.com/rust-lang/rust/issues/95439). pub(crate) struct UnsafeSharedRef<'a, T> { - cell: UnsafeCell<&'a mut T>, + cell: UnsafeCell<&'a mut T>, } unsafe impl<'a, T> Sync for UnsafeSharedRef<'a, T> {} impl<'a, T> UnsafeSharedRef<'a, T> { - pub fn new(data: &'a mut T) -> Self { - Self { - cell: UnsafeCell::new(data), - } - } - pub unsafe fn get(&self) -> &'a mut T { - unsafe { core::ptr::read(self.cell.get()) } + pub fn new(data: &'a mut T) -> Self { + Self { + cell: UnsafeCell::new(data), } + } + pub unsafe fn get(&self) -> &'a mut T { + unsafe { core::ptr::read(self.cell.get()) } + } } diff --git a/burn-ndarray/src/tensor.rs b/burn-ndarray/src/tensor.rs index db99bc87e6..ae05aa9e47 100644 --- a/burn-ndarray/src/tensor.rs +++ b/burn-ndarray/src/tensor.rs @@ -5,52 +5,52 @@ use ndarray::{ArcArray, Array, Dim, IxDyn}; /// Tensor primitive used by the [ndarray backend](crate::NdArray). #[derive(new, Debug, Clone)] pub struct NdArrayTensor { - /// Dynamic array that contains the data of type E. 
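// Usage sketch for the `run_par!` / `iter_range_par!` macros above: with the
// `std` feature the body runs on rayon's thread pool, without it the very
// same code compiles to a sequential loop. `partial_sums` is a hypothetical
// helper, not part of the crate.
#[cfg(feature = "std")]
use rayon::prelude::*;

fn partial_sums(data: &[f32], chunk: usize) -> Vec<f32> {
    run_par!(|| {
        iter_range_par!(0, data.len() / chunk)
            .map(|i| data[i * chunk..(i + 1) * chunk].iter().sum())
            .collect()
    })
}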
- pub array: ArcArray, + /// Dynamic array that contains the data of type E. + pub array: ArcArray, } impl NdArrayTensor { - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.array.shape().to_vec()) - } + pub(crate) fn shape(&self) -> Shape { + Shape::from(self.array.shape().to_vec()) + } } #[cfg(test)] mod utils { - use super::*; - use crate::element::FloatNdArrayElement; + use super::*; + use crate::element::FloatNdArrayElement; - impl NdArrayTensor + impl NdArrayTensor + where + E: Default + Clone, + { + pub(crate) fn into_data(self) -> Data where - E: Default + Clone, + E: FloatNdArrayElement, { - pub(crate) fn into_data(self) -> Data - where - E: FloatNdArrayElement, - { - let shape = self.shape(); - let values = self.array.into_iter().collect(); - - Data::new(values, shape) - } + let shape = self.shape(); + let values = self.array.into_iter().collect(); + + Data::new(values, shape) } + } } /// Converts a slice of usize to a typed dimension. #[macro_export(local_inner_macros)] macro_rules! to_typed_dims { - ( + ( $n:expr, $dims:expr, justdim ) => {{ - let mut dims = [0; $n]; - for i in 0..$n { - dims[i] = $dims[i]; - } - let dim: Dim<[usize; $n]> = Dim(dims); - dim - }}; + let mut dims = [0; $n]; + for i in 0..$n { + dims[i] = $dims[i]; + } + let dim: Dim<[usize; $n]> = Dim(dims); + dim + }}; } /// Reshapes an array into a tensor. @@ -101,82 +101,82 @@ macro_rules! reshape { impl NdArrayTensor where - E: Default + Clone, + E: Default + Clone, { - /// Create a new [ndarray tensor](NdArrayTensor) from [data](Data). - pub fn from_data(data: Data) -> NdArrayTensor { - let shape = data.shape.clone(); - let to_array = |data: Data| Array::from_iter(data.value).into_shared(); - let array = to_array(data); - - reshape!( - ty E, - shape shape, - array array, - d D - ) - } + /// Create a new [ndarray tensor](NdArrayTensor) from [data](Data). 
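// What `to_typed_dims!` expands to for one concrete rank, written out as a
// plain function for illustration: a dynamic `&[usize]` shape is copied into
// a fixed-size array so ndarray gets a statically typed `Dim`.
use ndarray::Dim;

fn typed_dims_3(dims: &[usize]) -> Dim<[usize; 3]> {
    let mut out = [0usize; 3];
    out.copy_from_slice(&dims[..3]); // panics if fewer than 3 dims are given
    Dim(out)
}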
+ pub fn from_data(data: Data) -> NdArrayTensor { + let shape = data.shape.clone(); + let to_array = |data: Data| Array::from_iter(data.value).into_shared(); + let array = to_array(data); + + reshape!( + ty E, + shape shape, + array array, + d D + ) + } } #[cfg(test)] mod tests { - use super::*; - use burn_common::rand::get_seeded_rng; - use burn_tensor::Distribution; - - #[test] - fn should_support_into_and_from_data_1d() { - let data_expected = Data::::random( - Shape::new([3]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_2d() { - let data_expected = Data::::random( - Shape::new([2, 3]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_3d() { - let data_expected = Data::::random( - Shape::new([2, 3, 4]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_4d() { - let data_expected = Data::::random( - Shape::new([2, 3, 4, 2]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_common::rand::get_seeded_rng; + use burn_tensor::Distribution; + + #[test] + fn should_support_into_and_from_data_1d() { + let data_expected = Data::::random( + Shape::new([3]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_2d() { + let data_expected = Data::::random( + Shape::new([2, 3]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_3d() { + let data_expected = Data::::random( + Shape::new([2, 3, 4]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_4d() { + let data_expected = Data::::random( + Shape::new([2, 3, 4, 2]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-no-std-tests/src/conv.rs b/burn-no-std-tests/src/conv.rs index 6f17949409..21d4cc67d8 100644 --- a/burn-no-std-tests/src/conv.rs +++ b/burn-no-std-tests/src/conv.rs @@ -1,48 +1,48 @@ // Originally copied from the burn/examples/mnist package use burn::{ - config::Config, - module::Module, - nn, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn, + tensor::{backend::Backend, Tensor}, }; #[derive(Module, Debug)] pub 
struct ConvBlock { - conv: nn::conv::Conv2d, - pool: nn::pool::MaxPool2d, - activation: nn::GELU, + conv: nn::conv::Conv2d, + pool: nn::pool::MaxPool2d, + activation: nn::GELU, } #[derive(Config)] pub struct ConvBlockConfig { - channels: [usize; 2], - #[config(default = "[3, 3]")] - kernel_size: [usize; 2], + channels: [usize; 2], + #[config(default = "[3, 3]")] + kernel_size: [usize; 2], } impl ConvBlock { - pub fn new(config: &ConvBlockConfig) -> Self { - let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size) - .with_padding(nn::PaddingConfig2d::Same) - .init(); - let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size) - .with_padding(nn::PaddingConfig2d::Same) - .init(); - let activation = nn::GELU::new(); + pub fn new(config: &ConvBlockConfig) -> Self { + let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size) + .with_padding(nn::PaddingConfig2d::Same) + .init(); + let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size) + .with_padding(nn::PaddingConfig2d::Same) + .init(); + let activation = nn::GELU::new(); - Self { - conv, - pool, - activation, - } + Self { + conv, + pool, + activation, } + } - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.conv.forward(input.clone()); - let x = self.pool.forward(x); - let x = self.activation.forward(x); + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.conv.forward(input.clone()); + let x = self.pool.forward(x); + let x = self.activation.forward(x); - (x + input) / 2.0 - } + (x + input) / 2.0 + } } diff --git a/burn-no-std-tests/src/mlp.rs b/burn-no-std-tests/src/mlp.rs index ec8f189718..1ef28fd496 100644 --- a/burn-no-std-tests/src/mlp.rs +++ b/burn-no-std-tests/src/mlp.rs @@ -3,65 +3,65 @@ use alloc::vec::Vec; use burn::{ - config::Config, - module::Module, - nn, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [Multilayer Perceptron](Mlp) layer. #[derive(Config)] pub struct MlpConfig { - /// The number of layers. - #[config(default = 3)] - pub num_layers: usize, - /// The dropout rate. - #[config(default = 0.5)] - pub dropout: f64, - /// The size of each layer. - #[config(default = 256)] - pub d_model: usize, + /// The number of layers. + #[config(default = 3)] + pub num_layers: usize, + /// The dropout rate. + #[config(default = 0.5)] + pub dropout: f64, + /// The size of each layer. + #[config(default = 256)] + pub d_model: usize, } /// Multilayer Perceptron module. #[derive(Module, Debug)] pub struct Mlp { - linears: Vec>, - dropout: nn::Dropout, - activation: nn::ReLU, + linears: Vec>, + dropout: nn::Dropout, + activation: nn::ReLU, } impl Mlp { - /// Create the module from the given configuration. - pub fn new(config: &MlpConfig) -> Self { - let mut linears = Vec::with_capacity(config.num_layers); + /// Create the module from the given configuration. + pub fn new(config: &MlpConfig) -> Self { + let mut linears = Vec::with_capacity(config.num_layers); - for _ in 0..config.num_layers { - linears.push(nn::LinearConfig::new(config.d_model, config.d_model).init()); - } - - Self { - linears, - dropout: nn::DropoutConfig::new(0.3).init(), - activation: nn::ReLU::new(), - } + for _ in 0..config.num_layers { + linears.push(nn::LinearConfig::new(config.d_model, config.d_model).init()); } - /// Applies the forward pass on the input tensor. 
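// Construction sketch for the config-driven modules above; the channel sizes
// are arbitrary and `build_modules` is a hypothetical helper. Fields left
// unset fall back to their #[config(default = ...)] values.
use burn::tensor::backend::Backend;

fn build_modules<B: Backend>() -> (ConvBlock<B>, Mlp<B>) {
    // kernel_size defaults to [3, 3].
    let conv = ConvBlock::new(&ConvBlockConfig::new([8, 8]));
    // num_layers, dropout and d_model all use their defaults.
    let mlp = Mlp::new(&MlpConfig::new());
    (conv, mlp)
}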
- /// - /// # Shapes - /// - /// - input: `[batch_size, d_model]` - /// - output: `[batch_size, d_model]` - pub fn forward(&self, input: Tensor) -> Tensor { - let mut x = input; + Self { + linears, + dropout: nn::DropoutConfig::new(0.3).init(), + activation: nn::ReLU::new(), + } + } - for linear in self.linears.iter() { - x = linear.forward(x); - x = self.dropout.forward(x); - x = self.activation.forward(x); - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[batch_size, d_model]` + /// - output: `[batch_size, d_model]` + pub fn forward(&self, input: Tensor) -> Tensor { + let mut x = input; - x + for linear in self.linears.iter() { + x = linear.forward(x); + x = self.dropout.forward(x); + x = self.activation.forward(x); } + + x + } } diff --git a/burn-no-std-tests/src/model.rs b/burn-no-std-tests/src/model.rs index 7028bdb6b3..6363a4ae21 100644 --- a/burn-no-std-tests/src/model.rs +++ b/burn-no-std-tests/src/model.rs @@ -1,66 +1,66 @@ // Originally copied from the burn/examples/mnist package use crate::{ - conv::{ConvBlock, ConvBlockConfig}, - mlp::{Mlp, MlpConfig}, + conv::{ConvBlock, ConvBlockConfig}, + mlp::{Mlp, MlpConfig}, }; use burn::{ - config::Config, - module::Module, - nn, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn, + tensor::{backend::Backend, Tensor}, }; #[derive(Config)] pub struct MnistConfig { - #[config(default = 42)] - pub seed: u64, + #[config(default = 42)] + pub seed: u64, - pub mlp: MlpConfig, + pub mlp: MlpConfig, - #[config(default = 784)] - pub input_size: usize, + #[config(default = 784)] + pub input_size: usize, - #[config(default = 10)] - pub output_size: usize, + #[config(default = 10)] + pub output_size: usize, } #[derive(Module, Debug)] pub struct Model { - mlp: Mlp, - conv: ConvBlock, - input: nn::Linear, - output: nn::Linear, - num_classes: usize, + mlp: Mlp, + conv: ConvBlock, + input: nn::Linear, + output: nn::Linear, + num_classes: usize, } impl Model { - pub fn new(config: &MnistConfig) -> Self { - let mlp = Mlp::new(&config.mlp); - let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init(); - let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init(); - let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1])); + pub fn new(config: &MnistConfig) -> Self { + let mlp = Mlp::new(&config.mlp); + let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init(); + let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init(); + let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1])); - Self { - mlp, - conv, - output, - input, - num_classes: config.output_size, - } + Self { + mlp, + conv, + output, + input, + num_classes: config.output_size, } + } - pub fn forward(&self, input: Tensor) -> Tensor { - let [batch_size, height, width] = input.dims(); + pub fn forward(&self, input: Tensor) -> Tensor { + let [batch_size, height, width] = input.dims(); - let x = input.reshape([batch_size, 1, height, width]).detach(); - let x = self.conv.forward(x); - let x = x.reshape([batch_size, height * width]); + let x = input.reshape([batch_size, 1, height, width]).detach(); + let x = self.conv.forward(x); + let x = x.reshape([batch_size, height * width]); - let x = self.input.forward(x); - let x = self.mlp.forward(x); + let x = self.input.forward(x); + let x = self.mlp.forward(x); - self.output.forward(x) - } + self.output.forward(x) + } } diff --git a/burn-no-std-tests/tests/integration_test.rs 
b/burn-no-std-tests/tests/integration_test.rs index 6f6558cea9..2907909cc2 100644 --- a/burn-no-std-tests/tests/integration_test.rs +++ b/burn-no-std-tests/tests/integration_test.rs @@ -8,23 +8,23 @@ use burn_ndarray::NdArray; #[test] fn test_mnist_model_with_random_input() { - type Backend = NdArray; + type Backend = NdArray; - // Model configurations - let mlp_config = MlpConfig::new(); - let mnist_config = MnistConfig::new(mlp_config); - let mnist_model: Model = Model::new(&mnist_config); + // Model configurations + let mlp_config = MlpConfig::new(); + let mnist_config = MnistConfig::new(mlp_config); + let mnist_model: Model = Model::new(&mnist_config); - // Pass a fixed seed for random, otherwise a build generated random seed is used - Backend::seed(mnist_config.seed); + // Pass a fixed seed for random, otherwise a build generated random seed is used + Backend::seed(mnist_config.seed); - // Some random input - let input_shape = [1, 28, 28]; - let input = Tensor::::random(input_shape, Default); + // Some random input + let input_shape = [1, 28, 28]; + let input = Tensor::::random(input_shape, Default); - // Run through the model - let output = mnist_model.forward(input); + // Run through the model + let output = mnist_model.forward(input); - assert_eq!(output.shape().dims, [1, 10]); - assert!(output.to_data().value.into_iter().all(|x| x <= 1.0)); + assert_eq!(output.shape().dims, [1, 10]); + assert!(output.to_data().value.into_iter().all(|x| x <= 1.0)); } diff --git a/burn-tch/src/backend.rs b/burn-tch/src/backend.rs index 3e62d40305..6a70c96914 100644 --- a/burn-tch/src/backend.rs +++ b/burn-tch/src/backend.rs @@ -19,46 +19,46 @@ use burn_tensor::backend::Backend; /// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan /// ``` pub enum LibTorchDevice { - /// CPU device. - Cpu, + /// CPU device. + Cpu, - /// Cuda device with the given index. The index is the index of the Cuda device in the list of - /// all Cuda devices found on the system. - Cuda(usize), + /// Cuda device with the given index. The index is the index of the Cuda device in the list of + /// all Cuda devices found on the system. + Cuda(usize), - /// Metal Performance Shaders device. - Mps, + /// Metal Performance Shaders device. + Mps, - /// Vulkan device. - Vulkan, + /// Vulkan device. 
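// Device-selection sketch matching the doc example above, assuming the crate
// re-exports LibTorchDevice at its root as that example suggests; the CUDA
// index and `pick_device` are illustrative only.
use burn_tch::LibTorchDevice;

fn pick_device(prefer_gpu: bool) -> LibTorchDevice {
    if prefer_gpu {
        LibTorchDevice::Cuda(0) // index into the system's list of CUDA devices
    } else {
        LibTorchDevice::Cpu
    }
}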
+ Vulkan, } impl From for tch::Device { - fn from(device: LibTorchDevice) -> Self { - match device { - LibTorchDevice::Cpu => tch::Device::Cpu, - LibTorchDevice::Cuda(num) => tch::Device::Cuda(num), - LibTorchDevice::Mps => tch::Device::Mps, - LibTorchDevice::Vulkan => tch::Device::Vulkan, - } + fn from(device: LibTorchDevice) -> Self { + match device { + LibTorchDevice::Cpu => tch::Device::Cpu, + LibTorchDevice::Cuda(num) => tch::Device::Cuda(num), + LibTorchDevice::Mps => tch::Device::Mps, + LibTorchDevice::Vulkan => tch::Device::Vulkan, } + } } impl From for LibTorchDevice { - fn from(device: tch::Device) -> Self { - match device { - tch::Device::Cpu => LibTorchDevice::Cpu, - tch::Device::Cuda(num) => LibTorchDevice::Cuda(num), - tch::Device::Mps => LibTorchDevice::Mps, - tch::Device::Vulkan => LibTorchDevice::Vulkan, - } + fn from(device: tch::Device) -> Self { + match device { + tch::Device::Cpu => LibTorchDevice::Cpu, + tch::Device::Cuda(num) => LibTorchDevice::Cuda(num), + tch::Device::Mps => LibTorchDevice::Mps, + tch::Device::Vulkan => LibTorchDevice::Vulkan, } + } } impl Default for LibTorchDevice { - fn default() -> Self { - Self::Cpu - } + fn default() -> Self { + Self::Cpu + } } /// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations. @@ -70,39 +70,39 @@ impl Default for LibTorchDevice { /// Refer to the [tch] crate for more information. #[derive(Clone, Copy, Default, Debug)] pub struct LibTorch { - _e: E, + _e: E, } impl Backend for LibTorch { - type Device = LibTorchDevice; - type FullPrecisionElem = f32; - type FullPrecisionBackend = LibTorch; + type Device = LibTorchDevice; + type FullPrecisionElem = f32; + type FullPrecisionBackend = LibTorch; - type TensorPrimitive = TchTensor; - type FloatElem = E; + type TensorPrimitive = TchTensor; + type FloatElem = E; - type IntTensorPrimitive = TchTensor; - type IntElem = i64; + type IntTensorPrimitive = TchTensor; + type IntElem = i64; - type BoolTensorPrimitive = TchTensor; + type BoolTensorPrimitive = TchTensor; - fn seed(seed: u64) { - tch::manual_seed(seed as i64); - } + fn seed(seed: u64) { + tch::manual_seed(seed as i64); + } - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn name() -> String { - "tch".to_string() - } + fn name() -> String { + "tch".to_string() + } - fn sync(device: &Self::Device) { - if let LibTorchDevice::Cuda(index) = device { - tch::Cuda::synchronize(*index as i64); - } else if let LibTorchDevice::Mps = device { - panic!("Can't sync MPS device") - } + fn sync(device: &Self::Device) { + if let LibTorchDevice::Cuda(index) = device { + tch::Cuda::synchronize(*index as i64); + } else if let LibTorchDevice::Mps = device { + panic!("Can't sync MPS device") } + } } diff --git a/burn-tch/src/lib.rs b/burn-tch/src/lib.rs index f7580bf641..0c1e130b39 100644 --- a/burn-tch/src/lib.rs +++ b/burn-tch/src/lib.rs @@ -14,12 +14,12 @@ pub use tensor::*; #[cfg(test)] mod tests { - extern crate alloc; + extern crate alloc; - type TestBackend = crate::LibTorch; - type TestTensor = burn_tensor::Tensor; - type TestTensorInt = burn_tensor::Tensor; + type TestBackend = crate::LibTorch; + type TestTensor = burn_tensor::Tensor; + type TestTensorInt = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!(); + burn_autodiff::testgen_all!(); } diff --git a/burn-tch/src/ops/activation.rs b/burn-tch/src/ops/activation.rs index 386e87c493..cf834a2368 100644 --- a/burn-tch/src/ops/activation.rs +++ 
b/burn-tch/src/ops/activation.rs @@ -2,24 +2,24 @@ use crate::{element::TchElement, LibTorch, TchTensor}; use burn_tensor::ops::ActivationOps; impl ActivationOps for LibTorch { - fn relu(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()) - } + fn relu(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()) + } - fn gelu(tensor: TchTensor) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.gelu_("none"), - |tensor| tensor.gelu("none"), - ) - } + fn gelu(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.gelu_("none"), + |tensor| tensor.gelu("none"), + ) + } - fn gelu_backward( - tensor: TchTensor, - grad: TchTensor, - ) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none"); + fn gelu_backward( + tensor: TchTensor, + grad: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none"); - TchTensor::from_existing(tensor, storage) - } + TchTensor::from_existing(tensor, storage) + } } diff --git a/burn-tch/src/ops/base.rs b/burn-tch/src/ops/base.rs index d2cd0d4dc3..215d4c8b45 100644 --- a/burn-tch/src/ops/base.rs +++ b/burn-tch/src/ops/base.rs @@ -5,408 +5,405 @@ use crate::{TchShape, TchTensor}; use std::{marker::PhantomData, ops::Range}; pub struct TchOps { - e: PhantomData, + e: PhantomData, } impl TchOps { - pub fn reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - let shape_tch: TchShape = shape.into(); - - TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage) - } - - pub fn repeat( - tensor: TchTensor, - dim: usize, - times: usize, - ) -> TchTensor { - let mut dims = [1; D]; - dims[dim] = times as i64; - let tensor = tch::Tensor::repeat(&tensor.tensor, dims); - TchTensor::new(tensor) - } - - pub fn slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - let storage = tensor.storage.clone(); - let mut tensor = tensor.tensor.shallow_clone(); - - for (i, index) in ranges.iter().enumerate().take(D2) { - let start = index.start as i64; - let length = (index.end - index.start) as i64; - tensor = tensor.narrow(i as i64, start, length); - } - - TchTensor::from_existing(tensor, storage) - } - - pub fn slice_assign( - tensor: TchTensor, - ranges: [Range; D2], - value: TchTensor, - ) -> TchTensor { - let tensor_original = tensor.tensor.copy(); - let tch_shape = TchShape::from(tensor.shape()); - - let mut tensor = tensor_original.view_(tch_shape.dims); - - for (i, index) in ranges.into_iter().enumerate().take(D2) { - let start = index.start as i64; - let length = (index.end - index.start) as i64; - - tensor = tensor.narrow(i as i64, start, length); - } - - tensor.copy_(&value.tensor); - - TchTensor::new(tensor_original) - } - - pub fn gather( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - ) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false); - - TchTensor::from_existing(tensor, storage) - } - - pub fn scatter( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor - .tensor - .scatter_add(dim as i64, &indices.tensor, &value.tensor); - - TchTensor::from_existing(tensor, storage) - } - - pub fn index_select_dim( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - ) -> TchTensor { - 
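// The `unary_ops(in_place, copy)` calls above pair a mutating LibTorch kernel
// with an allocating one. A plausible reading, hedged because the real
// dispatch lives inside TchTensor: run the in-place closure only when the
// tensor's storage is not shared with another tensor, otherwise fall back to
// the copying closure.
fn unary_ops_sketch<T>(
    storage_is_unique: bool,
    in_place: impl FnOnce() -> T,
    copy: impl FnOnce() -> T,
) -> T {
    if storage_is_unique {
        in_place()
    } else {
        copy()
    }
}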
let storage = tensor.storage.clone(); - let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor); - - TchTensor::from_existing(tensor, storage) - } - - pub fn select_assign( - tensor: TchTensor, - dim: usize, - indices_tensor: TchTensor, - value: TchTensor, - ) -> TchTensor { - let mut indices = Vec::with_capacity(D); - for _ in 0..D { - indices.push(None); - } - indices[dim] = Some(indices_tensor.tensor); - - tensor.unary_ops( - |mut tensor| tensor.index_put_(&indices, &value.tensor, true), - |tensor| tensor.index_put(&indices, &value.tensor, true), - ) - } - - pub fn cat(tensors: Vec>, dim: usize) -> TchTensor { - let tensors: Vec = tensors - .into_iter() - .map(|t| t.tensor.shallow_clone()) - .collect(); - let tensor = tch::Tensor::cat(&tensors, dim as i64); - - TchTensor::new(tensor) - } - - pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.eq_tensor(rhs), - ) - } - - pub fn equal_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool), - |tensor| tensor.eq(rhs.clone().into()), - ) - } - - pub fn greater( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.greater_tensor(rhs), - ) - } - - pub fn greater_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool), - |tensor| tensor.greater(rhs.clone().into()), - ) - } - - pub fn greater_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.greater_equal_tensor(rhs), - ) - } - - pub fn greater_equal_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| { - tensor - .greater_equal_(rhs.clone().into()) - .to_kind(tch::Kind::Bool) - }, - |tensor| tensor.greater_equal(rhs.clone().into()), - ) - } - - pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.less_tensor(rhs), - ) - } - - pub fn lower_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool), - |tensor| tensor.less(rhs.clone().into()), - ) - } - - pub fn lower_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.less_equal_tensor(rhs), - ) - } - - pub fn lower_equal_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| { - tensor - .less_equal_(rhs.clone().into()) - .to_kind(tch::Kind::Bool) - }, - |tensor| tensor.less_equal(rhs.clone().into()), - ) - } - - pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - 
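// The three closures handed to `binary_ops_tensor` above appear to cover:
// reuse lhs in place, reuse rhs in place (with operands swapped where the op
// allows it), or allocate a fresh result. Hedged sketch of that selection;
// the actual ownership checks live in TchTensor.
fn binary_ops_sketch<T>(
    lhs_unique: bool,
    rhs_unique: bool,
    on_lhs: impl FnOnce() -> T,
    on_rhs: impl FnOnce() -> T,
    alloc: impl FnOnce() -> T,
) -> T {
    if lhs_unique {
        on_lhs()
    } else if rhs_unique {
        on_rhs()
    } else {
        alloc()
    }
}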
TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_add_(rhs).unwrap(), - |lhs, rhs| rhs.f_add_(lhs).unwrap(), - |lhs, rhs| lhs.f_add(rhs).unwrap(), - ) - } - - pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_sub_(rhs).unwrap(), - |lhs, rhs| lhs.f_sub(rhs).unwrap(), - |lhs, rhs| lhs.f_sub(rhs).unwrap(), - ) - } - - pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_mul_(rhs).unwrap(), - |lhs, rhs| rhs.f_mul_(lhs).unwrap(), - |lhs, rhs| lhs.f_mul(rhs).unwrap(), - ) - } - - pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_div_(rhs).unwrap(), - |lhs, rhs| lhs.f_div(rhs).unwrap(), - |lhs, rhs| lhs.f_div(rhs).unwrap(), - ) - } - - pub fn mean(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.mean(E::KIND); - TchTensor::new(tensor) - } - - pub fn sum(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.sum(E::KIND); - TchTensor::new(tensor) - } - - pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchTensor::from_existing( - tensor - .tensor - .mean_dim(Some([dim as i64].as_slice()), true, E::KIND), - tensor.storage, - ) - } - - pub fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchTensor::from_existing( - tensor - .tensor - .sum_dim_intlist(Some([dim as i64].as_slice()), true, E::KIND), - tensor.storage, - ) - } - - pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.argmax(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.argmin(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn max_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - let storage = tensor.storage.clone(); - let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true); - - let tensor = TchTensor::from_existing(tensor, storage); - let indices = TchTensor::new(indices); - - (tensor, indices) - } - - pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn min_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - let storage = tensor.storage.clone(); - let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true); - - let tensor = TchTensor::from_existing(tensor, storage); - let indices = TchTensor::new(indices); - - (tensor, indices) - } - - pub fn clamp_min + Clone + Copy>( - tensor: TchTensor, - min: S, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.clamp_min_(min), - |tensor| tensor.clamp_min(min), - ) - } - - pub fn clamp_max + Clone + Copy>( - tensor: TchTensor, - max: S, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.clamp_max_(max), - |tensor| tensor.clamp_max(max), - ) - } - - pub fn clamp + Clone + Copy>( - tensor: TchTensor, - min: S, - max: S, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| 
tensor.clamp_(min, max), - |tensor| tensor.clamp(min, max), - ) - } - - pub fn swap_dims( - tensor: TchTensor, - dim1: usize, - dim2: usize, - ) -> TchTensor { - let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64); - TchTensor::new(tensor) - } + pub fn reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + let shape_tch: TchShape = shape.into(); + + TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage) + } + + pub fn repeat( + tensor: TchTensor, + dim: usize, + times: usize, + ) -> TchTensor { + let mut dims = [1; D]; + dims[dim] = times as i64; + let tensor = tch::Tensor::repeat(&tensor.tensor, dims); + TchTensor::new(tensor) + } + + pub fn slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + let storage = tensor.storage.clone(); + let mut tensor = tensor.tensor.shallow_clone(); + + for (i, index) in ranges.iter().enumerate().take(D2) { + let start = index.start as i64; + let length = (index.end - index.start) as i64; + tensor = tensor.narrow(i as i64, start, length); + } + + TchTensor::from_existing(tensor, storage) + } + + pub fn slice_assign( + tensor: TchTensor, + ranges: [Range; D2], + value: TchTensor, + ) -> TchTensor { + let tensor_original = tensor.tensor.copy(); + let tch_shape = TchShape::from(tensor.shape()); + + let mut tensor = tensor_original.view_(tch_shape.dims); + + for (i, index) in ranges.into_iter().enumerate().take(D2) { + let start = index.start as i64; + let length = (index.end - index.start) as i64; + + tensor = tensor.narrow(i as i64, start, length); + } + + tensor.copy_(&value.tensor); + + TchTensor::new(tensor_original) + } + + pub fn gather( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false); + + TchTensor::from_existing(tensor, storage) + } + + pub fn scatter( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor + .tensor + .scatter_add(dim as i64, &indices.tensor, &value.tensor); + + TchTensor::from_existing(tensor, storage) + } + + pub fn index_select_dim( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor); + + TchTensor::from_existing(tensor, storage) + } + + pub fn select_assign( + tensor: TchTensor, + dim: usize, + indices_tensor: TchTensor, + value: TchTensor, + ) -> TchTensor { + let mut indices = Vec::with_capacity(D); + for _ in 0..D { + indices.push(None); + } + indices[dim] = Some(indices_tensor.tensor); + + tensor.unary_ops( + |mut tensor| tensor.index_put_(&indices, &value.tensor, true), + |tensor| tensor.index_put(&indices, &value.tensor, true), + ) + } + + pub fn cat(tensors: Vec>, dim: usize) -> TchTensor { + let tensors: Vec = tensors + .into_iter() + .map(|t| t.tensor.shallow_clone()) + .collect(); + let tensor = tch::Tensor::cat(&tensors, dim as i64); + + TchTensor::new(tensor) + } + + pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.eq_tensor(rhs), + ) + } + + pub fn equal_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| 
tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool), + |tensor| tensor.eq(rhs.clone().into()), + ) + } + + pub fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.greater_tensor(rhs), + ) + } + + pub fn greater_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool), + |tensor| tensor.greater(rhs.clone().into()), + ) + } + + pub fn greater_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.greater_equal_tensor(rhs), + ) + } + + pub fn greater_equal_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| { + tensor + .greater_equal_(rhs.clone().into()) + .to_kind(tch::Kind::Bool) + }, + |tensor| tensor.greater_equal(rhs.clone().into()), + ) + } + + pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.less_tensor(rhs), + ) + } + + pub fn lower_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool), + |tensor| tensor.less(rhs.clone().into()), + ) + } + + pub fn lower_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.less_equal_tensor(rhs), + ) + } + + pub fn lower_equal_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| { + tensor + .less_equal_(rhs.clone().into()) + .to_kind(tch::Kind::Bool) + }, + |tensor| tensor.less_equal(rhs.clone().into()), + ) + } + + pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_add_(rhs).unwrap(), + |lhs, rhs| rhs.f_add_(lhs).unwrap(), + |lhs, rhs| lhs.f_add(rhs).unwrap(), + ) + } + + pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_sub_(rhs).unwrap(), + |lhs, rhs| lhs.f_sub(rhs).unwrap(), + |lhs, rhs| lhs.f_sub(rhs).unwrap(), + ) + } + + pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_mul_(rhs).unwrap(), + |lhs, rhs| rhs.f_mul_(lhs).unwrap(), + |lhs, rhs| lhs.f_mul(rhs).unwrap(), + ) + } + + pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_div_(rhs).unwrap(), + |lhs, rhs| lhs.f_div(rhs).unwrap(), + |lhs, rhs| lhs.f_div(rhs).unwrap(), + ) + } + + pub fn mean(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.mean(E::KIND); + TchTensor::new(tensor) + } + + pub fn sum(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.sum(E::KIND); + TchTensor::new(tensor) + } + + pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchTensor::from_existing( + tensor + .tensor + 
.mean_dim(Some([dim as i64].as_slice()), true, E::KIND), + tensor.storage, + ) + } + + pub fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchTensor::from_existing( + tensor + .tensor + .sum_dim_intlist(Some([dim as i64].as_slice()), true, E::KIND), + tensor.storage, + ) + } + + pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.argmax(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.argmin(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn max_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + let storage = tensor.storage.clone(); + let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true); + + let tensor = TchTensor::from_existing(tensor, storage); + let indices = TchTensor::new(indices); + + (tensor, indices) + } + + pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn min_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + let storage = tensor.storage.clone(); + let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true); + + let tensor = TchTensor::from_existing(tensor, storage); + let indices = TchTensor::new(indices); + + (tensor, indices) + } + + pub fn clamp_min + Clone + Copy>( + tensor: TchTensor, + min: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.clamp_min_(min), + |tensor| tensor.clamp_min(min), + ) + } + + pub fn clamp_max + Clone + Copy>( + tensor: TchTensor, + max: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.clamp_max_(max), + |tensor| tensor.clamp_max(max), + ) + } + + pub fn clamp + Clone + Copy>( + tensor: TchTensor, + min: S, + max: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.clamp_(min, max), + |tensor| tensor.clamp(min, max), + ) + } + + pub fn swap_dims( + tensor: TchTensor, + dim1: usize, + dim2: usize, + ) -> TchTensor { + let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64); + TchTensor::new(tensor) + } } diff --git a/burn-tch/src/ops/bool_tensor.rs b/burn-tch/src/ops/bool_tensor.rs index 76abdfadc4..c919ed5ffc 100644 --- a/burn-tch/src/ops/bool_tensor.rs +++ b/burn-tch/src/ops/bool_tensor.rs @@ -4,114 +4,111 @@ use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Reader, Shape}; use std::ops::Range; impl BoolTensorOps for LibTorch { - fn bool_from_data( - data: Data, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::from_data(data, (*device).into()) - } - - fn bool_shape(tensor: &TchTensor) -> Shape { - tensor.shape() - } - - fn bool_repeat( - tensor: TchTensor, - dim: usize, - times: usize, - ) -> TchTensor { - TchOps::repeat(tensor, dim, times) - } - - fn bool_into_data(tensor: TchTensor) -> Reader> { - let shape = Self::bool_shape(&tensor); - let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()])); - let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); - - 
Reader::Concrete(Data::new(values.unwrap(), shape)) - } - - fn bool_to_device( - tensor: TchTensor, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) - } - - fn bool_reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - TchOps::reshape(tensor, shape) - } - - fn bool_device(tensor: &TchTensor) -> LibTorchDevice { - tensor.tensor.device().into() - } - - fn bool_empty( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let tensor = tch::Tensor::empty( - shape.dims.map(|a| a as i64), - (tch::Kind::Bool, (*device).into()), - ); - - TchTensor::new(tensor) - } - - fn bool_slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - TchOps::slice(tensor, ranges) - } - - fn bool_slice_assign( - tensor: TchTensor, - ranges: [std::ops::Range; D2], - value: TchTensor, - ) -> TchTensor { - TchOps::slice_assign(tensor, ranges, value) - } - - fn bool_cat( - tensors: Vec>, - dim: usize, - ) -> TchTensor { - TchOps::cat(tensors, dim) - } - - fn bool_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::equal(lhs, rhs) - } - - fn bool_not(tensor: TchTensor) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool), - |tensor| tensor.eq(0), - ) - } - - fn bool_into_int(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(tch::Kind::Int64); - TchTensor::new(tensor) - } - - fn bool_into_float(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(E::KIND); - TchTensor::new(tensor) - } - - fn bool_swap_dims( - tensor: as Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::BoolTensorPrimitive { - TchOps::swap_dims(tensor, dim1, dim2) - } + fn bool_from_data( + data: Data, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::from_data(data, (*device).into()) + } + + fn bool_shape(tensor: &TchTensor) -> Shape { + tensor.shape() + } + + fn bool_repeat( + tensor: TchTensor, + dim: usize, + times: usize, + ) -> TchTensor { + TchOps::repeat(tensor, dim, times) + } + + fn bool_into_data(tensor: TchTensor) -> Reader> { + let shape = Self::bool_shape(&tensor); + let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()])); + let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); + + Reader::Concrete(Data::new(values.unwrap(), shape)) + } + + fn bool_to_device( + tensor: TchTensor, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::new(tensor.tensor.to((*device).into())) + } + + fn bool_reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + TchOps::reshape(tensor, shape) + } + + fn bool_device(tensor: &TchTensor) -> LibTorchDevice { + tensor.tensor.device().into() + } + + fn bool_empty( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let tensor = tch::Tensor::empty( + shape.dims.map(|a| a as i64), + (tch::Kind::Bool, (*device).into()), + ); + + TchTensor::new(tensor) + } + + fn bool_slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + TchOps::slice(tensor, ranges) + } + + fn bool_slice_assign( + tensor: TchTensor, + ranges: [std::ops::Range; D2], + value: TchTensor, + ) -> TchTensor { + TchOps::slice_assign(tensor, ranges, value) + } + + fn bool_cat(tensors: Vec>, dim: usize) -> TchTensor { + TchOps::cat(tensors, dim) + } + + fn bool_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::equal(lhs, rhs) + } + + fn bool_not(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| 
tensor.eq_(0).to_kind(tch::Kind::Bool), + |tensor| tensor.eq(0), + ) + } + + fn bool_into_int(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(tch::Kind::Int64); + TchTensor::new(tensor) + } + + fn bool_into_float(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(E::KIND); + TchTensor::new(tensor) + } + + fn bool_swap_dims( + tensor: as Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::BoolTensorPrimitive { + TchOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs index 16eb9af05c..bb6448f5d6 100644 --- a/burn-tch/src/ops/int_tensor.rs +++ b/burn-tch/src/ops/int_tensor.rs @@ -7,385 +7,365 @@ use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor}; use super::TchOps; impl IntTensorOps for LibTorch { - fn int_from_data( - data: Data, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::from_data(data, (*device).into()) - } - - fn int_shape(tensor: &TchTensor) -> Shape { - tensor.shape() - } - - fn int_repeat( - tensor: TchTensor, - dim: usize, - times: usize, - ) -> TchTensor { - TchOps::repeat(tensor, dim, times) - } - - fn int_into_data(tensor: TchTensor) -> Reader> { - let shape = Self::int_shape(&tensor); - let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()])); - let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); - - Reader::Concrete(Data::new(values.unwrap(), shape)) - } - - fn int_to_device( - tensor: TchTensor, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) - } - - fn int_reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - TchOps::reshape(tensor, shape) - } - - fn int_device(tensor: &TchTensor) -> LibTorchDevice { - tensor.tensor.device().into() - } - - fn int_empty( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let tensor = tch::Tensor::empty( - shape.dims.map(|a| a as i64), - (tch::Kind::Int64, (*device).into()), - ); - - TchTensor::new(tensor) - } - - fn int_slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - TchOps::slice(tensor, ranges) - } - - fn int_slice_assign( - tensor: TchTensor, - ranges: [std::ops::Range; D2], - value: TchTensor, - ) -> TchTensor { - TchOps::slice_assign(tensor, ranges, value) - } - - fn int_cat(tensors: Vec>, dim: usize) -> TchTensor { - TchOps::cat(tensors, dim) - } - - fn int_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::equal(lhs, rhs) - } - - fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor { - TchOps::equal_elem(lhs, rhs) - } - - fn int_greater( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::greater(lhs, rhs) - } - - fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor { - TchOps::greater_elem(lhs, rhs) - } - - fn int_greater_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::greater_equal(lhs, rhs) - } - - fn int_greater_equal_elem( - lhs: TchTensor, - rhs: i64, - ) -> TchTensor { - TchOps::greater_equal_elem(lhs, rhs) - } - - fn int_lower( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::lower(lhs, rhs) - } - - fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor { - TchOps::lower_elem(lhs, rhs) - } - - fn int_lower_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::lower_equal(lhs, rhs) - } - - fn int_lower_equal_elem( - lhs: TchTensor, - rhs: i64, - ) -> TchTensor { - TchOps::lower_equal_elem(lhs, 
rhs) - } - - fn int_add( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::add(lhs, rhs) - } - - fn int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), - |tensor| tensor.f_add_scalar(rhs).unwrap(), - ) - } - - fn int_sub( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::sub(lhs, rhs) - } - - fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), - |tensor| tensor.f_sub_scalar(rhs).unwrap(), - ) - } - - fn int_mul( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::mul(lhs, rhs) - } - - fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), - |tensor| tensor.f_mul_scalar(rhs).unwrap(), - ) - } - - fn int_div( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::div(lhs, rhs) - } - - fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - let lhs: TchTensor = - TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false)); - let output: TchTensor = lhs.unary_ops( - |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), - |tensor| tensor.f_div_scalar(rhs).unwrap(), - ); - TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) - } - - fn int_neg(tensor: TchTensor) -> TchTensor { - Self::int_mul_scalar(tensor, -1) - } - - fn int_zeros( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device))) - } - - fn int_ones( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device))) - } - - fn int_full( - shape: Shape, - fill_value: i64, - device: & as Backend>::Device, - ) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::full( - shape.dims, - fill_value, - (tch::Kind::Int64, device), - )) - } - - fn int_sum(tensor: TchTensor) -> TchTensor { - TchOps::sum(tensor) - } - - fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::sum_dim(tensor, dim) - } - - fn int_mean(tensor: TchTensor) -> TchTensor { - let tensor: TchTensor = - TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); - let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor); - - TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) - } - - fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { - let tensor: TchTensor = - TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); - - let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor); - - TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) - } - - fn int_gather( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - ) -> TchTensor { - TchOps::gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - ) -> TchTensor { - TchOps::index_select_dim(tensor, dim, indices) - } - - fn int_select_assign( - tensor: TchTensor, - dim: usize, - indices: 
TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::select_assign(tensor, dim, indices, value) - } - - fn int_mask_where( - tensor: TchTensor, - mask: TchTensor, - source: TchTensor, - ) -> TchTensor { - TchTensor::binary_ops_tensor( - tensor, - source, - |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), - |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), - |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), - ) - } - - fn int_mask_fill( - tensor: TchTensor, - mask: TchTensor, - value: i64, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), - |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), - ) - } - - fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmax(tensor, dim) - } - - fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmin(tensor, dim) - } - - fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::max_dim(tensor, dim) - } - - fn int_max_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::max_dim_with_indices(tensor, dim) - } - - fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::min_dim(tensor, dim) - } - - fn int_min_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::min_dim_with_indices(tensor, dim) - } - - fn int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor { - TchOps::clamp_min(tensor, min) - } - - fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor { - TchOps::clamp_max(tensor, max) - } - - fn int_clamp( - tensor: TchTensor, - min: i64, - max: i64, - ) -> TchTensor { - TchOps::clamp(tensor, min, max) - } - - fn int_abs(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) - } - - fn int_into_float(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(E::KIND); - TchTensor::new(tensor) - } - - fn int_swap_dims( - tensor: as Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::IntTensorPrimitive { - TchOps::swap_dims(tensor, dim1, dim2) - } + fn int_from_data( + data: Data, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::from_data(data, (*device).into()) + } + + fn int_shape(tensor: &TchTensor) -> Shape { + tensor.shape() + } + + fn int_repeat( + tensor: TchTensor, + dim: usize, + times: usize, + ) -> TchTensor { + TchOps::repeat(tensor, dim, times) + } + + fn int_into_data(tensor: TchTensor) -> Reader> { + let shape = Self::int_shape(&tensor); + let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()])); + let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); + + Reader::Concrete(Data::new(values.unwrap(), shape)) + } + + fn int_to_device( + tensor: TchTensor, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::new(tensor.tensor.to((*device).into())) + } + + fn int_reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + TchOps::reshape(tensor, shape) + } + + fn int_device(tensor: &TchTensor) -> LibTorchDevice { + tensor.tensor.device().into() + } + + fn int_empty( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let tensor = tch::Tensor::empty( + shape.dims.map(|a| a as i64), + (tch::Kind::Int64, (*device).into()), + ); + + TchTensor::new(tensor) + } + + fn int_slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + TchOps::slice(tensor, ranges) + } + + fn 
int_slice_assign( + tensor: TchTensor, + ranges: [std::ops::Range; D2], + value: TchTensor, + ) -> TchTensor { + TchOps::slice_assign(tensor, ranges, value) + } + + fn int_cat(tensors: Vec>, dim: usize) -> TchTensor { + TchOps::cat(tensors, dim) + } + + fn int_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::equal(lhs, rhs) + } + + fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor { + TchOps::equal_elem(lhs, rhs) + } + + fn int_greater( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::greater(lhs, rhs) + } + + fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor { + TchOps::greater_elem(lhs, rhs) + } + + fn int_greater_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::greater_equal(lhs, rhs) + } + + fn int_greater_equal_elem( + lhs: TchTensor, + rhs: i64, + ) -> TchTensor { + TchOps::greater_equal_elem(lhs, rhs) + } + + fn int_lower( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::lower(lhs, rhs) + } + + fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor { + TchOps::lower_elem(lhs, rhs) + } + + fn int_lower_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::lower_equal(lhs, rhs) + } + + fn int_lower_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor { + TchOps::lower_equal_elem(lhs, rhs) + } + + fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::add(lhs, rhs) + } + + fn int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), + |tensor| tensor.f_add_scalar(rhs).unwrap(), + ) + } + + fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::sub(lhs, rhs) + } + + fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), + |tensor| tensor.f_sub_scalar(rhs).unwrap(), + ) + } + + fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::mul(lhs, rhs) + } + + fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), + |tensor| tensor.f_mul_scalar(rhs).unwrap(), + ) + } + + fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::div(lhs, rhs) + } + + fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + let lhs: TchTensor = TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false)); + let output: TchTensor = lhs.unary_ops( + |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), + |tensor| tensor.f_div_scalar(rhs).unwrap(), + ); + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) + } + + fn int_neg(tensor: TchTensor) -> TchTensor { + Self::int_mul_scalar(tensor, -1) + } + + fn int_zeros( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device))) + } + + fn int_ones( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device))) + } + + fn int_full( + shape: Shape, + fill_value: i64, + device: & as Backend>::Device, + ) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::full( + shape.dims, + fill_value, + (tch::Kind::Int64, device), + )) + } + + fn int_sum(tensor: TchTensor) -> TchTensor { + 
TchOps::sum(tensor) + } + + fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::sum_dim(tensor, dim) + } + + fn int_mean(tensor: TchTensor) -> TchTensor { + let tensor: TchTensor = + TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); + let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor); + + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) + } + + fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { + let tensor: TchTensor = + TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); + + let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor); + + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) + } + + fn int_gather( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + ) -> TchTensor { + TchOps::gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + ) -> TchTensor { + TchOps::index_select_dim(tensor, dim, indices) + } + + fn int_select_assign( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::select_assign(tensor, dim, indices, value) + } + + fn int_mask_where( + tensor: TchTensor, + mask: TchTensor, + source: TchTensor, + ) -> TchTensor { + TchTensor::binary_ops_tensor( + tensor, + source, + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + ) + } + + fn int_mask_fill( + tensor: TchTensor, + mask: TchTensor, + value: i64, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), + |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), + ) + } + + fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmax(tensor, dim) + } + + fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmin(tensor, dim) + } + + fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::max_dim(tensor, dim) + } + + fn int_max_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::max_dim_with_indices(tensor, dim) + } + + fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::min_dim(tensor, dim) + } + + fn int_min_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::min_dim_with_indices(tensor, dim) + } + + fn int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor { + TchOps::clamp_min(tensor, min) + } + + fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor { + TchOps::clamp_max(tensor, max) + } + + fn int_clamp(tensor: TchTensor, min: i64, max: i64) -> TchTensor { + TchOps::clamp(tensor, min, max) + } + + fn int_abs(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) + } + + fn int_into_float(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(E::KIND); + TchTensor::new(tensor) + } + + fn int_swap_dims( + tensor: as Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::IntTensorPrimitive { + TchOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-tch/src/ops/module.rs b/burn-tch/src/ops/module.rs index c4b0e156ca..db522187d3 
100644 --- a/burn-tch/src/ops/module.rs +++ b/burn-tch/src/ops/module.rs @@ -1,286 +1,286 @@ use crate::{element::TchElement, LibTorch, TchTensor}; use burn_tensor::ops::{ - ConvOptions, ConvTransposeOptions, MaxPool1dWithIndices, MaxPool2dBackward, - MaxPool2dWithIndices, ModuleOps, + ConvOptions, ConvTransposeOptions, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, + ModuleOps, }; impl ModuleOps for LibTorch { - fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor { - let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false); - - TchTensor::new(tensor) - } - - fn embedding_backward( - weights: TchTensor, - output: TchTensor, - indices: TchTensor, - ) -> TchTensor { - let [n_embedding, _d_model] = weights.shape().dims; - let tensor = tch::Tensor::embedding_backward( - &output.tensor, - &indices.tensor, - n_embedding as i64, - -1, - false, - false, - ); - - TchTensor::new(tensor) - } - - fn conv1d( - x: TchTensor, - weight: TchTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> TchTensor { - let tensor = tch::Tensor::conv1d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.dilation.map(|i| i as i64), - options.groups as i64, - ); - - TchTensor::new(tensor) - } - - fn conv2d( - x: TchTensor, - weight: TchTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> TchTensor { - let tensor = tch::Tensor::conv2d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.dilation.map(|i| i as i64), - options.groups as i64, - ); - - TchTensor::new(tensor) - } - - fn conv_transpose2d( - x: TchTensor, - weight: TchTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> TchTensor { - let tensor = tch::Tensor::conv_transpose2d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.padding_out.map(|i| i as i64), - options.groups as i64, - options.dilation.map(|i| i as i64), - ); - - TchTensor::new(tensor) - } - - fn conv_transpose1d( - x: TchTensor, - weight: TchTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> TchTensor { - let tensor = tch::Tensor::conv_transpose1d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.padding_out.map(|i| i as i64), - options.groups as i64, - options.dilation.map(|i| i as i64), - ); - - TchTensor::new(tensor) - } - - fn avg_pool1d( - x: TchTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> TchTensor { - let tensor = tch::Tensor::avg_pool1d( - &x.tensor, - [kernel_size as i64], - [stride as i64], - [padding as i64], - false, - count_include_pad, - ); - - TchTensor::new(tensor) - } - fn avg_pool2d( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> TchTensor { - let tensor = tch::Tensor::avg_pool2d( - &x.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - false, - count_include_pad, - None, - ); - - TchTensor::new(tensor) - } - - fn avg_pool2d_backward( - x: TchTensor, - grad: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> TchTensor { - let tensor = 
tch::Tensor::avg_pool2d_backward( - &x.tensor, - &grad.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - false, - count_include_pad, - None, - ); - - TchTensor::new(tensor) - } - - fn max_pool1d( - x: TchTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> TchTensor { - let tensor = tch::Tensor::max_pool1d( - &x.tensor, - kernel_size as i64, - stride as i64, - padding as i64, - dilation as i64, - false, - ); - - TchTensor::new(tensor) - } - - fn max_pool1d_with_indices( - x: TchTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices> { - let (tensor, indices) = tch::Tensor::max_pool1d_with_indices( - &x.tensor, - kernel_size as i64, - stride as i64, - padding as i64, - dilation as i64, - false, - ); - - MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) - } - - fn max_pool2d( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> TchTensor { - let tensor = tch::Tensor::max_pool2d( - &x.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - [dilation[0] as i64, dilation[1] as i64], - false, - ); - - TchTensor::new(tensor) - } - - fn max_pool2d_with_indices( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - let (tensor, indices) = tch::Tensor::max_pool2d_with_indices( - &x.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - [dilation[0] as i64, dilation[1] as i64], - false, - ); - - MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) - } - - fn max_pool2d_with_indices_backward( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: TchTensor, - indices: TchTensor, - ) -> MaxPool2dBackward> { - let grad = tch::Tensor::max_pool2d_with_indices_backward( - &x.tensor, - &output_grad.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - [dilation[0] as i64, dilation[1] as i64], - false, - &indices.tensor, - ); - - MaxPool2dBackward::new(TchTensor::new(grad)) - } - - fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor { - let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64)); - - TchTensor::new(tensor) - } - - fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor { - let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor); - - TchTensor::new(tensor) - } - - fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor { - let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64); - - TchTensor::new(tensor) - } + fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor { + let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false); + + TchTensor::new(tensor) + } + + fn embedding_backward( + weights: TchTensor, + output: TchTensor, + indices: TchTensor, + ) -> TchTensor { + let [n_embedding, _d_model] = weights.shape().dims; + let tensor = tch::Tensor::embedding_backward( + &output.tensor, + &indices.tensor, + 
n_embedding as i64, + -1, + false, + false, + ); + + TchTensor::new(tensor) + } + + fn conv1d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> TchTensor { + let tensor = tch::Tensor::conv1d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.dilation.map(|i| i as i64), + options.groups as i64, + ); + + TchTensor::new(tensor) + } + + fn conv2d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> TchTensor { + let tensor = tch::Tensor::conv2d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.dilation.map(|i| i as i64), + options.groups as i64, + ); + + TchTensor::new(tensor) + } + + fn conv_transpose2d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> TchTensor { + let tensor = tch::Tensor::conv_transpose2d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.padding_out.map(|i| i as i64), + options.groups as i64, + options.dilation.map(|i| i as i64), + ); + + TchTensor::new(tensor) + } + + fn conv_transpose1d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> TchTensor { + let tensor = tch::Tensor::conv_transpose1d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.padding_out.map(|i| i as i64), + options.groups as i64, + options.dilation.map(|i| i as i64), + ); + + TchTensor::new(tensor) + } + + fn avg_pool1d( + x: TchTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> TchTensor { + let tensor = tch::Tensor::avg_pool1d( + &x.tensor, + [kernel_size as i64], + [stride as i64], + [padding as i64], + false, + count_include_pad, + ); + + TchTensor::new(tensor) + } + fn avg_pool2d( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> TchTensor { + let tensor = tch::Tensor::avg_pool2d( + &x.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + false, + count_include_pad, + None, + ); + + TchTensor::new(tensor) + } + + fn avg_pool2d_backward( + x: TchTensor, + grad: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> TchTensor { + let tensor = tch::Tensor::avg_pool2d_backward( + &x.tensor, + &grad.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + false, + count_include_pad, + None, + ); + + TchTensor::new(tensor) + } + + fn max_pool1d( + x: TchTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> TchTensor { + let tensor = tch::Tensor::max_pool1d( + &x.tensor, + kernel_size as i64, + stride as i64, + padding as i64, + dilation as i64, + false, + ); + + TchTensor::new(tensor) + } + + fn max_pool1d_with_indices( + x: TchTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices> { + let (tensor, indices) = tch::Tensor::max_pool1d_with_indices( + &x.tensor, + kernel_size as i64, + stride as i64, + padding as i64, + dilation as i64, + 
false, + ); + + MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) + } + + fn max_pool2d( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> TchTensor { + let tensor = tch::Tensor::max_pool2d( + &x.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + [dilation[0] as i64, dilation[1] as i64], + false, + ); + + TchTensor::new(tensor) + } + + fn max_pool2d_with_indices( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + let (tensor, indices) = tch::Tensor::max_pool2d_with_indices( + &x.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + [dilation[0] as i64, dilation[1] as i64], + false, + ); + + MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) + } + + fn max_pool2d_with_indices_backward( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: TchTensor, + indices: TchTensor, + ) -> MaxPool2dBackward> { + let grad = tch::Tensor::max_pool2d_with_indices_backward( + &x.tensor, + &output_grad.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + [dilation[0] as i64, dilation[1] as i64], + false, + &indices.tensor, + ); + + MaxPool2dBackward::new(TchTensor::new(grad)) + } + + fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor { + let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64)); + + TchTensor::new(tensor) + } + + fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor { + let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor); + + TchTensor::new(tensor) + } + + fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor { + let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64); + + TchTensor::new(tensor) + } } diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs index 2326b8a76f..25c0a3ac17 100644 --- a/burn-tch/src/ops/tensor.rs +++ b/burn-tch/src/ops/tensor.rs @@ -1,445 +1,438 @@ use super::TchOps; use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor}; use burn_tensor::{ - backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape, + backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape, }; use std::ops::Range; impl TensorOps for LibTorch { - fn from_data(data: Data, device: &LibTorchDevice) -> TchTensor { - TchTensor::from_data(data, (*device).into()) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: &LibTorchDevice, - ) -> TchTensor { - match distribution { - Distribution::Default => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor - .mut_ops(|tensor| tensor.rand_like_out(tensor)) - .unwrap() - } - Distribution::Bernoulli(prob) => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor - .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap()) - .unwrap() - } - Distribution::Uniform(from, to) => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor - .mut_ops(|tensor| tensor.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap())) - 
.unwrap() - } - Distribution::Normal(mean, std) => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap() - } - } - } - - fn arange(range: Range, device: &LibTorchDevice) -> TchTensor { - let device: tch::Device = (*device).into(); - let mut tensor = tch::Tensor::arange( - range.end as i64 - range.start as i64, - (tch::Kind::Int64, device), - ); - - if range.start != 0 { - tensor = tensor.f_add_scalar_(range.start as i64).unwrap(); - } - - TchTensor::new(tensor) - } - - fn repeat( - tensor: TchTensor, - dim: usize, - times: usize, - ) -> TchTensor { - TchOps::repeat(tensor, dim, times) - } - - fn zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device))) - } - - fn ones(shape: Shape, device: &LibTorchDevice) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device))) - } - - fn shape(tensor: & as Backend>::TensorPrimitive) -> Shape { - tensor.shape() - } - - fn into_data( - tensor: as Backend>::TensorPrimitive, - ) -> Reader as Backend>::FloatElem, D>> { - let shape = Self::shape(&tensor); - let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()])); - let values: Result, tch::TchError> = tensor.tensor.try_into(); - - Reader::Concrete(Data::new(values.unwrap(), shape)) - } - - fn device(tensor: &TchTensor) -> LibTorchDevice { - tensor.tensor.device().into() - } - - fn to_device( - tensor: TchTensor, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) - } - - fn empty( - shape: Shape, - device: & as Backend>::Device, - ) -> as Backend>::TensorPrimitive { - let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into())); - - TchTensor::new(tensor) - } - - fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::add(lhs, rhs) - } - - fn add_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), - |tensor| tensor.f_add_scalar(rhs).unwrap(), - ) - } - - fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::sub(lhs, rhs) - } - - fn sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), - |tensor| tensor.f_sub_scalar(rhs).unwrap(), - ) - } - - fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::mul(lhs, rhs) - } - - fn mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), - |tensor| tensor.f_mul_scalar(rhs).unwrap(), - ) - } - - fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::div(lhs, rhs) - } - - fn div_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), - |tensor| tensor.f_div_scalar(rhs).unwrap(), - ) - } - - fn matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - let tensor = lhs.tensor.matmul(&rhs.tensor); - TchTensor::new(tensor) - } - - fn neg(tensor: TchTensor) -> TchTensor { - Self::mul_scalar(tensor, (-1f32).elem::()) - } - - fn recip(tensor: TchTensor) -> TchTensor { - TchTensor::new(tensor.tensor.reciprocal()) - } - - fn swap_dims( - tensor: TchTensor, - dim1: 
usize, - dim2: usize, - ) -> TchTensor { - TchOps::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - TchOps::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - ) -> TchTensor { - TchOps::gather(dim, tensor, indices) - } - - fn scatter( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::scatter(dim, tensor, indices, value) - } - - fn select( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - ) -> TchTensor { - TchOps::index_select_dim(tensor, dim, indices) - } - - fn select_assign( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::select_assign(tensor, dim, indices, value) - } - - fn slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - TchOps::slice(tensor, ranges) - } - - fn slice_assign( - tensor: TchTensor, - ranges: [Range; D2], - value: TchTensor, - ) -> as Backend>::TensorPrimitive { - TchOps::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: TchTensor, - mask: TchTensor, - value: TchTensor, - ) -> TchTensor { - let output = value.tensor.where_self(&mask.tensor, &tensor.tensor); - - TchTensor::new(output) - } - - fn mask_fill( - tensor: TchTensor, - mask: TchTensor, - value: E, - ) -> TchTensor { - let value: f64 = value.elem(); - - tensor.unary_ops( - |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), - |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), - ) - } - - fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::equal(lhs, rhs) - } - - fn equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::equal_elem(lhs, rhs.elem::()) - } - - fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::greater(lhs, rhs) - } - - fn greater_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::greater_elem(lhs, rhs.elem::()) - } - - fn greater_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::greater_equal(lhs, rhs) - } - - fn greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::greater_equal_elem(lhs, rhs.elem::()) - } - - fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::lower(lhs, rhs) - } - - fn lower_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::lower_elem(lhs, rhs.elem::()) - } - - fn lower_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::lower_equal(lhs, rhs) - } - - fn lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::lower_equal_elem(lhs, rhs.elem::()) - } - - fn mean(tensor: TchTensor) -> TchTensor { - TchOps::mean(tensor) - } - - fn sum(tensor: TchTensor) -> TchTensor { - TchOps::sum(tensor) - } - - fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::mean_dim(tensor, dim) - } - - fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::sum_dim(tensor, dim) - } - - fn to_full_precision(tensor: &TchTensor) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.to_kind(tch::Kind::Float); - - TchTensor::from_existing(tensor, storage) - } - - fn from_full_precision(tensor: TchTensor) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.to_kind(E::KIND); - - TchTensor::from_existing(tensor, storage) - } - - fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmax(tensor, dim) - } - - fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmin(tensor, dim) - } - - fn 
max_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::max_dim(tensor, dim) - } - - fn max_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::max_dim_with_indices(tensor, dim) - } - - fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::min_dim(tensor, dim) - } - - fn min_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::min_dim_with_indices(tensor, dim) - } - - fn exp(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp()) - } - - fn log(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log()) - } - - fn log1p(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p()) - } - - fn powf(tensor: TchTensor, value: f32) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.f_pow_(value as f64).unwrap(), - |tensor| tensor.pow_tensor_scalar(value as f64), - ) - } - - fn sqrt(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt()) - } - - fn abs(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) - } - - fn cos(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos()) - } - - fn sin(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin()) - } - - fn tanh(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh()) - } - - fn erf(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf()) - } - - fn cat(tensors: Vec>, dim: usize) -> TchTensor { - TchOps::cat(tensors, dim) - } - - fn clamp_min( - tensor: TchTensor, - min: E, - ) -> as Backend>::TensorPrimitive { - TchOps::clamp_min(tensor, min.elem::()) - } - - fn clamp_max( - tensor: as Backend>::TensorPrimitive, - max: as Backend>::FloatElem, - ) -> as Backend>::TensorPrimitive { - TchOps::clamp_max(tensor, max.elem::()) - } - - fn clamp( - tensor: as Backend>::TensorPrimitive, - min: as Backend>::FloatElem, - max: as Backend>::FloatElem, - ) -> as Backend>::TensorPrimitive { - TchOps::clamp(tensor, min.elem::(), max.elem::()) - } - - fn into_int(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(tch::Kind::Int64); - TchTensor::new(tensor) - } + fn from_data(data: Data, device: &LibTorchDevice) -> TchTensor { + TchTensor::from_data(data, (*device).into()) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: &LibTorchDevice, + ) -> TchTensor { + match distribution { + Distribution::Default => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor + .mut_ops(|tensor| tensor.rand_like_out(tensor)) + .unwrap() + } + Distribution::Bernoulli(prob) => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor + .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap()) + .unwrap() + } + Distribution::Uniform(from, to) => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor + .mut_ops(|tensor| tensor.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap())) + .unwrap() + } + Distribution::Normal(mean, std) => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap() + } + } + } + + fn arange(range: Range, device: &LibTorchDevice) -> TchTensor { + let device: tch::Device = (*device).into(); + 
let mut tensor = tch::Tensor::arange( + range.end as i64 - range.start as i64, + (tch::Kind::Int64, device), + ); + + if range.start != 0 { + tensor = tensor.f_add_scalar_(range.start as i64).unwrap(); + } + + TchTensor::new(tensor) + } + + fn repeat(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { + TchOps::repeat(tensor, dim, times) + } + + fn zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device))) + } + + fn ones(shape: Shape, device: &LibTorchDevice) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device))) + } + + fn shape(tensor: & as Backend>::TensorPrimitive) -> Shape { + tensor.shape() + } + + fn into_data( + tensor: as Backend>::TensorPrimitive, + ) -> Reader as Backend>::FloatElem, D>> { + let shape = Self::shape(&tensor); + let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()])); + let values: Result, tch::TchError> = tensor.tensor.try_into(); + + Reader::Concrete(Data::new(values.unwrap(), shape)) + } + + fn device(tensor: &TchTensor) -> LibTorchDevice { + tensor.tensor.device().into() + } + + fn to_device( + tensor: TchTensor, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::new(tensor.tensor.to((*device).into())) + } + + fn empty( + shape: Shape, + device: & as Backend>::Device, + ) -> as Backend>::TensorPrimitive { + let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into())); + + TchTensor::new(tensor) + } + + fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::add(lhs, rhs) + } + + fn add_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), + |tensor| tensor.f_add_scalar(rhs).unwrap(), + ) + } + + fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::sub(lhs, rhs) + } + + fn sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), + |tensor| tensor.f_sub_scalar(rhs).unwrap(), + ) + } + + fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::mul(lhs, rhs) + } + + fn mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), + |tensor| tensor.f_mul_scalar(rhs).unwrap(), + ) + } + + fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::div(lhs, rhs) + } + + fn div_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), + |tensor| tensor.f_div_scalar(rhs).unwrap(), + ) + } + + fn matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + let tensor = lhs.tensor.matmul(&rhs.tensor); + TchTensor::new(tensor) + } + + fn neg(tensor: TchTensor) -> TchTensor { + Self::mul_scalar(tensor, (-1f32).elem::()) + } + + fn recip(tensor: TchTensor) -> TchTensor { + TchTensor::new(tensor.tensor.reciprocal()) + } + + fn swap_dims( + tensor: TchTensor, + dim1: usize, + dim2: usize, + ) -> TchTensor { + TchOps::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + TchOps::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + ) -> TchTensor { + TchOps::gather(dim, 
tensor, indices) + } + + fn scatter( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::scatter(dim, tensor, indices, value) + } + + fn select( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + ) -> TchTensor { + TchOps::index_select_dim(tensor, dim, indices) + } + + fn select_assign( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::select_assign(tensor, dim, indices, value) + } + + fn slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + TchOps::slice(tensor, ranges) + } + + fn slice_assign( + tensor: TchTensor, + ranges: [Range; D2], + value: TchTensor, + ) -> as Backend>::TensorPrimitive { + TchOps::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: TchTensor, + mask: TchTensor, + value: TchTensor, + ) -> TchTensor { + let output = value.tensor.where_self(&mask.tensor, &tensor.tensor); + + TchTensor::new(output) + } + + fn mask_fill( + tensor: TchTensor, + mask: TchTensor, + value: E, + ) -> TchTensor { + let value: f64 = value.elem(); + + tensor.unary_ops( + |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), + |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), + ) + } + + fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::equal(lhs, rhs) + } + + fn equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::equal_elem(lhs, rhs.elem::()) + } + + fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::greater(lhs, rhs) + } + + fn greater_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::greater_elem(lhs, rhs.elem::()) + } + + fn greater_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::greater_equal(lhs, rhs) + } + + fn greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::greater_equal_elem(lhs, rhs.elem::()) + } + + fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::lower(lhs, rhs) + } + + fn lower_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::lower_elem(lhs, rhs.elem::()) + } + + fn lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::lower_equal(lhs, rhs) + } + + fn lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::lower_equal_elem(lhs, rhs.elem::()) + } + + fn mean(tensor: TchTensor) -> TchTensor { + TchOps::mean(tensor) + } + + fn sum(tensor: TchTensor) -> TchTensor { + TchOps::sum(tensor) + } + + fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::mean_dim(tensor, dim) + } + + fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::sum_dim(tensor, dim) + } + + fn to_full_precision(tensor: &TchTensor) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.to_kind(tch::Kind::Float); + + TchTensor::from_existing(tensor, storage) + } + + fn from_full_precision(tensor: TchTensor) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.to_kind(E::KIND); + + TchTensor::from_existing(tensor, storage) + } + + fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmax(tensor, dim) + } + + fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmin(tensor, dim) + } + + fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::max_dim(tensor, dim) + } + + fn max_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::max_dim_with_indices(tensor, dim) + } + + fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { + 
TchOps::min_dim(tensor, dim) + } + + fn min_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::min_dim_with_indices(tensor, dim) + } + + fn exp(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp()) + } + + fn log(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log()) + } + + fn log1p(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p()) + } + + fn powf(tensor: TchTensor, value: f32) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_pow_(value as f64).unwrap(), + |tensor| tensor.pow_tensor_scalar(value as f64), + ) + } + + fn sqrt(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt()) + } + + fn abs(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) + } + + fn cos(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos()) + } + + fn sin(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin()) + } + + fn tanh(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh()) + } + + fn erf(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf()) + } + + fn cat(tensors: Vec>, dim: usize) -> TchTensor { + TchOps::cat(tensors, dim) + } + + fn clamp_min( + tensor: TchTensor, + min: E, + ) -> as Backend>::TensorPrimitive { + TchOps::clamp_min(tensor, min.elem::()) + } + + fn clamp_max( + tensor: as Backend>::TensorPrimitive, + max: as Backend>::FloatElem, + ) -> as Backend>::TensorPrimitive { + TchOps::clamp_max(tensor, max.elem::()) + } + + fn clamp( + tensor: as Backend>::TensorPrimitive, + min: as Backend>::FloatElem, + max: as Backend>::FloatElem, + ) -> as Backend>::TensorPrimitive { + TchOps::clamp(tensor, min.elem::(), max.elem::()) + } + + fn into_int(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(tch::Kind::Int64); + TchTensor::new(tensor) + } } diff --git a/burn-tch/src/tensor.rs b/burn-tch/src/tensor.rs index 1ca6f44f36..1f3e73b9c0 100644 --- a/burn-tch/src/tensor.rs +++ b/burn-tch/src/tensor.rs @@ -9,61 +9,61 @@ pub type StorageRef = Rc<*mut c_void>; /// A tensor that uses the tch backend. #[derive(Debug, PartialEq)] pub struct TchTensor { - /// Handle to the tensor. Call methods on this field. - pub tensor: tch::Tensor, - /// The tensor's storage - pub storage: StorageRef, - phantom: PhantomData, + /// Handle to the tensor. Call methods on this field. + pub tensor: tch::Tensor, + /// The tensor's storage + pub storage: StorageRef, + phantom: PhantomData, } impl TchTensor { - /// Create a new tensor. - /// - /// Note that if the tensor was created from an operation that may reuse the same tensor - /// storage as the parent, you should use [from_existing](TchTensor::from_existing) - /// instead. - pub fn new(tensor: tch::Tensor) -> Self { - let data = Rc::new(tensor.data_ptr()); - - Self { - tensor, - phantom: PhantomData, - storage: data, - } + /// Create a new tensor. + /// + /// Note that if the tensor was created from an operation that may reuse the same tensor + /// storage as the parent, you should use [from_existing](TchTensor::from_existing) + /// instead. 
+ pub fn new(tensor: tch::Tensor) -> Self { + let data = Rc::new(tensor.data_ptr()); + + Self { + tensor, + phantom: PhantomData, + storage: data, } - - /// Create a tensor that was created from an operation executed on a parent tensor. - /// - /// If the child tensor shared the same storage as its parent, it will be cloned, effectively - /// tracking how much tensors point to the same memory space. - pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self { - let storage_child = tensor.data_ptr(); - - let storage = match storage_child == *storage_parent { - true => storage_parent.clone(), - false => Rc::new(storage_child), - }; - - Self { - tensor, - storage, - phantom: PhantomData, - } + } + + /// Create a tensor that was created from an operation executed on a parent tensor. + /// + /// If the child tensor shared the same storage as its parent, it will be cloned, effectively + /// tracking how much tensors point to the same memory space. + pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self { + let storage_child = tensor.data_ptr(); + + let storage = match storage_child == *storage_parent { + true => storage_parent.clone(), + false => Rc::new(storage_child), + }; + + Self { + tensor, + storage, + phantom: PhantomData, } + } } impl std::ops::Add for TchTensor { - type Output = Self; + type Output = Self; - fn add(self, rhs: Self) -> Self::Output { - LibTorch::add(self, rhs) - } + fn add(self, rhs: Self) -> Self::Output { + LibTorch::add(self, rhs) + } } impl TchTensor { - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.tensor.size()) - } + pub(crate) fn shape(&self) -> Shape { + Shape::from(self.tensor.size()) + } } // This is safe since we don't use autodiff from LibTorch. @@ -73,209 +73,209 @@ unsafe impl Send for TchTensor {} unsafe impl Sync for TchTensor {} impl TchTensor { - /// Execute an operation on a tensor if the data can be reused. - pub fn mut_ops< - F: Fn(&mut tch::Tensor) -> tch::Tensor, - EOut: tch::kind::Element, - const D_OUT: usize, - >( - &mut self, - func: F, - ) -> Option> { - if Rc::strong_count(&self.storage) > 1 { - return None; - } - - let data = self.storage.clone(); - Some(TchTensor::from_existing(func(&mut self.tensor), data)) + /// Execute an operation on a tensor if the data can be reused. + pub fn mut_ops< + F: Fn(&mut tch::Tensor) -> tch::Tensor, + EOut: tch::kind::Element, + const D_OUT: usize, + >( + &mut self, + func: F, + ) -> Option> { + if Rc::strong_count(&self.storage) > 1 { + return None; } - /// Execute a unary ops reusing the tensor data if possible. - pub fn unary_ops( - self, - fown: FOwn, - fref: FRef, - ) -> TchTensor - where - FOwn: Fn(tch::Tensor) -> tch::Tensor, - FRef: Fn(&tch::Tensor) -> tch::Tensor, - { - if Rc::strong_count(&self.storage) > 1 { - return TchTensor::from_existing(fref(&self.tensor), self.storage); - } - TchTensor::from_existing(fown(self.tensor), self.storage) + let data = self.storage.clone(); + Some(TchTensor::from_existing(func(&mut self.tensor), data)) + } + /// Execute a unary ops reusing the tensor data if possible. + pub fn unary_ops( + self, + fown: FOwn, + fref: FRef, + ) -> TchTensor + where + FOwn: Fn(tch::Tensor) -> tch::Tensor, + FRef: Fn(&tch::Tensor) -> tch::Tensor, + { + if Rc::strong_count(&self.storage) > 1 { + return TchTensor::from_existing(fref(&self.tensor), self.storage); } - /// Execute a binary ops reusing the tensor data if possible. 
- pub fn binary_ops_tensor( - mut lhs: Self, - mut rhs: Self, - flmut: FLMut, - frmut: FRMut, - fref: FRef, - ) -> TchTensor - where - FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor, - FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor, - FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor, - { - let lhs_num_elems = lhs.shape().num_elements(); - let rhs_num_elems = rhs.shape().num_elements(); - - let safe_mut_lhs = lhs_num_elems > rhs_num_elems; - let safe_mut_rhs = rhs_num_elems > lhs_num_elems; - - if safe_mut_lhs { - if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) { - return output; - } - } + TchTensor::from_existing(fown(self.tensor), self.storage) + } + + /// Execute a binary ops reusing the tensor data if possible. + pub fn binary_ops_tensor( + mut lhs: Self, + mut rhs: Self, + flmut: FLMut, + frmut: FRMut, + fref: FRef, + ) -> TchTensor + where + FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor, + FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor, + FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor, + { + let lhs_num_elems = lhs.shape().num_elements(); + let rhs_num_elems = rhs.shape().num_elements(); + + let safe_mut_lhs = lhs_num_elems > rhs_num_elems; + let safe_mut_rhs = rhs_num_elems > lhs_num_elems; + + if safe_mut_lhs { + if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) { + return output; + } + } - if safe_mut_rhs { - if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) { - return output; - } - } + if safe_mut_rhs { + if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) { + return output; + } + } - let storage = lhs.storage; - let tensor = fref(&lhs.tensor, &rhs.tensor); + let storage = lhs.storage; + let tensor = fref(&lhs.tensor, &rhs.tensor); - TchTensor::from_existing(tensor, storage) - } + TchTensor::from_existing(tensor, storage) + } } impl Clone for TchTensor { - fn clone(&self) -> Self { - Self { - tensor: self.tensor.shallow_clone(), - phantom: PhantomData, - storage: self.storage.clone(), - } + fn clone(&self) -> Self { + Self { + tensor: self.tensor.shallow_clone(), + phantom: PhantomData, + storage: self.storage.clone(), } + } } /// A shape that can be used by LibTorch. pub struct TchShape { - /// The shape's dimensions. - pub dims: [i64; D], + /// The shape's dimensions. + pub dims: [i64; D], } impl From> for TchShape { - fn from(shape: Shape) -> Self { - let mut dims = [0; D]; - for (i, dim) in dims.iter_mut().enumerate().take(D) { - *dim = shape.dims[i] as i64; - } - TchShape { dims } + fn from(shape: Shape) -> Self { + let mut dims = [0; D]; + for (i, dim) in dims.iter_mut().enumerate().take(D) { + *dim = shape.dims[i] as i64; } + TchShape { dims } + } } impl TchTensor { - /// Creates a new tensor from a shape and a device. - /// - /// # Arguments - /// - /// * `data` - The tensor's data. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// A new tensor. - pub fn from_data(data: Data, device: tch::Device) -> Self { - let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device); - let shape_tch = TchShape::from(data.shape); - let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND); - - Self::new(tensor) - } + /// Creates a new tensor from a shape and a device. + /// + /// # Arguments + /// + /// * `data` - The tensor's data. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// A new tensor. 
+ pub fn from_data(data: Data, device: tch::Device) -> Self { + let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device); + let shape_tch = TchShape::from(data.shape); + let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND); + + Self::new(tensor) + } } #[cfg(test)] mod utils { - use super::*; - use crate::{backend::LibTorch, element::TchElement}; - - impl TchTensor { - pub(crate) fn into_data(self) -> Data - where - P: tch::kind::Element, - { - as TensorOps>>::into_data(self).read() - } + use super::*; + use crate::{backend::LibTorch, element::TchElement}; + + impl TchTensor { + pub(crate) fn into_data(self) -> Data + where + P: tch::kind::Element, + { + as TensorOps>>::into_data(self).read() } + } } impl TchTensor { - /// Creates an empty tensor from a shape and a device. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// A new empty tensor. - pub fn empty(shape: Shape, device: LibTorchDevice) -> Self { - let shape_tch = TchShape::from(shape); - let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into())); - - Self::new(tensor) - } + /// Creates an empty tensor from a shape and a device. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// A new empty tensor. + pub fn empty(shape: Shape, device: LibTorchDevice) -> Self { + let shape_tch = TchShape::from(shape); + let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into())); + + Self::new(tensor) + } } #[cfg(test)] mod tests { - use super::*; - use burn_tensor::{Distribution, Tensor}; - use rand::prelude::StdRng; - use rand::SeedableRng; - - #[test] - fn should_support_into_and_from_data_1d() { - let data_expected = Data::::random( - Shape::new([3]), - Distribution::Default, - &mut StdRng::from_entropy(), - ); - let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_2d() { - let data_expected = Data::::random( - Shape::new([2, 3]), - Distribution::Default, - &mut StdRng::from_entropy(), - ); - let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_not_update_inplace_after_reshape() { - let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); - let tensor_2 = tensor_1.clone(); - - let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0); - - assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); - } - - #[test] - fn should_not_update_inplace_after_slice() { - let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); - let tensor_2 = tensor_1.clone(); - - let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0); - - assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); - } + use super::*; + use burn_tensor::{Distribution, Tensor}; + use rand::prelude::StdRng; + use rand::SeedableRng; + + #[test] + fn should_support_into_and_from_data_1d() { + let data_expected = Data::::random( + Shape::new([3]), + Distribution::Default, + &mut StdRng::from_entropy(), + ); + let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_2d() { + 
let data_expected = Data::::random( + Shape::new([2, 3]), + Distribution::Default, + &mut StdRng::from_entropy(), + ); + let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_not_update_inplace_after_reshape() { + let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); + let tensor_2 = tensor_1.clone(); + + let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0); + + assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); + } + + #[test] + fn should_not_update_inplace_after_slice() { + let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); + let tensor_2 = tensor_1.clone(); + + let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0); + + assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); + } } diff --git a/burn-tensor-testgen/src/lib.rs b/burn-tensor-testgen/src/lib.rs index d67d13114a..5ddbcc2a8e 100644 --- a/burn-tensor-testgen/src/lib.rs +++ b/burn-tensor-testgen/src/lib.rs @@ -4,22 +4,22 @@ use quote::{format_ident, quote}; #[allow(missing_docs)] #[proc_macro_attribute] pub fn testgen(attr: TokenStream, item: TokenStream) -> TokenStream { - let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item); - let attr: proc_macro2::TokenStream = proc_macro2::TokenStream::from(attr); - let macro_ident = format_ident!("testgen_{}", attr.to_string()); + let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item); + let attr: proc_macro2::TokenStream = proc_macro2::TokenStream::from(attr); + let macro_ident = format_ident!("testgen_{}", attr.to_string()); - let macro_gen = quote! { - #[macro_export] - macro_rules! #macro_ident { - () => { - mod #attr { - use super::*; + let macro_gen = quote! { + #[macro_export] + macro_rules! #macro_ident { + () => { + mod #attr { + use super::*; - #item - } - }; - } - }; + #item + } + }; + } + }; - macro_gen.into() + macro_gen.into() } diff --git a/burn-tensor/src/tensor/activation/base.rs b/burn-tensor/src/tensor/activation/base.rs index 5fa502422c..bcc3f2593e 100644 --- a/burn-tensor/src/tensor/activation/base.rs +++ b/burn-tensor/src/tensor/activation/base.rs @@ -5,12 +5,12 @@ use crate::{ElementPrecision, Precision}; /// Applies the rectified linear unit function. pub fn relu(tensor: Tensor) -> Tensor { - tensor.relu() + tensor.relu() } /// Applies the Gaussian Error Linear Units function as described in the paper in [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf). pub fn gelu(tensor: Tensor) -> Tensor { - Tensor::from_primitive(B::gelu(tensor.primitive)) + Tensor::from_primitive(B::gelu(tensor.primitive)) } /// Applies the softmax function on the input tensor along the given dimension. @@ -22,13 +22,33 @@ pub fn gelu(tensor: Tensor) -> Tensor { /// The dimension argument `dim` specifies the dimension along which the function will be computed. /// It must in the range of `0` and `D-1`. pub fn softmax(tensor: Tensor, dim: usize) -> Tensor { - check!(TensorCheck::dim_ops::("softmax", dim)); + check!(TensorCheck::dim_ops::("softmax", dim)); - let tensor = tensor.clone() - tensor.detach().max_dim(dim); - let tensor = tensor.exp(); - let tensor_tmp = tensor.clone().sum_dim(dim); + let tensor = tensor.clone() - tensor.detach().max_dim(dim); + let tensor = tensor.exp(); + let tensor_tmp = tensor.clone().sum_dim(dim); - tensor.div(tensor_tmp) + tensor.div(tensor_tmp) +} + +/// Applies the "quiet softmax" function on the input tensor along the given dimension. 
+/// This function is similar to the softmax function, but it allows for "no selection", i.e.,
+/// all outputs can tend to zero.
+///
+/// `quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`
+///
+/// # Notes
+///
+/// The dimension argument `dim` specifies the dimension along which the function will be computed.
+/// It must be in the range of `0` and `D-1`.
+pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
+  check!(TensorCheck::dim_ops::<D>("softmax", dim));
+
+  let tensor = tensor.clone() - tensor.detach().max_dim(dim);
+  let tensor = tensor.exp();
+  let tensor_tmp = tensor.clone().sum_dim(dim);
+
+  tensor.div(tensor_tmp + 1)
 }
 
 /// Applies the log softmax function on the input tensor along the given dimension.
@@ -40,37 +60,37 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) ->
 /// The dimension argument `dim` specifies the dimension along which the function will be computed.
 /// It must in the range of `0` and `D-1`.
 pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
-    check!(TensorCheck::dim_ops::<D>("log softmax", dim));
+  check!(TensorCheck::dim_ops::<D>("log softmax", dim));
 
-    let tensor = tensor.clone() - tensor.detach().max_dim(dim);
-    let tensor_tmp = tensor.clone().exp().sum_dim(dim).log();
+  let tensor = tensor.clone() - tensor.detach().max_dim(dim);
+  let tensor_tmp = tensor.clone().exp().sum_dim(dim).log();
 
-    tensor.sub(tensor_tmp)
+  tensor.sub(tensor_tmp)
 }
 
 /// Applies the sigmoid function.
 pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    log_sigmoid(tensor).exp()
+  log_sigmoid(tensor).exp()
 }
 
 /// Applies the log sigmoid function.
 pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    match B::FloatElem::precision() {
-        Precision::Half => {
-            let tensor_full = tensor.to_full_precision();
-            let tensor_tmp = tensor_full.neg().exp().add_scalar(1.0_f32).log().neg();
-            Tensor::from_full_precision(tensor_tmp)
-        }
-        _ => tensor.neg().exp().add_scalar(1.0_f32).log().neg(),
+  match B::FloatElem::precision() {
+    Precision::Half => {
+      let tensor_full = tensor.to_full_precision();
+      let tensor_tmp = tensor_full.neg().exp().add_scalar(1.0_f32).log().neg();
+      Tensor::from_full_precision(tensor_tmp)
     }
+    _ => tensor.neg().exp().add_scalar(1.0_f32).log().neg(),
+  }
 }
 
 /// Applies the silu function
 pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    tensor.clone().mul(sigmoid(tensor))
+  tensor.clone().mul(sigmoid(tensor))
 }
 
 /// Applies the tanh function
 pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    tensor.tanh()
+  tensor.tanh()
 }
diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs
index 3a12b834e3..b80bd023f1 100644
--- a/burn-tensor/src/tensor/api/base.rs
+++ b/burn-tensor/src/tensor/api/base.rs
@@ -13,658 +13,658 @@ use burn_common::{reader::Reader, stub::Mutex};
 use core::{fmt::Debug, ops::Range};
 
 use crate::{
-    backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
+  backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
 };
 
 /// A tensor with a given backend, shape and data type.
 #[derive(new, Clone, Debug)]
 pub struct Tensor<B, const D: usize, K = Float>
 where
-    B: Backend,
-    K: TensorKind<B>,
+  B: Backend,
+  K: TensorKind<B>,
 {
-    pub(crate) primitive: K::Primitive<D>,
+  pub(crate) primitive: K::Primitive<D>,
 }
 
 impl<B, const D: usize, K> Tensor<B, D, K>
 where
-    B: Backend,
-    K: BasicOps<B>,
+  B: Backend,
+  K: BasicOps<B>,
 {
-    /// Converts the tensor into a primitive tensor.
-    pub fn into_primitive(self) -> K::Primitive<D> {
-        self.primitive
-    }
-
-    /// Converts from a primitive tensor into a tensor.
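As a quick reference for the new API above, here is a minimal usage sketch in the same `fn example<B: Backend>()` style the doc comments in this file already use; the logit values are only illustrative.

```rust
use burn_tensor::activation::{quiet_softmax, softmax};
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    // A single row of strongly negative logits: the "no selection" case.
    let logits = Tensor::<B, 2>::from_floats([[-10.0, -12.0, -11.0]]);

    // `softmax` renormalizes, so this row still sums to 1.0 along dim 1.
    let selected = softmax(logits.clone(), 1);
    // `quiet_softmax` divides by `1 + sum(exp)`, so every entry can stay near 0.
    let quiet = quiet_softmax(logits, 1);

    println!("{:?}", selected.into_data());
    println!("{:?}", quiet.into_data());
}
```

The `+ 1` in the denominator is what lets every output shrink toward zero here, whereas `softmax` always renormalizes the row to sum to one.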
- pub fn from_primitive(tensor: K::Primitive) -> Self { - Self::new(tensor) - } - - /// Create an empty tensor of the given shape. - pub fn empty>>(shape: S) -> Self { - Self::empty_device(shape, &B::Device::default()) - } - - /// Create an empty tensor of the given shape. - pub fn empty_device>>(shape: S, device: &B::Device) -> Self { - Self::new(K::empty(shape.into(), device)) - } - - /// Returns the dimensions of the current tensor. - /// - /// Equivalent to `tensor.shape().dims`. - pub fn dims(&self) -> [usize; D] { - Self::shape(self).dims - } - - /// Returns the shape of the current tensor. - pub fn shape(&self) -> Shape { - K::shape(&self.primitive) - } - - /// Reshape the tensor to have the given shape. - /// - /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]` - /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12]. - /// - /// A `0` in the shape instructs to keep the current dimension from the original tensor, - /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4]. - /// This is useful when reshaping tensors with unknown dimensions and combining with `-1` - /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor - /// with [1, 3, 4] dimensions to [1, 12]. - /// - /// # Arguments - /// - `shape`: The new shape of the tensor. - /// - /// # Panics - /// - If the tensor contains more than one `-1` in the shape. - /// - If the tensor contains values that are not positive (other than -1). - /// - If the shape does not match the number of elements of the original shape. - /// - /// # Example - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let tensor = Tensor::::ones([2, 3, 4]); - /// // Given a 3D tensor with dimensions (2, 3, 4), reshape it to (2, 12) - /// let reshaped_tensor: Tensor:: = tensor.reshape([2, -1]); - /// // The resulting tensor will have dimensions (2, 12). - /// println!("{:?}", reshaped_tensor.shape()); - /// } - /// ``` - pub fn reshape>(self, shape: S) -> Tensor { - // Convert reshape args to shape - let shape = shape.into_shape(&self); - Tensor::new(K::reshape::(self.primitive, shape)) - } - - /// Transpose the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - pub fn transpose(self) -> Tensor { - Tensor::new(K::transpose(self.primitive)) - } - - /// Swaps two dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { - Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) - } - - /// Flatten the tensor along a given range of dimensions. - /// - /// This function collapses the specified range of dimensions into a single dimension, - /// effectively flattening the tensor in that range. - /// - /// # Arguments - /// - /// - `start_dim`: The starting dimension of the range to be flattened. - /// - `end_dim`: The ending dimension of the range to be flattened (inclusive). - /// - /// # Type Parameters - /// - /// - `D2`: The resulting number of dimensions in the flattened tensor. - /// - /// # Returns - /// - /// A new `Tensor` instance with the specified range of dimensions flattened. 
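The `reshape` documentation above demonstrates `-1` but not the `0` placeholder, so a small sketch of the latter may help; it follows the same `fn example<B: Backend>()` convention and the shapes are arbitrary.

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    let tensor = Tensor::<B, 3>::ones([1, 3, 4]);

    // `0` keeps the size of that dimension, `-1` infers the rest:
    // [1, 3, 4] is reshaped to [1, 12].
    let reshaped: Tensor<B, 2> = tensor.reshape([0, -1]);
    println!("{:?}", reshaped.dims()); // [1, 12]
}
```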
- /// - /// # Example - /// - /// ```rust - /// - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 4])); - /// - /// // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2: - /// let flattened_tensor: Tensor:: = tensor.flatten(1, 2); - /// - /// // The resulting tensor will have dimensions (2, 12). - /// println!("{:?}", flattened_tensor.shape()); - /// } - /// - /// ``` - pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { - check!(TensorCheck::flatten::(start_dim, end_dim)); - - let current_dims = self.shape().dims; - let mut new_dims: [usize; D2] = [0; D2]; - let mut flatten_dims = 1; - - for i in current_dims[start_dim..=end_dim].iter() { - flatten_dims *= i; - } - - new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]); - new_dims[start_dim] = flatten_dims; - new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]); - - Tensor::new(K::reshape::(self.primitive, new_dims.into())) - } - - /// Squeeze the tensor along the given dimension, removing the specified dimension - /// of size one, and effectively reducing the rank of the tensor by one. - /// - /// # Arguments - /// - /// - `dim`: The dimension to be squeezed. - /// - /// # Type Parameters - /// - /// - 'D2': The resulting number of dimensions in the squeezed tensor. - /// - /// # Returns - /// - /// A new `Tensor` instance with the specified dimenension removed. - /// - /// # Example - /// - /// ```rust - /// - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 1, 4])); - /// - /// // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1 - /// let squeezed_tensor: Tensor:: = tensor.squeeze(1); - /// - /// // Resulting tensor will have dimensions (2, 4) - /// println!("{:?}", squeezed_tensor.shape()); - /// } - /// ``` - pub fn squeeze(self, dim: usize) -> Tensor { - check!(TensorCheck::squeeze::(dim, &self.shape().dims)); - - let current_dims = self.shape().dims; - let mut new_dims: [usize; D2] = [0; D2]; - - new_dims[..dim].copy_from_slice(¤t_dims[..dim]); - new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]); - - Tensor::new(K::reshape::(self.primitive, new_dims.into())) - } - - /// Unsqueeze the current tensor. Create new dimensions to fit the given size. - /// - /// If the output size is higher than the current tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([3, 3])); - /// let tensor = tensor.unsqueeze::<4>(); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [1, 1, 3, 3] } - /// } - /// ``` - pub fn unsqueeze(self) -> Tensor { - check!(TensorCheck::unsqueeze::()); - - let mut dims = [1; D2]; - let num_ones = D2 - D; - let shape = self.shape(); - - dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]); - - let shape = Shape::new(dims); - self.reshape(shape) - } - - /// Returns a tensor containing the elements selected from the given ranges. - /// - /// # Panics - /// - /// If a range exceeds the number of elements on a dimension. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); - /// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]); - /// println!("{:?}", tensor_slices.dims()); // [1, 3, 2] - /// - /// } - /// ``` - pub fn slice(self, ranges: [core::ops::Range; D2]) -> Self { - check!(TensorCheck::slice(&self.shape(), &ranges)); - Self::new(K::slice(self.primitive, ranges)) - } - - /// Returns a copy of the current tensor with the selected elements changed to the new ones at - /// the selected indices. - /// - /// # Panics - /// - /// - If a range exceeds the number of elements on a dimension. - /// - If the given values don't match the given ranges. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let tensor = Tensor::::ones([2, 3, 3]); - /// let values = Tensor::::zeros([1, 1, 1]); - /// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values); - /// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3] - /// } - /// ``` - pub fn slice_assign( - self, - ranges: [core::ops::Range; D2], - values: Self, - ) -> Self { - check!(TensorCheck::slice_assign( - &self.shape(), - &values.shape(), - &ranges - )); - Self::new(K::slice_assign(self.primitive, ranges, values.primitive)) - } - - /// Returns the device of the current tensor. - pub fn device(&self) -> B::Device { - K::device(&self.primitive) - } - - /// Returns a new tensor on the given device. - pub fn to_device(self, device: &B::Device) -> Self { - Self::new(K::to_device(self.primitive, device)) - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Returns the data of the current tensor. - pub async fn into_data(self) -> Data { - K::into_data(self.primitive).read().await - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Returns the data of the current tensor. - pub fn into_data(self) -> Data { - K::into_data(self.primitive).read() - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Returns the data of the current tensor. - pub async fn to_data(&self) -> Data { - K::into_data(self.primitive.clone()).read().await - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Returns the data of the current tensor without taking ownership. - pub fn to_data(&self) -> Data { - Self::into_data(self.clone()) - } - - /// Create a tensor from the given data. - pub fn from_data(data: T) -> Self - where - T: Into>, - { - Self::from_data_device(data, &B::Device::default()) - } - - /// Create a tensor from the given data on the given device. - pub fn from_data_device(data: T, device: &B::Device) -> Self - where - T: Into>, - { - Self::new(K::from_data(data.into(), device)) - } - - /// Repeat the tensor along the given dimension. - /// - /// # Panics - /// - /// If the selected dimension more than one item. - pub fn repeat(self, dim: usize, times: usize) -> Self { - Self::new(K::repeat(self.primitive, dim, times)) - } - - /// Applies element wise equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); - K::equal(self.primitive, other.primitive) - } - - /// Concatenates all tensors into a new one along the given dimension. 
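The `cat` constructor documented above and the `iter_dim` iterator a little further down are easiest to read together; a minimal sketch under the usual `fn example<B: Backend>()` convention, with arbitrary shapes:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    let a = Tensor::<B, 2>::ones([2, 3]);
    let b = Tensor::<B, 2>::zeros([2, 3]);

    // Concatenate along dim 0: [2, 3] + [2, 3] -> [4, 3].
    let stacked = Tensor::cat(vec![a, b], 0);
    println!("{:?}", stacked.dims()); // [4, 3]

    // Iterate over dim 0; each item keeps the rank, with that dim of size 1.
    for row in stacked.iter_dim(0) {
        println!("{:?}", row.dims()); // [1, 3]
    }
}
```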
- /// - /// # Panics - /// - /// If all tensors don't have the same shape. - pub fn cat(tensors: Vec, dim: usize) -> Self { - check!(TensorCheck::cat(&tensors, dim)); - - Self::new(K::cat( - tensors.into_iter().map(|vector| vector.primitive).collect(), - dim, - )) - } - - /// Iterate over slices of tensors alongside a given dimension. - /// - /// # Panics - /// - /// Given dimension is less than tensor rank. - /// - /// # Returns - /// - /// A tensor iterator. - pub fn iter_dim(self, dim: usize) -> DimIter { - check!(TensorCheck::dim_ops::("iter_dim", dim)); - DimIter::new(self, dim) - } + /// Converts the tensor into a primitive tensor. + pub fn into_primitive(self) -> K::Primitive { + self.primitive + } + + /// Converts from a primitive tensor into a tensor. + pub fn from_primitive(tensor: K::Primitive) -> Self { + Self::new(tensor) + } + + /// Create an empty tensor of the given shape. + pub fn empty>>(shape: S) -> Self { + Self::empty_device(shape, &B::Device::default()) + } + + /// Create an empty tensor of the given shape. + pub fn empty_device>>(shape: S, device: &B::Device) -> Self { + Self::new(K::empty(shape.into(), device)) + } + + /// Returns the dimensions of the current tensor. + /// + /// Equivalent to `tensor.shape().dims`. + pub fn dims(&self) -> [usize; D] { + Self::shape(self).dims + } + + /// Returns the shape of the current tensor. + pub fn shape(&self) -> Shape { + K::shape(&self.primitive) + } + + /// Reshape the tensor to have the given shape. + /// + /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]` + /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12]. + /// + /// A `0` in the shape instructs to keep the current dimension from the original tensor, + /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4]. + /// This is useful when reshaping tensors with unknown dimensions and combining with `-1` + /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor + /// with [1, 3, 4] dimensions to [1, 12]. + /// + /// # Arguments + /// - `shape`: The new shape of the tensor. + /// + /// # Panics + /// - If the tensor contains more than one `-1` in the shape. + /// - If the tensor contains values that are not positive (other than -1). + /// - If the shape does not match the number of elements of the original shape. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let tensor = Tensor::::ones([2, 3, 4]); + /// // Given a 3D tensor with dimensions (2, 3, 4), reshape it to (2, 12) + /// let reshaped_tensor: Tensor:: = tensor.reshape([2, -1]); + /// // The resulting tensor will have dimensions (2, 12). + /// println!("{:?}", reshaped_tensor.shape()); + /// } + /// ``` + pub fn reshape>(self, shape: S) -> Tensor { + // Convert reshape args to shape + let shape = shape.into_shape(&self); + Tensor::new(K::reshape::(self.primitive, shape)) + } + + /// Transpose the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + pub fn transpose(self) -> Tensor { + Tensor::new(K::transpose(self.primitive)) + } + + /// Swaps two dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. 
+ pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { + Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) + } + + /// Flatten the tensor along a given range of dimensions. + /// + /// This function collapses the specified range of dimensions into a single dimension, + /// effectively flattening the tensor in that range. + /// + /// # Arguments + /// + /// - `start_dim`: The starting dimension of the range to be flattened. + /// - `end_dim`: The ending dimension of the range to be flattened (inclusive). + /// + /// # Type Parameters + /// + /// - `D2`: The resulting number of dimensions in the flattened tensor. + /// + /// # Returns + /// + /// A new `Tensor` instance with the specified range of dimensions flattened. + /// + /// # Example + /// + /// ```rust + /// + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 4])); + /// + /// // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2: + /// let flattened_tensor: Tensor:: = tensor.flatten(1, 2); + /// + /// // The resulting tensor will have dimensions (2, 12). + /// println!("{:?}", flattened_tensor.shape()); + /// } + /// + /// ``` + pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { + check!(TensorCheck::flatten::(start_dim, end_dim)); + + let current_dims = self.shape().dims; + let mut new_dims: [usize; D2] = [0; D2]; + let mut flatten_dims = 1; + + for i in current_dims[start_dim..=end_dim].iter() { + flatten_dims *= i; + } + + new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]); + new_dims[start_dim] = flatten_dims; + new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]); + + Tensor::new(K::reshape::(self.primitive, new_dims.into())) + } + + /// Squeeze the tensor along the given dimension, removing the specified dimension + /// of size one, and effectively reducing the rank of the tensor by one. + /// + /// # Arguments + /// + /// - `dim`: The dimension to be squeezed. + /// + /// # Type Parameters + /// + /// - 'D2': The resulting number of dimensions in the squeezed tensor. + /// + /// # Returns + /// + /// A new `Tensor` instance with the specified dimenension removed. + /// + /// # Example + /// + /// ```rust + /// + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 1, 4])); + /// + /// // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1 + /// let squeezed_tensor: Tensor:: = tensor.squeeze(1); + /// + /// // Resulting tensor will have dimensions (2, 4) + /// println!("{:?}", squeezed_tensor.shape()); + /// } + /// ``` + pub fn squeeze(self, dim: usize) -> Tensor { + check!(TensorCheck::squeeze::(dim, &self.shape().dims)); + + let current_dims = self.shape().dims; + let mut new_dims: [usize; D2] = [0; D2]; + + new_dims[..dim].copy_from_slice(¤t_dims[..dim]); + new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]); + + Tensor::new(K::reshape::(self.primitive, new_dims.into())) + } + + /// Unsqueeze the current tensor. Create new dimensions to fit the given size. + /// + /// If the output size is higher than the current tensor. 
+ /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([3, 3])); + /// let tensor = tensor.unsqueeze::<4>(); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [1, 1, 3, 3] } + /// } + /// ``` + pub fn unsqueeze(self) -> Tensor { + check!(TensorCheck::unsqueeze::()); + + let mut dims = [1; D2]; + let num_ones = D2 - D; + let shape = self.shape(); + + dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]); + + let shape = Shape::new(dims); + self.reshape(shape) + } + + /// Returns a tensor containing the elements selected from the given ranges. + /// + /// # Panics + /// + /// If a range exceeds the number of elements on a dimension. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); + /// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]); + /// println!("{:?}", tensor_slices.dims()); // [1, 3, 2] + /// + /// } + /// ``` + pub fn slice(self, ranges: [core::ops::Range; D2]) -> Self { + check!(TensorCheck::slice(&self.shape(), &ranges)); + Self::new(K::slice(self.primitive, ranges)) + } + + /// Returns a copy of the current tensor with the selected elements changed to the new ones at + /// the selected indices. + /// + /// # Panics + /// + /// - If a range exceeds the number of elements on a dimension. + /// - If the given values don't match the given ranges. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let tensor = Tensor::::ones([2, 3, 3]); + /// let values = Tensor::::zeros([1, 1, 1]); + /// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values); + /// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3] + /// } + /// ``` + pub fn slice_assign( + self, + ranges: [core::ops::Range; D2], + values: Self, + ) -> Self { + check!(TensorCheck::slice_assign( + &self.shape(), + &values.shape(), + &ranges + )); + Self::new(K::slice_assign(self.primitive, ranges, values.primitive)) + } + + /// Returns the device of the current tensor. + pub fn device(&self) -> B::Device { + K::device(&self.primitive) + } + + /// Returns a new tensor on the given device. + pub fn to_device(self, device: &B::Device) -> Self { + Self::new(K::to_device(self.primitive, device)) + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Returns the data of the current tensor. + pub async fn into_data(self) -> Data { + K::into_data(self.primitive).read().await + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Returns the data of the current tensor. + pub fn into_data(self) -> Data { + K::into_data(self.primitive).read() + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Returns the data of the current tensor. + pub async fn to_data(&self) -> Data { + K::into_data(self.primitive.clone()).read().await + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Returns the data of the current tensor without taking ownership. + pub fn to_data(&self) -> Data { + Self::into_data(self.clone()) + } + + /// Create a tensor from the given data. 
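The creation and transfer helpers in this block (`from_data`, `from_data_device`, `to_device`, `to_data`, `into_data`) carry no doc example, so here is a hedged sketch; `from_floats` is used only for brevity and the values are arbitrary.

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    // Created on the default device; `from_data_device` would take an explicit one.
    let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]]);

    // Move the tensor, then read the values back out.
    // `to_data` clones the contents, `into_data` consumes the tensor.
    let tensor = tensor.to_device(&B::Device::default());
    println!("{:?}", tensor.to_data());
}
```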
+ pub fn from_data(data: T) -> Self + where + T: Into>, + { + Self::from_data_device(data, &B::Device::default()) + } + + /// Create a tensor from the given data on the given device. + pub fn from_data_device(data: T, device: &B::Device) -> Self + where + T: Into>, + { + Self::new(K::from_data(data.into(), device)) + } + + /// Repeat the tensor along the given dimension. + /// + /// # Panics + /// + /// If the selected dimension more than one item. + pub fn repeat(self, dim: usize, times: usize) -> Self { + Self::new(K::repeat(self.primitive, dim, times)) + } + + /// Applies element wise equal comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn equal(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); + K::equal(self.primitive, other.primitive) + } + + /// Concatenates all tensors into a new one along the given dimension. + /// + /// # Panics + /// + /// If all tensors don't have the same shape. + pub fn cat(tensors: Vec, dim: usize) -> Self { + check!(TensorCheck::cat(&tensors, dim)); + + Self::new(K::cat( + tensors.into_iter().map(|vector| vector.primitive).collect(), + dim, + )) + } + + /// Iterate over slices of tensors alongside a given dimension. + /// + /// # Panics + /// + /// Given dimension is less than tensor rank. + /// + /// # Returns + /// + /// A tensor iterator. + pub fn iter_dim(self, dim: usize) -> DimIter { + check!(TensorCheck::dim_ops::("iter_dim", dim)); + DimIter::new(self, dim) + } } /// Iterator given by (Tensor::iter_dim). pub struct DimIter where - B: Backend, - K: BasicOps, + B: Backend, + K: BasicOps, { - counter: usize, - dim: usize, - end_idx: usize, - ranges: [Range; D], - tensor: Tensor, + counter: usize, + dim: usize, + end_idx: usize, + ranges: [Range; D], + tensor: Tensor, } impl> Iterator for DimIter { - type Item = Tensor; - - fn next(&mut self) -> Option { - let res = if self.counter < self.end_idx { - let mut ranges = self.ranges.clone(); - ranges[self.dim] = self.counter..(self.counter + 1); - let slice = self.tensor.clone().slice(ranges); - Some(slice) - } else { - None - }; - self.counter += 1; - res - } + type Item = Tensor; + + fn next(&mut self) -> Option { + let res = if self.counter < self.end_idx { + let mut ranges = self.ranges.clone(); + ranges[self.dim] = self.counter..(self.counter + 1); + let slice = self.tensor.clone().slice(ranges); + Some(slice) + } else { + None + }; + self.counter += 1; + res + } } impl> DimIter { - fn new(tensor: Tensor, dim: usize) -> Self { - let dims = tensor.dims(); - let ranges = dims - .iter() - .map(|&dim| 0..dim) - .collect::>>(); - let ranges: [Range; D] = ranges.try_into().unwrap(); - Self { - end_idx: dims[dim], - ranges, - counter: 0, - dim, - tensor, - } - } + fn new(tensor: Tensor, dim: usize) -> Self { + let dims = tensor.dims(); + let ranges = dims + .iter() + .map(|&dim| 0..dim) + .collect::>>(); + let ranges: [Range; D] = ranges.try_into().unwrap(); + Self { + end_idx: dims[dim], + ranges, + counter: 0, + dim, + tensor, + } + } } impl Tensor where - B: Backend, - K: BasicOps, - >::Elem: Debug, + B: Backend, + K: BasicOps, + >::Elem: Debug, { - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - #[inline] - fn push_newline_indent(acc: &mut String, indent: usize) { - acc.push('\n'); - for _ in 0..indent { - acc.push(' '); - } - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn fmt_inner_tensor( - &self, - acc: &mut String, - depth: usize, - 
multi_index: &mut [usize], - range: (usize, usize), - ) { - let (start, end) = range; - for i in start..end { - if i > 0 { - acc.push_str(", "); - } - multi_index[depth] = i; - let range: [core::ops::Range; D] = - core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1); - - let elem = &self.clone().slice(range).into_data().value[0]; - acc.push_str(&format!("{elem:?}")); - } - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn fmt_outer_tensor( - &self, - acc: &mut String, - depth: usize, - multi_index: &mut [usize], - print_options: &PrintOptions, - summarize: bool, - range: (usize, usize), - ) { - let (start, end) = range; - for i in start..end { - if i > start { - acc.push(','); - Self::push_newline_indent(acc, depth + 1); - } - acc.push('['); - multi_index[depth] = i; - self.display_recursive(acc, depth + 1, multi_index, print_options, summarize); - acc.push(']'); - } - } - - /// Recursively formats the tensor data for display and appends it to the provided accumulator string. - /// - /// This function is designed to work with tensors of any dimensionality. - /// It traverses the tensor dimensions recursively, converting the elements - /// to strings and appending them to the accumulator string with the - /// appropriate formatting. - /// - /// # Arguments - /// - /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output. - /// * `depth` - The current depth of the tensor dimensions being processed. - /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn display_recursive( - &self, - acc: &mut String, - depth: usize, - multi_index: &mut [usize], - print_options: &PrintOptions, - summarize: bool, - ) { - let edge_items = print_options.edge_items; - - if depth == 0 { - acc.push('['); - } - - if depth == self.dims().len() - 1 { - // if we are at the innermost dimension, just push its elements into the accumulator - if summarize && self.dims()[depth] > 2 * edge_items { - // print the starting `edge_items` elements - self.fmt_inner_tensor(acc, depth, multi_index, (0, edge_items)); - acc.push_str(", ..."); - // print the last `edge_items` elements - self.fmt_inner_tensor( - acc, - depth, - multi_index, - (self.dims()[depth] - edge_items, self.dims()[depth]), - ); - } else { - // print all the elements - self.fmt_inner_tensor(acc, depth, multi_index, (0, self.dims()[depth])); - } - } else { - // otherwise, iterate through the current dimension and recursively display the inner tensors - if summarize && self.dims()[depth] > 2 * edge_items { - self.fmt_outer_tensor( - acc, - depth, - multi_index, - print_options, - summarize, - (0, edge_items), - ); - - acc.push(','); - Self::push_newline_indent(acc, depth + 1); - acc.push_str("..."); - Self::push_newline_indent(acc, depth + 1); - - self.fmt_outer_tensor( - acc, - depth, - multi_index, - print_options, - summarize, - (self.dims()[depth] - edge_items, self.dims()[depth]), - ); - } else { - self.fmt_outer_tensor( - acc, - depth, - multi_index, - print_options, - summarize, - (0, self.dims()[depth]), - ); - } - } - - if depth == 0 { - acc.push(']'); - } - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + #[inline] + fn push_newline_indent(acc: &mut String, indent: usize) { + acc.push('\n'); + for _ in 0..indent { + acc.push(' '); + } + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn fmt_inner_tensor( + &self, + acc: &mut 
String, + depth: usize, + multi_index: &mut [usize], + range: (usize, usize), + ) { + let (start, end) = range; + for i in start..end { + if i > 0 { + acc.push_str(", "); + } + multi_index[depth] = i; + let range: [core::ops::Range; D] = + core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1); + + let elem = &self.clone().slice(range).into_data().value[0]; + acc.push_str(&format!("{elem:?}")); + } + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn fmt_outer_tensor( + &self, + acc: &mut String, + depth: usize, + multi_index: &mut [usize], + print_options: &PrintOptions, + summarize: bool, + range: (usize, usize), + ) { + let (start, end) = range; + for i in start..end { + if i > start { + acc.push(','); + Self::push_newline_indent(acc, depth + 1); + } + acc.push('['); + multi_index[depth] = i; + self.display_recursive(acc, depth + 1, multi_index, print_options, summarize); + acc.push(']'); + } + } + + /// Recursively formats the tensor data for display and appends it to the provided accumulator string. + /// + /// This function is designed to work with tensors of any dimensionality. + /// It traverses the tensor dimensions recursively, converting the elements + /// to strings and appending them to the accumulator string with the + /// appropriate formatting. + /// + /// # Arguments + /// + /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output. + /// * `depth` - The current depth of the tensor dimensions being processed. + /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension. + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn display_recursive( + &self, + acc: &mut String, + depth: usize, + multi_index: &mut [usize], + print_options: &PrintOptions, + summarize: bool, + ) { + let edge_items = print_options.edge_items; + + if depth == 0 { + acc.push('['); + } + + if depth == self.dims().len() - 1 { + // if we are at the innermost dimension, just push its elements into the accumulator + if summarize && self.dims()[depth] > 2 * edge_items { + // print the starting `edge_items` elements + self.fmt_inner_tensor(acc, depth, multi_index, (0, edge_items)); + acc.push_str(", ..."); + // print the last `edge_items` elements + self.fmt_inner_tensor( + acc, + depth, + multi_index, + (self.dims()[depth] - edge_items, self.dims()[depth]), + ); + } else { + // print all the elements + self.fmt_inner_tensor(acc, depth, multi_index, (0, self.dims()[depth])); + } + } else { + // otherwise, iterate through the current dimension and recursively display the inner tensors + if summarize && self.dims()[depth] > 2 * edge_items { + self.fmt_outer_tensor( + acc, + depth, + multi_index, + print_options, + summarize, + (0, edge_items), + ); + + acc.push(','); + Self::push_newline_indent(acc, depth + 1); + acc.push_str("..."); + Self::push_newline_indent(acc, depth + 1); + + self.fmt_outer_tensor( + acc, + depth, + multi_index, + print_options, + summarize, + (self.dims()[depth] - edge_items, self.dims()[depth]), + ); + } else { + self.fmt_outer_tensor( + acc, + depth, + multi_index, + print_options, + summarize, + (0, self.dims()[depth]), + ); + } + } + + if depth == 0 { + acc.push(']'); + } + } } /// Options for Tensor pretty printing pub struct PrintOptions { - /// number of elements to start summarizing tensor - pub threshold: usize, - /// number of starting elements and ending elements to display - pub edge_items: usize, + /// number of elements to start summarizing tensor + 
pub threshold: usize, + /// number of starting elements and ending elements to display + pub edge_items: usize, } static PRINT_OPTS: Mutex = Mutex::new(PrintOptions::const_default()); impl PrintOptions { - // We cannot use the default trait as it's not const. - const fn const_default() -> Self { - Self { - threshold: 1000, - edge_items: 3, - } + // We cannot use the default trait as it's not const. + const fn const_default() -> Self { + Self { + threshold: 1000, + edge_items: 3, } + } } /// Set print options pub fn set_print_options(options: PrintOptions) { - *PRINT_OPTS.lock().unwrap() = options + *PRINT_OPTS.lock().unwrap() = options } /// Pretty print tensors impl core::fmt::Display for Tensor where - B: Backend, - B::IntElem: core::fmt::Display, - K: BasicOps, - >::Elem: Debug, + B: Backend, + B::IntElem: core::fmt::Display, + K: BasicOps, + >::Elem: Debug, { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - writeln!(f, "Tensor {{")?; + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + writeln!(f, "Tensor {{")?; - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - { - let po = PRINT_OPTS.lock().unwrap(); - let mut acc = String::new(); - let mut multi_index = vec![0; D]; - let summarize = self.shape().num_elements() > po.threshold; + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + { + let po = PRINT_OPTS.lock().unwrap(); + let mut acc = String::new(); + let mut multi_index = vec![0; D]; + let summarize = self.shape().num_elements() > po.threshold; - self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize); + self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize); - writeln!(f, " data:")?; - write!(f, "{acc}")?; - writeln!(f, ",")?; - } - - writeln!(f, " shape: {:?},", self.dims())?; - writeln!(f, " device: {:?},", self.device())?; - writeln!(f, " backend: {:?},", B::name())?; - writeln!(f, " kind: {:?},", K::name())?; - writeln!(f, " dtype: {:?},", K::elem_type_name())?; - write!(f, "}}") + writeln!(f, " data:")?; + write!(f, "{acc}")?; + writeln!(f, ",")?; } + + writeln!(f, " shape: {:?},", self.dims())?; + writeln!(f, " device: {:?},", self.device())?; + writeln!(f, " backend: {:?},", B::name())?; + writeln!(f, " kind: {:?},", K::name())?; + writeln!(f, " dtype: {:?},", K::elem_type_name())?; + write!(f, "}}") + } } /// Transpose marker (zero-size type). Used to sugar the transpose of a tensor, e.g. @@ -680,10 +680,10 @@ where pub struct T; impl core::ops::BitXor for Tensor { - type Output = Self; - fn bitxor(self, _: T) -> Self::Output { - self.transpose() - } + type Output = Self; + fn bitxor(self, _: T) -> Self::Output { + self.transpose() + } } /// Trait that list all operations that can be applied on all tensors. @@ -692,660 +692,646 @@ impl core::ops::BitXor for Tensor { /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). pub trait BasicOps: TensorKind { - /// The type of the tensor elements. - type Elem: 'static; - - /// Creates an empty tensor with the given shape. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The empty tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. 
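A short sketch of how the `PrintOptions` and `set_print_options` pair shown above is meant to be used, assuming both items are re-exported at the `burn_tensor` crate root like the rest of this module; the threshold values are arbitrary, and the extra `where` bound mirrors the one on the `Display` impl above.

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{set_print_options, PrintOptions, Tensor};

fn example<B: Backend>()
where
    B::IntElem: core::fmt::Display, // same bound as the `Display` impl above
{
    // Summarize any tensor with more than 16 elements, keeping 2 items per edge.
    set_print_options(PrintOptions {
        threshold: 16,
        edge_items: 2,
    });

    let tensor = Tensor::<B, 2>::ones([8, 8]);
    // Rows and columns beyond the edge items are now elided with `...`.
    println!("{tensor}");
}
```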
- /// - /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function, - /// which is more high-level and designed for public use. - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive; - - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function, - /// which is more high-level and designed for public use. - fn shape(tensor: &Self::Primitive) -> Shape; - - /// Reshapes the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `shape` - The new shape of the tensor. - /// - /// # Returns - /// - /// The reshaped tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function, - /// which is more high-level and designed for public use. - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive; - - /// Transposes a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn transpose(tensor: Self::Primitive) -> Self::Primitive; - - /// Swaps two dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive; - - /// Select tensor elements corresponding for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges of the elements to select. - /// - /// # Returns - /// - /// The selected elements. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function, - /// which is more high-level and designed for public use. - fn slice( - tensor: Self::Primitive, - range: [Range; D2], - ) -> Self::Primitive; - - /// Assigns the given value to the tensor elements corresponding for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges of the elements to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the assigned values. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. 
It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function, - /// which is more high-level and designed for public use. - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive; - - /// Returns the device on which the tensor is allocated. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device on which the tensor is allocated. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function, - /// which is more high-level and designed for public use. - fn device(tensor: &Self::Primitive) -> B::Device; - - /// Moves the tensor to the given device. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `device` - The device on which the tensor will be moved. - /// - /// # Returns - /// - /// The tensor on the given device. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function, - /// which is more high-level and designed for public use. - fn to_device( - tensor: Self::Primitive, - device: &B::Device, - ) -> Self::Primitive; - - /// Extracts the data from the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function, - /// which is more high-level and designed for public use. - fn into_data(tensor: Self::Primitive) -> Reader>; - - /// Creates a tensor from the given data. - /// - /// # Arguments - /// - /// * `data` - The data of the tensor. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function, - /// which is more high-level and designed for public use. - fn from_data( - data: Data, - device: &B::Device, - ) -> Self::Primitive; - - /// Repeat the tensor along the given dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension along which the tensor will be repeated. 
- /// * `times` - The number of times the tensor will be repeated. - /// - /// # Returns - /// - /// The repeated tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function, - /// which is more high-level and designed for public use. - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive; - - /// Concatenates the given tensors along the given dimension. - /// - /// # Arguments - /// - /// * `vectors` - The tensors to concatenate. - /// * `dim` - The dimension along which the tensors will be concatenated. - /// - /// # Returns - /// - /// The concatenated tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function, - /// which is more high-level and designed for public use. - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive; - - /// Equates the given tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The tensor of booleans indicating whether the corresponding elements are equal. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function, - /// which is more high-level and designed for public use. - fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Returns the name of the element type. - fn elem_type_name() -> &'static str { - core::any::type_name::() - } + /// The type of the tensor elements. + type Elem: 'static; + + /// Creates an empty tensor with the given shape. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The empty tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function, + /// which is more high-level and designed for public use. + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive; + + /// Returns the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. 
+ /// + /// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function, + /// which is more high-level and designed for public use. + fn shape(tensor: &Self::Primitive) -> Shape; + + /// Reshapes the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `shape` - The new shape of the tensor. + /// + /// # Returns + /// + /// The reshaped tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function, + /// which is more high-level and designed for public use. + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive; + + /// Transposes a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn transpose(tensor: Self::Primitive) -> Self::Primitive; + + /// Swaps two dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive; + + /// Select tensor elements corresponding for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges of the elements to select. + /// + /// # Returns + /// + /// The selected elements. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function, + /// which is more high-level and designed for public use. + fn slice( + tensor: Self::Primitive, + range: [Range; D2], + ) -> Self::Primitive; + + /// Assigns the given value to the tensor elements corresponding for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges of the elements to select. + /// * `value` - The value to assign. + /// + /// # Returns + /// + /// The tensor with the assigned values. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function, + /// which is more high-level and designed for public use. + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive; + + /// Returns the device on which the tensor is allocated. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device on which the tensor is allocated. 
+ /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function, + /// which is more high-level and designed for public use. + fn device(tensor: &Self::Primitive) -> B::Device; + + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device on which the tensor will be moved. + /// + /// # Returns + /// + /// The tensor on the given device. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function, + /// which is more high-level and designed for public use. + fn to_device( + tensor: Self::Primitive, + device: &B::Device, + ) -> Self::Primitive; + + /// Extracts the data from the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data of the tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function, + /// which is more high-level and designed for public use. + fn into_data(tensor: Self::Primitive) -> Reader>; + + /// Creates a tensor from the given data. + /// + /// # Arguments + /// + /// * `data` - The data of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function, + /// which is more high-level and designed for public use. + fn from_data(data: Data, device: &B::Device) + -> Self::Primitive; + + /// Repeat the tensor along the given dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension along which the tensor will be repeated. + /// * `times` - The number of times the tensor will be repeated. + /// + /// # Returns + /// + /// The repeated tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function, + /// which is more high-level and designed for public use. 
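For context on what the `BasicOps` plumbing buys at the call site: the same high-level methods work for every tensor kind, with `Float` dispatching to `B::*` and `Int` to `B::int_*` in the two implementations that follow. A minimal sketch with arbitrary shapes:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn example<B: Backend>() {
    // `transpose` goes through `B::transpose` for float tensors...
    let float = Tensor::<B, 2>::ones([2, 3]).transpose();
    // ...and through `B::int_transpose` for integer tensors.
    let int = Tensor::<B, 2, Int>::ones([2, 3]).transpose();

    println!("{:?} {:?}", float.dims(), int.dims()); // [3, 2] [3, 2]
}
```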
+ fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive; + + /// Concatenates the given tensors along the given dimension. + /// + /// # Arguments + /// + /// * `vectors` - The tensors to concatenate. + /// * `dim` - The dimension along which the tensors will be concatenated. + /// + /// # Returns + /// + /// The concatenated tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function, + /// which is more high-level and designed for public use. + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive; + + /// Equates the given tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The tensor of booleans indicating whether the corresponding elements are equal. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function, + /// which is more high-level and designed for public use. + fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor; + + /// Returns the name of the element type. + fn elem_type_name() -> &'static str { + core::any::type_name::() + } } impl BasicOps for Float { - type Elem = B::FloatElem; - - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { - B::empty(shape, device) - } - fn shape(tensor: &Self::Primitive) -> Shape { - B::shape(tensor) - } - - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::reshape(tensor, shape) - } - - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - B::transpose(tensor) - } - - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive { - check!(TensorCheck::swap_dims::(dim1, dim2)); - B::swap_dims(tensor, dim1, dim2) - } - - fn slice( - tensor: Self::Primitive, - ranges: [Range; D2], - ) -> Self::Primitive { - B::slice(tensor, ranges) - } - - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - B::slice_assign(tensor, ranges, value) - } - - fn device(tensor: &Self::Primitive) -> ::Device { - B::device(tensor) - } - - fn to_device( - tensor: Self::Primitive, - device: &::Device, - ) -> Self::Primitive { - B::to_device(tensor, device) - } - - fn into_data(tensor: Self::Primitive) -> Reader> { - B::into_data(tensor) - } - - fn from_data( - data: Data, - device: &B::Device, - ) -> Self::Primitive { - B::from_data(data, device) - } - - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive { - B::repeat(tensor, dim, times) - } - - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - B::cat(vectors, dim) - } - - fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::equal(lhs, rhs)) - } + type Elem = B::FloatElem; + + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { + B::empty(shape, device) + } + fn shape(tensor: &Self::Primitive) -> Shape { + 
B::shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + B::reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + B::transpose(tensor) + } + + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + check!(TensorCheck::swap_dims::(dim1, dim2)); + B::swap_dims(tensor, dim1, dim2) + } + + fn slice( + tensor: Self::Primitive, + ranges: [Range; D2], + ) -> Self::Primitive { + B::slice(tensor, ranges) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + B::slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + B::device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + B::to_device(tensor, device) + } + + fn into_data(tensor: Self::Primitive) -> Reader> { + B::into_data(tensor) + } + + fn from_data( + data: Data, + device: &B::Device, + ) -> Self::Primitive { + B::from_data(data, device) + } + + fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + B::repeat(tensor, dim, times) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + B::cat(vectors, dim) + } + + fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { + Tensor::new(B::equal(lhs, rhs)) + } } impl BasicOps for Int { - type Elem = B::IntElem; - - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { - B::int_empty(shape, device) - } - fn shape(tensor: &Self::Primitive) -> Shape { - B::int_shape(tensor) - } - - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::int_reshape(tensor, shape) - } - - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - B::int_transpose(tensor) - } - - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive { - check!(TensorCheck::swap_dims::(dim1, dim2)); - B::int_swap_dims(tensor, dim1, dim2) - } - - fn slice( - tensor: Self::Primitive, - ranges: [Range; D2], - ) -> Self::Primitive { - B::int_slice(tensor, ranges) - } - - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - B::int_slice_assign(tensor, ranges, value) - } - - fn device(tensor: &Self::Primitive) -> ::Device { - B::int_device(tensor) - } - - fn to_device( - tensor: Self::Primitive, - device: &::Device, - ) -> Self::Primitive { - B::int_to_device(tensor, device) - } - - fn into_data(tensor: Self::Primitive) -> Reader> { - B::int_into_data(tensor) - } - - fn from_data( - data: Data, - device: &B::Device, - ) -> Self::Primitive { - B::int_from_data(data, device) - } - - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive { - B::int_repeat(tensor, dim, times) - } - - fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_equal(lhs, rhs)) - } - - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - B::int_cat(vectors, dim) - } + type Elem = B::IntElem; + + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { + B::int_empty(shape, device) + } + fn shape(tensor: &Self::Primitive) -> Shape { + B::int_shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + B::int_reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + B::int_transpose(tensor) + } + + fn swap_dims( + 
tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + check!(TensorCheck::swap_dims::(dim1, dim2)); + B::int_swap_dims(tensor, dim1, dim2) + } + + fn slice( + tensor: Self::Primitive, + ranges: [Range; D2], + ) -> Self::Primitive { + B::int_slice(tensor, ranges) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + B::int_slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + B::int_device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + B::int_to_device(tensor, device) + } + + fn into_data(tensor: Self::Primitive) -> Reader> { + B::int_into_data(tensor) + } + + fn from_data( + data: Data, + device: &B::Device, + ) -> Self::Primitive { + B::int_from_data(data, device) + } + + fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + B::int_repeat(tensor, dim, times) + } + + fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { + Tensor::new(B::int_equal(lhs, rhs)) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + B::int_cat(vectors, dim) + } } impl BasicOps for Bool { - type Elem = bool; - - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { - B::bool_empty(shape, device) - } - fn shape(tensor: &Self::Primitive) -> Shape { - B::bool_shape(tensor) - } - - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::bool_reshape(tensor, shape) - } - - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - B::bool_transpose(tensor) - } - - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive { - check!(TensorCheck::swap_dims::(dim1, dim2)); - B::bool_swap_dims(tensor, dim1, dim2) - } - - fn slice( - tensor: Self::Primitive, - ranges: [Range; D2], - ) -> Self::Primitive { - B::bool_slice(tensor, ranges) - } - - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - B::bool_slice_assign(tensor, ranges, value) - } - - fn device(tensor: &Self::Primitive) -> ::Device { - B::bool_device(tensor) - } - - fn to_device( - tensor: Self::Primitive, - device: &::Device, - ) -> Self::Primitive { - B::bool_to_device(tensor, device) - } - - fn into_data(tensor: Self::Primitive) -> Reader> { - B::bool_into_data(tensor) - } - - fn from_data( - data: Data, - device: &B::Device, - ) -> Self::Primitive { - B::bool_from_data(data, device) - } - - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive { - B::bool_repeat(tensor, dim, times) - } - - fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::bool_equal(lhs, rhs)) - } - - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - B::bool_cat(vectors, dim) - } + type Elem = bool; + + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { + B::bool_empty(shape, device) + } + fn shape(tensor: &Self::Primitive) -> Shape { + B::bool_shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + B::bool_reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + B::bool_transpose(tensor) + } + + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + check!(TensorCheck::swap_dims::(dim1, dim2)); + B::bool_swap_dims(tensor, dim1, dim2) + } + + fn slice( + tensor: 
Self::Primitive, + ranges: [Range; D2], + ) -> Self::Primitive { + B::bool_slice(tensor, ranges) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + B::bool_slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + B::bool_device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + B::bool_to_device(tensor, device) + } + + fn into_data(tensor: Self::Primitive) -> Reader> { + B::bool_into_data(tensor) + } + + fn from_data( + data: Data, + device: &B::Device, + ) -> Self::Primitive { + B::bool_from_data(data, device) + } + + fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + B::bool_repeat(tensor, dim, times) + } + + fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { + Tensor::new(B::bool_equal(lhs, rhs)) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + B::bool_cat(vectors, dim) + } } /// Trait used for reshape arguments. pub trait ReshapeArgs { - /// Converts to a shape. - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape; + /// Converts to a shape. + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape; } impl ReshapeArgs for Shape { - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape { - check!(TensorCheck::reshape_args_usize(&self, &tensor.shape())); - - self - } + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape { + check!(TensorCheck::reshape_args_usize(&self, &tensor.shape())); + + self + } } impl ReshapeArgs for [usize; D2] { - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape { - let shape = Shape::from(self); + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape { + let shape = Shape::from(self); - check!(TensorCheck::reshape_args_usize(&shape, &tensor.shape())); + check!(TensorCheck::reshape_args_usize(&shape, &tensor.shape())); - shape - } + shape + } } impl ReshapeArgs for [i32; D2] { - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape { - // Validate the reshape arguments - check!(TensorCheck::reshape_args_i32(&self)); - - // Temporary shape - let mut new_shape: [i32; D2] = [1; D2]; - - // We need to find the index of the 0 dimension and - // replace it with the actual dimension value. - for (i, &s) in self.iter().enumerate() { - if s != 0 { - new_shape[i] = s; - } else { - new_shape[i] = tensor.dims()[i] as i32; - } + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape { + // Validate the reshape arguments + check!(TensorCheck::reshape_args_i32(&self)); + + // Temporary shape + let mut new_shape: [i32; D2] = [1; D2]; + + // We need to find the index of the 0 dimension and + // replace it with the actual dimension value. 
+ for (i, &s) in self.iter().enumerate() { + if s != 0 { + new_shape[i] = s; + } else { + new_shape[i] = tensor.dims()[i] as i32; + } + } + + // Find the index of the inferred dimension (-1) + let infer_index = new_shape.iter().position(|x| x == &-1); + + // Handle the case where the dimension is inferred (via -1) + if let Some(index) = infer_index { + // Handle the case where the dimension is inferred + let mut product = 1; + for (i, &s) in new_shape.iter().enumerate() { + if i != index { + product *= s; } - - // Find the index of the inferred dimension (-1) - let infer_index = new_shape.iter().position(|x| x == &-1); - - // Handle the case where the dimension is inferred (via -1) - if let Some(index) = infer_index { - // Handle the case where the dimension is inferred - let mut product = 1; - for (i, &s) in new_shape.iter().enumerate() { - if i != index { - product *= s; - } - } - let product_current = tensor.shape().num_elements() as i32; - - new_shape[index] = product_current / product; - - // Check if the reshape is valid - if product_current % product != 0 { - panic!( - "Cannot reshape tensor of shape {:?} to shape {:?}", - tensor.shape(), - new_shape - ); - } - }; - - // Convert each element to usize - let new_shape: [usize; D2] = new_shape.map(|x| x as usize); - - Shape::from(new_shape) - } + } + let product_current = tensor.shape().num_elements() as i32; + + new_shape[index] = product_current / product; + + // Check if the reshape is valid + if product_current % product != 0 { + panic!( + "Cannot reshape tensor of shape {:?} to shape {:?}", + tensor.shape(), + new_shape + ); + } + }; + + // Convert each element to usize + let new_shape: [usize; D2] = new_shape.map(|x| x as usize); + + Shape::from(new_shape) + } } diff --git a/burn-tensor/src/tensor/api/bool.rs b/burn-tensor/src/tensor/api/bool.rs index e7d7e97460..4047e64bac 100644 --- a/burn-tensor/src/tensor/api/bool.rs +++ b/burn-tensor/src/tensor/api/bool.rs @@ -2,30 +2,30 @@ use crate::{backend::Backend, Bool, Data, Int, Tensor}; impl Tensor where - B: Backend, + B: Backend, { - /// Create a boolean tensor from data. - pub fn from_bool(data: Data) -> Self { - Self::new(B::bool_from_data(data, &B::Device::default())) - } + /// Create a boolean tensor from data. + pub fn from_bool(data: Data) -> Self { + Self::new(B::bool_from_data(data, &B::Device::default())) + } - /// Create a boolean tensor from data on the given device. - pub fn from_bool_device(data: Data, device: &B::Device) -> Self { - Self::new(B::bool_from_data(data, device)) - } + /// Create a boolean tensor from data on the given device. + pub fn from_bool_device(data: Data, device: &B::Device) -> Self { + Self::new(B::bool_from_data(data, device)) + } - /// Convert the bool tensor into an int tensor. - pub fn int(self) -> Tensor { - Tensor::new(B::bool_into_int(self.primitive)) - } + /// Convert the bool tensor into an int tensor. + pub fn int(self) -> Tensor { + Tensor::new(B::bool_into_int(self.primitive)) + } - /// Convert the bool tensor into an float tensor. - pub fn float(self) -> Tensor { - Tensor::new(B::bool_into_float(self.primitive)) - } + /// Convert the bool tensor into an float tensor. + pub fn float(self) -> Tensor { + Tensor::new(B::bool_into_float(self.primitive)) + } - /// Inverses boolean values. - pub fn bool_not(self) -> Self { - Tensor::new(B::bool_not(self.primitive)) - } + /// Inverses boolean values. 
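The `ReshapeArgs` implementation for signed integer arrays above gives `0` and `-1` special meaning: `0` keeps the corresponding input dimension, and a single `-1` is inferred so the element count is preserved. A hedged sketch of what that looks like from the user side (names and shapes are illustrative, not from this patch):

```rust
use burn_tensor::{backend::Backend, Tensor};

// `0` keeps the matching input dimension; one `-1` is inferred from the remaining dims.
fn reshape_with_inference<B: Backend>() {
    let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // shape [2, 3]

    let _a = tensor.clone().reshape([0, -1, 1]); // shape [2, 3, 1]: dim 0 kept, dim 1 inferred
    let _b = tensor.reshape([-1, 2]);            // shape [3, 2]: first dim inferred as 6 / 2
}
```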
+ pub fn bool_not(self) -> Self { + Tensor::new(B::bool_not(self.primitive)) + } } diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs index da237ee1ae..cf064f8f4d 100644 --- a/burn-tensor/src/tensor/api/check.rs +++ b/burn-tensor/src/tensor/api/check.rs @@ -33,311 +33,303 @@ use core::ops::Range; /// implementation might re-implement the same checks, which may result in uncessary code /// duplication. Maybe a combination of both strategies could help to cover all usecases. pub(crate) enum TensorCheck { - Ok, - Failed(FailedTensorCheck), + Ok, + Failed(FailedTensorCheck), } impl TensorCheck { - /// Checks device and shape compatibility for element wise binary operations. - pub(crate) fn binary_ops_ew>( - ops: &str, - lhs: &Tensor, - rhs: &Tensor, - ) -> Self { - Self::Ok - .binary_ops_device(ops, &lhs.device(), &rhs.device()) - .binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape()) - } - - pub(crate) fn into_scalar(shape: &Shape) -> Self { - let mut check = Self::Ok; - - if shape.num_elements() != 1 { - check = check.register( - "Into Scalar", - TensorError::new("Only tensors with 1 element can be converted into scalar.") - .details(format!( - "Current tensor has {} elements", - shape.num_elements() - )), - ); - } - - check - } - - pub(crate) fn dim_ops(ops: &str, dim: usize) -> Self { - let mut check = Self::Ok; - - if dim >= D { - check = check.register( - ops, - TensorError::new("Given dimension is higher than the tensor rank.") - .details(format!("Tensor rank: '{D}', given dimension: '{dim}'.")), - ); - } - - check - } - - pub(crate) fn reshape_args_usize( - original: &Shape, - target: &Shape, - ) -> Self { - let mut check = Self::Ok; - - if original.num_elements() != target.num_elements() { - check = check.register("Reshape", TensorError::new( - "The given shape doesn't have the same number of elements as the current tensor.", - ) - .details(format!( - "Current shape: {:?}, target shape: {:?}.", - original.dims, target.dims - ))); - } - - check - } - - pub(crate) fn reshape_args_i32(target: &[i32; D]) -> Self { - let mut check = Self::Ok; - - if target.iter().any(|&dim| dim < -1) { - check = check.register( - "Reshape", - TensorError::new( - "The given shape cannot contain negative dimensions (other than -1).", - ) - .details(format!("Target shape: {:?}.", target)), - ); - } - - if target.iter().filter(|&x| x == &-1).count() > 1 { - check = check.register( - "Reshape", - TensorError::new("The given shape cannot contain more than one -1.") - .details(format!("Target shape: {:?}.", target)), - ); - } - - check - } - - pub(crate) fn flatten( - start_dim: usize, - end_dim: usize, - ) -> Self { - let mut check = Self::Ok; - - if start_dim > end_dim { - check = check.register( - "Flatten", - TensorError::new(format!( - "The start dim ({start_dim}) must be smaller than the end dim ({end_dim})" - )), - ); - } - - if D2 > D1 { - check = check.register( - "Flatten", - TensorError::new(format!("Result dim ({D2}) must be smaller than ({D1})")), - ); - } - - if D1 < end_dim + 1 { - check = check.register( - "Flatten", - TensorError::new(format!( - "The end dim ({end_dim}) must be greater than the tensor dim ({D2})" - )), - ); - } - - if D2 < D1 - (end_dim - start_dim) { - check = check.register( + /// Checks device and shape compatibility for element wise binary operations. 
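For the boolean tensor API touched in `bool.rs` above, a short round-trip example may help; it is a sketch only (the helper name and literal data are illustrative, and it assumes `Data::from` accepts a plain bool array as elsewhere in the crate):

```rust
use burn_tensor::{backend::Backend, Bool, Data, Tensor};

// Create a boolean mask, invert it, and convert it to the other tensor kinds.
fn bool_usage<B: Backend>() {
    let mask = Tensor::<B, 1, Bool>::from_bool(Data::from([true, false, true]));

    let _inverted = mask.clone().bool_not(); // [false, true, false]
    let _as_int = mask.clone().int();        // 1/0 integer tensor
    let _as_float = mask.float();            // 1.0/0.0 float tensor
}
```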
+ pub(crate) fn binary_ops_ew>( + ops: &str, + lhs: &Tensor, + rhs: &Tensor, + ) -> Self { + Self::Ok + .binary_ops_device(ops, &lhs.device(), &rhs.device()) + .binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape()) + } + + pub(crate) fn into_scalar(shape: &Shape) -> Self { + let mut check = Self::Ok; + + if shape.num_elements() != 1 { + check = check.register( + "Into Scalar", + TensorError::new("Only tensors with 1 element can be converted into scalar.").details( + format!("Current tensor has {} elements", shape.num_elements()), + ), + ); + } + + check + } + + pub(crate) fn dim_ops(ops: &str, dim: usize) -> Self { + let mut check = Self::Ok; + + if dim >= D { + check = check.register( + ops, + TensorError::new("Given dimension is higher than the tensor rank.") + .details(format!("Tensor rank: '{D}', given dimension: '{dim}'.")), + ); + } + + check + } + + pub(crate) fn reshape_args_usize( + original: &Shape, + target: &Shape, + ) -> Self { + let mut check = Self::Ok; + + if original.num_elements() != target.num_elements() { + check = check.register( + "Reshape", + TensorError::new( + "The given shape doesn't have the same number of elements as the current tensor.", + ) + .details(format!( + "Current shape: {:?}, target shape: {:?}.", + original.dims, target.dims + )), + ); + } + + check + } + + pub(crate) fn reshape_args_i32(target: &[i32; D]) -> Self { + let mut check = Self::Ok; + + if target.iter().any(|&dim| dim < -1) { + check = check.register( + "Reshape", + TensorError::new("The given shape cannot contain negative dimensions (other than -1).") + .details(format!("Target shape: {:?}.", target)), + ); + } + + if target.iter().filter(|&x| x == &-1).count() > 1 { + check = check.register( + "Reshape", + TensorError::new("The given shape cannot contain more than one -1.") + .details(format!("Target shape: {:?}.", target)), + ); + } + + check + } + + pub(crate) fn flatten( + start_dim: usize, + end_dim: usize, + ) -> Self { + let mut check = Self::Ok; + + if start_dim > end_dim { + check = check.register( + "Flatten", + TensorError::new(format!( + "The start dim ({start_dim}) must be smaller than the end dim ({end_dim})" + )), + ); + } + + if D2 > D1 { + check = check.register( + "Flatten", + TensorError::new(format!("Result dim ({D2}) must be smaller than ({D1})")), + ); + } + + if D1 < end_dim + 1 { + check = check.register( + "Flatten", + TensorError::new(format!( + "The end dim ({end_dim}) must be greater than the tensor dim ({D2})" + )), + ); + } + + if D2 < D1 - (end_dim - start_dim) { + check = check.register( "Flatten", TensorError::new(format!( "The destination dimension ({D2}) must be large enough to accommodate the flattening operation." 
)), ); - } - - check } - pub(crate) fn squeeze(dim: usize, tensor_dims: &[usize]) -> Self { - let mut check = Self::Ok; - // This should actually be to check that the dimension to squeeze - // has a size of 1 - if tensor_dims[dim] != 1 { - check = check.register( - "Squeeze", - TensorError::new(format!( - "Can't squeeze dimension {} because its size is not 1", - dim - )), - ); - } - - check - } - - pub(crate) fn unsqueeze() -> Self { - let mut check = Self::Ok; - if D2 < D1 { - check = check.register( - "Unsqueeze", - TensorError::new(format!( - "Can't unsqueeze smaller tensor, got dim {D2}, expected > {D1}" - )), - ); - } - - check - } - - pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { - let mut check = Self::Ok; - - if dim1 > D || dim2 > D { - check = check.register( - "Swap Dims", - TensorError::new("The swap dimensions must be smaller than the tensor dimension") - .details(format!( - "Swap dims ({dim1}, {dim2}) on tensor with ({D}) dimensions." - )), - ); - } - - check - } - - pub(crate) fn matmul( - lhs: &Tensor, - rhs: &Tensor, - ) -> Self { - let mut check = Self::Ok; - - check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); - - if D < 2 { - return check; - } - - let shape_lhs = lhs.shape(); - let shape_rhs = rhs.shape(); - - let dim_lhs = shape_lhs.dims[D - 1]; - let dim_rhs = shape_rhs.dims[D - 2]; + check + } - if dim_lhs != dim_rhs { - check = check.register( - "Matmul", - TensorError::new(format!( - "The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}." - )) - .details(format!( - "Lhs shape {:?}, rhs shape {:?}.", - shape_lhs.dims, shape_rhs.dims - )), - ); - } - - check - } - - pub(crate) fn cat>( - tensors: &[Tensor], - dim: usize, - ) -> Self { - let mut check = Self::Ok; - - if dim >= D { - check = check.register( - "Cat", - TensorError::new( - "Can't concatenate tensors on a dim that exceeds the tensors dimension", - ) - .details(format!( - "Trying to concatenate tensors with {D} dimensions on axis {dim}." - )), - ); - } - - if tensors.is_empty() { - return check.register( - "Cat", - TensorError::new("Can't concatenate an empty list of tensors."), - ); - } - - let mut shape_reference = tensors.get(0).unwrap().shape(); - shape_reference.dims[dim] = 1; // We want to check every dims except the one where the - // concatenation happens. - - for tensor in tensors { - let mut shape = tensor.shape(); - shape.dims[dim] = 1; // Ignore the concatenate dim. 
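The flatten and squeeze checks above encode the valid ranges for `start_dim`, `end_dim`, and the target rank. A hedged sketch of calls that satisfy those invariants (illustrative shapes; any violation panics through `check!`):

```rust
use burn_tensor::{backend::Backend, Tensor};

// `flatten` collapses dims `start_dim..=end_dim`; `squeeze` removes a dim of size 1.
fn flatten_and_squeeze<B: Backend>() {
    let x = Tensor::<B, 3>::zeros([2, 1, 4]);

    let _flat: Tensor<B, 2> = x.clone().flatten(1, 2); // [2, 1, 4] -> [2, 4]
    let _squeezed: Tensor<B, 2> = x.squeeze(1);        // [2, 1, 4] -> [2, 4]
}
```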
- - if shape_reference != shape { - return check.register( - "Cat", - TensorError::new("Can't concatenate tensors with different shapes, except for the provided dimension").details( - format!( - "Provided dimension ({}), tensors shapes: {:?}", - dim, - tensors.iter().map(Tensor::shape).collect::>() - ), - ), - ); - } - } - - check + pub(crate) fn squeeze(dim: usize, tensor_dims: &[usize]) -> Self { + let mut check = Self::Ok; + // This should actually be to check that the dimension to squeeze + // has a size of 1 + if tensor_dims[dim] != 1 { + check = check.register( + "Squeeze", + TensorError::new(format!( + "Can't squeeze dimension {} because its size is not 1", + dim + )), + ); } - pub(crate) fn slice( - shape: &Shape, - ranges: &[Range; D2], - ) -> Self { - let mut check = Self::Ok; - let n_dims_tensor = D1; - let n_dims_ranges = D2; + check + } + + pub(crate) fn unsqueeze() -> Self { + let mut check = Self::Ok; + if D2 < D1 { + check = check.register( + "Unsqueeze", + TensorError::new(format!( + "Can't unsqueeze smaller tensor, got dim {D2}, expected > {D1}" + )), + ); + } + + check + } + + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { + let mut check = Self::Ok; + + if dim1 > D || dim2 > D { + check = check.register( + "Swap Dims", + TensorError::new("The swap dimensions must be smaller than the tensor dimension").details( + format!("Swap dims ({dim1}, {dim2}) on tensor with ({D}) dimensions."), + ), + ); + } + + check + } + + pub(crate) fn matmul(lhs: &Tensor, rhs: &Tensor) -> Self { + let mut check = Self::Ok; + + check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); + + if D < 2 { + return check; + } + + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); + + let dim_lhs = shape_lhs.dims[D - 1]; + let dim_rhs = shape_rhs.dims[D - 2]; + + if dim_lhs != dim_rhs { + check = check.register( + "Matmul", + TensorError::new(format!( + "The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}." + )) + .details(format!( + "Lhs shape {:?}, rhs shape {:?}.", + shape_lhs.dims, shape_rhs.dims + )), + ); + } + + check + } + + pub(crate) fn cat>( + tensors: &[Tensor], + dim: usize, + ) -> Self { + let mut check = Self::Ok; + + if dim >= D { + check = check.register( + "Cat", + TensorError::new("Can't concatenate tensors on a dim that exceeds the tensors dimension") + .details(format!( + "Trying to concatenate tensors with {D} dimensions on axis {dim}." + )), + ); + } + + if tensors.is_empty() { + return check.register( + "Cat", + TensorError::new("Can't concatenate an empty list of tensors."), + ); + } + + let mut shape_reference = tensors.get(0).unwrap().shape(); + shape_reference.dims[dim] = 1; // We want to check every dims except the one where the + // concatenation happens. + + for tensor in tensors { + let mut shape = tensor.shape(); + shape.dims[dim] = 1; // Ignore the concatenate dim. 
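The cat check above requires every tensor to share the same shape on all dimensions except the concatenation dimension, and rejects an out-of-range `dim` or an empty list. A small sketch of the user-facing behavior (illustrative values, not part of the patch):

```rust
use burn_tensor::{backend::Backend, Tensor};

// Shapes only need to agree outside the concatenation dimension.
fn cat_usage<B: Backend>() {
    let a = Tensor::<B, 2>::from_floats([[1.0, 2.0]]);
    let b = Tensor::<B, 2>::from_floats([[3.0, 4.0]]);

    let _stacked = Tensor::cat(vec![a.clone(), b.clone()], 0); // shape [2, 2]

    // Would panic via `check!(TensorCheck::cat(..))`: dim 2 exceeds the tensor rank.
    // let _bad = Tensor::cat(vec![a, b], 2);
}
```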
+ + if shape_reference != shape { + return check.register( + "Cat", + TensorError::new( + "Can't concatenate tensors with different shapes, except for the provided dimension", + ) + .details(format!( + "Provided dimension ({}), tensors shapes: {:?}", + dim, + tensors.iter().map(Tensor::shape).collect::>() + )), + ); + } + } + + check + } - if n_dims_tensor < n_dims_ranges { - check = check.register("Slice", + pub(crate) fn slice( + shape: &Shape, + ranges: &[Range; D2], + ) -> Self { + let mut check = Self::Ok; + let n_dims_tensor = D1; + let n_dims_ranges = D2; + + if n_dims_tensor < n_dims_ranges { + check = check.register("Slice", TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.") .details( format!( "The ranges array must be smaller or equal to the tensor number of dimensions. \ Tensor number of dimensions: {n_dims_tensor}, ranges array length {n_dims_ranges}." ))); - } + } - for i in 0..usize::min(D1, D2) { - let d_tensor = shape.dims[i]; - let range = ranges.get(i).unwrap(); + for i in 0..usize::min(D1, D2) { + let d_tensor = shape.dims[i]; + let range = ranges.get(i).unwrap(); - if range.end > d_tensor { - check = check.register( - "Slice", - TensorError::new("The provided ranges array has a range that exceeds the current tensor size.") - .details(format!( - "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ + if range.end > d_tensor { + check = check.register( + "Slice", + TensorError::new( + "The provided ranges array has a range that exceeds the current tensor size.", + ) + .details(format!( + "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ Tensor shape {:?}, provided ranges {:?}.", - range.start, - range.end, - d_tensor, - i, - shape.dims, - ranges, - ))); - } + range.start, range.end, d_tensor, i, shape.dims, ranges, + )), + ); + } - if range.start >= range.end { - check = check.register( + if range.start >= range.end { + check = check.register( "Slice", TensorError::new("The provided range array has a range where the start index is bigger or equal to its end.") .details(format!( @@ -349,53 +341,53 @@ impl TensorCheck { shape.dims, ranges, ))); - } - } - - check - } - - pub(crate) fn slice_assign( - shape: &Shape, - shape_value: &Shape, - ranges: &[Range; D2], - ) -> Self { - let mut check = Self::Ok; - - if D1 < D2 { - check = check.register("Slice Assign", - TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.") - .details( - format!( - "The ranges array must be smaller or equal to the tensor number of dimensions. \ + } + } + + check + } + + pub(crate) fn slice_assign( + shape: &Shape, + shape_value: &Shape, + ranges: &[Range; D2], + ) -> Self { + let mut check = Self::Ok; + + if D1 < D2 { + check = check.register( + "Slice Assign", + TensorError::new( + "The provided ranges array has a higher number of dimensions than the current tensor.", + ) + .details(format!( + "The ranges array must be smaller or equal to the tensor number of dimensions. \ Tensor number of dimensions: {D1}, ranges array length {D2}." 
- ))); - } - - for i in 0..usize::min(D1, D2) { - let d_tensor = shape.dims[i]; - let d_tensor_value = shape_value.dims[i]; - let range = ranges.get(i).unwrap(); - - if range.end > d_tensor { - check = check.register( - "Range Assign", - TensorError::new("The provided ranges array has a range that exceeds the current tensor size.") - .details(format!( - "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ + )), + ); + } + + for i in 0..usize::min(D1, D2) { + let d_tensor = shape.dims[i]; + let d_tensor_value = shape_value.dims[i]; + let range = ranges.get(i).unwrap(); + + if range.end > d_tensor { + check = check.register( + "Range Assign", + TensorError::new( + "The provided ranges array has a range that exceeds the current tensor size.", + ) + .details(format!( + "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.", - range.start, - range.end, - d_tensor, - i, - shape.dims, - shape_value.dims, - ranges, - ))); - } + range.start, range.end, d_tensor, i, shape.dims, shape_value.dims, ranges, + )), + ); + } - if range.end - range.start != d_tensor_value { - check = check.register( + if range.end - range.start != d_tensor_value { + check = check.register( "Slice Assign", TensorError::new("The value tensor must match the amount of elements selected with the ranges array") .details(format!( @@ -409,10 +401,10 @@ impl TensorCheck { shape_value.dims, ranges, ))); - } + } - if range.start >= range.end { - check = check.register( + if range.start >= range.end { + check = check.register( "Slice Assign", TensorError::new("The provided ranges array has a range where the start index is bigger or equal to its end.") .details(format!( @@ -425,246 +417,240 @@ impl TensorCheck { shape_value.dims, ranges, ))); - } - } - - check - } - - pub(crate) fn gather( - dim: usize, - shape: &Shape, - shape_indices: &Shape, - ) -> Self { - Self::check_gather_scatter_indices(Self::Ok, "Gather", dim, shape, shape_indices) - } - - pub(crate) fn scatter( - dim: usize, - shape: &Shape, - shape_indices: &Shape, - shape_value: &Shape, - ) -> Self { - let ops = "Scatter"; - let mut check = - Self::check_gather_scatter_indices(Self::Ok, ops, dim, shape, shape_indices); - - if shape_indices != shape_value { - check = check.register( - ops, - TensorError::new( - "Indices tensor shape should be the same as the value tensor shape." 
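The slice and slice_assign checks above require each range to stay inside the tensor, be non-empty, and (for assignment) match the shape of the value tensor. A hedged usage sketch under those constraints (shapes are illustrative):

```rust
use burn_tensor::{backend::Backend, Tensor};

// Valid ranges for a [2, 4] tensor; anything outside these rules panics with the
// detailed messages assembled above.
fn slice_usage<B: Backend>() {
    let tensor = Tensor::<B, 2>::zeros([2, 4]);
    let value = Tensor::<B, 2>::ones([1, 2]);

    let _view = tensor.clone().slice([0..1, 1..3]);          // shape [1, 2]
    let _updated = tensor.slice_assign([0..1, 2..4], value); // writes a [1, 2] block
}
```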
- .to_string(), - ) - .details(format!( - "The shape differs: {:?} != {:?}", - shape_indices.dims, shape_value.dims - )), - ); - } - - check - } - - pub(crate) fn select(dim: usize) -> Self { - Self::check_select_basic::(Self::Ok, "select", dim) - } - - pub(crate) fn select_assign(dim: usize) -> Self { - Self::check_select_basic::(Self::Ok, "select_assign", dim) - } - - fn check_select_basic(mut check: Self, ops: &str, dim: usize) -> Self { - if dim > D { - check = check.register( - ops, - TensorError::new(format!( - "Can't index a tensor with ({D}) dimensions on axis ({dim})" - )), - ); - } - - check - } - fn check_gather_scatter_indices( - mut check: Self, - ops: &str, - dim: usize, - shape: &Shape, - shape_indices: &Shape, - ) -> Self { - if dim > D { - check = check.register( - ops, - TensorError::new(format!( - "Can't index a tensor with ({D}) dimensions on axis ({dim})" - )), - ); - } - - for i in 0..D { - if i == dim { - continue; - } - - let tensor_dim_i = shape.dims[i]; - let indices_dim_i = shape_indices.dims[i]; - - if tensor_dim_i != indices_dim_i { - check = check.register( - ops, - TensorError::new( - "The tensor shape should be the same as the index tensor shape." - .to_string(), - ) - .details(format!( - "The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}" - )), - ); - } - } - - check - } - - /// Checks aggregate dimension such as mean and sum. - pub(crate) fn aggregate_dim(ops: &str, dim: usize) -> Self { - let mut check = Self::Ok; + } + } + + check + } + + pub(crate) fn gather( + dim: usize, + shape: &Shape, + shape_indices: &Shape, + ) -> Self { + Self::check_gather_scatter_indices(Self::Ok, "Gather", dim, shape, shape_indices) + } + + pub(crate) fn scatter( + dim: usize, + shape: &Shape, + shape_indices: &Shape, + shape_value: &Shape, + ) -> Self { + let ops = "Scatter"; + let mut check = Self::check_gather_scatter_indices(Self::Ok, ops, dim, shape, shape_indices); + + if shape_indices != shape_value { + check = check.register( + ops, + TensorError::new( + "Indices tensor shape should be the same as the value tensor shape.".to_string(), + ) + .details(format!( + "The shape differs: {:?} != {:?}", + shape_indices.dims, shape_value.dims + )), + ); + } + + check + } + + pub(crate) fn select(dim: usize) -> Self { + Self::check_select_basic::(Self::Ok, "select", dim) + } + + pub(crate) fn select_assign(dim: usize) -> Self { + Self::check_select_basic::(Self::Ok, "select_assign", dim) + } + + fn check_select_basic(mut check: Self, ops: &str, dim: usize) -> Self { + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't index a tensor with ({D}) dimensions on axis ({dim})" + )), + ); + } + + check + } + fn check_gather_scatter_indices( + mut check: Self, + ops: &str, + dim: usize, + shape: &Shape, + shape_indices: &Shape, + ) -> Self { + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't index a tensor with ({D}) dimensions on axis ({dim})" + )), + ); + } + + for i in 0..D { + if i == dim { + continue; + } + + let tensor_dim_i = shape.dims[i]; + let indices_dim_i = shape_indices.dims[i]; + + if tensor_dim_i != indices_dim_i { + check = check.register( + ops, + TensorError::new( + "The tensor shape should be the same as the index tensor shape.".to_string(), + ) + .details(format!( + "The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}" + )), + ); + } + } + + check + } + + /// Checks aggregate dimension such as mean and sum. 
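`check_gather_scatter_indices` above requires the indices tensor to match the input shape on every dimension other than the one being gathered or scattered. A sketch of a gather call that satisfies it (values and helper name are illustrative):

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

// Indices of shape [2, 1] against an input of shape [2, 3], gathering along dim 1.
fn gather_usage<B: Backend>() {
    let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let indices = Tensor::<B, 1, Int>::arange(0..2).reshape([2, 1]); // [[0], [1]]

    let _picked = tensor.gather(1, indices); // [[1.0], [5.0]]
}
```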
+ pub(crate) fn aggregate_dim(ops: &str, dim: usize) -> Self { + let mut check = Self::Ok; + + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't aggregate a tensor with ({D}) dimensions on axis ({dim})" + )), + ); + } + + check + } + + /// The goal is to minimize the cost of checks when there are no error, but it's way less + /// important when an error occurred, crafting a comprehensive error message is more important + /// than optimizing string manipulation. + fn register(self, ops: &str, error: TensorError) -> Self { + let errors = match self { + Self::Ok => vec![error], + Self::Failed(mut failed) => { + failed.errors.push(error); + failed.errors + } + }; - if dim > D { - check = check.register( - ops, - TensorError::new(format!( - "Can't aggregate a tensor with ({D}) dimensions on axis ({dim})" - )), - ); + Self::Failed(FailedTensorCheck { + ops: ops.to_string(), + errors, + }) + } + + /// Checks if shapes are compatible for element wise operations supporting broadcasting. + pub(crate) fn binary_ops_ew_shape( + self, + ops: &str, + lhs: &Shape, + rhs: &Shape, + ) -> Self { + let mut check = self; + + for i in 0..D { + let d_lhs = lhs.dims[i]; + let d_rhs = rhs.dims[i]; + + if d_lhs != d_rhs { + let is_broadcast = d_lhs == 1 || d_rhs == 1; + + if is_broadcast { + continue; } - check - } - - /// The goal is to minimize the cost of checks when there are no error, but it's way less - /// important when an error occurred, crafting a comprehensive error message is more important - /// than optimizing string manipulation. - fn register(self, ops: &str, error: TensorError) -> Self { - let errors = match self { - Self::Ok => vec![error], - Self::Failed(mut failed) => { - failed.errors.push(error); - failed.errors - } - }; - - Self::Failed(FailedTensorCheck { - ops: ops.to_string(), - errors, - }) - } - - /// Checks if shapes are compatible for element wise operations supporting broadcasting. - pub(crate) fn binary_ops_ew_shape( - self, - ops: &str, - lhs: &Shape, - rhs: &Shape, - ) -> Self { - let mut check = self; - - for i in 0..D { - let d_lhs = lhs.dims[i]; - let d_rhs = rhs.dims[i]; - - if d_lhs != d_rhs { - let is_broadcast = d_lhs == 1 || d_rhs == 1; - - if is_broadcast { - continue; - } - - check = check.register(ops, - TensorError::new("The provided tensors have incompatible shapes.") - .details(format!( - "Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \ + check = check.register( + ops, + TensorError::new("The provided tensors have incompatible shapes.").details(format!( + "Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \ Lhs tensor shape {:?}, Rhs tensor shape {:?}.", - i, - d_lhs, - d_rhs, - lhs.dims, - rhs.dims, - ))); - } - } - - check - } - - /// Checks if tensor devices are equal. - fn binary_ops_device( - self, - ops: &str, - lhs: &Device, - rhs: &Device, - ) -> Self { - match lhs != rhs { - true => self.register( - ops, - TensorError::new("The provided tensors are not on the same device.").details( - format!("Lhs tensor device {lhs:?}, Rhs tensor device {rhs:?}.",), - ), - ), - false => self, - } - } + i, d_lhs, d_rhs, lhs.dims, rhs.dims, + )), + ); + } + } + + check + } + + /// Checks if tensor devices are equal. 
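`binary_ops_ew_shape` above only tolerates mismatched dimensions when one side is 1, which is the broadcasting rule users see in element-wise operators. A minimal sketch (shapes illustrative, not part of the patch):

```rust
use burn_tensor::{backend::Backend, Tensor};

// Dim 0 broadcasts from 1 to 2; a rhs of shape [2, 3] would panic instead.
fn broadcast_add<B: Backend>() {
    let a = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]]); // [2, 2]
    let b = Tensor::<B, 2>::from_floats([[10.0, 20.0]]);           // [1, 2]

    let _sum = a + b;
}
```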
+ fn binary_ops_device( + self, + ops: &str, + lhs: &Device, + rhs: &Device, + ) -> Self { + match lhs != rhs { + true => self.register( + ops, + TensorError::new("The provided tensors are not on the same device.").details(format!( + "Lhs tensor device {lhs:?}, Rhs tensor device {rhs:?}.", + )), + ), + false => self, + } + } } pub(crate) struct FailedTensorCheck { - ops: String, - errors: Vec, + ops: String, + errors: Vec, } impl FailedTensorCheck { - /// Format all the checks into a single message ready to be printed by a [panic](core::panic). - pub(crate) fn format(self) -> String { - self.errors.into_iter().enumerate().fold( - format!( - "=== Tensor Operation Error ===\n Operation: '{}'\n Reason:", - self.ops - ), - |accum, (number, error)| accum + error.format(number + 1).as_str(), - ) + "\n" - } + /// Format all the checks into a single message ready to be printed by a [panic](core::panic). + pub(crate) fn format(self) -> String { + self.errors.into_iter().enumerate().fold( + format!( + "=== Tensor Operation Error ===\n Operation: '{}'\n Reason:", + self.ops + ), + |accum, (number, error)| accum + error.format(number + 1).as_str(), + ) + "\n" + } } struct TensorError { - description: String, - details: Option, + description: String, + details: Option, } impl TensorError { - pub(crate) fn new>(description: S) -> Self { - TensorError { - description: description.into(), - details: None, - } - } - - pub(crate) fn details>(mut self, details: S) -> Self { - self.details = Some(details.into()); - self + pub(crate) fn new>(description: S) -> Self { + TensorError { + description: description.into(), + details: None, } + } - fn format(self, number: usize) -> String { - let mut message = format!("\n {number}. "); - message += self.description.as_str(); - message += " "; + pub(crate) fn details>(mut self, details: S) -> Self { + self.details = Some(details.into()); + self + } - if let Some(details) = self.details { - message += details.as_str(); - message += " "; - } + fn format(self, number: usize) -> String { + let mut message = format!("\n {number}. "); + message += self.description.as_str(); + message += " "; - message + if let Some(details) = self.details { + message += details.as_str(); + message += " "; } + + message + } } /// We use a macro for all checks, since the panic message file and line number will match the @@ -672,78 +658,78 @@ impl TensorError { /// and line number. #[macro_export(local_inner_macros)] macro_rules! 
check { - ($check:expr) => { - if let TensorCheck::Failed(check) = $check { - core::panic!("{}", check.format()); - } - }; + ($check:expr) => { + if let TensorCheck::Failed(check) = $check { + core::panic!("{}", check.format()); + } + }; } #[cfg(test)] mod tests { - use super::*; - - #[test] - #[should_panic] - fn reshape_invalid_shape() { - check!(TensorCheck::reshape_args_usize( - &Shape::new([2, 2]), - &Shape::new([1, 3]) - )); - } - - #[test] - fn reshape_valid_shape() { - check!(TensorCheck::reshape_args_usize( - &Shape::new([2, 2]), - &Shape::new([1, 4]) - )); - } - - #[test] - #[should_panic] - fn index_range_exceed_dimension() { - check!(TensorCheck::slice( - &Shape::new([3, 5, 7]), - &[0..2, 0..4, 1..8] - )); - } - - #[test] - #[should_panic] - fn index_range_exceed_number_of_dimensions() { - check!(TensorCheck::slice(&Shape::new([3, 5]), &[0..1, 0..1, 0..1])); - } - - #[test] - #[should_panic] - fn binary_ops_shapes_no_broadcast() { - check!(TensorCheck::binary_ops_ew_shape( - TensorCheck::Ok, - "TestOps", - &Shape::new([3, 5]), - &Shape::new([3, 6]) - )); - } - - #[test] - fn binary_ops_shapes_with_broadcast() { - check!(TensorCheck::binary_ops_ew_shape( - TensorCheck::Ok, - "Test", - &Shape::new([3, 5]), - &Shape::new([1, 5]) - )); - } - - #[test] - #[should_panic] - fn binary_ops_devices() { - check!(TensorCheck::binary_ops_device( - TensorCheck::Ok, - "Test", - &5, // We can pass anything that implements PartialEq as device - &8 - )); - } + use super::*; + + #[test] + #[should_panic] + fn reshape_invalid_shape() { + check!(TensorCheck::reshape_args_usize( + &Shape::new([2, 2]), + &Shape::new([1, 3]) + )); + } + + #[test] + fn reshape_valid_shape() { + check!(TensorCheck::reshape_args_usize( + &Shape::new([2, 2]), + &Shape::new([1, 4]) + )); + } + + #[test] + #[should_panic] + fn index_range_exceed_dimension() { + check!(TensorCheck::slice( + &Shape::new([3, 5, 7]), + &[0..2, 0..4, 1..8] + )); + } + + #[test] + #[should_panic] + fn index_range_exceed_number_of_dimensions() { + check!(TensorCheck::slice(&Shape::new([3, 5]), &[0..1, 0..1, 0..1])); + } + + #[test] + #[should_panic] + fn binary_ops_shapes_no_broadcast() { + check!(TensorCheck::binary_ops_ew_shape( + TensorCheck::Ok, + "TestOps", + &Shape::new([3, 5]), + &Shape::new([3, 6]) + )); + } + + #[test] + fn binary_ops_shapes_with_broadcast() { + check!(TensorCheck::binary_ops_ew_shape( + TensorCheck::Ok, + "Test", + &Shape::new([3, 5]), + &Shape::new([1, 5]) + )); + } + + #[test] + #[should_panic] + fn binary_ops_devices() { + check!(TensorCheck::binary_ops_device( + TensorCheck::Ok, + "Test", + &5, // We can pass anything that implements PartialEq as device + &8 + )); + } } diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs index d2cededbab..4115b8538b 100644 --- a/burn-tensor/src/tensor/api/float.rs +++ b/burn-tensor/src/tensor/api/float.rs @@ -12,316 +12,316 @@ use crate::Tensor; impl Tensor where - B: Backend, + B: Backend, { - /// Executes an operation on the tensor and modifies its value. - /// - /// # Notes - /// - /// This won't necessary reuse the same tensor data/buffer, but it should if there is - /// no other reference pointing to the same tensor. - /// - /// Wrapping operations with inplace is not an optimization, it's mainly there if you - /// want to mutate a tensor by using owned operations. A plausible usage would be to - /// update the weights of a mutable model reference. 
- pub fn inplace Self>(&mut self, func: F) { - let mut tensor_owned = Tensor::empty([0; D]); - core::mem::swap(&mut tensor_owned, self); - - let mut tensor_new = func(tensor_owned); - core::mem::swap(&mut tensor_new, self); - } - - /// Applies element wise exponential operation. - /// - /// `y = e^x` - pub fn exp(self) -> Self { - Self::new(B::exp(self.primitive)) - } - - /// Applies element wise natural log operation *ln*. - /// - /// `y = log(x)` - pub fn log(self) -> Self { - Self::new(B::log(self.primitive)) - } - - /// Applies the natural logarithm of one plus the input tensor, element-wise. - /// - /// `y = log(x+1)` - pub fn log1p(self) -> Self { - Self::new(B::log1p(self.primitive)) - } - - /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise. - /// - /// `y = erf(x)` - pub fn erf(self) -> Self { - Self::new(B::erf(self.primitive)) - } - - /// Applies element wise power operation. - /// - /// `y = x^a` - pub fn powf(self, value: f32) -> Self { - Self::new(B::powf(self.primitive, value)) - } - - /// Applies element wise reciprocal operation. - pub fn recip(self) -> Self { - Self::new(B::recip(self.primitive)) - } - - /// Applies element wise root square operation. - pub fn sqrt(self) -> Self { - Self::new(B::sqrt(self.primitive)) - } - - /// Applies element wise cosine operation. - pub fn cos(self) -> Self { - Self::new(B::cos(self.primitive)) - } - - /// Applies element wise sine operation. - pub fn sin(self) -> Self { - Self::new(B::sin(self.primitive)) - } - - /// Applies element wise hyperbolic tangent operation. - pub fn tanh(self) -> Self { - Self::new(B::tanh(self.primitive)) - } - - /// Create a tensor from floats (f32). - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let _ = Tensor::::from_floats([1.0, 2.0]); - /// let _ = Tensor::::from_floats([[1.0, 2.0], [3.0, 4.0]]); - /// } - /// ``` - pub fn from_floats>>(floats: A) -> Self { - Self::from_data(floats.into().convert()) - } - - /// Returns a new tensor with the same shape and device as the current tensor and the data - /// casted to Integer. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let float_tensor = Tensor::::from_floats([1.0, 2.0]); - /// let int_tensor = float_tensor.int(); - /// } - /// ``` - pub fn int(self) -> Tensor { - Tensor::new(B::into_int(self.primitive)) - } - - /// Returns a new tensor with the same shape and device as the current tensor filled with zeros. - pub fn zeros_like(&self) -> Self { - Tensor::new(B::zeros(self.shape(), &self.device())) - } - - /// Returns a new tensor with the same shape and device as the current tensor filled with ones. - pub fn ones_like(&self) -> Self { - Tensor::new(B::ones(self.shape(), &self.device())) - } - - /// Returns a new tensor with the same shape and device as the current tensor filled random - /// values sampled from the given distribution. - pub fn random_like(&self, distribution: Distribution) -> Self { - Tensor::new(B::random(self.shape(), distribution, &self.device())) - } - - /// Create a one hot tensor. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let one_hot = Tensor::::one_hot(2, 10); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - /// ``` - pub fn one_hot(index: usize, num_classes: usize) -> Self { - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]))) - } - - /// Applies the matrix multiplication operation. - /// - /// `C = AB` - /// - /// # Panics - /// - /// If the two tensors dont' have a compatible shape. - pub fn matmul(self, other: Self) -> Self { - check!(TensorCheck::matmul(&self, &other)); - Self::new(B::matmul(self.primitive, other.primitive)) - } - - /// Calculate the variance along the given dimension. - pub fn var(self, dim: usize) -> Self { - stats::var(self, dim) - } - - /// Calculate the variance along the given dimension without applying the Bessel’s correction. - pub fn var_bias(self, dim: usize) -> Self { - stats::var_bias(self, dim) - } - - /// Calculate the variance along the given dimension and also returns the mean. - pub fn var_mean(self, dim: usize) -> (Self, Self) { - let mean = self.clone().mean_dim(dim); - let var = stats::var_with_mean(self, mean.clone(), dim); - (var, mean) - } - - /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean. - pub fn var_mean_bias(self, dim: usize) -> (Self, Self) { - let mean = self.clone().mean_dim(dim); - let var = stats::var_with_mean_bias(self, mean.clone(), dim); - (var, mean) - } - - /// Create a random tensor of the given shape where each element is sampled from the given - /// distribution. - pub fn random>>(shape: S, distribution: Distribution) -> Self { - let tensor = B::random(shape.into(), distribution, &B::Device::default()); - Self::new(tensor) - } - - /// Create a random tensor of the given shape on the given device where each element is - /// sampled from the given distribution. - pub fn random_device>>( - shape: S, - distribution: Distribution, - device: &B::Device, - ) -> Self { - let tensor = B::random(shape.into(), distribution, device); - Self::new(tensor) - } - /// Returns a tensor with full precision based on the selected backend. - pub fn to_full_precision(&self) -> Tensor { - Tensor::new(B::to_full_precision(&self.primitive)) - } - - /// Returns a tensor on the selected backend from a full precision tensor. - pub fn from_full_precision(tensor: Tensor) -> Self { - Self::new(B::from_full_precision(tensor.primitive)) - } - - /// Detach the current tensor from the autodiff graph. - /// This function does nothing when autodiff is not enabled. - /// This can be used in batchers or elsewhere to ensure that previous operations are not - /// considered in the autodiff graph. - pub fn detach(self) -> Self { - Self::new(B::detach(self.primitive)) - } - - /// Mark the tensor to keep gradients during the backward pass. - /// This function does nothing when autodiff is not enabled. - pub fn require_grad(self) -> Self { - self.set_require_grad(true) - } - - /// Returns true if the tensor requires gradients during the backward pass. 
- pub fn is_require_grad(&self) -> bool { - B::is_require_grad(&self.primitive) - } - - /// Mark the tensor as tracked or untracked depending on the require grad argument. - /// When tracked, the gradients will be available after the backward pass. - /// - /// This function does nothing when autodiff is not enabled. - pub fn set_require_grad(self, require_grad: bool) -> Self { - Self::new(B::set_require_grad(self.primitive, require_grad)) - } - - /// Applies the relu function to the tensor. - pub(crate) fn relu(self) -> Self { - Self::new(B::relu(self.primitive)) - } - - /// Calculate covaraince matrix between different entries alongside a given dimension. - /// - /// # Arguments - /// - /// * `size` - The size of the square matrix. - /// * `correction_factor` - Is usually 1 for samples and 0 for population. - pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor { - let n = self.dims()[dim]; - let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0); - centered - .clone() - .transpose() - .matmul(centered) - .div_scalar(n as f32 - correction_factor as f32) - } + /// Executes an operation on the tensor and modifies its value. + /// + /// # Notes + /// + /// This won't necessary reuse the same tensor data/buffer, but it should if there is + /// no other reference pointing to the same tensor. + /// + /// Wrapping operations with inplace is not an optimization, it's mainly there if you + /// want to mutate a tensor by using owned operations. A plausible usage would be to + /// update the weights of a mutable model reference. + pub fn inplace Self>(&mut self, func: F) { + let mut tensor_owned = Tensor::empty([0; D]); + core::mem::swap(&mut tensor_owned, self); + + let mut tensor_new = func(tensor_owned); + core::mem::swap(&mut tensor_new, self); + } + + /// Applies element wise exponential operation. + /// + /// `y = e^x` + pub fn exp(self) -> Self { + Self::new(B::exp(self.primitive)) + } + + /// Applies element wise natural log operation *ln*. + /// + /// `y = log(x)` + pub fn log(self) -> Self { + Self::new(B::log(self.primitive)) + } + + /// Applies the natural logarithm of one plus the input tensor, element-wise. + /// + /// `y = log(x+1)` + pub fn log1p(self) -> Self { + Self::new(B::log1p(self.primitive)) + } + + /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise. + /// + /// `y = erf(x)` + pub fn erf(self) -> Self { + Self::new(B::erf(self.primitive)) + } + + /// Applies element wise power operation. + /// + /// `y = x^a` + pub fn powf(self, value: f32) -> Self { + Self::new(B::powf(self.primitive, value)) + } + + /// Applies element wise reciprocal operation. + pub fn recip(self) -> Self { + Self::new(B::recip(self.primitive)) + } + + /// Applies element wise root square operation. + pub fn sqrt(self) -> Self { + Self::new(B::sqrt(self.primitive)) + } + + /// Applies element wise cosine operation. + pub fn cos(self) -> Self { + Self::new(B::cos(self.primitive)) + } + + /// Applies element wise sine operation. + pub fn sin(self) -> Self { + Self::new(B::sin(self.primitive)) + } + + /// Applies element wise hyperbolic tangent operation. + pub fn tanh(self) -> Self { + Self::new(B::tanh(self.primitive)) + } + + /// Create a tensor from floats (f32). 
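The element-wise math methods reindented in this hunk (`exp`, `log`, `sqrt`, `powf`, `sin`, `tanh`, ...) compose by value, so chains read left to right. A hedged sketch (function name and inputs are illustrative):

```rust
use burn_tensor::{backend::Backend, Tensor};

// Chaining the element-wise float API.
fn elementwise_math<B: Backend>() {
    let x = Tensor::<B, 1>::from_floats([0.5, 1.0, 2.0]);

    let _softplus = x.clone().exp().add_scalar(1.0).log(); // log(e^x + 1)
    let _abs_like = x.clone().powf(2.0).sqrt();            // sqrt(x^2)
    let _mixed = x.sin().tanh();
}
```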
+ /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let _ = Tensor::::from_floats([1.0, 2.0]); + /// let _ = Tensor::::from_floats([[1.0, 2.0], [3.0, 4.0]]); + /// } + /// ``` + pub fn from_floats>>(floats: A) -> Self { + Self::from_data(floats.into().convert()) + } + + /// Returns a new tensor with the same shape and device as the current tensor and the data + /// casted to Integer. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let float_tensor = Tensor::::from_floats([1.0, 2.0]); + /// let int_tensor = float_tensor.int(); + /// } + /// ``` + pub fn int(self) -> Tensor { + Tensor::new(B::into_int(self.primitive)) + } + + /// Returns a new tensor with the same shape and device as the current tensor filled with zeros. + pub fn zeros_like(&self) -> Self { + Tensor::new(B::zeros(self.shape(), &self.device())) + } + + /// Returns a new tensor with the same shape and device as the current tensor filled with ones. + pub fn ones_like(&self) -> Self { + Tensor::new(B::ones(self.shape(), &self.device())) + } + + /// Returns a new tensor with the same shape and device as the current tensor filled random + /// values sampled from the given distribution. + pub fn random_like(&self, distribution: Distribution) -> Self { + Tensor::new(B::random(self.shape(), distribution, &self.device())) + } + + /// Create a one hot tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let one_hot = Tensor::::one_hot(2, 10); + /// println!("{}", one_hot.to_data()); + /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + /// } + /// ``` + pub fn one_hot(index: usize, num_classes: usize) -> Self { + let mut dims = [1; D]; + dims[D - 1] = num_classes; + let shape = Shape::new(dims); + let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); + let tensor = Tensor::zeros(shape); + let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); + ranges[D - 1] = index..index + 1; + + tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]))) + } + + /// Applies the matrix multiplication operation. + /// + /// `C = AB` + /// + /// # Panics + /// + /// If the two tensors dont' have a compatible shape. + pub fn matmul(self, other: Self) -> Self { + check!(TensorCheck::matmul(&self, &other)); + Self::new(B::matmul(self.primitive, other.primitive)) + } + + /// Calculate the variance along the given dimension. + pub fn var(self, dim: usize) -> Self { + stats::var(self, dim) + } + + /// Calculate the variance along the given dimension without applying the Bessel’s correction. + pub fn var_bias(self, dim: usize) -> Self { + stats::var_bias(self, dim) + } + + /// Calculate the variance along the given dimension and also returns the mean. + pub fn var_mean(self, dim: usize) -> (Self, Self) { + let mean = self.clone().mean_dim(dim); + let var = stats::var_with_mean(self, mean.clone(), dim); + (var, mean) + } + + /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean. 
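A short sketch combining the methods documented in this hunk: `matmul` requires matching inner dimensions, `var_mean` returns both statistics so the mean is computed once, and `one_hot` places a single 1 at the given index. Shapes and names here are illustrative only:

```rust
use burn_tensor::{backend::Backend, Tensor};

fn matmul_and_stats<B: Backend>() {
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]]);

    let _gram = x.clone().matmul(x.clone().transpose()); // [2, 2] x [2, 2] -> [2, 2]
    let (_var, _mean) = x.var_mean(1);                   // both of shape [2, 1]

    let _one_hot = Tensor::<B, 1>::one_hot(2, 5);        // [0, 0, 1, 0, 0]
}
```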
+ pub fn var_mean_bias(self, dim: usize) -> (Self, Self) { + let mean = self.clone().mean_dim(dim); + let var = stats::var_with_mean_bias(self, mean.clone(), dim); + (var, mean) + } + + /// Create a random tensor of the given shape where each element is sampled from the given + /// distribution. + pub fn random>>(shape: S, distribution: Distribution) -> Self { + let tensor = B::random(shape.into(), distribution, &B::Device::default()); + Self::new(tensor) + } + + /// Create a random tensor of the given shape on the given device where each element is + /// sampled from the given distribution. + pub fn random_device>>( + shape: S, + distribution: Distribution, + device: &B::Device, + ) -> Self { + let tensor = B::random(shape.into(), distribution, device); + Self::new(tensor) + } + /// Returns a tensor with full precision based on the selected backend. + pub fn to_full_precision(&self) -> Tensor { + Tensor::new(B::to_full_precision(&self.primitive)) + } + + /// Returns a tensor on the selected backend from a full precision tensor. + pub fn from_full_precision(tensor: Tensor) -> Self { + Self::new(B::from_full_precision(tensor.primitive)) + } + + /// Detach the current tensor from the autodiff graph. + /// This function does nothing when autodiff is not enabled. + /// This can be used in batchers or elsewhere to ensure that previous operations are not + /// considered in the autodiff graph. + pub fn detach(self) -> Self { + Self::new(B::detach(self.primitive)) + } + + /// Mark the tensor to keep gradients during the backward pass. + /// This function does nothing when autodiff is not enabled. + pub fn require_grad(self) -> Self { + self.set_require_grad(true) + } + + /// Returns true if the tensor requires gradients during the backward pass. + pub fn is_require_grad(&self) -> bool { + B::is_require_grad(&self.primitive) + } + + /// Mark the tensor as tracked or untracked depending on the require grad argument. + /// When tracked, the gradients will be available after the backward pass. + /// + /// This function does nothing when autodiff is not enabled. + pub fn set_require_grad(self, require_grad: bool) -> Self { + Self::new(B::set_require_grad(self.primitive, require_grad)) + } + + /// Applies the relu function to the tensor. + pub(crate) fn relu(self) -> Self { + Self::new(B::relu(self.primitive)) + } + + /// Calculate covaraince matrix between different entries alongside a given dimension. + /// + /// # Arguments + /// + /// * `size` - The size of the square matrix. + /// * `correction_factor` - Is usually 1 for samples and 0 for population. + pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor { + let n = self.dims()[dim]; + let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0); + centered + .clone() + .transpose() + .matmul(centered) + .div_scalar(n as f32 - correction_factor as f32) + } } impl Tensor { - /// Backward pass of the tensor. - pub fn backward(&self) -> B::Gradients { - B::backward::(self.primitive.clone()) - } - - /// Get the gradients of a tensor if it exist. - /// - /// Returns a new reference to the same tensor. Therefore the same grad tensor can - /// be accessed multiple times. If you only need to get the gradients one time, - /// consider using [grad_remove](Tensor::grad_remove) for better performance. - pub fn grad(&self, grads: &B::Gradients) -> Option> { - B::grad(&self.primitive, grads).map(Tensor::new) - } - - /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result. 
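The variance and covariance helpers above pair naturally. A small sketch, assuming a `[samples, features]` layout with statistics taken along dimension 0; the sample values are illustrative:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// Sample variance, mean, and covariance of a 3 x 2 matrix along the sample dimension.
fn feature_stats<B: Backend>() -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2>) {
    // Three samples of two features each.
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
    // Variance and mean along dim 0; the reduced dimension is kept with size 1.
    let (var, mean) = x.clone().var_mean(0);
    // Covariance along dim 0 with correction factor 1 (samples), giving a 2 x 2 matrix.
    let cov = x.cov(0, 1);
    (var, mean, cov)
}
```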
- pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option> { - B::grad_remove(&self.primitive, grads).map(Tensor::new) - } - - /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided - /// gradient. - pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor) { - B::grad_replace(&self.primitive, grads, grad.primitive); - } - - /// Returns the inner tensor without the autodiff information. - pub fn inner(self) -> Tensor { - Tensor::new(B::inner(self.primitive)) - } - - /// Convert a tensor to the autodiff backend. - /// - /// # Arguments - /// - /// * `inner` - The tensor to convert. - /// - /// # Returns - /// - /// The tensor converted to the autodiff backend. - pub fn from_inner(inner: Tensor) -> Self { - Self::new(B::from_inner(inner.primitive)) - } + /// Backward pass of the tensor. + pub fn backward(&self) -> B::Gradients { + B::backward::(self.primitive.clone()) + } + + /// Get the gradients of a tensor if it exist. + /// + /// Returns a new reference to the same tensor. Therefore the same grad tensor can + /// be accessed multiple times. If you only need to get the gradients one time, + /// consider using [grad_remove](Tensor::grad_remove) for better performance. + pub fn grad(&self, grads: &B::Gradients) -> Option> { + B::grad(&self.primitive, grads).map(Tensor::new) + } + + /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result. + pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option> { + B::grad_remove(&self.primitive, grads).map(Tensor::new) + } + + /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided + /// gradient. + pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor) { + B::grad_replace(&self.primitive, grads, grad.primitive); + } + + /// Returns the inner tensor without the autodiff information. + pub fn inner(self) -> Tensor { + Tensor::new(B::inner(self.primitive)) + } + + /// Convert a tensor to the autodiff backend. + /// + /// # Arguments + /// + /// * `inner` - The tensor to convert. + /// + /// # Returns + /// + /// The tensor converted to the autodiff backend. + pub fn from_inner(inner: Tensor) -> Self { + Self::new(B::from_inner(inner.primitive)) + } } diff --git a/burn-tensor/src/tensor/api/int.rs b/burn-tensor/src/tensor/api/int.rs index 395db8c280..f141c40ca8 100644 --- a/burn-tensor/src/tensor/api/int.rs +++ b/burn-tensor/src/tensor/api/int.rs @@ -3,84 +3,84 @@ use core::ops::Range; impl Tensor where - B: Backend, + B: Backend, { - /// Returns a new integer tensor on the default device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - pub fn arange(range: Range) -> Self { - Tensor::new(B::arange(range, &B::Device::default())) - } + /// Returns a new integer tensor on the default device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + pub fn arange(range: Range) -> Self { + Tensor::new(B::arange(range, &B::Device::default())) + } - /// Returns a new integer tensor on the default device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - /// * `step` - The step between each value. - pub fn arange_step(range: Range, step: usize) -> Self { - Tensor::new(B::arange_step(range, step, &B::Device::default())) - } + /// Returns a new integer tensor on the default device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `step` - The step between each value. 
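The autodiff methods above (`require_grad`, `backward`, `grad_remove`) fit together as follows. A hedged sketch assuming `B::InnerBackend` as the non-autodiff backend behind `B`; the loss function is illustrative:

```rust
use burn_tensor::backend::AutodiffBackend;
use burn_tensor::Tensor;

/// Gradient of sum(x^2) with respect to `x`, taken out of the gradient container.
fn grad_of_sum_of_squares<B: AutodiffBackend>(
    x: Tensor<B, 1>,
) -> Option<Tensor<B::InnerBackend, 1>> {
    // Track gradients for the leaf tensor, build a scalar loss, and run backward.
    let x = x.require_grad();
    let loss = x.clone().powf(2.0).sum();
    let mut grads = loss.backward();
    // `grad_remove` pops the gradient out of the container instead of copying it.
    x.grad_remove(&mut grads)
}
```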
+ pub fn arange_step(range: Range, step: usize) -> Self { + Tensor::new(B::arange_step(range, step, &B::Device::default())) + } - /// Returns a new integer tensor on the specified device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - /// * `device` - The device to create the tensor on. - pub fn arange_device(range: Range, device: &B::Device) -> Self { - Tensor::new(B::arange(range, device)) - } + /// Returns a new integer tensor on the specified device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `device` - The device to create the tensor on. + pub fn arange_device(range: Range, device: &B::Device) -> Self { + Tensor::new(B::arange(range, device)) + } - /// Returns a new integer tensor on the specified device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - /// * `step` - The step between each value. - pub fn arange_step_device(range: Range, step: usize, device: &B::Device) -> Self { - Tensor::new(B::arange_step(range, step, device)) - } + /// Returns a new integer tensor on the specified device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `step` - The step between each value. + pub fn arange_step_device(range: Range, step: usize, device: &B::Device) -> Self { + Tensor::new(B::arange_step(range, step, device)) + } } impl Tensor where - B: Backend, + B: Backend, { - /// Create a tensor from integers (i32). - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example() { - /// let _x: Tensor = Tensor::from_ints([1, 2]); - /// let _y: Tensor = Tensor::from_ints([[1, 2], [3, 4]]); - /// } - /// ``` - pub fn from_ints>>(ints: A) -> Self { - Self::from_data(ints.into().convert()) - } + /// Create a tensor from integers (i32). + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Int}; + /// + /// fn example() { + /// let _x: Tensor = Tensor::from_ints([1, 2]); + /// let _y: Tensor = Tensor::from_ints([[1, 2], [3, 4]]); + /// } + /// ``` + pub fn from_ints>>(ints: A) -> Self { + Self::from_data(ints.into().convert()) + } - /// Returns a new tensor with the same shape and device as the current tensor and the data - /// casted to Float. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor}; - /// - /// fn example() { - /// let int_tensor = Tensor::::arange(0..5); - /// let float_tensor = int_tensor.float(); - /// } - /// ``` - pub fn float(self) -> Tensor { - Tensor::new(B::int_into_float(self.primitive)) - } + /// Returns a new tensor with the same shape and device as the current tensor and the data + /// casted to Float. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Int, Tensor}; + /// + /// fn example() { + /// let int_tensor = Tensor::::arange(0..5); + /// let float_tensor = int_tensor.float(); + /// } + /// ``` + pub fn float(self) -> Tensor { + Tensor::new(B::int_into_float(self.primitive)) + } } diff --git a/burn-tensor/src/tensor/api/kind.rs b/burn-tensor/src/tensor/api/kind.rs index 208aa26ef4..d7e7cc9eb5 100644 --- a/burn-tensor/src/tensor/api/kind.rs +++ b/burn-tensor/src/tensor/api/kind.rs @@ -14,30 +14,30 @@ pub struct Bool; /// A type-level representation of the kind of a tensor. 
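The `arange_*` constructors and the `float` cast documented in this hunk compose directly. A minimal sketch; the range, step, and function name are illustrative:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

/// Builds 0, 2, 4, 6, 8 as an integer tensor on the given device, then casts it to float.
fn even_numbers_as_float<B: Backend>(device: &B::Device) -> Tensor<B, 1> {
    let ints: Tensor<B, 1, Int> = Tensor::arange_step_device(0..10, 2, device);
    ints.float()
}
```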
pub trait TensorKind: Clone + core::fmt::Debug { - /// The primitive type of the tensor. - type Primitive: Clone + core::fmt::Debug; + /// The primitive type of the tensor. + type Primitive: Clone + core::fmt::Debug; - /// The name of the tensor kind. - fn name() -> &'static str; + /// The name of the tensor kind. + fn name() -> &'static str; } impl TensorKind for Float { - type Primitive = B::TensorPrimitive; - fn name() -> &'static str { - "Float" - } + type Primitive = B::TensorPrimitive; + fn name() -> &'static str { + "Float" + } } impl TensorKind for Int { - type Primitive = B::IntTensorPrimitive; - fn name() -> &'static str { - "Int" - } + type Primitive = B::IntTensorPrimitive; + fn name() -> &'static str { + "Int" + } } impl TensorKind for Bool { - type Primitive = B::BoolTensorPrimitive; - fn name() -> &'static str { - "Bool" - } + type Primitive = B::BoolTensorPrimitive; + fn name() -> &'static str { + "Bool" + } } diff --git a/burn-tensor/src/tensor/api/numeric.rs b/burn-tensor/src/tensor/api/numeric.rs index fe3d064120..2ee70d39db 100644 --- a/burn-tensor/src/tensor/api/numeric.rs +++ b/burn-tensor/src/tensor/api/numeric.rs @@ -1,486 +1,486 @@ use crate::{ - backend::Backend, check, check::TensorCheck, BasicOps, Bool, Element, ElementConversion, Float, - Int, Shape, Tensor, TensorKind, + backend::Backend, check, check::TensorCheck, BasicOps, Bool, Element, ElementConversion, Float, + Int, Shape, Tensor, TensorKind, }; impl Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Convert the tensor into a scalar. - /// - /// # Panics - /// - /// If the tensor doesn't have one element. - pub fn into_scalar(self) -> K::Elem { - check!(TensorCheck::into_scalar(&self.shape())); - let data = self.into_data(); - data.value[0] - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Convert the tensor into a scalar. - /// - /// # Panics - /// - /// If the tensor doesn't have one element. - pub async fn into_scalar(self) -> K::Elem { - check!(TensorCheck::into_scalar(&self.shape())); - let data = self.into_data().await; - data.value[0] - } - - /// Applies element wise addition operation. - /// - /// `y = x2 + x1` - #[allow(clippy::should_implement_trait)] - pub fn add(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Add", &self, &other)); - Self::new(K::add(self.primitive, other.primitive)) - } - - /// Applies element wise addition operation with a scalar. - /// - /// `y = x + s` - pub fn add_scalar(self, other: E) -> Self { - Self::new(K::add_scalar(self.primitive, other)) - } - - /// Applies element wise subtraction operation. - /// - /// `y = x2 - x1` - #[allow(clippy::should_implement_trait)] - pub fn sub(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Sub", &self, &other)); - Self::new(K::sub(self.primitive, other.primitive)) - } - - /// Applies element wise subtraction operation with a scalar. - /// - /// `y = x - s` - pub fn sub_scalar(self, other: E) -> Self { - Self::new(K::sub_scalar(self.primitive, other)) - } - - /// Applies element wise division operation. - /// - /// `y = x2 / x1` - #[allow(clippy::should_implement_trait)] - pub fn div(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Div", &self, &other)); - Self::new(K::div(self.primitive, other.primitive)) - } - - /// Applies element wise division operation with a scalar. 
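For the `TensorKind` markers reformatted above, a small sketch of how generic code can query the kind names; this assumes the markers and the trait are re-exported at the crate root as in the other doc examples:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Bool, Float, Int, TensorKind};

/// Prints the name of each type-level tensor kind for a given backend.
fn print_kind_names<B: Backend>() {
    println!("{}", <Float as TensorKind<B>>::name()); // "Float"
    println!("{}", <Int as TensorKind<B>>::name()); // "Int"
    println!("{}", <Bool as TensorKind<B>>::name()); // "Bool"
}
```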
- /// - /// `y = x / s` - pub fn div_scalar(self, other: E) -> Self { - Self::new(K::div_scalar(self.primitive, other)) - } - /// - /// Applies element wise multiplication operation. - /// - /// `y = x2 * x1` - #[allow(clippy::should_implement_trait)] - pub fn mul(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Mul", &self, &other)); - Self::new(K::mul(self.primitive, other.primitive)) - } - - /// Applies element wise multiplication operation with a scalar. - /// - /// `y = x * s` - pub fn mul_scalar(self, other: E) -> Self { - Self::new(K::mul_scalar(self.primitive, other)) - } - - /// Switch sign of each element in the tensor. - /// - /// `y = -x` - #[allow(clippy::should_implement_trait)] - pub fn neg(self) -> Self { - Self::new(K::neg(self.primitive)) - } - - /// Create a tensor of the given shape where each element is zero. - pub fn zeros>>(shape: S) -> Self { - Self::zeros_device(shape, &B::Device::default()) - } - - /// Create a tensor of the given shape where each element is zero. - pub fn zeros_device>>(shape: S, device: &B::Device) -> Self { - Self::new(K::zeros(shape.into(), device)) - } - - /// Create a tensor of the given shape where each element is one. - pub fn ones>>(shape: S) -> Self { - Self::ones_device(shape, &B::Device::default()) - } - - /// Create a tensor of the given shape where each element is one. - pub fn ones_device>>(shape: S, device: &B::Device) -> Self { - Self::new(K::ones(shape.into(), device)) - } - - /// Create a tensor of the given shape where each element is equal to the provided value. - pub fn full>, E: ElementConversion>(shape: S, fill_value: E) -> Self { - Self::full_device(shape, fill_value, &B::Device::default()) - } - - /// Create a tensor of the given shape where each element is equal to the provided value. - pub fn full_device>, E: ElementConversion>( - shape: S, - fill_value: E, - device: &B::Device, - ) -> Self { - Self::new(K::full(shape.into(), fill_value, device)) - } - - /// Aggregate all elements in the tensor with the mean operation. - pub fn mean(self) -> Tensor { - Tensor::new(K::mean(self.primitive)) - } - - /// Aggregate all elements in the tensor with the sum operation. - pub fn sum(self) -> Tensor { - Tensor::new(K::sum(self.primitive)) - } - - /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation. - pub fn mean_dim(self, dim: usize) -> Self { - check!(TensorCheck::aggregate_dim::("Mean", dim)); - Self::new(K::mean_dim(self.primitive, dim)) - } - - /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation. - pub fn sum_dim(self, dim: usize) -> Self { - check!(TensorCheck::aggregate_dim::("Sum", dim)); - Self::new(K::sum_dim(self.primitive, dim)) - } - - /// Applies element wise equal comparison and returns a boolean tensor. - pub fn equal_elem(self, other: E) -> Tensor { - K::equal_elem::(self.primitive, other.elem()) - } - - /// Applies element wise greater comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn greater(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Greater", &self, &other)); - K::greater(self.primitive, other.primitive) - } - - /// Applies element wise greater-equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. 
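The creation helpers (`ones`, `full`) and the element-wise and aggregation ops documented above chain into short pipelines. A hedged sketch with illustrative shapes and values:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// (1 + 4) / 5 = 1 for every element; the global sum collapses to a single-element tensor.
fn scaled_total<B: Backend>() -> Tensor<B, 1> {
    let x = Tensor::<B, 2>::ones([2, 3]);
    let y = Tensor::<B, 2>::full([2, 3], 4.0);
    x.add(y).div_scalar(5.0).sum()
}
```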
- pub fn greater_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other)); - K::greater_equal(self.primitive, other.primitive) - } - - /// Applies element wise lower comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn lower(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Lower", &self, &other)); - K::lower(self.primitive, other.primitive) - } - - /// Applies element wise lower-equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn lower_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other)); - K::lower_equal(self.primitive, other.primitive) - } - - /// Applies element wise greater comparison and returns a boolean tensor. - pub fn greater_elem(self, other: E) -> Tensor { - K::greater_elem(self.primitive, other.elem()) - } - - /// Applies element wise greater-equal comparison and returns a boolean tensor. - pub fn greater_equal_elem(self, other: E) -> Tensor { - K::greater_equal_elem(self.primitive, other.elem()) - } - - /// Applies element wise lower comparison and returns a boolean tensor. - pub fn lower_elem(self, other: E) -> Tensor { - K::lower_elem(self.primitive, other.elem()) - } - - /// Applies element wise lower-equal comparison and returns a boolean tensor. - pub fn lower_equal_elem(self, other: E) -> Tensor { - K::lower_equal_elem(self.primitive, other.elem()) - } - - /// Update the given tensor with the value tensor where the mask is true. - /// - /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of - /// a scalar. - pub fn mask_where(self, mask: Tensor, value: Self) -> Self { - Self::new(K::mask_where(self.primitive, mask, value.primitive)) - } - - /// Update the given tensor with the value where the mask is true. - /// - /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of - /// a tensor. - pub fn mask_fill(self, mask: Tensor, value: E) -> Self { - Self::new(K::mask_fill(self.primitive, mask, value.elem())) - } - - /// Gather tensor elements corresponding to the given indices from the specified dim. - /// - /// Example using a 3D tensor: - /// - /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0` - /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1` - /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2` - /// - /// # Notes - /// - /// The index tensor should have the same shape as the original tensor except for the dim - /// specified. - pub fn gather(self, dim: usize, indices: Tensor) -> Self { - check!(TensorCheck::gather::( - dim, - &self.shape(), - &indices.shape() - )); - - Self::new(K::gather(dim, self.primitive, indices)) - } - - /// Assign the gathered elements corresponding to the given indices along the specified dimension - /// from the value tensor to the original tensor using sum reduction. - /// - /// Example using a 3D tensor: - /// - /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0` - /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1` - /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2` - /// - /// # Notes - /// - /// The index tensor should have the same shape as the original tensor except for the specified - /// dimension. The value and index tensors should have the same shape. 
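The scalar comparisons and `mask_fill` above combine into a simple thresholding pattern. A sketch only; the function name is illustrative and this is not how the library implements its activation functions:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// Sets every negative entry to zero using a comparison mask.
fn zero_negatives<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    // `lower_elem` returns a boolean tensor that is true where x < 0.
    let negative = x.clone().lower_elem(0.0);
    // Overwrite the masked positions with 0, leaving the rest untouched.
    x.mask_fill(negative, 0.0)
}
```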
- /// - /// Other references to the input tensor will not be modified by this operation. - pub fn scatter(self, dim: usize, indices: Tensor, values: Self) -> Self { - check!(TensorCheck::scatter::( - dim, - &self.shape(), - &indices.shape(), - &values.shape() - )); - - Self::new(K::scatter(dim, self.primitive, indices, values.primitive)) - } - - /// Select the tensor elements along the given dimension corresponding to the given indices. - /// - /// Example using a 3D tensor: - /// - /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0` - /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1` - /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2` - pub fn select(self, dim: usize, indices: Tensor) -> Self { - check!(TensorCheck::select::(dim)); - Self::new(K::select(self.primitive, dim, indices)) - } - - /// Assign the selected elements along the given dimension corresponding to the given indices - /// from the value tensor to the original tensor using sum reduction. - /// - /// Example using a 3D tensor: - /// - /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0` - /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1` - /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2` - pub fn select_assign( - self, - dim: usize, - indices: Tensor, - values: Tensor, - ) -> Self { - check!(TensorCheck::select_assign::(dim)); - - Self::new(K::select_assign( - self.primitive, - dim, - indices, - values.primitive, - )) - } - - /// Applies the argmax function along the given dimension and returns an integer tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); - /// let tensor = tensor.argmax(1); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [2, 1, 3] } - /// } - /// ``` - pub fn argmax(self, dim: usize) -> Tensor { - Tensor::new(K::argmax(self.primitive, dim)) - } - - /// Find the maximum value. - pub fn max(self) -> Tensor { - Tensor::new(K::max(self.primitive)) - } - - /// Find the maximum value along the given dimension. - pub fn max_dim(self, dim: usize) -> Tensor { - check!(TensorCheck::aggregate_dim::("Max", dim)); - - Tensor::new(K::max_dim(self.primitive, dim)) - } - - /// Find the maximum value along the given dimension. - /// - /// Also returns the indices. - pub fn max_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { - check!(TensorCheck::aggregate_dim::("Max", dim)); - - let (tensor, index) = K::max_dim_with_indices(self.primitive, dim); - - let tensor = Tensor::new(tensor); - let index = Tensor::new(index); - - (tensor, index) - } - - /// Applies the argmin function along the given dimension and returns an integer tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); - /// let tensor = tensor.argmin(1); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [2, 1, 3] } - /// } - /// ``` - pub fn argmin(self, dim: usize) -> Tensor { - Tensor::new(K::argmin(self.primitive, dim)) - } - - /// Find the minimum value. - pub fn min(self) -> Tensor { - Tensor::new(K::min(self.primitive)) - } - - /// Find the minimum value along the given dimension. 
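For the `select` / `select_assign` pair above, a hedged sketch that doubles two rows of a matrix by selecting them and adding them back with sum reduction; the row indices are illustrative:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

/// Doubles rows 0 and 2 of `x`: select them, then scatter them back onto themselves.
fn double_rows_0_and_2<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    let indices = Tensor::<B, 1, Int>::from_ints([0, 2]);
    let rows = x.clone().select(0, indices.clone());
    // `select_assign` adds the values onto the indexed rows (sum reduction).
    x.select_assign(0, indices, rows)
}
```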
- pub fn min_dim(self, dim: usize) -> Tensor { - check!(TensorCheck::aggregate_dim::("Min", dim)); - Tensor::new(K::min_dim(self.primitive, dim)) - } - - /// Find the minimum value along the given dimension. - /// - /// Also returns the indices. - pub fn min_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { - check!(TensorCheck::aggregate_dim::("Min", dim)); - - let (tensor, index) = K::min_dim_with_indices(self.primitive, dim); - - let tensor = Tensor::new(tensor); - let index = Tensor::new(index); - - (tensor, index) - } - - /// Clamp the tensor between the given min and max values. - /// - /// # Arguments - /// - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped between the given min and max values. - pub fn clamp(self, min: E, max: E) -> Self { - Self::new(K::clamp(self.primitive, min.elem(), max.elem())) - } - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped under the given min value. - pub fn clamp_min(self, min: E) -> Self { - Self::new(K::clamp_min(self.primitive, min.elem())) - } - - /// Clamps a tensor over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped over the given max value. - /// - pub fn clamp_max(self, max: E) -> Self { - Self::new(K::clamp_max(self.primitive, max.elem())) - } - - /// Apply element wise absolute value operation - pub fn abs(self) -> Self { - Self::new(K::abs(self.primitive)) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Convert the tensor into a scalar. + /// + /// # Panics + /// + /// If the tensor doesn't have one element. + pub fn into_scalar(self) -> K::Elem { + check!(TensorCheck::into_scalar(&self.shape())); + let data = self.into_data(); + data.value[0] + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Convert the tensor into a scalar. + /// + /// # Panics + /// + /// If the tensor doesn't have one element. + pub async fn into_scalar(self) -> K::Elem { + check!(TensorCheck::into_scalar(&self.shape())); + let data = self.into_data().await; + data.value[0] + } + + /// Applies element wise addition operation. + /// + /// `y = x2 + x1` + #[allow(clippy::should_implement_trait)] + pub fn add(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Add", &self, &other)); + Self::new(K::add(self.primitive, other.primitive)) + } + + /// Applies element wise addition operation with a scalar. + /// + /// `y = x + s` + pub fn add_scalar(self, other: E) -> Self { + Self::new(K::add_scalar(self.primitive, other)) + } + + /// Applies element wise subtraction operation. + /// + /// `y = x2 - x1` + #[allow(clippy::should_implement_trait)] + pub fn sub(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Sub", &self, &other)); + Self::new(K::sub(self.primitive, other.primitive)) + } + + /// Applies element wise subtraction operation with a scalar. + /// + /// `y = x - s` + pub fn sub_scalar(self, other: E) -> Self { + Self::new(K::sub_scalar(self.primitive, other)) + } + + /// Applies element wise division operation. 
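The clamping and max-with-indices helpers documented above are often used together when post-processing scores. A minimal sketch; the clamp range is illustrative:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

/// Restricts values to [-1, 1], then returns the per-row maximum and its index.
fn clamp_then_row_max<B: Backend>(x: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2, Int>) {
    let clamped = x.clamp(-1.0, 1.0);
    // The reduced dimension is kept with size 1; the indices point into dim 1.
    clamped.max_dim_with_indices(1)
}
```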
+ /// + /// `y = x2 / x1` + #[allow(clippy::should_implement_trait)] + pub fn div(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Div", &self, &other)); + Self::new(K::div(self.primitive, other.primitive)) + } + + /// Applies element wise division operation with a scalar. + /// + /// `y = x / s` + pub fn div_scalar(self, other: E) -> Self { + Self::new(K::div_scalar(self.primitive, other)) + } + /// + /// Applies element wise multiplication operation. + /// + /// `y = x2 * x1` + #[allow(clippy::should_implement_trait)] + pub fn mul(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Mul", &self, &other)); + Self::new(K::mul(self.primitive, other.primitive)) + } + + /// Applies element wise multiplication operation with a scalar. + /// + /// `y = x * s` + pub fn mul_scalar(self, other: E) -> Self { + Self::new(K::mul_scalar(self.primitive, other)) + } + + /// Switch sign of each element in the tensor. + /// + /// `y = -x` + #[allow(clippy::should_implement_trait)] + pub fn neg(self) -> Self { + Self::new(K::neg(self.primitive)) + } + + /// Create a tensor of the given shape where each element is zero. + pub fn zeros>>(shape: S) -> Self { + Self::zeros_device(shape, &B::Device::default()) + } + + /// Create a tensor of the given shape where each element is zero. + pub fn zeros_device>>(shape: S, device: &B::Device) -> Self { + Self::new(K::zeros(shape.into(), device)) + } + + /// Create a tensor of the given shape where each element is one. + pub fn ones>>(shape: S) -> Self { + Self::ones_device(shape, &B::Device::default()) + } + + /// Create a tensor of the given shape where each element is one. + pub fn ones_device>>(shape: S, device: &B::Device) -> Self { + Self::new(K::ones(shape.into(), device)) + } + + /// Create a tensor of the given shape where each element is equal to the provided value. + pub fn full>, E: ElementConversion>(shape: S, fill_value: E) -> Self { + Self::full_device(shape, fill_value, &B::Device::default()) + } + + /// Create a tensor of the given shape where each element is equal to the provided value. + pub fn full_device>, E: ElementConversion>( + shape: S, + fill_value: E, + device: &B::Device, + ) -> Self { + Self::new(K::full(shape.into(), fill_value, device)) + } + + /// Aggregate all elements in the tensor with the mean operation. + pub fn mean(self) -> Tensor { + Tensor::new(K::mean(self.primitive)) + } + + /// Aggregate all elements in the tensor with the sum operation. + pub fn sum(self) -> Tensor { + Tensor::new(K::sum(self.primitive)) + } + + /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation. + pub fn mean_dim(self, dim: usize) -> Self { + check!(TensorCheck::aggregate_dim::("Mean", dim)); + Self::new(K::mean_dim(self.primitive, dim)) + } + + /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation. + pub fn sum_dim(self, dim: usize) -> Self { + check!(TensorCheck::aggregate_dim::("Sum", dim)); + Self::new(K::sum_dim(self.primitive, dim)) + } + + /// Applies element wise equal comparison and returns a boolean tensor. + pub fn equal_elem(self, other: E) -> Tensor { + K::equal_elem::(self.primitive, other.elem()) + } + + /// Applies element wise greater comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. 
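Because `mean_dim` keeps the reduced dimension with size 1, its result broadcasts back over the original tensor, which is exactly the pattern the `cov` implementation above relies on. A sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// Subtracts each row's mean from that row.
fn center_rows<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    // Shape [rows, 1]: broadcasts over the columns of `x` in the subtraction.
    let row_mean = x.clone().mean_dim(1);
    x - row_mean
}
```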
+ pub fn greater(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Greater", &self, &other)); + K::greater(self.primitive, other.primitive) + } + + /// Applies element wise greater-equal comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn greater_equal(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other)); + K::greater_equal(self.primitive, other.primitive) + } + + /// Applies element wise lower comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn lower(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Lower", &self, &other)); + K::lower(self.primitive, other.primitive) + } + + /// Applies element wise lower-equal comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn lower_equal(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other)); + K::lower_equal(self.primitive, other.primitive) + } + + /// Applies element wise greater comparison and returns a boolean tensor. + pub fn greater_elem(self, other: E) -> Tensor { + K::greater_elem(self.primitive, other.elem()) + } + + /// Applies element wise greater-equal comparison and returns a boolean tensor. + pub fn greater_equal_elem(self, other: E) -> Tensor { + K::greater_equal_elem(self.primitive, other.elem()) + } + + /// Applies element wise lower comparison and returns a boolean tensor. + pub fn lower_elem(self, other: E) -> Tensor { + K::lower_elem(self.primitive, other.elem()) + } + + /// Applies element wise lower-equal comparison and returns a boolean tensor. + pub fn lower_equal_elem(self, other: E) -> Tensor { + K::lower_equal_elem(self.primitive, other.elem()) + } + + /// Update the given tensor with the value tensor where the mask is true. + /// + /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of + /// a scalar. + pub fn mask_where(self, mask: Tensor, value: Self) -> Self { + Self::new(K::mask_where(self.primitive, mask, value.primitive)) + } + + /// Update the given tensor with the value where the mask is true. + /// + /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of + /// a tensor. + pub fn mask_fill(self, mask: Tensor, value: E) -> Self { + Self::new(K::mask_fill(self.primitive, mask, value.elem())) + } + + /// Gather tensor elements corresponding to the given indices from the specified dim. + /// + /// Example using a 3D tensor: + /// + /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0` + /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1` + /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2` + /// + /// # Notes + /// + /// The index tensor should have the same shape as the original tensor except for the dim + /// specified. + pub fn gather(self, dim: usize, indices: Tensor) -> Self { + check!(TensorCheck::gather::( + dim, + &self.shape(), + &indices.shape() + )); + + Self::new(K::gather(dim, self.primitive, indices)) + } + + /// Assign the gathered elements corresponding to the given indices along the specified dimension + /// from the value tensor to the original tensor using sum reduction. 
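The `gather` contract spelled out above (index tensor matches the input on every dimension except the gathered one) looks like this in practice; the function name is illustrative:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

/// output[i, j] = x[i, idx[i, j]] for dim = 1.
fn take_per_row<B: Backend>(x: Tensor<B, 2>, idx: Tensor<B, 2, Int>) -> Tensor<B, 2> {
    // `idx` must match `x` on every dimension except dim 1.
    x.gather(1, idx)
}
```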
+ /// + /// Example using a 3D tensor: + /// + /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0` + /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1` + /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2` + /// + /// # Notes + /// + /// The index tensor should have the same shape as the original tensor except for the specified + /// dimension. The value and index tensors should have the same shape. + /// + /// Other references to the input tensor will not be modified by this operation. + pub fn scatter(self, dim: usize, indices: Tensor, values: Self) -> Self { + check!(TensorCheck::scatter::( + dim, + &self.shape(), + &indices.shape(), + &values.shape() + )); + + Self::new(K::scatter(dim, self.primitive, indices, values.primitive)) + } + + /// Select the tensor elements along the given dimension corresponding to the given indices. + /// + /// Example using a 3D tensor: + /// + /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0` + /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1` + /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2` + pub fn select(self, dim: usize, indices: Tensor) -> Self { + check!(TensorCheck::select::(dim)); + Self::new(K::select(self.primitive, dim, indices)) + } + + /// Assign the selected elements along the given dimension corresponding to the given indices + /// from the value tensor to the original tensor using sum reduction. + /// + /// Example using a 3D tensor: + /// + /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0` + /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1` + /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2` + pub fn select_assign( + self, + dim: usize, + indices: Tensor, + values: Tensor, + ) -> Self { + check!(TensorCheck::select_assign::(dim)); + + Self::new(K::select_assign( + self.primitive, + dim, + indices, + values.primitive, + )) + } + + /// Applies the argmax function along the given dimension and returns an integer tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); + /// let tensor = tensor.argmax(1); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [2, 1, 3] } + /// } + /// ``` + pub fn argmax(self, dim: usize) -> Tensor { + Tensor::new(K::argmax(self.primitive, dim)) + } + + /// Find the maximum value. + pub fn max(self) -> Tensor { + Tensor::new(K::max(self.primitive)) + } + + /// Find the maximum value along the given dimension. + pub fn max_dim(self, dim: usize) -> Tensor { + check!(TensorCheck::aggregate_dim::("Max", dim)); + + Tensor::new(K::max_dim(self.primitive, dim)) + } + + /// Find the maximum value along the given dimension. + /// + /// Also returns the indices. + pub fn max_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { + check!(TensorCheck::aggregate_dim::("Max", dim)); + + let (tensor, index) = K::max_dim_with_indices(self.primitive, dim); + + let tensor = Tensor::new(tensor); + let index = Tensor::new(index); + + (tensor, index) + } + + /// Applies the argmin function along the given dimension and returns an integer tensor. 
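And the inverse operation, `scatter` with sum reduction, following the same shape contract described above; names are illustrative:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

/// x[idx[i, j], j] += values[i, j] along dim 0; `idx` and `values` share a shape.
fn scatter_add<B: Backend>(
    x: Tensor<B, 2>,
    idx: Tensor<B, 2, Int>,
    values: Tensor<B, 2>,
) -> Tensor<B, 2> {
    x.scatter(0, idx, values)
}
```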
+ /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); + /// let tensor = tensor.argmin(1); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [2, 1, 3] } + /// } + /// ``` + pub fn argmin(self, dim: usize) -> Tensor { + Tensor::new(K::argmin(self.primitive, dim)) + } + + /// Find the minimum value. + pub fn min(self) -> Tensor { + Tensor::new(K::min(self.primitive)) + } + + /// Find the minimum value along the given dimension. + pub fn min_dim(self, dim: usize) -> Tensor { + check!(TensorCheck::aggregate_dim::("Min", dim)); + Tensor::new(K::min_dim(self.primitive, dim)) + } + + /// Find the minimum value along the given dimension. + /// + /// Also returns the indices. + pub fn min_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { + check!(TensorCheck::aggregate_dim::("Min", dim)); + + let (tensor, index) = K::min_dim_with_indices(self.primitive, dim); + + let tensor = Tensor::new(tensor); + let index = Tensor::new(index); + + (tensor, index) + } + + /// Clamp the tensor between the given min and max values. + /// + /// # Arguments + /// + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped between the given min and max values. + pub fn clamp(self, min: E, max: E) -> Self { + Self::new(K::clamp(self.primitive, min.elem(), max.elem())) + } + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped under the given min value. + pub fn clamp_min(self, min: E) -> Self { + Self::new(K::clamp_min(self.primitive, min.elem())) + } + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped over the given max value. + /// + pub fn clamp_max(self, max: E) -> Self { + Self::new(K::clamp_max(self.primitive, max.elem())) + } + + /// Apply element wise absolute value operation + pub fn abs(self) -> Self { + Self::new(K::abs(self.primitive)) + } } impl Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - /// Create diagonal matrix. - /// - /// # Arguments - /// - /// * `size` - The size of the square matrix. - pub fn diagonal(size: usize) -> Self { - let indices = Tensor::::arange(0..size).unsqueeze(); - let ones = K::ones([1, size].into(), &B::Device::default()); - let zeros = K::zeros([size, size].into(), &B::Device::default()); - Self::new(K::scatter(0, zeros, indices, ones)) - } + /// Create diagonal matrix. + /// + /// # Arguments + /// + /// * `size` - The size of the square matrix. + pub fn diagonal(size: usize) -> Self { + let indices = Tensor::::arange(0..size).unsqueeze(); + let ones = K::ones([1, size].into(), &B::Device::default()); + let zeros = K::zeros([size, size].into(), &B::Device::default()); + Self::new(K::scatter(0, zeros, indices, ones)) + } } /// Trait that list all operations that can be applied on all numerical tensors. @@ -490,1647 +490,1623 @@ where /// This is an internal trait, use the public API provided by [tensor struct](Tensor). 
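The `diagonal` constructor at the end of this hunk builds an identity-like matrix from `scatter` internally. A one-liner usage sketch with an illustrative size:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// A 4 x 4 matrix with ones on the diagonal and zeros elsewhere.
fn identity_4x4<B: Backend>() -> Tensor<B, 2> {
    Tensor::<B, 2>::diagonal(4)
}
```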
pub trait Numeric: BasicOps where - Self::Elem: Element, + Self::Elem: Element, { - /// Adds two tensors together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The sum of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function, - /// which is more high-level and designed for public use. - fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Adds a scalar to a tensor element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The sum of the tensor and the scalar. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function, - /// which is more high-level and designed for public use. - fn add_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Subtracts two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The difference of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function, - /// which is more high-level and designed for public use. - fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Subtracts a scalar from a tensor element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The difference of the tensor and the scalar. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function, - /// which is more high-level and designed for public use. - fn sub_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Divides two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The quotient of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. 
- /// - /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function, - /// which is more high-level and designed for public use. - fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Divides a tensor by a scalar element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The quotient of the tensor and the scalar. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function, - /// which is more high-level and designed for public use. - fn div_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Multiplies two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The product of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function, - /// which is more high-level and designed for public use. - fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Multiplies a tensor by a scalar element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The product of the tensor and the scalar. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function, - /// which is more high-level and designed for public use. - fn mul_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Negates a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to negate. - /// - /// # Returns - /// - /// The negated tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function, - /// which is more high-level and designed for public use. - fn neg(tensor: Self::Primitive) -> Self::Primitive; - - /// Creates a tensor filled with zeros. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor filled with zeros. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. 
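Although the `Numeric` trait itself is internal, its bounds are what let user code stay generic over both `Float` and `Int` tensors while still calling the high-level `Tensor` methods the docs above point to. A hedged sketch, assuming `Numeric` and `Element` are re-exported at the crate root as in the `use` list at the top of this file:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Element, Numeric, Tensor};

/// Works for any kind whose element type is numeric (Float or Int).
fn double_then_sum<B, const D: usize, K>(x: Tensor<B, D, K>) -> Tensor<B, 1, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    // `mul_scalar` and `sum` are the public counterparts of the trait methods below.
    x.mul_scalar(2).sum()
}
```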
It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function, - /// which is more high-level and designed for public use. - fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive; - - /// Creates a tensor filled with ones. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor filled with ones. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function, - /// which is more high-level and designed for public use. - fn ones(shape: Shape, device: &B::Device) -> Self::Primitive; - - /// Creates a tensor filled with elements equal to the given value. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `fill_value` - The value with which to fill the tensor - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor filled with elements equal to the given value - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function, - /// which is more high-level and designed for public use. - fn full( - shape: Shape, - fill_value: E, - device: &B::Device, - ) -> Self::Primitive; - - /// Sums all the elements of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// - /// # Returns - /// - /// The sum of all the elements of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function, - /// which is more high-level and designed for public use. - fn sum(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Sums all the elements of the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// * `dim` - The dimension along which to sum. - /// - /// # Returns - /// - /// The sum of all the elements of the tensor along the specified dimension. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function, - /// which is more high-level and designed for public use. 
- fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Computes the mean of all the elements of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// - /// # Returns - /// - /// The mean of all the elements of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function, - /// which is more high-level and designed for public use. - fn mean(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Computes the mean of all the elements of the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// * `dim` - The dimension along which to compute the mean. - /// - /// # Returns - /// - /// The mean of all the elements of the tensor along the specified dimension. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For computing the mean of all the elements of a tensor along a dimension, users should prefer - /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use. - fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Element-wise equality between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding elements of the input tensors are equal, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise equality between two tensors, users should prefer the [Tensor::equal_elem](Tensor::equal_elem) function, - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; - - /// Element-wise greater than comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is greater than the corresponding element - /// of the right hand side tensor, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function, - /// which is more high-level and designed for public use. 
- fn greater( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Element-wise greater than comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is greater than the right hand side - /// scalar, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than comparison between a tensor and a scalar, users should prefer - /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use. - fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) - -> Tensor; - - /// Element-wise greater than or equal comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is greater than or equal to the - /// corresponding element of the right hand side tensor, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than or equal comparison between two tensors, users should prefer - /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use. - fn greater_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Element-wise greater than or equal comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is greater than or equal to the right - /// hand side scalar, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer - /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use. - fn greater_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor; - - /// Element-wise less than comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. 
- /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is less than the corresponding element of - /// the right hand side tensor, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function, - /// which is more high-level and designed for public use. - fn lower( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Element-wise less than comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is less than the right hand side scalar, - /// and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than comparison between a tensor and a scalar, users should prefer - /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use. - fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; - - /// Element-wise less than or equal comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is less than or equal to the corresponding - /// element of the right hand side tensor, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than or equal comparison between two tensors, users should prefer - /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use. - fn lower_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Element-wise less than or equal comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is less than or equal to the right hand - /// side scalar, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. 
It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer - /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use. - fn lower_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor; - - /// Selects elements from a tensor based on a boolean mask. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true. - /// * `mask` - The boolean mask to use for selecting elements. - /// * `source` - The tensor to select elements from when the corresponding element of the mask is false. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensors, where each element is taken from the - /// corresponding element of the left hand side tensor if the corresponding element of the mask - /// is true, and from the corresponding element of the right hand side tensor otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For selecting elements from a tensor based on a boolean mask, users should prefer the - /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use. - fn mask_where( - tensor: Self::Primitive, - mask: Tensor, - source: Self::Primitive, - ) -> Self::Primitive; - - /// Fills elements of a tensor based on a boolean mask. - /// - /// # Arguments - /// - /// * `tensor` - The tensor where will be overwritten with the value - /// when the corresponding element of the mask is true. - /// * `mask` - The boolean mask to use for filling elements. - /// * `value` - The value to fill elements with when the corresponding element of the mask is true. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensors, where each element is taken from the - /// corresponding element unmodified if the corresponding element of the mask is false, and - /// filled with the value otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For filling elements of a tensor based on a boolean mask, users should prefer the - /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use. - fn mask_fill( - tensor: Self::Primitive, - mask: Tensor, - value: Self::Elem, - ) -> Self::Primitive; - - /// Gathers elements from a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to gather elements. - /// * `tensor` - The tensor to gather elements from. - /// * `indices` - The indices of the elements to gather. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. 
It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For gathering elements from a tensor along an axis, users should prefer the - /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use. - fn gather( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - ) -> Self::Primitive; - - /// Scatters elements into a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to scatter elements. - /// * `tensor` - The tensor to scatter elements into. - /// * `indices` - The indices of the elements to scatter. - /// * `values` - The values to scatter into the tensor. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis, - /// except for the elements at the specified indices, which are taken from the corresponding - /// element of the values tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function, - /// which is more high-level and designed for public use. - fn scatter( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive; - - /// Select tensor elements along the given dimension corresponding for the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select elements from. - /// * `dim` - The axis along which to select elements. - /// * `indices` - The indices of the elements to select. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For selecting elements from a tensor along an axis, users should prefer the - /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use. - fn select( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - ) -> Self::Primitive; - - /// Assign the selected elements along the given dimension corresponding to the given indices - /// from the value tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to assign elements to. - /// * `dim` - The axis along which to assign elements. - /// * `indices` - The indices of the elements to assign. - /// * `values` - The values to assign to the tensor. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis, - /// except for the elements at the specified indices, which are taken from the corresponding - /// element of the values tensor. 
- /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For assigning elements to a tensor along an axis, users should prefer the - /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use. - fn select_assign( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive; - - /// Gets the indices of the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to get the indices of the maximum elements. - /// * `tensor` - The tensor to get the indices of the maximum elements from. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the index of the - /// maximum element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use. - fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive; - - /// Gets the indices of the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to get the indices of the minimum elements. - /// * `tensor` - The tensor to get the indices of the minimum elements from. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the index of the - /// minimum element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use. - fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive; - - /// Gets the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to get the maximum elements. - /// - /// # Returns - /// - /// A single-element tensor containing the maximum element of the input tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use. - fn max(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Gets the maximum elements of a tensor along an axis. 
- /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements from. - /// * `dim` - The axis along which to get the maximum elements. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the maximum element - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use. - fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Gets the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements from. - /// * `dim` - The axis along which to get the maximum elements. - /// - /// # Returns - /// - /// A tuple containing the maximum element of the input tensor, and a tensor with the same shape - /// as the input tensor, where each element is the index of the maximum element of the input tensor - /// at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use. - fn max_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, B::IntTensorPrimitive); - - /// Gets the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements from. - /// - /// # Returns - /// - /// A single-element tensor containing the minimum element of the input tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use. - fn min(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Gets the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements from. - /// * `dim` - The axis along which to get the minimum elements. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the minimum element - /// of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. 
- /// - /// For getting the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use. - fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Gets the minimum elements and indices of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements from. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor and corresponding indices, where - /// each element is the minimum element of the input tensor at the corresponding index - /// along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use. - fn min_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, B::IntTensorPrimitive); - - /// Clamp the tensor between the given min and max values. - /// - /// # Arguments - /// - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped between the given min and max values. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users. - /// - /// For clamping a tensor between the given min and max values, users should prefer the - /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use. - fn clamp( - tensor: Self::Primitive, - min: Self::Elem, - max: Self::Elem, - ) -> Self::Primitive; - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped under the given min value. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users. - /// - /// For clamping a tensor under a minimum value, users should prefer the - /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use. - fn clamp_min(tensor: Self::Primitive, min: Self::Elem) - -> Self::Primitive; - - /// Clamps a tensor over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped over the given max value. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users. - /// - /// For clamping a tensor over a maximum value, users should prefer the - /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use. 
- fn clamp_max(tensor: Self::Primitive, max: Self::Elem) - -> Self::Primitive; - - /// Calculate absolute value on all elements of a tensor - /// - /// # Arguments - /// - /// * `tensor` - The tensor to apply abs to. - /// - /// # Returns - /// - /// A tensor with absolute values. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function, - /// which is more high-level and designed for public use. - fn abs(tensor: Self::Primitive) -> Self::Primitive; + /// Adds two tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The sum of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function, + /// which is more high-level and designed for public use. + fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Adds a scalar to a tensor element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The sum of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function, + /// which is more high-level and designed for public use. + fn add_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The difference of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function, + /// which is more high-level and designed for public use. + fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Subtracts a scalar from a tensor element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The difference of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. 
+ /// + /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function, + /// which is more high-level and designed for public use. + fn sub_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Divides two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The quotient of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function, + /// which is more high-level and designed for public use. + fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Divides a tensor by a scalar element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The quotient of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function, + /// which is more high-level and designed for public use. + fn div_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Multiplies two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The product of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function, + /// which is more high-level and designed for public use. + fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Multiplies a tensor by a scalar element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The product of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function, + /// which is more high-level and designed for public use. + fn mul_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Negates a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to negate. + /// + /// # Returns + /// + /// The negated tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. 
It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function, + /// which is more high-level and designed for public use. + fn neg(tensor: Self::Primitive) -> Self::Primitive; + + /// Creates a tensor filled with zeros. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor filled with zeros. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function, + /// which is more high-level and designed for public use. + fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive; + + /// Creates a tensor filled with ones. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor filled with ones. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function, + /// which is more high-level and designed for public use. + fn ones(shape: Shape, device: &B::Device) -> Self::Primitive; + + /// Creates a tensor filled with elements equal to the given value. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `fill_value` - The value with which to fill the tensor + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor filled with elements equal to the given value + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function, + /// which is more high-level and designed for public use. + fn full( + shape: Shape, + fill_value: E, + device: &B::Device, + ) -> Self::Primitive; + + /// Sums all the elements of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// + /// # Returns + /// + /// The sum of all the elements of the tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function, + /// which is more high-level and designed for public use. + fn sum(tensor: Self::Primitive) -> Self::Primitive<1>; + + /// Sums all the elements of the tensor along a dimension. 
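For readers skimming this hunk, here is a minimal usage sketch of the high-level arithmetic entry points referenced in the comments above (`Tensor::add`, `Tensor::sub_scalar`, `Tensor::div_scalar`, `Tensor::mul_scalar`, `Tensor::neg`). The helper names are hypothetical, the import paths assume the `burn_tensor` crate layout, and tensor constructors are deliberately avoided because their signatures vary between versions; this is an illustration, not part of the patch.

```rust
use burn_tensor::{backend::Backend, Tensor};

// Hypothetical helpers for illustration only; they simply chain the
// high-level methods that dispatch to the `Numeric` functions above.
fn normalize<B: Backend, const D: usize>(x: Tensor<B, D>, shift: f32, scale: f32) -> Tensor<B, D> {
    // sub_scalar / div_scalar correspond to Numeric::sub_scalar / div_scalar.
    x.sub_scalar(shift).div_scalar(scale)
}

fn blend_average<B: Backend, const D: usize>(a: Tensor<B, D>, b: Tensor<B, D>) -> Tensor<B, D> {
    // add / mul_scalar / neg correspond to Numeric::add / mul_scalar / neg.
    a.add(b).mul_scalar(0.5).neg()
}
```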
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to sum.
+    /// * `dim` - The dimension along which to sum.
+    ///
+    /// # Returns
+    ///
+    /// The sum of all the elements of the tensor along the specified dimension.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
+    /// which is more high-level and designed for public use.
+    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+    /// Computes the mean of all the elements of the tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the mean of.
+    ///
+    /// # Returns
+    ///
+    /// The mean of all the elements of the tensor.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
+    /// which is more high-level and designed for public use.
+    fn mean(tensor: Self::Primitive) -> Self::Primitive<1>;
+
+    /// Computes the mean of all the elements of the tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the mean of.
+    /// * `dim` - The dimension along which to compute the mean.
+    ///
+    /// # Returns
+    ///
+    /// The mean of all the elements of the tensor along the specified dimension.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
+    /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
+    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+    /// Element-wise equality comparison between a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
+    /// corresponding element of the tensor is equal to the scalar, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise equality comparison between a tensor and a scalar, users should prefer the [Tensor::equal_elem](Tensor::equal_elem) function,
+    /// which is more high-level and designed for public use.
+    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor;
+
+    /// Element-wise greater than comparison between two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
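As a companion to the reduction and scalar-equality docs above, a small sketch of the public counterparts (`Tensor::sum`, `Tensor::mean_dim`, `Tensor::equal_elem`). The helper name and the keep-dim shape comments are assumptions made for illustration, not part of the patch.

```rust
use burn_tensor::{backend::Backend, Bool, Tensor};

// Hypothetical helper: reduces a 2-D float tensor and builds a zero mask.
fn summarize<B: Backend>(x: Tensor<B, 2>) -> (Tensor<B, 1>, Tensor<B, 2>, Tensor<B, 2, Bool>) {
    let total = x.clone().sum();           // Numeric::sum -> single-element, rank-1 tensor
    let row_means = x.clone().mean_dim(1); // Numeric::mean_dim -> reduced dim kept with size 1
    let zero_mask = x.equal_elem(0.0);     // Numeric::equal_elem -> boolean tensor
    (total, row_means, zero_mask)
}
```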
+ /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensors, where each element is true if the + /// corresponding element of the left hand side tensor is greater than the corresponding element + /// of the right hand side tensor, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function, + /// which is more high-level and designed for public use. + fn greater( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor; + + /// Element-wise greater than comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is greater than the right hand side + /// scalar, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise greater than comparison between a tensor and a scalar, users should prefer + /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use. + fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; + + /// Element-wise greater than or equal comparison between two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensors, where each element is true if the + /// corresponding element of the left hand side tensor is greater than or equal to the + /// corresponding element of the right hand side tensor, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise greater than or equal comparison between two tensors, users should prefer + /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use. + fn greater_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor; + + /// Element-wise greater than or equal comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is greater than or equal to the right + /// hand side scalar, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. 
It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer + /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use. + fn greater_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor; + + /// Element-wise less than comparison between two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensors, where each element is true if the + /// corresponding element of the left hand side tensor is less than the corresponding element of + /// the right hand side tensor, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function, + /// which is more high-level and designed for public use. + fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor; + + /// Element-wise less than comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is less than the right hand side scalar, + /// and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than comparison between a tensor and a scalar, users should prefer + /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use. + fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; + + /// Element-wise less than or equal comparison between two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensors, where each element is true if the + /// corresponding element of the left hand side tensor is less than or equal to the corresponding + /// element of the right hand side tensor, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than or equal comparison between two tensors, users should prefer + /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use. 
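The comparison functions documented above all surface as `Bool` tensors in the public API. A hedged sketch, assuming the usual `burn_tensor` re-exports; `range_masks` is an illustrative helper, not an API of the crate.

```rust
use burn_tensor::{backend::Backend, Bool, Tensor};

// Builds two boolean masks: elements >= low, and elements < high.
fn range_masks<B: Backend, const D: usize>(
    x: Tensor<B, D>,
    low: f32,
    high: f32,
) -> (Tensor<B, D, Bool>, Tensor<B, D, Bool>) {
    // greater_equal_elem / lower_elem dispatch to the Numeric comparisons above.
    (x.clone().greater_equal_elem(low), x.lower_elem(high))
}
```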
+    fn lower_equal(
+        lhs: Self::Primitive,
+        rhs: Self::Primitive,
+    ) -> Tensor;
+
+    /// Element-wise less than or equal comparison between a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
+    /// corresponding element of the left hand side tensor is less than or equal to the right hand
+    /// side scalar, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
+    /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
+    fn lower_equal_elem(
+        lhs: Self::Primitive,
+        rhs: Self::Elem,
+    ) -> Tensor;
+
+    /// Selects elements from a tensor based on a boolean mask.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
+    /// * `mask` - The boolean mask to use for selecting elements.
+    /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensors, where each element is taken from the
+    /// corresponding element of the input tensor if the corresponding element of the mask
+    /// is true, and from the corresponding element of the source tensor otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For selecting elements from a tensor based on a boolean mask, users should prefer the
+    /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
+    fn mask_where(
+        tensor: Self::Primitive,
+        mask: Tensor,
+        source: Self::Primitive,
+    ) -> Self::Primitive;
+
+    /// Fills elements of a tensor based on a boolean mask.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor whose elements will be overwritten with the value
+    ///   when the corresponding element of the mask is true.
+    /// * `mask` - The boolean mask to use for filling elements.
+    /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensors, where each element is taken from the
+    /// corresponding element unmodified if the corresponding element of the mask is false, and
+    /// filled with the value otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For filling elements of a tensor based on a boolean mask, users should prefer the
+    /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
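Tying the masking docs above to the public surface: a comparison produces the `Bool` mask and `Tensor::mask_fill` writes a value wherever the mask is true. The helper below is a sketch for illustration only, assuming float elements.

```rust
use burn_tensor::{backend::Backend, Tensor};

// Replaces every negative element with 0.0 using a boolean mask.
fn clip_negatives<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    let mask = x.clone().lower_elem(0.0); // true where x < 0
    x.mask_fill(mask, 0.0)                // fill those positions with 0.0
}
```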
+ fn mask_fill( + tensor: Self::Primitive, + mask: Tensor, + value: Self::Elem, + ) -> Self::Primitive; + + /// Gathers elements from a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to gather elements. + /// * `tensor` - The tensor to gather elements from. + /// * `indices` - The indices of the elements to gather. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For gathering elements from a tensor along an axis, users should prefer the + /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use. + fn gather( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + ) -> Self::Primitive; + + /// Scatters elements into a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to scatter elements. + /// * `tensor` - The tensor to scatter elements into. + /// * `indices` - The indices of the elements to scatter. + /// * `values` - The values to scatter into the tensor. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis, + /// except for the elements at the specified indices, which are taken from the corresponding + /// element of the values tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function, + /// which is more high-level and designed for public use. + fn scatter( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive; + + /// Select tensor elements along the given dimension corresponding for the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select elements from. + /// * `dim` - The axis along which to select elements. + /// * `indices` - The indices of the elements to select. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For selecting elements from a tensor along an axis, users should prefer the + /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use. 
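For the gather/select family documented above, a sketch of the public methods under the shape conventions described in these docs (same-rank integer indices for gather, a rank-1 index tensor for select). The helper names are illustrative, not crate APIs.

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

// Picks whole rows of a matrix by index along dim 0.
fn pick_rows<B: Backend>(x: Tensor<B, 2>, rows: Tensor<B, 1, Int>) -> Tensor<B, 2> {
    x.select(0, rows) // dispatches to the backend select documented above
}

// Picks one element per position along dim 1, following the index tensor.
fn gather_cols<B: Backend>(x: Tensor<B, 2>, idx: Tensor<B, 2, Int>) -> Tensor<B, 2> {
    x.gather(1, idx) // dispatches to the backend gather documented above
}
```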
+ fn select( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + ) -> Self::Primitive; + + /// Assign the selected elements along the given dimension corresponding to the given indices + /// from the value tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to assign elements to. + /// * `dim` - The axis along which to assign elements. + /// * `indices` - The indices of the elements to assign. + /// * `values` - The values to assign to the tensor. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis, + /// except for the elements at the specified indices, which are taken from the corresponding + /// element of the values tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For assigning elements to a tensor along an axis, users should prefer the + /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use. + fn select_assign( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive; + + /// Gets the indices of the maximum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to get the indices of the maximum elements. + /// * `tensor` - The tensor to get the indices of the maximum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the index of the + /// maximum element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the + /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use. + fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive; + + /// Gets the indices of the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to get the indices of the minimum elements. + /// * `tensor` - The tensor to get the indices of the minimum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the index of the + /// minimum element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the + /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use. 
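`Tensor::argmax`, referenced above, is the typical way to turn logits into class indices. The shape comment below (the reduced dimension is kept with size 1) is an assumption about the high-level API's convention, stated only for illustration.

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

// Index of the largest logit per row; output shape is assumed to be [batch, 1].
fn predicted_classes<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 2, Int> {
    logits.argmax(1)
}
```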
+    fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
+
+    /// Gets the maximum element of a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum element from.
+    ///
+    /// # Returns
+    ///
+    /// A single-element tensor containing the maximum element of the input tensor.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For getting the maximum element of a tensor, users should prefer the
+    /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
+    fn max(tensor: Self::Primitive) -> Self::Primitive<1>;
+
+    /// Gets the maximum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements from.
+    /// * `dim` - The axis along which to get the maximum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor, where each element is the maximum element
+    /// of the input tensor at the corresponding index along the specified axis.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For getting the maximum elements of a tensor along an axis, users should prefer the
+    /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
+    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+    /// Gets the maximum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements from.
+    /// * `dim` - The axis along which to get the maximum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tuple containing the maximum elements of the input tensor along the specified axis, and a tensor
+    /// with the same shape as the input tensor, where each element is the index of the maximum element
+    /// of the input tensor at the corresponding index along the specified axis.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For getting the maximum elements of a tensor along an axis, users should prefer the
+    /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
+    fn max_dim_with_indices(
+        tensor: Self::Primitive,
+        dim: usize,
+    ) -> (Self::Primitive, B::IntTensorPrimitive);
+
+    /// Gets the minimum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the minimum elements from.
+    ///
+    /// # Returns
+    ///
+    /// A single-element tensor containing the minimum element of the input tensor.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
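A sketch of the `Tensor::max_dim_with_indices` counterpart documented above, which returns both the per-axis maxima and where they occur; the shape remark in the comment is an assumption made for illustration.

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

// Per-row maxima and their column indices along dim 1; both outputs keep rank 2.
fn row_max<B: Backend>(x: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2, Int>) {
    x.max_dim_with_indices(1)
}
```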
+ /// + /// For getting the minimum elements of a tensor along an axis, users should prefer the + /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use. + fn min(tensor: Self::Primitive) -> Self::Primitive<1>; + + /// Gets the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements from. + /// * `dim` - The axis along which to get the minimum elements. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the minimum element + /// of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the minimum elements of a tensor along an axis, users should prefer the + /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use. + fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; + + /// Gets the minimum elements and indices of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor and corresponding indices, where + /// each element is the minimum element of the input tensor at the corresponding index + /// along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the minimum elements of a tensor along an axis, users should prefer the + /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use. + fn min_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, B::IntTensorPrimitive); + + /// Clamp the tensor between the given min and max values. + /// + /// # Arguments + /// + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped between the given min and max values. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users. + /// + /// For clamping a tensor between the given min and max values, users should prefer the + /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use. + fn clamp( + tensor: Self::Primitive, + min: Self::Elem, + max: Self::Elem, + ) -> Self::Primitive; + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped under the given min value. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users. 
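The clamp family documented above keeps values inside a range. A minimal sketch, assuming float elements and the `Tensor::clamp`/`clamp_min`/`clamp_max` methods referenced in the docs.

```rust
use burn_tensor::{backend::Backend, Tensor};

// Restricts all values to [-1.0, 1.0]; equivalent to clamp_min(-1.0).clamp_max(1.0).
fn squash<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    x.clamp(-1.0, 1.0)
}
```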
+ /// + /// For clamping a tensor under a minimum value, users should prefer the + /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use. + fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive; + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped over the given max value. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users. + /// + /// For clamping a tensor over a maximum value, users should prefer the + /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use. + fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive; + + /// Calculate absolute value on all elements of a tensor + /// + /// # Arguments + /// + /// * `tensor` - The tensor to apply abs to. + /// + /// # Returns + /// + /// A tensor with absolute values. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function, + /// which is more high-level and designed for public use. + fn abs(tensor: Self::Primitive) -> Self::Primitive; } impl Numeric for Int { - fn add( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_add(lhs, rhs) - } - fn add_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_add_scalar(lhs, rhs.elem()) - } - fn sub( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_sub(lhs, rhs) - } - fn sub_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_sub_scalar(lhs, rhs.elem()) - } - fn div( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_div(lhs, rhs) - } - fn div_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_div_scalar(lhs, rhs.elem()) - } - fn mul( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_mul(lhs, rhs) - } - fn mul_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_mul_scalar(lhs, rhs.elem()) - } - fn neg(tensor: Self::Primitive) -> Self::Primitive { - B::int_neg(tensor) - } - fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { - B::int_zeros(shape, device) - } - fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { - B::int_ones(shape, device) - } - fn full( - shape: Shape, - fill_value: E, - device: &B::Device, - ) -> Self::Primitive { - B::int_full(shape, fill_value.elem(), device) - } - fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_sum(tensor) - } - fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_sum_dim(tensor, dim) - } - fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_mean(tensor) - } - fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_mean_dim(tensor, dim) - } - - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::int_equal_elem(lhs, rhs)) - } - fn greater( - lhs: 
Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_greater(lhs, rhs)) - } - - fn greater_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::int_greater_elem(lhs, rhs)) - } - - fn greater_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_greater_equal(lhs, rhs)) - } - - fn greater_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::int_greater_equal_elem(lhs, rhs)) - } - - fn lower( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_lower(lhs, rhs)) - } - - fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::int_lower_elem(lhs, rhs)) - } - - fn lower_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_lower_equal(lhs, rhs)) - } - - fn lower_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::int_lower_equal_elem(lhs, rhs)) - } - - fn mask_where( - tensor: Self::Primitive, - mask: Tensor, - source: Self::Primitive, - ) -> Self::Primitive { - B::int_mask_where(tensor, mask.primitive, source) - } - - fn mask_fill( - tensor: Self::Primitive, - mask: Tensor, - value: Self::Elem, - ) -> Self::Primitive { - B::int_mask_fill(tensor, mask.primitive, value) - } - - fn select( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - ) -> Self::Primitive { - B::int_select(tensor, dim, indices.primitive) - } - - fn select_assign( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::int_select_assign(tensor, dim, indices.primitive, values) - } - fn gather( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - ) -> Self::Primitive { - B::int_gather(dim, tensor, indices.primitive) - } - - fn scatter( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::int_scatter(dim, tensor, indices.primitive, values) - } - - fn argmax( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::int_argmax(tensor, dim) - } - - fn argmin( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::int_argmin(tensor, dim) - } - - fn max(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_max(tensor) - } - - fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_max_dim(tensor, dim) - } - - fn max_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::int_max_dim_with_indices(tensor, dim) - } - - fn min(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_min(tensor) - } - - fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_min_dim(tensor, dim) - } - - fn min_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::int_min_dim_with_indices(tensor, dim) - } - - fn clamp( - tensor: Self::Primitive, - min: B::IntElem, - max: B::IntElem, - ) -> Self::Primitive { - B::int_clamp(tensor, min, max) - } - - fn clamp_min( - tensor: Self::Primitive, - min: B::IntElem, - ) -> Self::Primitive { - B::int_clamp_min(tensor, min) - } - - fn clamp_max( - tensor: Self::Primitive, - max: B::IntElem, - ) -> Self::Primitive { - B::int_clamp_max(tensor, max) - } - - fn abs(tensor: Self::Primitive) -> Self::Primitive { - B::int_abs(tensor) - } + fn add( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + 
B::int_add(lhs, rhs) + } + fn add_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::int_add_scalar(lhs, rhs.elem()) + } + fn sub( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::int_sub(lhs, rhs) + } + fn sub_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::int_sub_scalar(lhs, rhs.elem()) + } + fn div( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::int_div(lhs, rhs) + } + fn div_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::int_div_scalar(lhs, rhs.elem()) + } + fn mul( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::int_mul(lhs, rhs) + } + fn mul_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::int_mul_scalar(lhs, rhs.elem()) + } + fn neg(tensor: Self::Primitive) -> Self::Primitive { + B::int_neg(tensor) + } + fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { + B::int_zeros(shape, device) + } + fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { + B::int_ones(shape, device) + } + fn full( + shape: Shape, + fill_value: E, + device: &B::Device, + ) -> Self::Primitive { + B::int_full(shape, fill_value.elem(), device) + } + fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_sum(tensor) + } + fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_sum_dim(tensor, dim) + } + fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_mean(tensor) + } + fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_mean_dim(tensor, dim) + } + + fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::int_equal_elem(lhs, rhs)) + } + fn greater( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_greater(lhs, rhs)) + } + + fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::int_greater_elem(lhs, rhs)) + } + + fn greater_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_greater_equal(lhs, rhs)) + } + + fn greater_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::int_greater_equal_elem(lhs, rhs)) + } + + fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { + Tensor::new(B::int_lower(lhs, rhs)) + } + + fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::int_lower_elem(lhs, rhs)) + } + + fn lower_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_lower_equal(lhs, rhs)) + } + + fn lower_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::int_lower_equal_elem(lhs, rhs)) + } + + fn mask_where( + tensor: Self::Primitive, + mask: Tensor, + source: Self::Primitive, + ) -> Self::Primitive { + B::int_mask_where(tensor, mask.primitive, source) + } + + fn mask_fill( + tensor: Self::Primitive, + mask: Tensor, + value: Self::Elem, + ) -> Self::Primitive { + B::int_mask_fill(tensor, mask.primitive, value) + } + + fn select( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + ) -> Self::Primitive { + B::int_select(tensor, dim, indices.primitive) + } + + fn select_assign( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::int_select_assign(tensor, dim, indices.primitive, values) + } + fn gather( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + ) -> Self::Primitive { + B::int_gather(dim, 
tensor, indices.primitive) + } + + fn scatter( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::int_scatter(dim, tensor, indices.primitive, values) + } + + fn argmax( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::int_argmax(tensor, dim) + } + + fn argmin( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::int_argmin(tensor, dim) + } + + fn max(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_max(tensor) + } + + fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_max_dim(tensor, dim) + } + + fn max_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::int_max_dim_with_indices(tensor, dim) + } + + fn min(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_min(tensor) + } + + fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_min_dim(tensor, dim) + } + + fn min_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::int_min_dim_with_indices(tensor, dim) + } + + fn clamp( + tensor: Self::Primitive, + min: B::IntElem, + max: B::IntElem, + ) -> Self::Primitive { + B::int_clamp(tensor, min, max) + } + + fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive { + B::int_clamp_min(tensor, min) + } + + fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive { + B::int_clamp_max(tensor, max) + } + + fn abs(tensor: Self::Primitive) -> Self::Primitive { + B::int_abs(tensor) + } } impl Numeric for Float { - fn add( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::add(lhs, rhs) - } - fn add_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::add_scalar(lhs, rhs.elem()) - } - fn sub( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::sub(lhs, rhs) - } - fn sub_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sub_scalar(lhs, rhs.elem()) - } - fn div( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::div(lhs, rhs) - } - fn div_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::div_scalar(lhs, rhs.elem()) - } - fn mul( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::mul(lhs, rhs) - } - fn mul_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::mul_scalar(lhs, rhs.elem()) - } - fn neg(tensor: Self::Primitive) -> Self::Primitive { - B::neg(tensor) - } - fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { - B::zeros(shape, device) - } - fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { - B::ones(shape, device) - } - fn full( - shape: Shape, - fill_value: E, - device: &B::Device, - ) -> Self::Primitive { - B::full(shape, fill_value.elem(), device) - } - fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { - B::sum(tensor) - } - fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::sum_dim(tensor, dim) - } - fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { - B::mean(tensor) - } - fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::mean_dim(tensor, dim) - } - - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::equal_elem(lhs, rhs)) - } - fn greater( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::greater(lhs, rhs)) - } - - fn greater_elem( - lhs: 
Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::greater_elem(lhs, rhs)) - } - - fn greater_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::greater_equal(lhs, rhs)) - } - - fn greater_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::greater_equal_elem(lhs, rhs)) - } - - fn lower( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::lower(lhs, rhs)) - } - - fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::lower_elem(lhs, rhs)) - } - - fn lower_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::lower_equal(lhs, rhs)) - } - - fn lower_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::lower_equal_elem(lhs, rhs)) - } - - fn mask_where( - tensor: Self::Primitive, - mask: Tensor, - source: Self::Primitive, - ) -> Self::Primitive { - B::mask_where(tensor, mask.primitive, source) - } - - fn mask_fill( - tensor: Self::Primitive, - mask: Tensor, - value: Self::Elem, - ) -> Self::Primitive { - B::mask_fill(tensor, mask.primitive, value) - } - - fn select( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - ) -> Self::Primitive { - B::select(tensor, dim, indices.primitive) - } - - fn select_assign( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::select_assign(tensor, dim, indices.primitive, values) - } - - fn gather( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - ) -> Self::Primitive { - B::gather(dim, tensor, indices.primitive) - } - - fn scatter( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::scatter(dim, tensor, indices.primitive, values) - } - - fn argmax( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::argmax(tensor, dim) - } - - fn argmin( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::argmin(tensor, dim) - } - - fn max(tensor: Self::Primitive) -> Self::Primitive<1> { - B::max(tensor) - } - - fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::max_dim(tensor, dim) - } - - fn max_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::max_dim_with_indices(tensor, dim) - } - - fn min(tensor: Self::Primitive) -> Self::Primitive<1> { - B::min(tensor) - } - - fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::min_dim(tensor, dim) - } - - fn min_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::min_dim_with_indices(tensor, dim) - } - - fn clamp( - tensor: Self::Primitive, - min: B::FloatElem, - max: B::FloatElem, - ) -> Self::Primitive { - B::clamp(tensor, min, max) - } - - fn clamp_min( - tensor: Self::Primitive, - min: B::FloatElem, - ) -> Self::Primitive { - B::clamp_min(tensor, min) - } - - fn clamp_max( - tensor: Self::Primitive, - max: B::FloatElem, - ) -> Self::Primitive { - B::clamp_max(tensor, max) - } - - fn abs(tensor: Self::Primitive) -> Self::Primitive { - B::abs(tensor) - } + fn add( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::add(lhs, rhs) + } + fn add_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::add_scalar(lhs, rhs.elem()) + } + fn sub( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::sub(lhs, 
rhs) + } + fn sub_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::sub_scalar(lhs, rhs.elem()) + } + fn div( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::div(lhs, rhs) + } + fn div_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::div_scalar(lhs, rhs.elem()) + } + fn mul( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::mul(lhs, rhs) + } + fn mul_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::mul_scalar(lhs, rhs.elem()) + } + fn neg(tensor: Self::Primitive) -> Self::Primitive { + B::neg(tensor) + } + fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { + B::zeros(shape, device) + } + fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { + B::ones(shape, device) + } + fn full( + shape: Shape, + fill_value: E, + device: &B::Device, + ) -> Self::Primitive { + B::full(shape, fill_value.elem(), device) + } + fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { + B::sum(tensor) + } + fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::sum_dim(tensor, dim) + } + fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { + B::mean(tensor) + } + fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::mean_dim(tensor, dim) + } + + fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::equal_elem(lhs, rhs)) + } + fn greater( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::greater(lhs, rhs)) + } + + fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::greater_elem(lhs, rhs)) + } + + fn greater_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::greater_equal(lhs, rhs)) + } + + fn greater_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::greater_equal_elem(lhs, rhs)) + } + + fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { + Tensor::new(B::lower(lhs, rhs)) + } + + fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::lower_elem(lhs, rhs)) + } + + fn lower_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::lower_equal(lhs, rhs)) + } + + fn lower_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::lower_equal_elem(lhs, rhs)) + } + + fn mask_where( + tensor: Self::Primitive, + mask: Tensor, + source: Self::Primitive, + ) -> Self::Primitive { + B::mask_where(tensor, mask.primitive, source) + } + + fn mask_fill( + tensor: Self::Primitive, + mask: Tensor, + value: Self::Elem, + ) -> Self::Primitive { + B::mask_fill(tensor, mask.primitive, value) + } + + fn select( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + ) -> Self::Primitive { + B::select(tensor, dim, indices.primitive) + } + + fn select_assign( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::select_assign(tensor, dim, indices.primitive, values) + } + + fn gather( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + ) -> Self::Primitive { + B::gather(dim, tensor, indices.primitive) + } + + fn scatter( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::scatter(dim, tensor, indices.primitive, values) + } + + fn argmax( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::argmax(tensor, dim) + } + + fn argmin( + 
tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::argmin(tensor, dim) + } + + fn max(tensor: Self::Primitive) -> Self::Primitive<1> { + B::max(tensor) + } + + fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::max_dim(tensor, dim) + } + + fn max_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::max_dim_with_indices(tensor, dim) + } + + fn min(tensor: Self::Primitive) -> Self::Primitive<1> { + B::min(tensor) + } + + fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::min_dim(tensor, dim) + } + + fn min_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::min_dim_with_indices(tensor, dim) + } + + fn clamp( + tensor: Self::Primitive, + min: B::FloatElem, + max: B::FloatElem, + ) -> Self::Primitive { + B::clamp(tensor, min, max) + } + + fn clamp_min( + tensor: Self::Primitive, + min: B::FloatElem, + ) -> Self::Primitive { + B::clamp_min(tensor, min) + } + + fn clamp_max( + tensor: Self::Primitive, + max: B::FloatElem, + ) -> Self::Primitive { + B::clamp_max(tensor, max) + } + + fn abs(tensor: Self::Primitive) -> Self::Primitive { + B::abs(tensor) + } } impl core::ops::Add for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn add(self, rhs: Tensor) -> Self { - Self::add(self, rhs) - } + fn add(self, rhs: Tensor) -> Self { + Self::add(self, rhs) + } } impl core::ops::Add for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn add(self, other: E) -> Self { - Tensor::add_scalar(self, other) - } + fn add(self, other: E) -> Self { + Tensor::add_scalar(self, other) + } } impl core::ops::Sub> for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn sub(self, rhs: Tensor) -> Self { - Tensor::sub(self, rhs) - } + fn sub(self, rhs: Tensor) -> Self { + Tensor::sub(self, rhs) + } } impl core::ops::Sub for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn sub(self, other: E) -> Self { - Tensor::sub_scalar(self, other) - } + fn sub(self, other: E) -> Self { + Tensor::sub_scalar(self, other) + } } impl core::ops::Div> for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn div(self, rhs: Tensor) -> Self { - Tensor::div(self, rhs) - } + fn div(self, rhs: Tensor) -> Self { + Tensor::div(self, rhs) + } } impl core::ops::Div for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn div(self, other: E) -> Self { - Tensor::div_scalar(self, other) - } + fn div(self, other: E) -> Self { + Tensor::div_scalar(self, other) + } } impl core::ops::Mul> for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn mul(self, rhs: Tensor) -> Self { - 
Tensor::mul(self, rhs) - } + fn mul(self, rhs: Tensor) -> Self { + Tensor::mul(self, rhs) + } } impl core::ops::Mul for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn mul(self, other: E) -> Self { - Tensor::mul_scalar(self, other) - } + fn mul(self, other: E) -> Self { + Tensor::mul_scalar(self, other) + } } impl core::ops::Neg for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn neg(self) -> Self { - Tensor::neg(self) - } + fn neg(self) -> Self { + Tensor::neg(self) + } } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 131855cce8..7e6141ea63 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -50,145 +50,144 @@ use crate::tensor::Element; /// Most of the documentation for each function can be found on the user API [tensor struct](crate::Tensor). /// For modules, public functions are often created, which can be used by `burn-core` modules. pub trait Backend: - TensorOps - + BoolTensorOps - + IntTensorOps - + ModuleOps - + ActivationOps - + Clone - + Sized - + Default - + Send - + Sync - + core::fmt::Debug - + 'static + TensorOps + + BoolTensorOps + + IntTensorOps + + ModuleOps + + ActivationOps + + Clone + + Sized + + Default + + Send + + Sync + + core::fmt::Debug + + 'static { - /// Device type. - type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; + /// Device type. + type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; - /// Pointer to another backend that have a full precision float element type - type FullPrecisionBackend: Backend; - /// Full precision float element type. - type FullPrecisionElem: Element; + /// Pointer to another backend that have a full precision float element type + type FullPrecisionBackend: Backend; + /// Full precision float element type. + type FullPrecisionElem: Element; - /// Tensor primitive to be used for all float operations. - type TensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; - /// Float element type. - type FloatElem: Element; + /// Tensor primitive to be used for all float operations. + type TensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + /// Float element type. + type FloatElem: Element; - /// Tensor primitive to be used for all int operations. - type IntTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; - /// Int element type. - type IntElem: Element; + /// Tensor primitive to be used for all int operations. + type IntTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + /// Int element type. + type IntElem: Element; - /// Tensor primitive to be used for all bool operations. - type BoolTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + /// Tensor primitive to be used for all bool operations. + type BoolTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; - /// If autodiff is enabled. - fn ad_enabled() -> bool { - false - } + /// If autodiff is enabled. + fn ad_enabled() -> bool { + false + } - /// Name of the backend. - fn name() -> String; + /// Name of the backend. + fn name() -> String; - /// Seed the backend. - fn seed(seed: u64); + /// Seed the backend. 
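// Because `Tensor` implements the standard operator traits shown above,
// arithmetic can be written with ordinary Rust operators; a sketch assuming
// `B: Backend` and two float tensors of matching shape.
use burn_tensor::{backend::Backend, Tensor};

fn operator_examples<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>) -> Tensor<B, 2> {
    // Tensor-tensor operands use add/sub/mul/div; scalar operands go through
    // the `*_scalar` variants via `ElementConversion`.
    let scaled = a.clone() * 2.0 + b.clone();
    let halved = (a - b) / 2.0;
    // Unary minus dispatches to the `Neg` implementation.
    -(scaled * halved)
}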
+ fn seed(seed: u64); - /// Sync the backend, ensure that all computation are finished. - fn sync(_device: &Self::Device) {} + /// Sync the backend, ensure that all computation are finished. + fn sync(_device: &Self::Device) {} } /// Trait that allows a backend to support autodiff. pub trait AutodiffBackend: Backend { - /// The inner backend type. - type InnerBackend: Backend< - Device = Self::Device, - FloatElem = Self::FloatElem, - IntElem = Self::IntElem, - FullPrecisionElem = Self::FullPrecisionElem, - >; - - /// Gradients type. - type Gradients: Send + Sync; - - /// Backward pass. - /// - /// # Arguments - /// - /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed. - /// - /// # Returns - /// - /// The gradients. - fn backward(tensor: FloatTensor) -> Self::Gradients; - - /// Returns the gradients of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to extract the gradients from. - /// - /// # Returns - /// - /// An optional tensor containing the gradient. - fn grad( - tensor: &FloatTensor, - grads: &Self::Gradients, - ) -> Option>; - - /// Pops the gradients of a tensor and returns them. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to pop the gradients from. - /// * `grads` - The gradients. - /// - /// # Returns - /// - /// An optional tensor containing the given gradients. - fn grad_remove( - tensor: &FloatTensor, - grads: &mut Self::Gradients, - ) -> Option>; - - /// Replace the gradients of a tensor with the one provided. - /// - /// If no gradient existed for the provided tensor, register it. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to pop the gradients from. - /// * `grads` - The gradients. - /// * `grad` - The updated grad tensor. - fn grad_replace( - tensor: &FloatTensor, - grads: &mut Self::Gradients, - grad: FloatTensor, - ); - - /// Returns the tensor with inner backend type. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the inner backend tensor for. - /// - /// # Returns - /// - /// The inner backend tensor. - fn inner(tensor: FloatTensor) -> FloatTensor; - - /// Converts the inner backend tensor to the autodiff backend tensor. - /// - /// # Arguments - /// - /// * `tensor` - The inner backend tensor to convert. - /// - /// - /// # Returns - /// - /// The autodiff backend tensor. - fn from_inner( - tensor: FloatTensor, - ) -> FloatTensor; + /// The inner backend type. + type InnerBackend: Backend< + Device = Self::Device, + FloatElem = Self::FloatElem, + IntElem = Self::IntElem, + FullPrecisionElem = Self::FullPrecisionElem, + >; + + /// Gradients type. + type Gradients: Send + Sync; + + /// Backward pass. + /// + /// # Arguments + /// + /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed. + /// + /// # Returns + /// + /// The gradients. + fn backward(tensor: FloatTensor) -> Self::Gradients; + + /// Returns the gradients of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to extract the gradients from. + /// + /// # Returns + /// + /// An optional tensor containing the gradient. + fn grad( + tensor: &FloatTensor, + grads: &Self::Gradients, + ) -> Option>; + + /// Pops the gradients of a tensor and returns them. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to pop the gradients from. + /// * `grads` - The gradients. + /// + /// # Returns + /// + /// An optional tensor containing the given gradients. 
+ fn grad_remove( + tensor: &FloatTensor, + grads: &mut Self::Gradients, + ) -> Option>; + + /// Replace the gradients of a tensor with the one provided. + /// + /// If no gradient existed for the provided tensor, register it. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to pop the gradients from. + /// * `grads` - The gradients. + /// * `grad` - The updated grad tensor. + fn grad_replace( + tensor: &FloatTensor, + grads: &mut Self::Gradients, + grad: FloatTensor, + ); + + /// Returns the tensor with inner backend type. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the inner backend tensor for. + /// + /// # Returns + /// + /// The inner backend tensor. + fn inner(tensor: FloatTensor) -> FloatTensor; + + /// Converts the inner backend tensor to the autodiff backend tensor. + /// + /// # Arguments + /// + /// * `tensor` - The inner backend tensor to convert. + /// + /// + /// # Returns + /// + /// The autodiff backend tensor. + fn from_inner(tensor: FloatTensor) + -> FloatTensor; } diff --git a/burn-tensor/src/tensor/container.rs b/burn-tensor/src/tensor/container.rs index 76bbde6151..7432d4ee70 100644 --- a/burn-tensor/src/tensor/container.rs +++ b/burn-tensor/src/tensor/container.rs @@ -12,79 +12,80 @@ use crate::{backend::Backend, Tensor}; /// Contains tensor of arbitrary dimension. #[derive(Debug)] pub struct TensorContainer { - tensors: HashMap>, + tensors: HashMap>, } impl Default for TensorContainer where - ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, + ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } type TensorPrimitive = ::TensorPrimitive; impl TensorContainer where - ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, + ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, { - /// Create an empty container. - pub fn new() -> Self { - Self { - tensors: HashMap::new(), - } + /// Create an empty container. + pub fn new() -> Self { + Self { + tensors: HashMap::new(), } + } - /// Get a tensor with the given ID. - pub fn get(&self, id: &ID) -> Option> - where - B: Backend, - { - let grad = match self.tensors.get(id) { - Some(grad) => grad, - None => return None, - }; + /// Get a tensor with the given ID. + pub fn get(&self, id: &ID) -> Option> + where + B: Backend, + { + let grad = match self.tensors.get(id) { + Some(grad) => grad, + None => return None, + }; - let tensor = grad - .downcast_ref::>() - .map(|primitive| Tensor::::from_primitive(primitive.clone())) - .unwrap(); + let tensor = grad + .downcast_ref::>() + .map(|primitive| Tensor::::from_primitive(primitive.clone())) + .unwrap(); - Some(tensor) - } + Some(tensor) + } - /// Register a new tensor for the given ID. - /// - /// # Notes - /// - /// If a tensor is already registered for the given ID, it will be replaced. - pub fn register(&mut self, id: ID, value: Tensor) - where - B: Backend, - { - self.tensors.insert(id, Box::new(value.into_primitive())); - } + /// Register a new tensor for the given ID. + /// + /// # Notes + /// + /// If a tensor is already registered for the given ID, it will be replaced. + pub fn register(&mut self, id: ID, value: Tensor) + where + B: Backend, + { + self.tensors.insert(id, Box::new(value.into_primitive())); + } - /// Remove a tensor for the given ID and returns it. 
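// Sketch of the gradient workflow these trait methods support, expressed
// through the user-facing `Tensor` API; assumes an autodiff-capable backend `B`.
use burn_tensor::{backend::AutodiffBackend, Tensor};

fn grad_example<B: AutodiffBackend>(x: Tensor<B, 2>) {
    // Mark the tensor as requiring gradients, build a scalar loss, then run backward.
    let x = x.require_grad();
    let loss = x.clone().powf(2.0).sum();
    let grads = loss.backward();
    // The gradient of `x` lives on the inner (non-autodiff) backend.
    let _grad_x: Option<Tensor<B::InnerBackend, 2>> = x.grad(&grads);
}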
- pub fn remove(&mut self, id: &ID) -> Option> - where - B: Backend, - { - self.tensors - .remove(id) - .map(|item| item.downcast::>().unwrap()) - .map(|primitive| Tensor::from_primitive(*primitive)) - } + /// Remove a tensor for the given ID and returns it. + pub fn remove(&mut self, id: &ID) -> Option> + where + B: Backend, + { + self + .tensors + .remove(id) + .map(|item| item.downcast::>().unwrap()) + .map(|primitive| Tensor::from_primitive(*primitive)) + } - /// The number of tensors registered. - pub fn len(&self) -> usize { - self.tensors.len() - } + /// The number of tensors registered. + pub fn len(&self) -> usize { + self.tensors.len() + } - /// If any tensor is contained. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } + /// If any tensor is contained. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } diff --git a/burn-tensor/src/tensor/data.rs b/burn-tensor/src/tensor/data.rs index dfbae95c3d..5cbf215199 100644 --- a/burn-tensor/src/tensor/data.rs +++ b/burn-tensor/src/tensor/data.rs @@ -9,523 +9,514 @@ use rand::{distributions::Standard, Rng, RngCore}; /// Data structure for serializing and deserializing tensor data. #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)] pub struct DataSerialize { - /// The values of the tensor. - pub value: Vec, - /// The shape of the tensor. - pub shape: Vec, + /// The values of the tensor. + pub value: Vec, + /// The shape of the tensor. + pub shape: Vec, } /// Data structure for tensors. #[derive(new, Debug, Clone, PartialEq, Eq)] pub struct Data { - /// The values of the tensor. - pub value: Vec, + /// The values of the tensor. + pub value: Vec, - /// The shape of the tensor. - pub shape: Shape, + /// The shape of the tensor. + pub shape: Shape, } /// Distribution for random value of a tensor. #[derive(Debug, Clone, Copy)] pub enum Distribution { - /// Uniform distribution from 0 (inclusive) to 1 (exclusive). - Default, + /// Uniform distribution from 0 (inclusive) to 1 (exclusive). + Default, - /// Bernoulli distribution with the given probability. - Bernoulli(f64), + /// Bernoulli distribution with the given probability. + Bernoulli(f64), - /// Uniform distribution. The range is inclusive. - Uniform(E, E), + /// Uniform distribution. The range is inclusive. + Uniform(E, E), - /// Normal distribution with the given mean and standard deviation. - Normal(f64, f64), + /// Normal distribution with the given mean and standard deviation. + Normal(f64, f64), } /// Distribution sampler for random value of a tensor. #[derive(new)] pub struct DistributionSampler<'a, E, R> where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, - R: RngCore, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, + R: RngCore, { - kind: DistributionSamplerKind, - rng: &'a mut R, + kind: DistributionSamplerKind, + rng: &'a mut R, } /// Distribution sampler kind for random value of a tensor. pub enum DistributionSamplerKind where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, { - /// Standard distribution. - Standard(rand::distributions::Standard), + /// Standard distribution. + Standard(rand::distributions::Standard), - /// Uniform distribution. - Uniform(rand::distributions::Uniform), + /// Uniform distribution. + Uniform(rand::distributions::Uniform), - /// Bernoulli distribution. 
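// Sketch of `TensorContainer`, which stores tensors of arbitrary rank behind a
// caller-chosen ID type; the import path below is an assumption.
use burn_tensor::{backend::Backend, container::TensorContainer, Tensor};

fn container_example<B: Backend>(t: Tensor<B, 3>) {
    let mut container = TensorContainer::<usize>::new();
    container.register(0, t);
    // Retrieval is parameterized by backend and rank.
    let restored: Option<Tensor<B, 3>> = container.get(&0);
    assert!(restored.is_some());
    assert_eq!(container.len(), 1);
}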
- Bernoulli(rand::distributions::Bernoulli), + /// Bernoulli distribution. + Bernoulli(rand::distributions::Bernoulli), - /// Normal distribution. - Normal(rand_distr::Normal), + /// Normal distribution. + Normal(rand_distr::Normal), } impl<'a, E, R> DistributionSampler<'a, E, R> where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, - E: Element, - R: RngCore, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, + E: Element, + R: RngCore, { - /// Sames a random value from the distribution. - pub fn sample(&mut self) -> E { - match &self.kind { - DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution), - DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution), - DistributionSamplerKind::Bernoulli(distribution) => { - if self.rng.sample(distribution) { - 1.elem() - } else { - 0.elem() - } - } - DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(), + /// Sames a random value from the distribution. + pub fn sample(&mut self) -> E { + match &self.kind { + DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution), + DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution), + DistributionSamplerKind::Bernoulli(distribution) => { + if self.rng.sample(distribution) { + 1.elem() + } else { + 0.elem() } + } + DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(), } + } } impl Distribution where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, { - /// Creates a new distribution sampler. - /// - /// # Arguments - /// - /// * `rng` - The random number generator. - /// - /// # Returns - /// - /// The distribution sampler. - pub fn sampler(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> { - let kind = match self { - Distribution::Default => { - DistributionSamplerKind::Standard(rand::distributions::Standard {}) - } - Distribution::Uniform(low, high) => { - DistributionSamplerKind::Uniform(rand::distributions::Uniform::new(low, high)) - } - Distribution::Bernoulli(prob) => DistributionSamplerKind::Bernoulli( - rand::distributions::Bernoulli::new(prob).unwrap(), - ), - Distribution::Normal(mean, std) => { - DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap()) - } - }; - - DistributionSampler::new(kind, rng) - } + /// Creates a new distribution sampler. + /// + /// # Arguments + /// + /// * `rng` - The random number generator. + /// + /// # Returns + /// + /// The distribution sampler. + pub fn sampler(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> { + let kind = match self { + Distribution::Default => DistributionSamplerKind::Standard(rand::distributions::Standard {}), + Distribution::Uniform(low, high) => { + DistributionSamplerKind::Uniform(rand::distributions::Uniform::new(low, high)) + } + Distribution::Bernoulli(prob) => { + DistributionSamplerKind::Bernoulli(rand::distributions::Bernoulli::new(prob).unwrap()) + } + Distribution::Normal(mean, std) => { + DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap()) + } + }; + + DistributionSampler::new(kind, rng) + } } impl Distribution where - E: Element, + E: Element, { - /// Converts the distribution to a different element type. - /// - /// # Returns - /// - /// The converted distribution. 
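// Sketch of drawing random values through `Distribution::sampler`, assuming the
// same `rand` version burn-tensor depends on and `f32` elements.
use burn_tensor::Distribution;
use rand::{rngs::StdRng, SeedableRng};

fn sampling_example() {
    let mut rng = StdRng::seed_from_u64(42);
    let mut sampler = Distribution::Uniform(0.0f32, 1.0f32).sampler(&mut rng);
    let values: Vec<f32> = (0..4).map(|_| sampler.sample()).collect();
    assert!(values.iter().all(|v| (0.0..=1.0).contains(v)));
}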
- pub fn convert(self) -> Distribution { - match self { - Distribution::Default => Distribution::Default, - Distribution::Uniform(a, b) => { - Distribution::Uniform(EOther::from_elem(a), EOther::from_elem(b)) - } - Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob), - Distribution::Normal(mean, std) => Distribution::Normal(mean, std), - } + /// Converts the distribution to a different element type. + /// + /// # Returns + /// + /// The converted distribution. + pub fn convert(self) -> Distribution { + match self { + Distribution::Default => Distribution::Default, + Distribution::Uniform(a, b) => { + Distribution::Uniform(EOther::from_elem(a), EOther::from_elem(b)) + } + Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob), + Distribution::Normal(mean, std) => Distribution::Normal(mean, std), } + } } impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); + /// Converts the data to a different element type. + pub fn convert(self) -> Data { + let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - Data { - value, - shape: self.shape, - } + Data { + value, + shape: self.shape, } - - /// Asserts each value is within a given range. - /// - /// # Arguments - /// - /// * `range` - The range. - /// - /// # Panics - /// - /// If any value is not within the half-open range bounded inclusively below - /// and exclusively above (`start..end`). - pub fn assert_within_range(&self, range: core::ops::Range) { - let start = range.start.elem::(); - let end = range.end.elem::(); - - for elem in self.value.iter() { - let elem = elem.elem::(); - if elem < start || elem >= end { - panic!("Element ({elem:?}) is not within range {range:?}"); - } - } + } + + /// Asserts each value is within a given range. + /// + /// # Arguments + /// + /// * `range` - The range. + /// + /// # Panics + /// + /// If any value is not within the half-open range bounded inclusively below + /// and exclusively above (`start..end`). + pub fn assert_within_range(&self, range: core::ops::Range) { + let start = range.start.elem::(); + let end = range.end.elem::(); + + for elem in self.value.iter() { + let elem = elem.elem::(); + if elem < start || elem >= end { + panic!("Element ({elem:?}) is not within range {range:?}"); + } } + } } impl DataSerialize { - /// Converts the data to a different element type. - pub fn convert(self) -> DataSerialize { - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); + /// Converts the data to a different element type. + pub fn convert(self) -> DataSerialize { + let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - DataSerialize { - value, - shape: self.shape, - } + DataSerialize { + value, + shape: self.shape, } + } } impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| (a as i64).elem()).collect(); + /// Converts the data to a different element type. + pub fn convert(self) -> Data { + let value: Vec = self.value.into_iter().map(|a| (a as i64).elem()).collect(); - Data { - value, - shape: self.shape, - } + Data { + value, + shape: self.shape, } + } } impl Data { - /// Populates the data with random values. 
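// Sketch of converting `Data` between element types and validating a range,
// matching the `convert` and `assert_within_range` helpers above.
use burn_tensor::Data;

fn data_convert_example() {
    let ints: Data<i64, 1> = Data::from([0i64, 1, 2, 3]);
    // Element-wise conversion into another element type.
    let floats: Data<f32, 1> = ints.convert();
    // Panics if any element falls outside the half-open range 0.0..4.0.
    floats.assert_within_range(0.0..4.0);
}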
- pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(E::random(distribution, rng)); - } + /// Populates the data with random values. + pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); - Data::new(data, shape) + for _ in 0..num_elements { + data.push(E::random(distribution, rng)); } + + Data::new(data, shape) + } } impl Data where - E: Element, + E: Element, { - /// Populates the data with zeros. - pub fn zeros>>(shape: S) -> Data { - let shape = shape.into(); - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(0.elem()); - } - - Data::new(data, shape) + /// Populates the data with zeros. + pub fn zeros>>(shape: S) -> Data { + let shape = shape.into(); + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); + + for _ in 0..num_elements { + data.push(0.elem()); } + + Data::new(data, shape) + } } impl Data where - E: Element, + E: Element, { - /// Populates the data with ones. - pub fn ones(shape: Shape) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(1.elem()); - } + /// Populates the data with ones. + pub fn ones(shape: Shape) -> Data { + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); - Data::new(data, shape) + for _ in 0..num_elements { + data.push(1.elem()); } + + Data::new(data, shape) + } } impl Data where - E: Element, + E: Element, { - /// Populates the data with the given value - pub fn full(shape: Shape, fill_value: E) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - for _ in 0..num_elements { - data.push(fill_value) - } - - Data::new(data, shape) + /// Populates the data with the given value + pub fn full(shape: Shape, fill_value: E) -> Data { + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); + for _ in 0..num_elements { + data.push(fill_value) } + + Data::new(data, shape) + } } impl Data { - /// Serializes the data. - /// - /// # Returns - /// - /// The serialized data. - pub fn serialize(&self) -> DataSerialize { - DataSerialize { - value: self.value.clone(), - shape: self.shape.dims.to_vec(), - } + /// Serializes the data. + /// + /// # Returns + /// + /// The serialized data. + pub fn serialize(&self) -> DataSerialize { + DataSerialize { + value: self.value.clone(), + shape: self.shape.dims.to_vec(), } + } } impl + Clone + core::fmt::Debug + PartialEq, const D: usize> Data { - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `precision` - The precision of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq(&self, other: &Self, precision: usize) { - let tolerance = libm::pow(0.1, precision as f64); - - self.assert_approx_eq_diff(other, tolerance) + /// Asserts the data is approximately equal to another data. + /// + /// # Arguments + /// + /// * `other` - The other data. + /// * `precision` - The precision of the comparison. 
+ /// + /// # Panics + /// + /// Panics if the data is not approximately equal. + #[track_caller] + pub fn assert_approx_eq(&self, other: &Self, precision: usize) { + let tolerance = libm::pow(0.1, precision as f64); + + self.assert_approx_eq_diff(other, tolerance) + } + + /// Asserts the data is approximately equal to another data. + /// + /// # Arguments + /// + /// * `other` - The other data. + /// * `tolerance` - The tolerance of the comparison. + /// + /// # Panics + /// + /// Panics if the data is not approximately equal. + #[track_caller] + pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) { + let mut message = String::new(); + if self.shape != other.shape { + message += format!( + "\n => Shape is different: {:?} != {:?}", + self.shape.dims, other.shape.dims + ) + .as_str(); } - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `tolerance` - The tolerance of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) { - let mut message = String::new(); - if self.shape != other.shape { - message += format!( - "\n => Shape is different: {:?} != {:?}", - self.shape.dims, other.shape.dims - ) - .as_str(); - } - - let iter = self.value.clone().into_iter().zip(other.value.clone()); + let iter = self.value.clone().into_iter().zip(other.value.clone()); - let mut num_diff = 0; - let max_num_diff = 5; + let mut num_diff = 0; + let max_num_diff = 5; - for (i, (a, b)) in iter.enumerate() { - let a: f64 = a.into(); - let b: f64 = b.into(); + for (i, (a, b)) in iter.enumerate() { + let a: f64 = a.into(); + let b: f64 = b.into(); - let err = libm::sqrt(libm::pow(a - b, 2.0)); + let err = libm::sqrt(libm::pow(a - b, 2.0)); - if err > tolerance { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}" - ) - .as_str(); - } - num_diff += 1; - } + if err > tolerance { + // Only print the first 5 different values. + if num_diff < max_num_diff { + message += + format!("\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}") + .as_str(); } + num_diff += 1; + } + } - if num_diff >= max_num_diff { - message += format!("\n{} more errors...", num_diff - 5).as_str(); - } + if num_diff >= max_num_diff { + message += format!("\n{} more errors...", num_diff - 5).as_str(); + } - if !message.is_empty() { - panic!("Tensors are not approx eq:{}", message); - } + if !message.is_empty() { + panic!("Tensors are not approx eq:{}", message); } + } } impl Data { - /// Converts the usize data to a different element type. - pub fn from_usize(self) -> Data { - let value: Vec = self - .value - .into_iter() - .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap()) - .collect(); - - Data { - value, - shape: self.shape, - } + /// Converts the usize data to a different element type. 
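// Sketch of the approximate-equality helpers documented above; `precision`
// translates to a tolerance of 0.1^precision.
use burn_tensor::Data;

fn approx_eq_example() {
    let expected = Data::from([[1.0f32, 2.0, 3.0]]);
    let actual = Data::from([[1.0f32, 2.0005, 3.0]]);
    // Passes: every |a - b| is within 0.1^3 = 0.001.
    expected.assert_approx_eq(&actual, 3);
    // An explicit tolerance can be supplied instead of a precision exponent.
    expected.assert_approx_eq_diff(&actual, 1e-3);
}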
+ pub fn from_usize(self) -> Data { + let value: Vec = self + .value + .into_iter() + .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap()) + .collect(); + + Data { + value, + shape: self.shape, } + } } impl From<&DataSerialize> for Data { - fn from(data: &DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value.clone(), Shape::new(dims)) - } + fn from(data: &DataSerialize) -> Self { + let mut dims = [0; D]; + dims[..D].copy_from_slice(&data.shape[..D]); + Data::new(data.value.clone(), Shape::new(dims)) + } } impl From> for Data { - fn from(data: DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value, Shape::new(dims)) - } + fn from(data: DataSerialize) -> Self { + let mut dims = [0; D]; + dims[..D].copy_from_slice(&data.shape[..D]); + Data::new(data.value, Shape::new(dims)) + } } impl From<[E; A]> for Data { - fn from(elems: [E; A]) -> Self { - let mut data = Vec::with_capacity(2 * A); - for elem in elems.into_iter() { - data.push(elem); - } - - Data::new(data, Shape::new([A])) + fn from(elems: [E; A]) -> Self { + let mut data = Vec::with_capacity(2 * A); + for elem in elems.into_iter() { + data.push(elem); } + + Data::new(data, Shape::new([A])) + } } impl From<&[E]> for Data { - fn from(elems: &[E]) -> Self { - let mut data = Vec::with_capacity(elems.len()); - for elem in elems.iter() { - data.push(*elem); - } - - Data::new(data, Shape::new([elems.len()])) + fn from(elems: &[E]) -> Self { + let mut data = Vec::with_capacity(elems.len()); + for elem in elems.iter() { + data.push(*elem); } + + Data::new(data, Shape::new([elems.len()])) + } } impl From<[[E; B]; A]> for Data { - fn from(elems: [[E; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B); - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - data.push(elem); - } - } - - Data::new(data, Shape::new([A, B])) + fn from(elems: [[E; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B); + for elem in elems.into_iter().take(A) { + for elem in elem.into_iter().take(B) { + data.push(elem); + } } + + Data::new(data, Shape::new([A, B])) + } } impl - From<[[[E; C]; B]; A]> for Data + From<[[[E; C]; B]; A]> for Data { - fn from(elems: [[[E; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - data.push(elem); - } - } - } + fn from(elems: [[[E; C]; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B * C); - Data::new(data, Shape::new([A, B, C])) + for elem in elems.into_iter().take(A) { + for elem in elem.into_iter().take(B) { + for elem in elem.into_iter().take(C) { + data.push(elem); + } + } } + + Data::new(data, Shape::new([A, B, C])) + } } -impl< - E: core::fmt::Debug + Copy, - const A: usize, - const B: usize, - const C: usize, - const D: usize, - > From<[[[[E; D]; C]; B]; A]> for Data +impl + From<[[[[E; D]; C]; B]; A]> for Data { - fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C * D); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - for elem in elem.into_iter().take(D) { - data.push(elem); - } - } - } - } + fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B * C * D); - Data::new(data, Shape::new([A, B, C, D])) + for elem in 
elems.into_iter().take(A) { + for elem in elem.into_iter().take(B) { + for elem in elem.into_iter().take(C) { + for elem in elem.into_iter().take(D) { + data.push(elem); + } + } + } } + + Data::new(data, Shape::new([A, B, C, D])) + } } impl core::fmt::Display for Data { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("{:?}", &self.value).as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(format!("{:?}", &self.value).as_str()) + } } #[cfg(test)] mod tests { - use super::*; - use rand::{rngs::StdRng, SeedableRng}; - - #[test] - fn should_have_right_num_elements() { - let shape = Shape::new([3, 5, 6]); - let num_elements = shape.num_elements(); - let data = - Data::::random(shape, Distribution::Default, &mut StdRng::from_entropy()); - - assert_eq!(num_elements, data.value.len()); - } - - #[test] - fn should_have_right_shape() { - let data = Data::from([[3.0, 5.0, 6.0]]); - assert_eq!(data.shape, Shape::new([1, 3])); - - let data = Data::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]); - assert_eq!(data.shape, Shape::new([2, 3])); - - let data = Data::from([3.0, 5.0, 6.0]); - assert_eq!(data.shape, Shape::new([3])); - } - - #[test] - fn should_assert_appox_eq_limit() { - let data1 = Data::::from([[3.0, 5.0, 6.0]]); - let data2 = Data::::from([[3.01, 5.0, 6.0]]); - - data1.assert_approx_eq(&data2, 2); - } - - #[test] - #[should_panic] - fn should_assert_appox_eq_above_limit() { - let data1 = Data::::from([[3.0, 5.0, 6.0]]); - let data2 = Data::::from([[3.011, 5.0, 6.0]]); - - data1.assert_approx_eq(&data2, 2); - } - - #[test] - #[should_panic] - fn should_assert_appox_eq_check_shape() { - let data1 = Data::::from([[3.0, 5.0, 6.0, 7.0]]); - let data2 = Data::::from([[3.0, 5.0, 6.0]]); - - data1.assert_approx_eq(&data2, 2); - } + use super::*; + use rand::{rngs::StdRng, SeedableRng}; + + #[test] + fn should_have_right_num_elements() { + let shape = Shape::new([3, 5, 6]); + let num_elements = shape.num_elements(); + let data = Data::::random(shape, Distribution::Default, &mut StdRng::from_entropy()); + + assert_eq!(num_elements, data.value.len()); + } + + #[test] + fn should_have_right_shape() { + let data = Data::from([[3.0, 5.0, 6.0]]); + assert_eq!(data.shape, Shape::new([1, 3])); + + let data = Data::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]); + assert_eq!(data.shape, Shape::new([2, 3])); + + let data = Data::from([3.0, 5.0, 6.0]); + assert_eq!(data.shape, Shape::new([3])); + } + + #[test] + fn should_assert_appox_eq_limit() { + let data1 = Data::::from([[3.0, 5.0, 6.0]]); + let data2 = Data::::from([[3.01, 5.0, 6.0]]); + + data1.assert_approx_eq(&data2, 2); + } + + #[test] + #[should_panic] + fn should_assert_appox_eq_above_limit() { + let data1 = Data::::from([[3.0, 5.0, 6.0]]); + let data2 = Data::::from([[3.011, 5.0, 6.0]]); + + data1.assert_approx_eq(&data2, 2); + } + + #[test] + #[should_panic] + fn should_assert_appox_eq_check_shape() { + let data1 = Data::::from([[3.0, 5.0, 6.0, 7.0]]); + let data2 = Data::::from([[3.0, 5.0, 6.0]]); + + data1.assert_approx_eq(&data2, 2); + } } diff --git a/burn-tensor/src/tensor/element.rs b/burn-tensor/src/tensor/element.rs index 2108a033b0..5bcd40dcc5 100644 --- a/burn-tensor/src/tensor/element.rs +++ b/burn-tensor/src/tensor/element.rs @@ -5,110 +5,110 @@ use rand::RngCore; /// Element trait for tensor. 
pub trait Element: - ToPrimitive - + ElementRandom - + ElementConversion - + ElementPrecision - + core::fmt::Debug - + core::fmt::Display - + Default - + Send - + Sync - + Copy - + 'static + ToPrimitive + + ElementRandom + + ElementConversion + + ElementPrecision + + core::fmt::Debug + + core::fmt::Display + + Default + + Send + + Sync + + Copy + + 'static { } /// Element conversion trait for tensor. pub trait ElementConversion { - /// Converts an element to another element. - /// - /// # Arguments - /// - /// * `elem` - The element to convert. - /// - /// # Returns - /// - /// The converted element. - fn from_elem(elem: E) -> Self; - - /// Converts and returns the converted element. - fn elem(self) -> E; + /// Converts an element to another element. + /// + /// # Arguments + /// + /// * `elem` - The element to convert. + /// + /// # Returns + /// + /// The converted element. + fn from_elem(elem: E) -> Self; + + /// Converts and returns the converted element. + fn elem(self) -> E; } /// Element trait for random value of a tensor. pub trait ElementRandom { - /// Returns a random value for the given distribution. - /// - /// # Arguments - /// - /// * `distribution` - The distribution to sample from. - /// * `rng` - The random number generator. - /// - /// # Returns - /// - /// The random value. - fn random(distribution: Distribution, rng: &mut R) -> Self - where - Self: Sized; + /// Returns a random value for the given distribution. + /// + /// # Arguments + /// + /// * `distribution` - The distribution to sample from. + /// * `rng` - The random number generator. + /// + /// # Returns + /// + /// The random value. + fn random(distribution: Distribution, rng: &mut R) -> Self + where + Self: Sized; } /// Element precision trait for tensor. #[derive(Clone, PartialEq, Eq, Copy, Debug)] pub enum Precision { - /// Double precision, e.g. f64. - Double, + /// Double precision, e.g. f64. + Double, - /// Full precision, e.g. f32. - Full, + /// Full precision, e.g. f32. + Full, - /// Half precision, e.g. f16. - Half, + /// Half precision, e.g. f16. + Half, - /// Other precision. - Other, + /// Other precision. + Other, } /// Element precision trait for tensor. pub trait ElementPrecision { - /// Returns the precision of the element. - fn precision() -> Precision; + /// Returns the precision of the element. + fn precision() -> Precision; } /// Macro to implement the element trait for a type. #[macro_export] macro_rules! 
make_element { - ( + ( ty $type:ident $precision:expr, convert $convert:expr, random $random:expr ) => { - impl Element for $type {} - - impl ElementConversion for $type { - fn from_elem(elem: E) -> Self { - #[allow(clippy::redundant_closure_call)] - $convert(&elem) - } - fn elem(self) -> E { - E::from_elem(self) - } - } - - impl ElementPrecision for $type { - fn precision() -> Precision { - $precision - } - } - - impl ElementRandom for $type { - fn random(distribution: Distribution, rng: &mut R) -> Self { - #[allow(clippy::redundant_closure_call)] - $random(distribution, rng) - } - } - }; + impl Element for $type {} + + impl ElementConversion for $type { + fn from_elem(elem: E) -> Self { + #[allow(clippy::redundant_closure_call)] + $convert(&elem) + } + fn elem(self) -> E { + E::from_elem(self) + } + } + + impl ElementPrecision for $type { + fn precision() -> Precision { + $precision + } + } + + impl ElementRandom for $type { + fn random(distribution: Distribution, rng: &mut R) -> Self { + #[allow(clippy::redundant_closure_call)] + $random(distribution, rng) + } + } + }; } make_element!( diff --git a/burn-tensor/src/tensor/loss/mod.rs b/burn-tensor/src/tensor/loss/mod.rs index 339c4071f3..535427a601 100644 --- a/burn-tensor/src/tensor/loss/mod.rs +++ b/burn-tensor/src/tensor/loss/mod.rs @@ -12,12 +12,12 @@ use crate::{activation, Tensor}; /// /// The log softmax cross entropy. pub fn cross_entropy_with_logits( - logits: Tensor, - target_probs: Tensor, + logits: Tensor, + target_probs: Tensor, ) -> Tensor { - let tensor = activation::log_softmax(logits, D - 1); - let tensor = tensor.mul(target_probs); - let tensor = tensor.sum_dim(D - 1); + let tensor = activation::log_softmax(logits, D - 1); + let tensor = tensor.mul(target_probs); + let tensor = tensor.sum_dim(D - 1); - tensor.mean().neg() + tensor.mean().neg() } diff --git a/burn-tensor/src/tensor/module.rs b/burn-tensor/src/tensor/module.rs index 4ad956db56..be8ec8c906 100644 --- a/burn-tensor/src/tensor/module.rs +++ b/burn-tensor/src/tensor/module.rs @@ -1,221 +1,221 @@ use crate::{ - backend::Backend, - ops::{ConvOptions, ConvTransposeOptions, UnfoldOptions}, - Int, Tensor, + backend::Backend, + ops::{ConvOptions, ConvTransposeOptions, UnfoldOptions}, + Int, Tensor, }; /// Applies the [embedding module](crate::ops::ModuleOps::embedding). pub fn embedding(weights: Tensor, indices: Tensor) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::embedding(weights.primitive, indices.primitive)) + Tensor::new(B::embedding(weights.primitive, indices.primitive)) } /// Applies a [1D convolution](crate::ops::ModuleOps::conv2d). pub fn conv1d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvOptions<1>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvOptions<1>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv1d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv1d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [2D convolution](crate::ops::ModuleOps::conv2d). 
pub fn conv2d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvOptions<2>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvOptions<2>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv2d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv2d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d). pub fn conv_transpose1d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvTransposeOptions<1>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvTransposeOptions<1>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv_transpose1d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv_transpose1d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d). pub fn conv_transpose2d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvTransposeOptions<2>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvTransposeOptions<2>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv_transpose2d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv_transpose2d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d). pub fn unfold4d(x: Tensor, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::unfold4d(x.primitive, kernel_size, options)) + Tensor::new(B::unfold4d(x.primitive, kernel_size, options)) } /// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d). pub fn max_pool1d( - x: Tensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: Tensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::max_pool1d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )) + Tensor::new(B::max_pool1d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )) } /// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d). pub fn max_pool2d( - x: Tensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::max_pool2d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )) + Tensor::new(B::max_pool2d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )) } /// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d). pub fn avg_pool2d( - x: Tensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::avg_pool2d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )) + Tensor::new(B::avg_pool2d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )) } /// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d). 
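A hedged usage sketch of the convolution helper above, assuming the usual NCHW layout and that `ConvOptions::new(stride, padding, dilation, groups)` matches the crate's constructor:

    use burn_tensor::{backend::Backend, module::conv2d, ops::ConvOptions, Tensor};

    // x: [batch, c_in, h, w], weight: [c_out, c_in, k_h, k_w], no bias.
    fn conv_sketch<B: Backend>(x: Tensor<B, 4>, weight: Tensor<B, 4>) -> Tensor<B, 4> {
        // Stride 1, no padding, dilation 1, a single group.
        conv2d(x, weight, None, ConvOptions::new([1, 1], [0, 0], [1, 1], 1))
    }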
pub fn avg_pool1d( - x: Tensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, + x: Tensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::avg_pool1d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )) + Tensor::new(B::avg_pool1d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )) } /// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d). pub fn max_pool1d_with_indices( - x: Tensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: Tensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> (Tensor, Tensor) where - B: Backend, + B: Backend, { - let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - (Tensor::new(output.output), Tensor::new(output.indices)) + (Tensor::new(output.output), Tensor::new(output.indices)) } /// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices). pub fn max_pool2d_with_indices( - x: Tensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (Tensor, Tensor) where - B: Backend, + B: Backend, { - let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - (Tensor::new(output.output), Tensor::new(output.indices)) + (Tensor::new(output.output), Tensor::new(output.indices)) } /// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d). pub fn adaptive_avg_pool2d(x: Tensor, output_size: [usize; 2]) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::adaptive_avg_pool2d(x.primitive, output_size)) + Tensor::new(B::adaptive_avg_pool2d(x.primitive, output_size)) } /// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d). pub fn adaptive_avg_pool1d(x: Tensor, output_size: usize) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::adaptive_avg_pool1d(x.primitive, output_size)) + Tensor::new(B::adaptive_avg_pool1d(x.primitive, output_size)) } diff --git a/burn-tensor/src/tensor/named/base.rs b/burn-tensor/src/tensor/named/base.rs index ab1f5aa9d8..a958050556 100644 --- a/burn-tensor/src/tensor/named/base.rs +++ b/burn-tensor/src/tensor/named/base.rs @@ -6,76 +6,76 @@ use crate::{Distribution, NamedDims, Shape, Tensor}; /// A tensor with named dimensions. 
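A short sketch tying together the pooling helpers above; the shapes, the 2x2 window, and the import path are illustrative assumptions:

    use burn_tensor::{
        backend::Backend,
        module::{adaptive_avg_pool2d, max_pool2d_with_indices},
        Int, Tensor,
    };

    // x: [batch, channels, h, w].
    fn pool_sketch<B: Backend>(
        x: Tensor<B, 4>,
    ) -> (Tensor<B, 4>, Tensor<B, 4, Int>, Tensor<B, 4>) {
        // 2x2 max pooling, stride 2, no padding, dilation 1, keeping argmax indices.
        let (pooled, indices) = max_pool2d_with_indices(x.clone(), [2, 2], [2, 2], [0, 0], [1, 1]);
        // Adaptive average pooling to a fixed 7x7 output, independent of input size.
        let averaged = adaptive_avg_pool2d(x, [7, 7]);
        (pooled, indices, averaged)
    }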
#[derive(Debug, Clone)] pub struct NamedTensor> { - pub(crate) tensor: D::Tensor, + pub(crate) tensor: D::Tensor, } impl>, const D: usize> From> - for Tensor + for Tensor { - fn from(nt: NamedTensor) -> Self { - nt.tensor - } + fn from(nt: NamedTensor) -> Self { + nt.tensor + } } impl>, const D: usize> From> - for NamedTensor + for NamedTensor { - fn from(tensor: Tensor) -> Self { - Self::from_tensor(tensor) - } + fn from(tensor: Tensor) -> Self { + Self::from_tensor(tensor) + } } impl> core::fmt::Display for NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(&format!( - "NamedTensor[shape={:?}, dims={}]", - self.shape().dims, - ND::to_string(), - )) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(&format!( + "NamedTensor[shape={:?}, dims={}]", + self.shape().dims, + ND::to_string(), + )) + } } impl NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - /// Create a named tensor from a tensor. - pub fn from_tensor(tensor: Tensor) -> Self { - Self { tensor } - } + /// Create a named tensor from a tensor. + pub fn from_tensor(tensor: Tensor) -> Self { + Self { tensor } + } - /// Create a random named tensor of the given shape where each element is sampled from - /// the given distribution. - pub fn random>>(shape: S, distribution: Distribution) -> Self { - Self::from_tensor(Tensor::random(shape, distribution)) - } + /// Create a random named tensor of the given shape where each element is sampled from + /// the given distribution. + pub fn random>>(shape: S, distribution: Distribution) -> Self { + Self::from_tensor(Tensor::random(shape, distribution)) + } - /// Returns the shape of the current tensor. - pub fn shape(&self) -> Shape { - self.tensor.shape() - } + /// Returns the shape of the current tensor. + pub fn shape(&self) -> Shape { + self.tensor.shape() + } - /// Applies element wise multiplication operation. - /// - /// `y = x2 * x1` - #[allow(clippy::should_implement_trait)] - pub fn mul(self, rhs: Self) -> Self { - Self::from_tensor(self.tensor.mul(rhs.tensor)) - } + /// Applies element wise multiplication operation. + /// + /// `y = x2 * x1` + #[allow(clippy::should_implement_trait)] + pub fn mul(self, rhs: Self) -> Self { + Self::from_tensor(self.tensor.mul(rhs.tensor)) + } - /// Reshape the tensor to have the given shape. - /// - /// # Panics - /// - /// If the tensor can not be reshape to the given shape. - pub fn reshape(self, shape: S, _: ND2) -> NamedTensor - where - S: Into>, - ND2: NamedDims>, - { - NamedTensor::from_tensor(self.tensor.reshape(shape.into())) - } + /// Reshape the tensor to have the given shape. + /// + /// # Panics + /// + /// If the tensor can not be reshape to the given shape. + pub fn reshape(self, shape: S, _: ND2) -> NamedTensor + where + S: Into>, + ND2: NamedDims>, + { + NamedTensor::from_tensor(self.tensor.reshape(shape.into())) + } } diff --git a/burn-tensor/src/tensor/named/dims.rs b/burn-tensor/src/tensor/named/dims.rs index 6f5631db77..5b4c37908d 100644 --- a/burn-tensor/src/tensor/named/dims.rs +++ b/burn-tensor/src/tensor/named/dims.rs @@ -6,90 +6,90 @@ use crate::Tensor; /// Dimension trait. pub trait Dim: core::fmt::Debug { - /// Converts the dimension to a string. - fn to_string() -> String; + /// Converts the dimension to a string. + fn to_string() -> String; } /// Named dimensions trait. pub trait NamedDims: core::fmt::Debug { - /// Tensor type. - type Tensor; + /// Tensor type. 
+ type Tensor; - /// Converts the named dimensions to a string. - fn to_string() -> String; + /// Converts the named dimensions to a string. + fn to_string() -> String; } /// Named dimension macro. #[macro_export] macro_rules! NamedDim { - ($name:ident) => { - #[derive(Debug, Clone)] - pub struct $name; - impl Dim for $name { - fn to_string() -> String { - stringify!($name).to_string() - } - } - }; + ($name:ident) => { + #[derive(Debug, Clone)] + pub struct $name; + impl Dim for $name { + fn to_string() -> String { + stringify!($name).to_string() + } + } + }; } impl NamedDims for (D1,) where - B: Backend, - D1: Dim, + B: Backend, + D1: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!("[{}]", D1::to_string()) - } + type Tensor = Tensor; + fn to_string() -> String { + format!("[{}]", D1::to_string()) + } } impl NamedDims for (D1, D2) where - B: Backend, - D1: Dim, - D2: Dim, + B: Backend, + D1: Dim, + D2: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!("[{}, {}]", D1::to_string(), D2::to_string()) - } + type Tensor = Tensor; + fn to_string() -> String { + format!("[{}, {}]", D1::to_string(), D2::to_string()) + } } impl NamedDims for (D1, D2, D3) where - B: Backend, - D1: Dim, - D2: Dim, - D3: Dim, + B: Backend, + D1: Dim, + D2: Dim, + D3: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!( - "[{}, {}, {}]", - D1::to_string(), - D2::to_string(), - D3::to_string() - ) - } + type Tensor = Tensor; + fn to_string() -> String { + format!( + "[{}, {}, {}]", + D1::to_string(), + D2::to_string(), + D3::to_string() + ) + } } impl NamedDims for (D1, D2, D3, D4) where - B: Backend, - D1: Dim, - D2: Dim, - D3: Dim, - D4: Dim, + B: Backend, + D1: Dim, + D2: Dim, + D3: Dim, + D4: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!( - "[{}, {}, {}, {}]", - D1::to_string(), - D2::to_string(), - D3::to_string(), - D4::to_string() - ) - } + type Tensor = Tensor; + fn to_string() -> String { + format!( + "[{}, {}, {}, {}]", + D1::to_string(), + D2::to_string(), + D3::to_string(), + D4::to_string() + ) + } } diff --git a/burn-tensor/src/tensor/named/matmul.rs b/burn-tensor/src/tensor/named/matmul.rs index ef8e9849d0..0a7df3534e 100644 --- a/burn-tensor/src/tensor/named/matmul.rs +++ b/burn-tensor/src/tensor/named/matmul.rs @@ -2,58 +2,58 @@ use crate::backend::Backend; use crate::{Dim, NamedDims, NamedTensor, Tensor}; pub trait Matmul { - fn matmul(self, rhs: Rhs) -> Out; + fn matmul(self, rhs: Rhs) -> Out; } impl NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - /// Applies the matrix multiplication operation. - /// - /// `C = AB` - /// - /// # Panics - /// - /// If the two tensors dont' have a compatible shape. - pub fn matmul( - self, - rhs: NamedTensor, - ) -> NamedTensor - where - NamedDimsRhs: NamedDims>, - NamedDimsOut: NamedDims>, - Self: Matmul, NamedTensor>, - { - Matmul::matmul(self, rhs) - } + /// Applies the matrix multiplication operation. + /// + /// `C = AB` + /// + /// # Panics + /// + /// If the two tensors dont' have a compatible shape. 
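A hedged sketch of the named-dimension pieces above, together with the named `matmul` defined just below; the dimension names, shapes, and backend `B` are illustrative assumptions:

    use burn_tensor::{backend::Backend, Dim, Distribution, NamedDim, NamedTensor};

    NamedDim!(Batch);
    NamedDim!(Seq);
    NamedDim!(Model);

    fn named_sketch<B: Backend>() -> NamedTensor<B, (Batch, Seq, Seq)> {
        let lhs: NamedTensor<B, (Batch, Seq, Model)> =
            NamedTensor::random([2, 8, 16], Distribution::Default);
        let rhs: NamedTensor<B, (Batch, Model, Seq)> =
            NamedTensor::random([2, 16, 8], Distribution::Default);
        // The inner dimension names must line up for this call to type-check.
        lhs.matmul(rhs)
    }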
+ pub fn matmul( + self, + rhs: NamedTensor, + ) -> NamedTensor + where + NamedDimsRhs: NamedDims>, + NamedDimsOut: NamedDims>, + Self: Matmul, NamedTensor>, + { + Matmul::matmul(self, rhs) + } } impl Matmul, NamedTensor> - for NamedTensor + for NamedTensor { - fn matmul(self, rhs: NamedTensor) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) - } + fn matmul(self, rhs: NamedTensor) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) + } } impl - Matmul, NamedTensor> - for NamedTensor + Matmul, NamedTensor> + for NamedTensor { - fn matmul(self, rhs: NamedTensor) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) - } + fn matmul(self, rhs: NamedTensor) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) + } } impl - Matmul, NamedTensor> - for NamedTensor + Matmul, NamedTensor> + for NamedTensor { - fn matmul( - self, - rhs: NamedTensor, - ) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) - } + fn matmul( + self, + rhs: NamedTensor, + ) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) + } } diff --git a/burn-tensor/src/tensor/named/swap_dims.rs b/burn-tensor/src/tensor/named/swap_dims.rs index 51d0b3b11b..1a2f4f18f5 100644 --- a/burn-tensor/src/tensor/named/swap_dims.rs +++ b/burn-tensor/src/tensor/named/swap_dims.rs @@ -2,53 +2,53 @@ use crate::backend::Backend; use crate::{Dim, NamedDims, NamedTensor, Tensor}; pub trait SwapDims { - fn swap_dims(self) -> N; + fn swap_dims(self) -> N; } impl NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - /// Swap two dimensions. - pub fn swap_dims(self) -> NamedTensor - where - ND2: NamedDims>, - Self: SwapDims, D1, D2>, - { - SwapDims::swap_dims(self) - } + /// Swap two dimensions. + pub fn swap_dims(self) -> NamedTensor + where + ND2: NamedDims>, + Self: SwapDims, D1, D2>, + { + SwapDims::swap_dims(self) + } } macro_rules! 
generate_permut { - (2 => $output:ty, ($dim1:expr, $dim2:expr)) => { - impl SwapDims, $dim1, $dim2> - for NamedTensor - { - fn swap_dims(self) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) - } - } - }; + (2 => $output:ty, ($dim1:expr, $dim2:expr)) => { + impl SwapDims, $dim1, $dim2> + for NamedTensor + { + fn swap_dims(self) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) + } + } + }; - (3 => $output:ty, ($dim1:expr, $dim2:expr)) => { - impl SwapDims, $dim1, $dim2> - for NamedTensor - { - fn swap_dims(self) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) - } - } - }; + (3 => $output:ty, ($dim1:expr, $dim2:expr)) => { + impl SwapDims, $dim1, $dim2> + for NamedTensor + { + fn swap_dims(self) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) + } + } + }; - (4 => $output:ty, ($dim1:expr, $dim2:expr)) => { - impl - SwapDims, $dim1, $dim2> for NamedTensor - { - fn swap_dims(self) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) - } - } - }; + (4 => $output:ty, ($dim1:expr, $dim2:expr)) => { + impl + SwapDims, $dim1, $dim2> for NamedTensor + { + fn swap_dims(self) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) + } + } + }; } generate_permut!(2 => (D2, D1), (0, 1)); diff --git a/burn-tensor/src/tensor/ops/activation.rs b/burn-tensor/src/tensor/ops/activation.rs index b0aef546d5..a2f53f7a5a 100644 --- a/burn-tensor/src/tensor/ops/activation.rs +++ b/burn-tensor/src/tensor/ops/activation.rs @@ -7,99 +7,99 @@ use super::FloatTensor; /// /// This trait let backend implementations override activation functions for better performance. pub trait ActivationOps { - /// Applies the ReLU activation function. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The output tensor. - fn relu(tensor: FloatTensor) -> FloatTensor { - let mask = B::lower_equal_elem(tensor.clone(), 0.elem()); - - B::mask_fill(tensor, mask, 0.elem()) - } - - /// Applies the ReLU activation function backward. - /// - /// # Arguments - /// - /// * `output` - The output tensor. - /// - /// # Returns - /// - /// The gradient. - fn relu_backward( - output: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - let mask = B::lower_equal_elem(output, 0.elem()); - - B::mask_fill(grad, mask, 0.elem()) - } - - /// Applies the Gelu activation function. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The output tensor. - fn gelu(tensor: FloatTensor) -> FloatTensor { - let x = B::div_scalar(tensor.clone(), SQRT_2.elem()); - let x = B::erf(x); - let x = B::add_scalar(x, 1i32.elem()); - let x = B::mul(tensor, x); - - B::div_scalar(x, 2i32.elem()) - } - - /// Applies the Gelu activation function backward. - /// - /// # Arguments - /// - /// * `x` - The tensor. - /// * `grad` - The gradient. - /// - /// # Returns - /// - /// The output tensor. - fn gelu_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - // Derivative of the approximate gelu implementation based on tanh. 
- - let constant_1 = 0.0356774; - let constant_2 = 0.797885; - let constant_3 = 0.0535161; - let constant_4 = 0.398942; - - let x3 = B::powf(x.clone(), 3.0); - - let c1 = B::mul_scalar(x3.clone(), constant_1.elem()); - let c2 = B::mul_scalar(x.clone(), constant_2.elem()); - let c3 = B::mul_scalar(x3, constant_3.elem()); - let c4 = B::mul_scalar(x, constant_4.elem()); - - let inner1 = B::add(c1, c2); - let inner2 = B::add(c3, c4); - - let tanh = B::tanh(inner1); - - let sech = B::powf(tanh.clone(), 2.0); - let sech = B::neg(sech); - let sech = B::add_scalar(sech, 1.elem()); - - let y1 = B::mul_scalar(tanh, 0.5.elem()); - let y2 = B::mul(inner2, sech); - let y2 = B::add_scalar(y2, 0.5.elem()); - let y = B::add(y1, y2); - - B::mul(y, grad) - } + /// Applies the ReLU activation function. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The output tensor. + fn relu(tensor: FloatTensor) -> FloatTensor { + let mask = B::lower_equal_elem(tensor.clone(), 0.elem()); + + B::mask_fill(tensor, mask, 0.elem()) + } + + /// Applies the ReLU activation function backward. + /// + /// # Arguments + /// + /// * `output` - The output tensor. + /// + /// # Returns + /// + /// The gradient. + fn relu_backward( + output: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + let mask = B::lower_equal_elem(output, 0.elem()); + + B::mask_fill(grad, mask, 0.elem()) + } + + /// Applies the Gelu activation function. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The output tensor. + fn gelu(tensor: FloatTensor) -> FloatTensor { + let x = B::div_scalar(tensor.clone(), SQRT_2.elem()); + let x = B::erf(x); + let x = B::add_scalar(x, 1i32.elem()); + let x = B::mul(tensor, x); + + B::div_scalar(x, 2i32.elem()) + } + + /// Applies the Gelu activation function backward. + /// + /// # Arguments + /// + /// * `x` - The tensor. + /// * `grad` - The gradient. + /// + /// # Returns + /// + /// The output tensor. + fn gelu_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + // Derivative of the approximate gelu implementation based on tanh. + + let constant_1 = 0.0356774; + let constant_2 = 0.797885; + let constant_3 = 0.0535161; + let constant_4 = 0.398942; + + let x3 = B::powf(x.clone(), 3.0); + + let c1 = B::mul_scalar(x3.clone(), constant_1.elem()); + let c2 = B::mul_scalar(x.clone(), constant_2.elem()); + let c3 = B::mul_scalar(x3, constant_3.elem()); + let c4 = B::mul_scalar(x, constant_4.elem()); + + let inner1 = B::add(c1, c2); + let inner2 = B::add(c3, c4); + + let tanh = B::tanh(inner1); + + let sech = B::powf(tanh.clone(), 2.0); + let sech = B::neg(sech); + let sech = B::add_scalar(sech, 1.elem()); + + let y1 = B::mul_scalar(tanh, 0.5.elem()); + let y2 = B::mul(inner2, sech); + let y2 = B::add_scalar(y2, 0.5.elem()); + let y = B::add(y1, y2); + + B::mul(y, grad) + } } diff --git a/burn-tensor/src/tensor/ops/bool_tensor.rs b/burn-tensor/src/tensor/ops/bool_tensor.rs index cf478f972a..fcf2a55d6a 100644 --- a/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -7,255 +7,254 @@ use core::ops::Range; /// Bool Tensor API for basic operations, see [tensor](crate::Tensor) /// for documentation on each function. pub trait BoolTensorOps { - /// Creates a new bool tensor. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The boolean tensor with the given shape. 
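For reference on the constants in `gelu_backward` above: the forward `gelu` uses the exact erf form, while the gradient uses the tanh approximation gelu(x) ≈ 0.5 * x * (1 + tanh(u)) with u = sqrt(2/pi) * (x + 0.044715 * x^3). Differentiating gives

    d/dx gelu(x) ≈ 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)

which is what the code evaluates: `inner1` is u (0.044715 * sqrt(2/pi) ≈ 0.0356774 and sqrt(2/pi) ≈ 0.797885), `inner2` is the polynomial factor of the second term (1.5 * 0.044715 * sqrt(2/pi) ≈ 0.0535161 and 0.5 * sqrt(2/pi) ≈ 0.398942), and the result is y = 0.5 * tanh(u) + inner2 * sech^2(u) + 0.5.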
- fn bool_empty(shape: Shape, device: &Device) -> BoolTensor; + /// Creates a new bool tensor. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The boolean tensor with the given shape. + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor; - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn bool_shape(tensor: &BoolTensor) -> Shape; + /// Returns the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + fn bool_shape(tensor: &BoolTensor) -> Shape; - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn bool_into_data(tensor: BoolTensor) -> Reader>; + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn bool_into_data(tensor: BoolTensor) -> Reader>; - /// Gets the data from the tensor. - /// - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// - /// # Returns - /// - /// The data cloned from the data structure. - fn bool_to_data(tensor: &BoolTensor) -> Reader> { - Self::bool_into_data(tensor.clone()) - } - - /// Creates a tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the data. - fn bool_from_data(data: Data, device: &Device) -> BoolTensor; + /// Gets the data from the tensor. + /// + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// + /// # Returns + /// + /// The data cloned from the data structure. + fn bool_to_data(tensor: &BoolTensor) -> Reader> { + Self::bool_into_data(tensor.clone()) + } - /// Converts bool tensor to int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The int tensor with the same data as the bool tensor. - fn bool_into_int(tensor: BoolTensor) -> IntTensor; + /// Creates a tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the data. + fn bool_from_data(data: Data, device: &Device) -> BoolTensor; - /// Converts bool tensor to float tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The float tensor with the same data as the bool tensor. - fn bool_into_float(tensor: BoolTensor) -> FloatTensor; + /// Converts bool tensor to int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The int tensor with the same data as the bool tensor. + fn bool_into_int(tensor: BoolTensor) -> IntTensor; - /// Gets the device of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device of the tensor. - fn bool_device(tensor: &BoolTensor) -> Device; + /// Converts bool tensor to float tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The float tensor with the same data as the bool tensor. 
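A minimal sketch of the boolean conversions above, assuming the conventional mapping of `true` to 1 and `false` to 0:

    use burn_tensor::{
        backend::Backend,
        ops::{BoolTensor, FloatTensor, IntTensor},
    };

    fn mask_to_numeric<B: Backend>(
        mask: BoolTensor<B, 2>,
    ) -> (IntTensor<B, 2>, FloatTensor<B, 2>) {
        let ints = B::bool_into_int(mask.clone());
        let floats = B::bool_into_float(mask);
        (ints, floats)
    }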
+ fn bool_into_float(tensor: BoolTensor) -> FloatTensor; - /// Moves the tensor to the device. - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor; + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn bool_device(tensor: &BoolTensor) -> Device; - /// Reshapes the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `shape` - The new shape. - /// - /// # Returns - /// - /// The tensor with the new shape. - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor; + /// Moves the tensor to the device. + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor; - /// Gets the values from the tensor for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges to get the values from. - /// - /// # Returns - /// - /// The tensor with the values for the given ranges. - fn bool_slice( - tensor: BoolTensor, - ranges: [Range; D2], - ) -> BoolTensor; + /// Reshapes the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `shape` - The new shape. + /// + /// # Returns + /// + /// The tensor with the new shape. + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor; - /// Sets the values in the tensor for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges to set the values for. - /// * `value` - The values to set. - /// - /// # Returns - /// - /// The tensor with the values set for the given ranges. - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [Range; D2], - value: BoolTensor, - ) -> BoolTensor; + /// Gets the values from the tensor for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges to get the values from. + /// + /// # Returns + /// + /// The tensor with the values for the given ranges. + fn bool_slice( + tensor: BoolTensor, + ranges: [Range; D2], + ) -> BoolTensor; - /// Repeats one dimension of the tensor a given number of times along that dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to repeat. - /// * `times` - The number of times to repeat the dimension. - /// - /// # Returns - /// - /// The tensor with the dimension repeated. - fn bool_repeat( - tensor: BoolTensor, - dim: usize, - times: usize, - ) -> BoolTensor { - let mut shape = Self::bool_shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } - shape.dims[dim] = times; + /// Sets the values in the tensor for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges to set the values for. + /// * `value` - The values to set. + /// + /// # Returns + /// + /// The tensor with the values set for the given ranges. + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [Range; D2], + value: BoolTensor, + ) -> BoolTensor; - let mut i = 0; - let ranges_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); + /// Repeats one dimension of the tensor a given number of times along that dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to repeat. + /// * `times` - The number of times to repeat the dimension. + /// + /// # Returns + /// + /// The tensor with the dimension repeated. 
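A hedged sketch of the slicing ops above; the [3, 4] input shape is an illustrative assumption:

    use burn_tensor::{backend::Backend, ops::BoolTensor, Shape};

    // Take rows 0..2 and columns 1..3 of a [3, 4] boolean tensor.
    fn slice_sketch<B: Backend>(t: BoolTensor<B, 2>) -> BoolTensor<B, 2> {
        let sub = B::bool_slice(t, [0..2, 1..3]);
        assert_eq!(B::bool_shape(&sub), Shape::new([2, 2]));
        sub
    }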
+ fn bool_repeat( + tensor: BoolTensor, + dim: usize, + times: usize, + ) -> BoolTensor { + let mut shape = Self::bool_shape(&tensor); + if shape.dims[dim] != 1 { + panic!("Can only repeat dimension with dim=1"); + } + shape.dims[dim] = times; - let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor)); - for i in 0..times { - let mut ranges = ranges_select_all.clone(); - ranges[dim] = i..i + 1; - tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone()); - } + let mut i = 0; + let ranges_select_all = [0; D].map(|_| { + let start = 0; + let end = shape.dims[i]; + i += 1; + start..end + }); - tensor_output + let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor)); + for i in 0..times { + let mut ranges = ranges_select_all.clone(); + ranges[dim] = i..i + 1; + tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone()); } - /// Concatenates the tensors along the given dimension. - /// - /// # Arguments - /// - /// * `tensors` - The tensors to concatenate. - /// * `dim` - The dimension to concatenate along. - /// - /// # Returns - /// - /// The tensor with the tensors concatenated along the given dimension. - fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor; + tensor_output + } - /// Equates the two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The tensor with the result of the equate. - fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) - -> BoolTensor; + /// Concatenates the tensors along the given dimension. + /// + /// # Arguments + /// + /// * `tensors` - The tensors to concatenate. + /// * `dim` - The dimension to concatenate along. + /// + /// # Returns + /// + /// The tensor with the tensors concatenated along the given dimension. + fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor; - /// Inverses boolean values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The tensor with the result of the negation. - fn bool_not(tensor: BoolTensor) -> BoolTensor; + /// Equates the two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The tensor with the result of the equate. + fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor; - /// Transposes a bool tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn bool_transpose(tensor: BoolTensor) -> BoolTensor { - Self::bool_swap_dims(tensor, D - 2, D - 1) - } + /// Inverses boolean values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The tensor with the result of the negation. + fn bool_not(tensor: BoolTensor) -> BoolTensor; + + /// Transposes a bool tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn bool_transpose(tensor: BoolTensor) -> BoolTensor { + Self::bool_swap_dims(tensor, D - 2, D - 1) + } - /// Swaps two dimensions of a bool tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. 
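As a concrete reading of the default `bool_repeat` above, which only accepts a source dimension of size 1 and writes the input `times` times through `bool_slice_assign`; the shape is illustrative:

    use burn_tensor::{backend::Backend, ops::BoolTensor};

    // A [2, 1, 3] tensor repeated 4 times along dim 1 becomes [2, 4, 3];
    // any other size on dim 1 panics, as in the default implementation.
    fn repeat_sketch<B: Backend>(t: BoolTensor<B, 3>) -> BoolTensor<B, 3> {
        B::bool_repeat(t, 1, 4)
    }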
- fn bool_swap_dims( - tensor: BoolTensor, - dim1: usize, - dim2: usize, - ) -> BoolTensor; + /// Swaps two dimensions of a bool tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. + fn bool_swap_dims( + tensor: BoolTensor, + dim1: usize, + dim2: usize, + ) -> BoolTensor; } diff --git a/burn-tensor/src/tensor/ops/int_tensor.rs b/burn-tensor/src/tensor/ops/int_tensor.rs index bf0b324c00..e61da7622b 100644 --- a/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/burn-tensor/src/tensor/ops/int_tensor.rs @@ -7,847 +7,844 @@ use core::ops::Range; /// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor) /// for documentation on each function. pub trait IntTensorOps { - /// Creates a new int tensor. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The integer tensor with the given shape. - fn int_empty(shape: Shape, device: &Device) -> IntTensor; - - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn int_shape(tensor: &IntTensor) -> Shape; - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn int_into_data(tensor: IntTensor) -> Reader, D>>; - - /// Gets the data from the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data cloned from the data structure. - fn int_to_data(tensor: &IntTensor) -> Reader, D>> { - Self::int_into_data(tensor.clone()) + /// Creates a new int tensor. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The integer tensor with the given shape. + fn int_empty(shape: Shape, device: &Device) -> IntTensor; + + /// Returns the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + fn int_shape(tensor: &IntTensor) -> Shape; + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn int_into_data(tensor: IntTensor) -> Reader, D>>; + + /// Gets the data from the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data cloned from the data structure. + fn int_to_data(tensor: &IntTensor) -> Reader, D>> { + Self::int_into_data(tensor.clone()) + } + + /// Creates a tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the data. + fn int_from_data( + data: Data, D>, + device: &Device, + ) -> IntTensor; + + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn int_device(tensor: &IntTensor) -> Device; + + /// Moves the tensor to the given device. 
+ fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor; + + /// Reshapes the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `shape` - The new shape. + /// + /// # Returns + /// + /// The tensor with the new shape. + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor; + + /// Gets the element at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The elements at the given indices. + fn int_slice( + tensor: IntTensor, + indices: [Range; D2], + ) -> IntTensor; + + /// Sets the element at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The tensor with the element at the given indices set. + fn int_slice_assign( + tensor: IntTensor, + indices: [Range; D2], + value: IntTensor, + ) -> IntTensor; + + /// Converts int tensor to float tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The int tensor with the same data as the float tensor. + fn int_into_float(tensor: IntTensor) -> FloatTensor; + + /// Fills the tensor with values from the source tensor if the mask is true at the given + /// indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `mask` - The mask. + /// * `source` - The source tensor. + /// + /// # Returns + /// + /// The tensor with the values filled. + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + source: IntTensor, + ) -> IntTensor; + + /// Fills the tensor with the given value if the mask is true at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `mask` - The mask. + /// * `value` - The value. + /// + /// # Returns + /// + /// The tensor with the values filled. + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor; + + /// Gather elements from the tensor at the given indices. + /// + /// # Arguments + /// + /// * `dim` - The dimension to gather from. + /// * `tensor` - The tensor. + /// * `indices` - The indices. + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor; + + /// Scatter a given value to the tensor at the given indices. + /// + /// # Arguments + /// + /// * `dim` - The dimension to scatter to. + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// * `value` - The value. + /// + /// # Returns + /// + /// The tensor with the values scattered. + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor; + + /// Select tensor elements along the given dimension corresponding to the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The tensor with the selected elements. + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor; + + /// Assign the selected elements along the given dimension corresponding to the given indices + /// to the given value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices. + /// * `value` - The value. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. 
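A hedged sketch of `int_gather` above, following the usual gather convention (out[i][j] = tensor[i][indices[i][j]] for dim = 1); the concrete values are illustrative:

    use burn_tensor::{backend::Backend, ops::IntTensor};

    // With tensor = [[10, 11, 12]] and indices = [[2, 0, 0]], the result
    // would be [[12, 10, 10]].
    fn gather_sketch<B: Backend>(
        tensor: IntTensor<B, 2>,
        indices: IntTensor<B, 2>,
    ) -> IntTensor<B, 2> {
        B::int_gather(1, tensor, indices)
    }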
+ fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor; + + /// Repeats the tensor along the given dimension the given number of times. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to repeat. + /// * `times` - The number of times to repeat. + /// + /// # Returns + /// + /// The tensor with the given dimension repeated the given number of times. + fn int_repeat( + tensor: IntTensor, + dim: usize, + times: usize, + ) -> IntTensor { + let mut shape = Self::int_shape(&tensor); + if shape.dims[dim] != 1 { + panic!("Can only repeat dimension with dim=1"); } - - /// Creates a tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the data. - fn int_from_data( - data: Data, D>, - device: &Device, - ) -> IntTensor; - - /// Gets the device of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device of the tensor. - fn int_device(tensor: &IntTensor) -> Device; - - /// Moves the tensor to the given device. - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor; - - /// Reshapes the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `shape` - The new shape. - /// - /// # Returns - /// - /// The tensor with the new shape. - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor; - - /// Gets the element at the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `indices` - The indices. - /// - /// # Returns - /// - /// The elements at the given indices. - fn int_slice( - tensor: IntTensor, - indices: [Range; D2], - ) -> IntTensor; - - /// Sets the element at the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `indices` - The indices. - /// - /// # Returns - /// - /// The tensor with the element at the given indices set. - fn int_slice_assign( - tensor: IntTensor, - indices: [Range; D2], - value: IntTensor, - ) -> IntTensor; - - /// Converts int tensor to float tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The int tensor with the same data as the float tensor. - fn int_into_float(tensor: IntTensor) -> FloatTensor; - - /// Fills the tensor with values from the source tensor if the mask is true at the given - /// indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `mask` - The mask. - /// * `source` - The source tensor. - /// - /// # Returns - /// - /// The tensor with the values filled. - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - source: IntTensor, - ) -> IntTensor; - - /// Fills the tensor with the given value if the mask is true at the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `mask` - The mask. - /// * `value` - The value. - /// - /// # Returns - /// - /// The tensor with the values filled. - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor; - - /// Gather elements from the tensor at the given indices. - /// - /// # Arguments - /// - /// * `dim` - The dimension to gather from. - /// * `tensor` - The tensor. - /// * `indices` - The indices. 
- fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor; - - /// Scatter a given value to the tensor at the given indices. - /// - /// # Arguments - /// - /// * `dim` - The dimension to scatter to. - /// * `tensor` - The tensor. - /// * `indices` - The indices. - /// * `value` - The value. - /// - /// # Returns - /// - /// The tensor with the values scattered. - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor; - - /// Select tensor elements along the given dimension corresponding to the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices. - /// - /// # Returns - /// - /// The tensor with the selected elements. - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor; - - /// Assign the selected elements along the given dimension corresponding to the given indices - /// to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices. - /// * `value` - The value. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor; - - /// Repeats the tensor along the given dimension the given number of times. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to repeat. - /// * `times` - The number of times to repeat. - /// - /// # Returns - /// - /// The tensor with the given dimension repeated the given number of times. - fn int_repeat( - tensor: IntTensor, - dim: usize, - times: usize, - ) -> IntTensor { - let mut shape = Self::int_shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } - shape.dims[dim] = times; - - let mut i = 0; - let indices_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); - - let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor)); - for i in 0..times { - let mut indices = indices_select_all.clone(); - indices[dim] = i..i + 1; - tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone()); - } - - tensor_output - } - - /// Concatenates the given tensors along the given dimension. - /// - /// # Arguments - /// - /// * `tensors` - The tensors. - /// * `dim` - The dimension to concatenate along. - /// - /// # Returns - /// - /// The concatenated tensor. - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor; - - /// Elementwise equality comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; - - /// Elementwise equality comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; - - /// Elementwise greater than comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. 
- /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; - - /// Elementwise greater than comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; - - /// Elementwise greater than or equal comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor; - - /// Elementwise greater than or equal comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor; - - /// Elementwise less than comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; - - /// Elementwise less than comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; - - /// Elementwise less than or equal comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor; - - /// Elementwise less than or equal comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor; - - // ==== NUMERIC ==== // - - /// Elementwise addition. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the addition. - fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise addition with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the addition. - fn int_add_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn int_clamp_min(tensor: IntTensor, min: IntElem) -> IntTensor { - let mask = Self::int_lower_elem(tensor.clone(), min); - Self::int_mask_fill(tensor, mask, min) - } - - /// Clamps a tensor over a maximum value. 
- /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn int_clamp_max(tensor: IntTensor, max: IntElem) -> IntTensor { - let mask = Self::int_greater_elem(tensor.clone(), max); - Self::int_mask_fill(tensor, mask, max) - } - - /// Clamps a tensor between a minimum and maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn int_clamp( - tensor: IntTensor, - min: IntElem, - max: IntElem, - ) -> IntTensor { - Self::int_clamp_min(Self::int_clamp_max(tensor, max), min) - } - - /// Elementwise subtraction. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the subtraction. - fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise subtraction with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the subtraction. - fn int_sub_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Elementwise multiplication. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the multiplication. - fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise multiplication with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the multiplication. - fn int_mul_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Elementwise division. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the division. - fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise division with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the division. - fn int_div_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Elementwise negation. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to negate. - /// - /// # Returns - /// - /// The negated tensor. - fn int_neg(tensor: IntTensor) -> IntTensor { - Self::int_mul_scalar(tensor, (-1.0).elem::>()) - } - - /// Creates a tensor of zeros. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor of zeros. - fn int_zeros(shape: Shape, device: &Device) -> IntTensor; - - /// Creates a tensor of ones. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor of ones. - fn int_ones(shape: Shape, device: &Device) -> IntTensor; - - /// Creates a tensor filled with given value. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `fill_value` - The value with which to fill the tensor. - /// * `device` - The device to create the tensor on. 
- /// - /// # Returns - /// - /// The tensor filled with given value - fn int_full( - shape: Shape, - fill_value: IntElem, - device: &Device, - ) -> IntTensor { - Self::int_add_scalar(Self::int_zeros(shape, device), fill_value) - } - - /// Sums all elements in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// - /// # Returns - /// - /// The sum of all elements in the tensor. - fn int_sum(tensor: IntTensor) -> IntTensor; - - /// Sums all elements in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// * `dim` - The dimension to sum along. - /// - /// # Returns - /// - /// The sum of all elements in the tensor along the dimension. - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Computes the mean of all elements in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// - /// # Returns - /// - /// The mean of all elements in the tensor. - fn int_mean(tensor: IntTensor) -> IntTensor { - let num_elems = B::int_shape(&tensor).num_elements(); - B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem()) - } - - /// Computes the mean of all elements in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// - /// # Returns - /// - /// The mean of all elements in the tensor along the dimension. - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Gets the indices of the maximum elements along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum indices of. - /// * `dim` - The dimension to get the maximum indices along. - /// - /// # Returns - /// - /// The indices of the maximum elements along the dimension. - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Gets the indices of the minimum elements along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum indices of. - /// * `dim` - The dimension to get the minimum indices along. - /// - /// # Returns - /// - /// The indices of the minimum elements along the dimension. - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Gets the maximum element in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum element of. - /// - /// # Returns - /// - /// The maximum element in the tensor. - fn int_max(tensor: IntTensor) -> IntTensor { - let shape = B::int_shape(&tensor); - let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); - - B::int_max_dim(tensor, 0) - } - - /// Gets the maximum element in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum element of. - /// * `dim` - The dimension to get the maximum element along. - /// - /// # Returns - /// - /// The maximum element in the tensor along the dimension. - fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let index = B::int_argmax(tensor.clone(), dim); - - B::int_gather(D - 1, tensor, index) - } - - /// Gets the maximum elements and corresponding indices along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements and indices of. - /// * `dim` - The dimension to get the maximum elements and indices along. - /// - /// # Returns - /// - /// The maximum elements and corresponding indices along the dimension. 
- fn int_max_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - let index = B::int_argmax(tensor.clone(), dim); - let values = B::int_gather(D - 1, tensor, index.clone()); - - (values, index) - } - - /// Gets the minimum element in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum element of. - /// - /// # Returns - /// - /// The minimum element in the tensor. - fn int_min(tensor: IntTensor) -> IntTensor { - let shape = B::int_shape(&tensor); - let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); - - B::int_min_dim(tensor, 0) - } - - /// Gets the minimum elements in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum element of. - /// * `dim` - The dimension to get the minimum element along. - /// - /// # Returns - /// - /// The minimum element in the tensor along the dimension. - fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let index = B::int_argmin(tensor.clone(), dim); - - B::int_gather(D - 1, tensor, index) - } - - /// Gets the minimum elements and corresponding indices along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements and indices of. - /// * `dim` - The dimension to get the minimum elements and indices along. - /// - /// # Returns - /// - /// The minimum elements and corresponding indices along the dimension. - fn int_min_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - let indices = B::int_argmin(tensor.clone(), dim); - let values = B::int_gather(D - 1, tensor, indices.clone()); - - (values, indices) - } - - /// Returns a new tensor with absolute values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take absolute value of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with absolute values. - fn int_abs(tensor: IntTensor) -> IntTensor; - - /// Transposes an int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn int_transpose(tensor: IntTensor) -> IntTensor { - Self::int_swap_dims(tensor, D - 2, D - 1) + shape.dims[dim] = times; + + let mut i = 0; + let indices_select_all = [0; D].map(|_| { + let start = 0; + let end = shape.dims[i]; + i += 1; + start..end + }); + + let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor)); + for i in 0..times { + let mut indices = indices_select_all.clone(); + indices[dim] = i..i + 1; + tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone()); } - /// Swaps two dimensions of an int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn int_swap_dims( - tensor: IntTensor, - dim1: usize, - dim2: usize, - ) -> IntTensor; + tensor_output + } + + /// Concatenates the given tensors along the given dimension. + /// + /// # Arguments + /// + /// * `tensors` - The tensors. + /// * `dim` - The dimension to concatenate along. + /// + /// # Returns + /// + /// The concatenated tensor. + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor; + + /// Elementwise equality comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. 
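Aside (not part of the patch): the slice-assign loop in the added lines just above is the tail of the repeat-style default implementation (its signature falls outside this excerpt). The pattern is: allocate the output with `int_empty`, then copy the source into each unit-width slot along `dim` with `int_slice_assign`. The sketch below spells that out as a standalone helper; the helper name is invented, the generic parameters (`<B, D>` and friends) are written back in by hand since this listing drops them, and the import paths are assumed to mirror the `use crate::{backend::Backend, ops::{...}}` imports visible elsewhere in this patch.

use burn_tensor::{backend::Backend, ops::IntTensor};

/// Repeat `tensor` `times` times along `dim` (its size along `dim` is assumed to be 1),
/// built from the same `int_empty` + `int_slice_assign` primitives as the default above.
fn repeat_dim_sketch<B: Backend, const D: usize>(
    tensor: IntTensor<B, D>,
    dim: usize,
    times: usize,
) -> IntTensor<B, D> {
    let mut shape = B::int_shape(&tensor);
    shape.dims[dim] = times;

    // One "take everything" range per dimension, i.e. 0..shape.dims[i].
    let mut i = 0;
    let select_all = [0; D].map(|_| {
        let range = 0..shape.dims[i];
        i += 1;
        range
    });

    let mut output = B::int_empty(shape, &B::int_device(&tensor));
    for step in 0..times {
        // Write the source into the unit-width slot `step..step + 1` along `dim`.
        let mut ranges = select_all.clone();
        ranges[dim] = step..step + 1;
        output = B::int_slice_assign(output, ranges, tensor.clone());
    }
    output
}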
+ /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; + + /// Elementwise equality comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; + + /// Elementwise greater than comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; + + /// Elementwise greater than comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; + + /// Elementwise greater than or equal comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor; + + /// Elementwise greater than or equal comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor; + + /// Elementwise less than comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; + + /// Elementwise less than comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; + + /// Elementwise less than or equal comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor; + + /// Elementwise less than or equal comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor; + + // ==== NUMERIC ==== // + + /// Elementwise addition. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the addition. + fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise addition with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. 
+ /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the addition. + fn int_add_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn int_clamp_min(tensor: IntTensor, min: IntElem) -> IntTensor { + let mask = Self::int_lower_elem(tensor.clone(), min); + Self::int_mask_fill(tensor, mask, min) + } + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn int_clamp_max(tensor: IntTensor, max: IntElem) -> IntTensor { + let mask = Self::int_greater_elem(tensor.clone(), max); + Self::int_mask_fill(tensor, mask, max) + } + + /// Clamps a tensor between a minimum and maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn int_clamp( + tensor: IntTensor, + min: IntElem, + max: IntElem, + ) -> IntTensor { + Self::int_clamp_min(Self::int_clamp_max(tensor, max), min) + } + + /// Elementwise subtraction. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the subtraction. + fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise subtraction with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the subtraction. + fn int_sub_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Elementwise multiplication. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the multiplication. + fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise multiplication with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the multiplication. + fn int_mul_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Elementwise division. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the division. + fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise division with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the division. + fn int_div_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Elementwise negation. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to negate. + /// + /// # Returns + /// + /// The negated tensor. + fn int_neg(tensor: IntTensor) -> IntTensor { + Self::int_mul_scalar(tensor, (-1.0).elem::>()) + } + + /// Creates a tensor of zeros. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor of zeros. 
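Aside (not part of the patch): the clamp defaults just above all follow one pattern, compare elementwise to get a boolean mask, then overwrite the masked entries with `int_mask_fill`. The same two primitives give other simple elementwise ops for free on any backend. A minimal sketch under those assumptions, with an invented helper name, the dropped generic parameters restored by hand, and import paths assumed from this patch's own `use` statements:

use burn_tensor::{backend::Backend, ops::{IntElem, IntTensor}, ElementConversion};

/// ReLU-style thresholding for int tensors: everything below zero becomes zero.
/// This is `int_clamp_min(tensor, 0)` spelled out with the compare-then-mask pattern.
fn int_relu_sketch<B: Backend, const D: usize>(tensor: IntTensor<B, D>) -> IntTensor<B, D> {
    let zero: IntElem<B> = 0.elem();
    // Boolean mask of the elements strictly below zero...
    let mask = B::int_lower_elem(tensor.clone(), zero);
    // ...which are then overwritten with zero, leaving the rest untouched.
    B::int_mask_fill(tensor, mask, zero)
}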
+ fn int_zeros(shape: Shape, device: &Device) -> IntTensor; + + /// Creates a tensor of ones. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor of ones. + fn int_ones(shape: Shape, device: &Device) -> IntTensor; + + /// Creates a tensor filled with given value. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `fill_value` - The value with which to fill the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor filled with given value + fn int_full( + shape: Shape, + fill_value: IntElem, + device: &Device, + ) -> IntTensor { + Self::int_add_scalar(Self::int_zeros(shape, device), fill_value) + } + + /// Sums all elements in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// + /// # Returns + /// + /// The sum of all elements in the tensor. + fn int_sum(tensor: IntTensor) -> IntTensor; + + /// Sums all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension to sum along. + /// + /// # Returns + /// + /// The sum of all elements in the tensor along the dimension. + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Computes the mean of all elements in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to compute the mean of. + /// + /// # Returns + /// + /// The mean of all elements in the tensor. + fn int_mean(tensor: IntTensor) -> IntTensor { + let num_elems = B::int_shape(&tensor).num_elements(); + B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem()) + } + + /// Computes the mean of all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to compute the mean of. + /// + /// # Returns + /// + /// The mean of all elements in the tensor along the dimension. + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Gets the indices of the maximum elements along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum indices of. + /// * `dim` - The dimension to get the maximum indices along. + /// + /// # Returns + /// + /// The indices of the maximum elements along the dimension. + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Gets the indices of the minimum elements along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum indices of. + /// * `dim` - The dimension to get the minimum indices along. + /// + /// # Returns + /// + /// The indices of the minimum elements along the dimension. + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Gets the maximum element in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum element of. + /// + /// # Returns + /// + /// The maximum element in the tensor. + fn int_max(tensor: IntTensor) -> IntTensor { + let shape = B::int_shape(&tensor); + let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); + + B::int_max_dim(tensor, 0) + } + + /// Gets the maximum element in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum element of. + /// * `dim` - The dimension to get the maximum element along. + /// + /// # Returns + /// + /// The maximum element in the tensor along the dimension. 
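Aside (not part of the patch): two details of the defaults just above are worth spelling out. `int_full` has no backend hook of its own, it is literally "make zeros, then add the fill value as a scalar". And `int_mean` divides an integer sum by the element count, so the result is itself an integer and the fractional part of the true mean cannot be represented (the mean of `[1, 2]` cannot come out as `1.5`). A short sketch of the `int_full` composition; the `[2, 3]` shape and helper name are made up for illustration, and the dropped generics are restored by hand:

use burn_tensor::{backend::Backend, ops::{Device, IntElem, IntTensor}, Shape};

/// A [2, 3] int tensor filled with `value`, built the same way as the default:
/// zeros first, then a scalar add.
fn int_full_2x3_sketch<B: Backend>(value: IntElem<B>, device: &Device<B>) -> IntTensor<B, 2> {
    let zeros = B::int_zeros(Shape::new([2, 3]), device);
    B::int_add_scalar(zeros, value)
}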
+ fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let index = B::int_argmax(tensor.clone(), dim); + + B::int_gather(D - 1, tensor, index) + } + + /// Gets the maximum elements and corresponding indices along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements and indices of. + /// * `dim` - The dimension to get the maximum elements and indices along. + /// + /// # Returns + /// + /// The maximum elements and corresponding indices along the dimension. + fn int_max_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + let index = B::int_argmax(tensor.clone(), dim); + let values = B::int_gather(D - 1, tensor, index.clone()); + + (values, index) + } + + /// Gets the minimum element in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum element of. + /// + /// # Returns + /// + /// The minimum element in the tensor. + fn int_min(tensor: IntTensor) -> IntTensor { + let shape = B::int_shape(&tensor); + let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); + + B::int_min_dim(tensor, 0) + } + + /// Gets the minimum elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum element of. + /// * `dim` - The dimension to get the minimum element along. + /// + /// # Returns + /// + /// The minimum element in the tensor along the dimension. + fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let index = B::int_argmin(tensor.clone(), dim); + + B::int_gather(D - 1, tensor, index) + } + + /// Gets the minimum elements and corresponding indices along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements and indices of. + /// * `dim` - The dimension to get the minimum elements and indices along. + /// + /// # Returns + /// + /// The minimum elements and corresponding indices along the dimension. + fn int_min_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + let indices = B::int_argmin(tensor.clone(), dim); + let values = B::int_gather(D - 1, tensor, indices.clone()); + + (values, indices) + } + + /// Returns a new tensor with absolute values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take absolute value of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with absolute values. + fn int_abs(tensor: IntTensor) -> IntTensor; + + /// Transposes an int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn int_transpose(tensor: IntTensor) -> IntTensor { + Self::int_swap_dims(tensor, D - 2, D - 1) + } + + /// Swaps two dimensions of an int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. 
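Aside (not part of the patch): the max/min defaults above compose two primitives, `int_argmax`/`int_argmin` to find the winning positions and `int_gather` to pull the values back out. Note that, as written, the gather axis is hard-coded to `D - 1` rather than `dim`, so the defaults appear to assume reduction over the last dimension; backends are free to override them. A sketch of the rank-2, last-axis case with an invented helper name and the dropped generics restored by hand:

use burn_tensor::{backend::Backend, ops::IntTensor};

/// Per-row maximum of a rank-2 int tensor, spelled out as argmax + gather.
/// Input is [rows, cols]; the reduced axis is kept, so the result is [rows, 1].
fn int_max_last_dim_sketch<B: Backend>(tensor: IntTensor<B, 2>) -> IntTensor<B, 2> {
    let dim = 1; // the last axis of a rank-2 tensor
    // Indices of each row's maximum, shape [rows, 1].
    let indices = B::int_argmax(tensor.clone(), dim);
    // Gather those positions back out of the original tensor, shape [rows, 1].
    B::int_gather(dim, tensor, indices)
}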
+ fn int_swap_dims( + tensor: IntTensor, + dim1: usize, + dim2: usize, + ) -> IntTensor; } diff --git a/burn-tensor/src/tensor/ops/modules/base.rs b/burn-tensor/src/tensor/ops/modules/base.rs index 4290fbca23..636c073bc4 100644 --- a/burn-tensor/src/tensor/ops/modules/base.rs +++ b/burn-tensor/src/tensor/ops/modules/base.rs @@ -1,441 +1,434 @@ use super::{conv, pool, unfold::unfold4d_using_conv2d}; use crate::{ - backend::Backend, - ops::{FloatTensor, IntTensor}, - Shape, + backend::Backend, + ops::{FloatTensor, IntTensor}, + Shape, }; /// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d). #[derive(new)] pub struct Conv2dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, - /// Weights gradient. - pub weights_grad: FloatTensor, + /// Weights gradient. + pub weights_grad: FloatTensor, - /// Bias gradient. - pub bias_grad: Option>, + /// Bias gradient. + pub bias_grad: Option>, } /// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d). #[derive(new)] pub struct MaxPool1dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, } /// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices). #[derive(new)] pub struct MaxPool1dWithIndices { - /// The output tensor. - pub output: FloatTensor, + /// The output tensor. + pub output: FloatTensor, - /// The indices tensor. - pub indices: IntTensor, + /// The indices tensor. + pub indices: IntTensor, } /// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d). #[derive(new)] pub struct MaxPool2dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, } /// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices). #[derive(new)] pub struct MaxPool2dWithIndices { - /// The output tensor. - pub output: FloatTensor, + /// The output tensor. + pub output: FloatTensor, - /// The indices tensor. - pub indices: IntTensor, + /// The indices tensor. + pub indices: IntTensor, } /// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d). #[derive(new)] pub struct Conv1dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, - /// Weights gradient. - pub weights_grad: FloatTensor, + /// Weights gradient. + pub weights_grad: FloatTensor, - /// Bias gradient. - pub bias_grad: Option>, + /// Bias gradient. + pub bias_grad: Option>, } /// Convolution options. #[derive(new, Debug, Clone, Hash)] pub struct ConvOptions { - /// Stride. - pub stride: [usize; N], + /// Stride. + pub stride: [usize; N], - /// Padding. - pub padding: [usize; N], + /// Padding. + pub padding: [usize; N], - /// Dilation. - pub dilation: [usize; N], + /// Dilation. + pub dilation: [usize; N], - /// Groups. - pub groups: usize, + /// Groups. + pub groups: usize, } /// Transposed convolution options. #[derive(new, Debug, Clone, Hash)] pub struct ConvTransposeOptions { - /// Stride. - pub stride: [usize; N], + /// Stride. + pub stride: [usize; N], - /// Padding. - pub padding: [usize; N], + /// Padding. + pub padding: [usize; N], - /// Padding out. - pub padding_out: [usize; N], + /// Padding out. + pub padding_out: [usize; N], - /// Dilation. - pub dilation: [usize; N], + /// Dilation. + pub dilation: [usize; N], - /// Groups. - pub groups: usize, + /// Groups. + pub groups: usize, } /// Unfold operation options. 
#[derive(new, Debug, Clone)] pub struct UnfoldOptions { - /// The number of positions to slide over the input tensor in each dimension. - /// A stride of `[1, 1]` will slide the kernel one pixel at a time. - pub stride: [usize; 2], + /// The number of positions to slide over the input tensor in each dimension. + /// A stride of `[1, 1]` will slide the kernel one pixel at a time. + pub stride: [usize; 2], - /// The number of zero-padding pixels added to each side of the input tensor in each dimension. - pub padding: [usize; 2], + /// The number of zero-padding pixels added to each side of the input tensor in each dimension. + pub padding: [usize; 2], - /// The spacing between the blocks (patches) in the original input tensor. - pub dilation: [usize; 2], + /// The spacing between the blocks (patches) in the original input tensor. + pub dilation: [usize; 2], } /// Module operations trait. pub trait ModuleOps { - /// Embedding operation. - /// - /// # Arguments - /// - /// * `weights` - The embedding weights. - /// * `indices` - The indices tensor. - /// - /// # Returns - /// - /// The output tensor. - fn embedding(weights: FloatTensor, indices: IntTensor) -> FloatTensor { - let [batch_size, seq_length] = B::int_shape(&indices).dims; - let [_, d_model] = B::shape(&weights).dims; - - let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); - let output = B::select(weights, 0, indices); - - B::reshape(output, Shape::new([batch_size, seq_length, d_model])) - } - - /// Embedding backward operation. - /// - /// # Arguments - /// - /// * `weights` - The embedding weights. - /// * `output_grad` - The output gradient. - /// * `indices` - The indices tensor. - /// - /// # Returns - /// - /// The gradient. - fn embedding_backward( - weights: FloatTensor, - output_grad: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - let [batch_size, seq_length] = B::int_shape(&indices).dims; - let [n_embeddings, d_model] = B::shape(&weights).dims; - let device = B::device(&weights); - - let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); - let output_grad = B::reshape(output_grad, Shape::new([batch_size * seq_length, d_model])); - let grad = B::zeros(Shape::new([n_embeddings, d_model]), &device); - - B::select_assign(grad, 0, indices, output_grad) - } - /// One dimensional convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, length]`, - /// weight: `[channels_out, channels_in, kernel_size]`, - /// bias: `[channels_out]`, - fn conv1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> FloatTensor { - conv::conv1d_from_conv2d::(x, weight, bias, options) - } - /// Backward pass for the [conv1d](ModuleOps::conv1d) operation. - fn conv1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<1>, - ) -> Conv1dBackward { - conv::conv1d_backward(x, weight, bias, output_grad, options) - } - /// Two dimensional convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, height, width]`, - /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`, - /// bias: `[channels_out]`, - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor; - /// Backward pass for the [conv2d](ModuleOps::conv2d) operation. 
- fn conv2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<2>, - ) -> Conv2dBackward { - conv::conv2d_backward(x, weight, bias, output_grad, options) - } - /// One dimensional transposed convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, length]`, - /// weight: `[channels_in, channels_out, length]`, - /// bias: `[channels_out]`, - fn conv_transpose1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> FloatTensor { - conv::conv_transpose1d_from_conv_transpose2d::(x, weight, bias, options) - } - /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation. - fn conv_transpose1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<1>, - ) -> Conv1dBackward { - conv::conv_transpose1d_backward(x, weight, bias, output_grad, options) - } - /// Two dimensional transposed convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, height, width]`, - /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`, - /// bias: `[channels_out]`, - fn conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor; - - /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation. - fn conv_transpose2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<2>, - ) -> Conv2dBackward { - conv::conv_transpose2d_backward(x, weight, bias, output_grad, options) - } - - /// Four-dimensional unfolding. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, height, width]`, - /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, - fn unfold4d( - x: FloatTensor, - kernel_size: [usize; 2], - options: UnfoldOptions, - ) -> FloatTensor { - unfold4d_using_conv2d::(x, kernel_size, options) - } - - /// One dimensional avg pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, length], - fn avg_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - pool::avg_pool1d_from_2d::(x, kernel_size, stride, padding, count_include_pad) - } - /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation. - fn avg_pool1d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - pool::avg_pool1d_backward_from_2d::( - x, - grad, - kernel_size, - stride, - padding, - count_include_pad, - ) - } - /// Two dimensional avg pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor; - /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation. - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor; - /// Two dimensional adaptive avg pooling. 
- /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor; - /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation. - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor; - /// One dimensional adaptive avg pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, length], - fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { - pool::adaptive_avg_pool1d_from_2d::(x, output_size) - } - /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation. - fn adaptive_avg_pool1d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - pool::adaptive_avg_pool1d_backward_from_2d::(x, grad) - } - /// One dimensional max pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, length], - fn max_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> FloatTensor { - pool::max_pool1d_from_2d::(x, kernel_size, stride, padding, dilation) - } - - /// One dimensional max pooling with indices. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn max_pool1d_with_indices( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices { - pool::max_pool1d_with_indices_from_2d::(x, kernel_size, stride, padding, dilation) - } - /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation. - fn max_pool1d_with_indices_backward( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool1dBackward { - pool::max_pool1d_with_indices_backward_from_2d::( - x, - kernel_size, - stride, - padding, - dilation, - output_grad, - indices, - ) - } - - /// Two dimensional max pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor; - - /// Two dimensional max pooling with indices. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices; - /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation. - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward; + /// Embedding operation. + /// + /// # Arguments + /// + /// * `weights` - The embedding weights. + /// * `indices` - The indices tensor. + /// + /// # Returns + /// + /// The output tensor. + fn embedding(weights: FloatTensor, indices: IntTensor) -> FloatTensor { + let [batch_size, seq_length] = B::int_shape(&indices).dims; + let [_, d_model] = B::shape(&weights).dims; + + let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); + let output = B::select(weights, 0, indices); + + B::reshape(output, Shape::new([batch_size, seq_length, d_model])) + } + + /// Embedding backward operation. + /// + /// # Arguments + /// + /// * `weights` - The embedding weights. 
+ /// * `output_grad` - The output gradient. + /// * `indices` - The indices tensor. + /// + /// # Returns + /// + /// The gradient. + fn embedding_backward( + weights: FloatTensor, + output_grad: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + let [batch_size, seq_length] = B::int_shape(&indices).dims; + let [n_embeddings, d_model] = B::shape(&weights).dims; + let device = B::device(&weights); + + let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); + let output_grad = B::reshape(output_grad, Shape::new([batch_size * seq_length, d_model])); + let grad = B::zeros(Shape::new([n_embeddings, d_model]), &device); + + B::select_assign(grad, 0, indices, output_grad) + } + /// One dimensional convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, length]`, + /// weight: `[channels_out, channels_in, kernel_size]`, + /// bias: `[channels_out]`, + fn conv1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> FloatTensor { + conv::conv1d_from_conv2d::(x, weight, bias, options) + } + /// Backward pass for the [conv1d](ModuleOps::conv1d) operation. + fn conv1d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<1>, + ) -> Conv1dBackward { + conv::conv1d_backward(x, weight, bias, output_grad, options) + } + /// Two dimensional convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, height, width]`, + /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`, + /// bias: `[channels_out]`, + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor; + /// Backward pass for the [conv2d](ModuleOps::conv2d) operation. + fn conv2d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<2>, + ) -> Conv2dBackward { + conv::conv2d_backward(x, weight, bias, output_grad, options) + } + /// One dimensional transposed convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, length]`, + /// weight: `[channels_in, channels_out, length]`, + /// bias: `[channels_out]`, + fn conv_transpose1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> FloatTensor { + conv::conv_transpose1d_from_conv_transpose2d::(x, weight, bias, options) + } + /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation. + fn conv_transpose1d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<1>, + ) -> Conv1dBackward { + conv::conv_transpose1d_backward(x, weight, bias, output_grad, options) + } + /// Two dimensional transposed convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, height, width]`, + /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`, + /// bias: `[channels_out]`, + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor; + + /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation. + fn conv_transpose2d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<2>, + ) -> Conv2dBackward { + conv::conv_transpose2d_backward(x, weight, bias, output_grad, options) + } + + /// Four-dimensional unfolding. 
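Aside (not part of the patch): the `embedding` default above is an index-select on dimension 0 of the weight matrix, wrapped in reshapes that flatten and then restore the `[batch, seq]` layout; `embedding_backward` mirrors it by scattering the output-gradient rows into a zero matrix with `select_assign`. The forward pass again as a standalone sketch, with an invented function name and the dropped generic parameters (`<B, 2>`, `<B, 3>`) written back in by hand:

use burn_tensor::{backend::Backend, ops::{FloatTensor, IntTensor}, Shape};

/// Embedding lookup: `weights` is [vocab, d_model], `indices` is [batch, seq],
/// and the result is [batch, seq, d_model].
fn embedding_sketch<B: Backend>(
    weights: FloatTensor<B, 2>,
    indices: IntTensor<B, 2>,
) -> FloatTensor<B, 3> {
    let [batch_size, seq_length] = B::int_shape(&indices).dims;
    let [_vocab, d_model] = B::shape(&weights).dims;

    // Flatten the indices so a single select on dim 0 performs every lookup at once.
    let flat = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
    let rows = B::select(weights, 0, flat);

    B::reshape(rows, Shape::new([batch_size, seq_length, d_model]))
}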
+ /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, height, width]`, + /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, + fn unfold4d( + x: FloatTensor, + kernel_size: [usize; 2], + options: UnfoldOptions, + ) -> FloatTensor { + unfold4d_using_conv2d::(x, kernel_size, options) + } + + /// One dimensional avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, length], + fn avg_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + pool::avg_pool1d_from_2d::(x, kernel_size, stride, padding, count_include_pad) + } + /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation. + fn avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + pool::avg_pool1d_backward_from_2d::(x, grad, kernel_size, stride, padding, count_include_pad) + } + /// Two dimensional avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor; + /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation. + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor; + /// Two dimensional adaptive avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor; + /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation. + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor; + /// One dimensional adaptive avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, length], + fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { + pool::adaptive_avg_pool1d_from_2d::(x, output_size) + } + /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation. + fn adaptive_avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + pool::adaptive_avg_pool1d_backward_from_2d::(x, grad) + } + /// One dimensional max pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, length], + fn max_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> FloatTensor { + pool::max_pool1d_from_2d::(x, kernel_size, stride, padding, dilation) + } + + /// One dimensional max pooling with indices. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool1d_with_indices( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices { + pool::max_pool1d_with_indices_from_2d::(x, kernel_size, stride, padding, dilation) + } + /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation. 
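Aside (not part of the patch): the 1-D pooling defaults above delegate to `pool::*_from_2d` helpers whose bodies are not shown here; presumably they use the same lowering as `conv1d_from_conv2d` later in this patch, treating `[batch, channels, length]` as `[batch, channels, length, 1]` and running the 2-D kernel with a `[k, 1]` window. A hand-written version of that lowering for max pooling, under that assumption, with an invented function name and the dropped generics restored:

use burn_tensor::{backend::Backend, ops::FloatTensor, Shape};

/// 1-D max pooling expressed through the backend's 2-D max pooling.
fn max_pool1d_via_2d_sketch<B: Backend>(
    x: FloatTensor<B, 3>, // [batch, channels, length]
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> FloatTensor<B, 3> {
    let [batch, channels, length] = B::shape(&x).dims;
    // Add a trailing unit axis so the 2-D op sees [batch, channels, length, 1].
    let x = B::reshape(x, Shape::new([batch, channels, length, 1]));
    let out = B::max_pool2d(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
    );
    // Drop the unit axis again: [batch, channels, length_out, 1] -> [batch, channels, length_out].
    let [batch, channels, length_out, _] = B::shape(&out).dims;
    B::reshape(out, Shape::new([batch, channels, length_out]))
}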
+ fn max_pool1d_with_indices_backward( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool1dBackward { + pool::max_pool1d_with_indices_backward_from_2d::( + x, + kernel_size, + stride, + padding, + dilation, + output_grad, + indices, + ) + } + + /// Two dimensional max pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor; + + /// Two dimensional max pooling with indices. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices; + /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation. + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward; } diff --git a/burn-tensor/src/tensor/ops/modules/conv.rs b/burn-tensor/src/tensor/ops/modules/conv.rs index 666068f566..d97969a9b9 100644 --- a/burn-tensor/src/tensor/ops/modules/conv.rs +++ b/burn-tensor/src/tensor/ops/modules/conv.rs @@ -5,763 +5,757 @@ use libm::ceilf; /// Calculate the expected padding size required when applying a convolution. pub fn calculate_conv_padding( - kernel_size: usize, - stride: usize, - size_in: usize, - size_out: usize, + kernel_size: usize, + stride: usize, + size_in: usize, + size_out: usize, ) -> usize { - let kernel_size = kernel_size as f32; - let stride = stride as f32; - let size_in = size_in as f32; - let size_out = size_out as f32; + let kernel_size = kernel_size as f32; + let stride = stride as f32; + let size_in = size_in as f32; + let size_out = size_out as f32; - let padding = stride * (size_out - 1.) - size_in + kernel_size; - let padding = ceilf(padding / 2.); + let padding = stride * (size_out - 1.) - size_in + kernel_size; + let padding = ceilf(padding / 2.); - padding as usize + padding as usize } /// Calculate the expected output size when doing a convolution operation. pub fn calculate_conv_output_size( - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - size_in: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + size_in: usize, ) -> usize { - (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 } /// Calculate the expected output size when doing a transposed convolution operation. pub fn calculate_conv_transpose_output_size( - kernel_size: usize, - stride: usize, - padding: usize, - padding_out: usize, - dilation: usize, - size_in: usize, + kernel_size: usize, + stride: usize, + padding: usize, + padding_out: usize, + dilation: usize, + size_in: usize, ) -> usize { - (size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1 + (size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1 } /// Calculate the expected output size when doing a pooling operation. 
pub fn calculate_pool_output_size( - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - size_in: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + size_in: usize, ) -> usize { - ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 + ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 } /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions. pub(crate) fn conv1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<1>, ) -> Conv1dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _, length_in] = B::shape(&x).dims; - let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims; - let [_, _, kernel_size] = weight_shape.dims; - - let padding_out = calculate_padding_out( - kernel_size, - options.stride[0], - options.padding[0], - options.dilation[0], - length_in, - length_out, - ); - - let x_grad = B::conv_transpose1d( - output_grad.clone(), - weight, - None, - ConvTransposeOptions::new( - options.stride, - options.padding, - [padding_out], - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => conv1d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), - false => conv1d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv1dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _, length_in] = B::shape(&x).dims; + let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims; + let [_, _, kernel_size] = weight_shape.dims; + + let padding_out = calculate_padding_out( + kernel_size, + options.stride[0], + options.padding[0], + options.dilation[0], + length_in, + length_out, + ); + + let x_grad = B::conv_transpose1d( + output_grad.clone(), + weight, + None, + ConvTransposeOptions::new( + options.stride, + options.padding, + [padding_out], + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => conv1d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), + false => conv1d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv1dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass using convolutions. 
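Aside (not part of the patch): a quick numeric check of the sizing formulas above, assuming the helpers are exported under `burn_tensor::ops::conv` (the path the backend crates use for them). With `size_in = 7`, `kernel_size = 3`, `stride = 2`, `padding = 1`, `dilation = 1`, the convolution output is `(7 + 2 - 2 - 1) / 2 + 1 = 4`, and the matching transposed convolution (with `padding_out = 0`) maps that 4 back to 7:

use burn_tensor::ops::conv::{calculate_conv_output_size, calculate_conv_transpose_output_size};

fn main() {
    // (7 + 2*1 - 1*(3 - 1) - 1) / 2 + 1 = 6 / 2 + 1 = 4
    let out = calculate_conv_output_size(3, 2, 1, 1, 7);
    assert_eq!(out, 4);

    // (4 - 1) * 2 + 1 * (3 - 1) + 0 - 2 * 1 + 1 = 6 + 2 - 2 + 1 = 7
    let back = calculate_conv_transpose_output_size(3, 2, 1, 0, 1, 4);
    assert_eq!(back, 7);
}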
pub(crate) fn conv2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<2>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<2>, ) -> Conv2dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims; - let [_, _, height_out, width_out] = B::shape(&output_grad).dims; - let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims; - - let padding_1_out = calculate_padding_out( - kernel_size_1, - options.stride[0], - options.padding[0], - options.dilation[0], - height_in, - height_out, - ); - let padding_2_out = calculate_padding_out( - kernel_size_2, - options.stride[1], - options.padding[1], - options.dilation[1], - width_in, - width_out, - ); - - let x_grad = B::conv_transpose2d( - output_grad.clone(), - weight, - None, - ConvTransposeOptions::new( - options.stride, - options.padding, - [padding_1_out, padding_2_out], - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => conv2d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), - false => conv2d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv2dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape( - grad, - Shape::new([channels_out, batch_size * height_out * width_out]), - ); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims; + let [_, _, height_out, width_out] = B::shape(&output_grad).dims; + let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims; + + let padding_1_out = calculate_padding_out( + kernel_size_1, + options.stride[0], + options.padding[0], + options.dilation[0], + height_in, + height_out, + ); + let padding_2_out = calculate_padding_out( + kernel_size_2, + options.stride[1], + options.padding[1], + options.dilation[1], + width_in, + width_out, + ); + + let x_grad = B::conv_transpose2d( + output_grad.clone(), + weight, + None, + ConvTransposeOptions::new( + options.stride, + options.padding, + [padding_1_out, padding_2_out], + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => conv2d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), + false => conv2d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv2dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape( + grad, + Shape::new([channels_out, batch_size * height_out * width_out]), + ); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass using convolutions. 
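Aside (not part of the patch): in the conv1d/conv2d backward passes above, the bias gradient is the same reduction every time, sum the output gradient over the batch and spatial axes, leaving one value per output channel. The channel axis is swapped to the front, everything else is flattened, and a single `sum_dim` finishes the job. A standalone sketch of that reduction for the 2-D case, with an invented function name and the dropped generics restored by hand:

use burn_tensor::{backend::Backend, ops::FloatTensor, Shape};

/// Bias gradient for a 2-D convolution: `output_grad` is [batch, channels_out, h, w],
/// the result is [channels_out].
fn conv2d_bias_grad_sketch<B: Backend>(output_grad: FloatTensor<B, 4>) -> FloatTensor<B, 1> {
    let [batch, channels_out, h, w] = B::shape(&output_grad).dims;

    // [channels_out, batch, h, w] -> [channels_out, batch * h * w]
    let grad = B::swap_dims(output_grad, 0, 1);
    let grad = B::reshape(grad, Shape::new([channels_out, batch * h * w]));

    // Sum away everything except the channel axis: [channels_out, 1] -> [channels_out].
    let grad = B::sum_dim(grad, 1);
    B::reshape(grad, Shape::new([channels_out]))
}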
pub(crate) fn conv_transpose2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<2>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<2>, ) -> Conv2dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _channels_in, _, _] = B::shape(&x).dims; - let [_, channels_out, height_out, width_out] = B::shape(&output_grad).dims; - - let x_grad = B::conv2d( - output_grad.clone(), - weight, - None, - ConvOptions::new( - options.stride, - options.padding, - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => conv_transpose2d_weight_grad_no_groups::( - x, - output_grad.clone(), - weight_shape, - options, - ), - false => conv_transpose2d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv2dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape( - grad, - Shape::new([channels_out, batch_size * height_out * width_out]), - ); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _channels_in, _, _] = B::shape(&x).dims; + let [_, channels_out, height_out, width_out] = B::shape(&output_grad).dims; + + let x_grad = B::conv2d( + output_grad.clone(), + weight, + None, + ConvOptions::new( + options.stride, + options.padding, + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => { + conv_transpose2d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options) + } + false => conv_transpose2d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv2dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape( + grad, + Shape::new([channels_out, batch_size * height_out * width_out]), + ); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass using convolutions. 
pub(crate) fn conv_transpose1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<1>, ) -> Conv1dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _channels_in, _] = B::shape(&x).dims; - let [_, channels_out, length_out] = B::shape(&output_grad).dims; - - let x_grad = B::conv1d( - output_grad.clone(), - weight, - None, - ConvOptions::new( - options.stride, - options.padding, - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => conv_transpose1d_weight_grad_no_groups::( - x, - output_grad.clone(), - weight_shape, - options, - ), - false => conv_transpose1d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv1dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _channels_in, _] = B::shape(&x).dims; + let [_, channels_out, length_out] = B::shape(&output_grad).dims; + + let x_grad = B::conv1d( + output_grad.clone(), + weight, + None, + ConvOptions::new( + options.stride, + options.padding, + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => { + conv_transpose1d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options) + } + false => conv_transpose1d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv1dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Execute a 1D convolution using a 2D convolution. 
pub(crate) fn conv1d_from_conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, ) -> FloatTensor { - let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims; - let [batch_size, channels_in, length_in] = B::shape(&x).dims; - - let weight = B::reshape( - weight, - Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]), - ); - let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); - - let tensor = B::conv2d( - x, - weight, - bias, - ConvOptions::new( - [options.stride[0], 1], - [options.padding[0], 0], - [options.dilation[0], 1], - options.groups, - ), - ); - let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; - B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) + let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims; + let [batch_size, channels_in, length_in] = B::shape(&x).dims; + + let weight = B::reshape( + weight, + Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]), + ); + let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); + + let tensor = B::conv2d( + x, + weight, + bias, + ConvOptions::new( + [options.stride[0], 1], + [options.padding[0], 0], + [options.dilation[0], 1], + options.groups, + ), + ); + let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; + B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } /// Execute a 1D transposed convolution using a 2D transposed convolution. pub(crate) fn conv_transpose1d_from_conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, ) -> FloatTensor { - let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims; - let [batch_size, _channels_in, length_in] = B::shape(&x).dims; - - let weight = B::reshape( - weight, - Shape::new([channels_in, channels_out, kernel_size, 1]), - ); - let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); - - let tensor = B::conv_transpose2d( - x, - weight, - bias, - ConvTransposeOptions::new( - [options.stride[0], 1], - [options.padding[0], 0], - [options.padding_out[0], 0], - [options.dilation[0], 1], - options.groups, - ), - ); - let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; - B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) + let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims; + let [batch_size, _channels_in, length_in] = B::shape(&x).dims; + + let weight = B::reshape( + weight, + Shape::new([channels_in, channels_out, kernel_size, 1]), + ); + let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); + + let tensor = B::conv_transpose2d( + x, + weight, + bias, + ConvTransposeOptions::new( + [options.stride[0], 1], + [options.padding[0], 0], + [options.padding_out[0], 0], + [options.dilation[0], 1], + options.groups, + ), + ); + let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; + B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } fn conv1d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvOptions<1>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: 
ConvOptions<1>, ) -> FloatTensor { - let [channels_out, increment_ci, kernel_size] = B::shape(&weight_grad).dims; - let increment_co = channels_out / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv1d( - x, - grad, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - weight_grad = B::slice_assign( - weight_grad, - [start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size], - weight_grad_tmp, - ); - } + let [channels_out, increment_ci, kernel_size] = B::shape(&weight_grad).dims; + let increment_co = channels_out / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) * increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv1d( + x, + grad, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + weight_grad = B::slice_assign( + weight_grad, + [start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size], + weight_grad_tmp, + ); + } - weight_grad + weight_grad } fn conv2d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvOptions<2>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: ConvOptions<2>, ) -> FloatTensor { - let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; - let increment_co = channels_out / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv2d( - x, - grad, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - weight_grad = B::slice_assign( - weight_grad, - [ - start_idx_co..end_idx_co, - 0..increment_ci, - 0..kernel_size_1, - 0..kernel_size_2, - ], - weight_grad_tmp, - ); - } + let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; + let increment_co = channels_out / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) 
* increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv2d( + x, + grad, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + weight_grad = B::slice_assign( + weight_grad, + [ + start_idx_co..end_idx_co, + 0..increment_ci, + 0..kernel_size_1, + 0..kernel_size_2, + ], + weight_grad_tmp, + ); + } - weight_grad + weight_grad } fn conv_transpose2d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvTransposeOptions<2>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: ConvTransposeOptions<2>, ) -> FloatTensor { - let [channels_in, increment_co, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; - let increment_ci = channels_in / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv2d( - grad, - x, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::shape(&weight_grad_tmp).dims; - - if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { - weight_grad_tmp = B::slice( - weight_grad_tmp, - [ - 0..increment_ci, - 0..increment_co, - 0..kernel_size_1, - 0..kernel_size_2, - ], - ); - } - - weight_grad = B::slice_assign( - weight_grad, - [ - start_idx_ci..end_idx_ci, - 0..increment_co, - 0..kernel_size_1, - 0..kernel_size_2, - ], - weight_grad_tmp, - ); + let [channels_in, increment_co, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; + let increment_ci = channels_in / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) * increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv2d( + grad, + x, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::shape(&weight_grad_tmp).dims; + + if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { + weight_grad_tmp = B::slice( + weight_grad_tmp, + [ + 0..increment_ci, + 0..increment_co, + 0..kernel_size_1, + 0..kernel_size_2, + ], + ); } - weight_grad + weight_grad = B::slice_assign( + weight_grad, + [ + start_idx_ci..end_idx_ci, + 0..increment_co, + 0..kernel_size_1, + 0..kernel_size_2, + ], + weight_grad_tmp, + ); + } + + weight_grad } fn conv1d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<3>, - options: 
ConvOptions<1>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<3>, + options: ConvOptions<1>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv1d( - x_swapped, - output_grad_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv1d( + x_swapped, + output_grad_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + if B::shape(&weight_grad) != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + ], ); - let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); - - if B::shape(&weight_grad) != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - ], - ); - } - weight_grad + } + weight_grad } fn conv_transpose1d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvTransposeOptions<1>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: ConvTransposeOptions<1>, ) -> FloatTensor { - let [channels_in, increment_co, kernel_size] = B::shape(&weight_grad).dims; - let increment_ci = channels_in / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv1d( - grad, - x, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - let [_, _, kernel_size_tmp] = B::shape(&weight_grad_tmp).dims; - - if kernel_size_tmp != kernel_size { - weight_grad_tmp = B::slice( - weight_grad_tmp, - [0..increment_ci, 0..increment_co, 0..kernel_size], - ); - } - - weight_grad = B::slice_assign( - weight_grad, - [start_idx_ci..end_idx_ci, 0..increment_co, 0..kernel_size], - weight_grad_tmp, - ); + let [channels_in, increment_co, kernel_size] = B::shape(&weight_grad).dims; + let increment_ci = channels_in / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) * increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv1d( + grad, + x, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + let [_, _, kernel_size_tmp] = B::shape(&weight_grad_tmp).dims; + + if kernel_size_tmp != kernel_size { + weight_grad_tmp = B::slice( + weight_grad_tmp, + 
[0..increment_ci, 0..increment_co, 0..kernel_size], + ); } - weight_grad + weight_grad = B::slice_assign( + weight_grad, + [start_idx_ci..end_idx_ci, 0..increment_co, 0..kernel_size], + weight_grad_tmp, + ); + } + + weight_grad } fn conv2d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<4>, - options: ConvOptions<2>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<4>, + options: ConvOptions<2>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv2d( - x_swapped, - output_grad_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv2d( + x_swapped, + output_grad_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + if B::shape(&weight_grad) != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + 0..weight_shape.dims[3], + ], ); - let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); - - if B::shape(&weight_grad) != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - 0..weight_shape.dims[3], - ], - ); - } - weight_grad + } + weight_grad } fn conv_transpose1d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<3>, - options: ConvTransposeOptions<1>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<3>, + options: ConvTransposeOptions<1>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv1d( - output_grad_swapped, - x_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv1d( + output_grad_swapped, + x_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + let grad_shape = B::shape(&weight_grad); + + if grad_shape != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + ], ); - let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); - - let grad_shape = B::shape(&weight_grad); - - if grad_shape != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - ], - ); - } - weight_grad + } + weight_grad } fn conv_transpose2d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<4>, - options: ConvTransposeOptions<2>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<4>, + options: ConvTransposeOptions<2>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv2d( - output_grad_swapped, - x_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), + let x_swapped = 
B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv2d( + output_grad_swapped, + x_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + let grad_shape = B::shape(&weight_grad); + + if grad_shape != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + 0..weight_shape.dims[3], + ], ); - let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); - - let grad_shape = B::shape(&weight_grad); - - if grad_shape != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - 0..weight_shape.dims[3], - ], - ); - } - weight_grad + } + weight_grad } fn calculate_padding_out( - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - size_in: usize, - size_out: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + size_in: usize, + size_out: usize, ) -> usize { - if stride <= 1 { - return 0; - } - - let out = 1 + libm::ceil( - (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64, - ) as usize; - i64::max(0, out as i64 - size_out as i64) as usize + if stride <= 1 { + return 0; + } + + let out = 1 + + libm::ceil((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64) + as usize; + i64::max(0, out as i64 - size_out as i64) as usize } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_calculate_output_size_1() { - let kernel_size = 3; - let stride = 1; - let padding = 1; - let size_in = 3; - let dilation = 1; - - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_out, 3); - } - - #[test] - fn test_calculate_output_size_2() { - let kernel_size = 5; - let stride = 2; - let padding = 3; - let size_in = 27; - let dilation = 1; - - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_out, 15); - } - - #[test] - fn test_calculate_output_size_3() { - let kernel_size = 5; - let stride = 2; - let padding = 3; - let size_in = 27; - let dilation = 2; - - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_out, 13); - } - - #[test] - fn test_calculate_same_padding_1() { - let kernel_size = 3; - let stride = 1; - let size_in = 3; - let dilation = 1; - - let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_in, size_out, "Expected size"); - } - - #[test] - fn test_calculate_same_padding_2() { - let kernel_size = 3; - let stride = 2; - let size_in = 7; - let dilation = 1; - - let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_in, size_out, "Expected size"); - } - - #[test] - fn test_calculate_output_padding_1() { - let kernel_size = 3; - let stride = 2; - let size_in = 7; - let size_out = 10; - let dilation = 1; - - let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out); - let size_out_expected = - calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - 
assert_eq!(size_out, size_out_expected, "Expected size"); - } + use super::*; + + #[test] + fn test_calculate_output_size_1() { + let kernel_size = 3; + let stride = 1; + let padding = 1; + let size_in = 3; + let dilation = 1; + + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, 3); + } + + #[test] + fn test_calculate_output_size_2() { + let kernel_size = 5; + let stride = 2; + let padding = 3; + let size_in = 27; + let dilation = 1; + + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, 15); + } + + #[test] + fn test_calculate_output_size_3() { + let kernel_size = 5; + let stride = 2; + let padding = 3; + let size_in = 27; + let dilation = 2; + + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, 13); + } + + #[test] + fn test_calculate_same_padding_1() { + let kernel_size = 3; + let stride = 1; + let size_in = 3; + let dilation = 1; + + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_in, size_out, "Expected size"); + } + + #[test] + fn test_calculate_same_padding_2() { + let kernel_size = 3; + let stride = 2; + let size_in = 7; + let dilation = 1; + + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_in, size_out, "Expected size"); + } + + #[test] + fn test_calculate_output_padding_1() { + let kernel_size = 3; + let stride = 2; + let size_in = 7; + let size_out = 10; + let dilation = 1; + + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out); + let size_out_expected = + calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, size_out_expected, "Expected size"); + } } diff --git a/burn-tensor/src/tensor/ops/modules/pool.rs b/burn-tensor/src/tensor/ops/modules/pool.rs index 4d096ae4a5..0b3c830687 100644 --- a/burn-tensor/src/tensor/ops/modules/pool.rs +++ b/burn-tensor/src/tensor/ops/modules/pool.rs @@ -1,167 +1,167 @@ use crate::{ - backend::Backend, - ops::{FloatTensor, IntTensor}, - Shape, + backend::Backend, + ops::{FloatTensor, IntTensor}, + Shape, }; use super::{MaxPool1dBackward, MaxPool1dWithIndices}; pub(crate) fn avg_pool1d_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, ) -> FloatTensor { - let [batch_size, channels, length] = B::shape(&x).dims; + let [batch_size, channels, length] = B::shape(&x).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); - let x = B::avg_pool2d( - x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - count_include_pad, - ); + let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); + let x = B::avg_pool2d( + x, + [kernel_size, 1], + [stride, 1], + [padding, 0], + count_include_pad, + ); - let [batch_size, channels, length, _] = B::shape(&x).dims; + let [batch_size, channels, length, _] = B::shape(&x).dims; - B::reshape(x, Shape::from([batch_size, channels, length])) + B::reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn avg_pool1d_backward_from_2d( - x: FloatTensor, - grad: 
FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, + x: FloatTensor, + grad: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, ) -> FloatTensor { - let [batch_size, channels, length_in] = B::shape(&x).dims; - let [_, _, length_out] = B::shape(&grad).dims; - - let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); - let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); - - let grad_x = B::avg_pool2d_backward( - x, - grad_x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - count_include_pad, - ); - - B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) + let [batch_size, channels, length_in] = B::shape(&x).dims; + let [_, _, length_out] = B::shape(&grad).dims; + + let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); + let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); + + let grad_x = B::avg_pool2d_backward( + x, + grad_x, + [kernel_size, 1], + [stride, 1], + [padding, 0], + count_include_pad, + ); + + B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) } pub(crate) fn adaptive_avg_pool1d_from_2d( - x: FloatTensor, - output_size: usize, + x: FloatTensor, + output_size: usize, ) -> FloatTensor { - let [batch_size, channels, length] = B::shape(&x).dims; + let [batch_size, channels, length] = B::shape(&x).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); - let x = B::adaptive_avg_pool2d(x, [output_size, 1]); + let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); + let x = B::adaptive_avg_pool2d(x, [output_size, 1]); - let [batch_size, channels, length, _] = B::shape(&x).dims; + let [batch_size, channels, length, _] = B::shape(&x).dims; - B::reshape(x, Shape::from([batch_size, channels, length])) + B::reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn adaptive_avg_pool1d_backward_from_2d( - x: FloatTensor, - grad: FloatTensor, + x: FloatTensor, + grad: FloatTensor, ) -> FloatTensor { - let [batch_size, channels, length_in] = B::shape(&x).dims; - let [_, _, length_out] = B::shape(&grad).dims; + let [batch_size, channels, length_in] = B::shape(&x).dims; + let [_, _, length_out] = B::shape(&grad).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); - let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); + let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); + let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); - let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x); + let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x); - B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) + B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) } pub(crate) fn max_pool1d_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> FloatTensor { - let [batch_size, channels, length] = B::shape(&x).dims; + let [batch_size, channels, length] = B::shape(&x).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); - let x = B::max_pool2d( - x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - [dilation, 1], - ); + let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); + let x = B::max_pool2d( + x, + 
[kernel_size, 1], + [stride, 1], + [padding, 0], + [dilation, 1], + ); - let [batch_size, channels, length, _] = B::shape(&x).dims; + let [batch_size, channels, length, _] = B::shape(&x).dims; - B::reshape(x, Shape::from([batch_size, channels, length])) + B::reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn max_pool1d_with_indices_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> MaxPool1dWithIndices { - let [batch_size, channels, length] = B::shape(&x).dims; - - let x = B::reshape(x, Shape::from([batch_size, channels, 1, length])); - let x = B::max_pool2d_with_indices( - x, - [1, kernel_size], - [1, stride], - [0, padding], - [1, dilation], - ); - let [batch_size, channels, _, length] = B::shape(&x.output).dims; - let output = B::reshape(x.output, Shape::from([batch_size, channels, length])); - let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length])); - MaxPool1dWithIndices::new(output, indices) + let [batch_size, channels, length] = B::shape(&x).dims; + + let x = B::reshape(x, Shape::from([batch_size, channels, 1, length])); + let x = B::max_pool2d_with_indices( + x, + [1, kernel_size], + [1, stride], + [0, padding], + [1, dilation], + ); + let [batch_size, channels, _, length] = B::shape(&x.output).dims; + let output = B::reshape(x.output, Shape::from([batch_size, channels, length])); + let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length])); + MaxPool1dWithIndices::new(output, indices) } pub(crate) fn max_pool1d_with_indices_backward_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - output_grad: FloatTensor, - indices: IntTensor, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: FloatTensor, + indices: IntTensor, ) -> MaxPool1dBackward { - let [batch_size, channels, length_in] = B::shape(&x).dims; - let [_, _, length_out] = B::shape(&output_grad).dims; - - let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); - let grad_x = B::reshape( - output_grad, - Shape::from([batch_size, channels, length_out, 1]), - ); - let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1])); - - let grad_x = B::max_pool2d_with_indices_backward( - x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - [dilation, 1], - grad_x, - indices, - ) - .x_grad; - - MaxPool1dBackward::new(B::reshape( - grad_x, - Shape::from([batch_size, channels, length_in]), - )) + let [batch_size, channels, length_in] = B::shape(&x).dims; + let [_, _, length_out] = B::shape(&output_grad).dims; + + let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); + let grad_x = B::reshape( + output_grad, + Shape::from([batch_size, channels, length_out, 1]), + ); + let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1])); + + let grad_x = B::max_pool2d_with_indices_backward( + x, + [kernel_size, 1], + [stride, 1], + [padding, 0], + [dilation, 1], + grad_x, + indices, + ) + .x_grad; + + MaxPool1dBackward::new(B::reshape( + grad_x, + Shape::from([batch_size, channels, length_in]), + )) } diff --git a/burn-tensor/src/tensor/ops/modules/unfold.rs b/burn-tensor/src/tensor/ops/modules/unfold.rs index 65f47562e8..b542bfe1a5 100644 --- a/burn-tensor/src/tensor/ops/modules/unfold.rs +++ 
b/burn-tensor/src/tensor/ops/modules/unfold.rs @@ -15,72 +15,71 @@ use super::{ConvOptions, UnfoldOptions}; /// the convolution operation's mechanism as it moves across the input tensor, picking up the desired /// values in the pattern of the unfolding operation. pub(crate) fn create_unfolding_weight( - in_channels: usize, - kernel_size: [usize; 2], - device: &B::Device, + in_channels: usize, + kernel_size: [usize; 2], + device: &B::Device, ) -> FloatTensor { - let shape = Shape::new([ - in_channels * kernel_size[0] * kernel_size[1], - in_channels, - kernel_size[0], - kernel_size[1], - ]); + let shape = Shape::new([ + in_channels * kernel_size[0] * kernel_size[1], + in_channels, + kernel_size[0], + kernel_size[1], + ]); - let mut strides = [0; 4]; - let mut current = 1; - shape - .dims - .iter() - .enumerate() - .rev() - .for_each(|(index, val)| { - strides[index] = current; - current *= val; - }); + let mut strides = [0; 4]; + let mut current = 1; + shape + .dims + .iter() + .enumerate() + .rev() + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); - let num_elements = shape.num_elements(); + let num_elements = shape.num_elements(); - let mut weight: Vec = vec![0.0.elem(); num_elements]; + let mut weight: Vec = vec![0.0.elem(); num_elements]; - for k in 0..in_channels { - for i in 0..kernel_size[0] { - for j in 0..kernel_size[1] { - let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j; - let index = - output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3]; + for k in 0..in_channels { + for i in 0..kernel_size[0] { + for j in 0..kernel_size[1] { + let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j; + let index = output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3]; - weight[index] = 1.elem(); - } - } + weight[index] = 1.elem(); + } } + } - B::from_data(Data::new(weight, shape), device) + B::from_data(Data::new(weight, shape), device) } /// Compute the unfold4d operation using the conv2d operations. 
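The weight built by `create_unfolding_weight` is a one-hot tensor of shape [in_channels * kh * kw, in_channels, kh, kw]: each output channel selects exactly one (channel, row, col) position of the kernel window, so a conv2d with this weight copies every input patch into its own set of output channels. A plain-Rust sketch of the same construction (flat row-major buffer plus strides, mirroring the loop above; standalone and independent of any backend, names are illustrative):

// Builds the unfolding weight as a flat row-major buffer.
// Entry (c * kh * kw + i * kw + j, c, i, j) is 1.0, everything else 0.0.
fn unfolding_weight(in_channels: usize, kernel_size: [usize; 2]) -> Vec<f32> {
    let [kh, kw] = kernel_size;
    let shape = [in_channels * kh * kw, in_channels, kh, kw];

    // Row-major strides, computed from the innermost dimension outwards.
    let mut strides = [0usize; 4];
    let mut current = 1;
    for (stride, dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *stride = current;
        current *= dim;
    }

    let mut weight = vec![0.0_f32; shape.iter().product()];
    for c in 0..in_channels {
        for i in 0..kh {
            for j in 0..kw {
                let out_channel = c * kh * kw + i * kw + j;
                let index =
                    out_channel * strides[0] + c * strides[1] + i * strides[2] + j * strides[3];
                weight[index] = 1.0;
            }
        }
    }
    weight
}

fn main() {
    let w = unfolding_weight(2, [2, 2]);
    // 8 output channels, each selecting exactly one (channel, row, col) position.
    assert_eq!(w.iter().filter(|v| **v == 1.0).count(), 2 * 2 * 2);
    assert_eq!(w.len(), 8 * 2 * 2 * 2);
}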
pub(crate) fn unfold4d_using_conv2d( - x: FloatTensor, - kernel_size: [usize; 2], - options: UnfoldOptions, + x: FloatTensor, + kernel_size: [usize; 2], + options: UnfoldOptions, ) -> FloatTensor { - let [_batch_size, in_channels, _in_height, _in_width] = B::shape(&x).dims; - let weight = create_unfolding_weight::(in_channels, kernel_size, &B::device(&x)); - let unfolded = B::conv2d( - x, - weight, - None, - ConvOptions { - stride: options.stride, - padding: options.padding, - dilation: options.dilation, - groups: 1, - }, - ); + let [_batch_size, in_channels, _in_height, _in_width] = B::shape(&x).dims; + let weight = create_unfolding_weight::(in_channels, kernel_size, &B::device(&x)); + let unfolded = B::conv2d( + x, + weight, + None, + ConvOptions { + stride: options.stride, + padding: options.padding, + dilation: options.dilation, + groups: 1, + }, + ); - let [batch_size, channels_out, out_height, out_width] = B::shape(&unfolded).dims; + let [batch_size, channels_out, out_height, out_width] = B::shape(&unfolded).dims; - B::reshape( - unfolded, - Shape::new([batch_size, channels_out, out_height * out_width]), - ) + B::reshape( + unfolded, + Shape::new([batch_size, channels_out, out_height * out_width]), + ) } diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs index 63c081df40..056c6d92d9 100644 --- a/burn-tensor/src/tensor/ops/tensor.rs +++ b/burn-tensor/src/tensor/ops/tensor.rs @@ -6,1073 +6,1064 @@ use core::ops::Range; /// Operations on float tensors. pub trait TensorOps { - /// Creates a new tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given data. - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor; - - /// Creates a new tensor with random values. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `distribution` - The distribution to sample from. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given shape and random values. - fn random( - shape: Shape, - distribution: Distribution>, - device: &Device, - ) -> FloatTensor; - - /// Creates a new tensor with zeros. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given shape and zeros. - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::zeros(shape), device) + /// Creates a new tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given data. + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor; + + /// Creates a new tensor with random values. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `distribution` - The distribution to sample from. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given shape and random values. + fn random( + shape: Shape, + distribution: Distribution>, + device: &Device, + ) -> FloatTensor; + + /// Creates a new tensor with zeros. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. 
+ /// + /// # Returns + /// + /// The tensor with the given shape and zeros. + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::zeros(shape), device) + } + + /// Creates a new tensor with ones. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given shape and ones. + fn ones(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::ones(shape), device) + } + + /// Creates a tensor filled with given value. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `fill_value` - The value with which to fill the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor filled with given value + fn full( + shape: Shape, + fill_value: FloatElem, + device: &Device, + ) -> FloatTensor { + Self::add_scalar(Self::zeros(shape, device), fill_value) + } + + /// Gets the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + fn shape(tensor: &FloatTensor) -> Shape; + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn to_data(tensor: &FloatTensor) -> Reader, D>> { + Self::into_data(tensor.clone()) + } + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn into_data(tensor: FloatTensor) -> Reader, D>>; + + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn device(tensor: &FloatTensor) -> Device; + + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device to move the tensor to. + /// + /// # Returns + /// + /// The tensor on the given device. + fn to_device(tensor: FloatTensor, device: &Device) -> FloatTensor; + + /// Creates a new tensor with values from the given range. + /// + /// # Arguments + /// + /// * `range` - The range of values. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given values. + /// + /// # Remarks + /// + /// Uses `arange_step` with a step size of 1 under the hood. + fn arange(range: Range, device: &Device) -> IntTensor { + Self::arange_step(range, 1, device) + } + + /// Converts float tensor to int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The int tensor with the same data as the float tensor. + fn into_int(tensor: FloatTensor) -> IntTensor; + + /// Creates a new tensor with values from the given range with the given step size. + /// + /// # Arguments + /// + /// * `range` - The range of values. + /// * `step` - The step size. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given values. + fn arange_step(range: Range, step: usize, device: &Device) -> IntTensor { + let value = range + .step_by(step) + .map(|i| (i as i64).elem()) + .collect::>>(); + let shape = Shape::new([value.len()]); + let data = Data::new(value, shape); + B::int_from_data(data, device) + } + + /// Creates an empty tensor with the given shape. 
+ /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The empty tensor with the given shape. + fn empty(shape: Shape, device: &Device) -> FloatTensor; + + /// Repeat the tensor along the given dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to repeat. + /// * `times` - The number of times to repeat the dimension. + /// + /// # Returns + /// + /// The tensor with the given dimension repeated. + fn repeat( + tensor: FloatTensor, + dim: usize, + times: usize, + ) -> FloatTensor { + let mut shape = B::shape(&tensor); + if shape.dims[dim] != 1 { + panic!("Can only repeat dimension with dim=1"); } - - /// Creates a new tensor with ones. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given shape and ones. - fn ones(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::ones(shape), device) - } - - /// Creates a tensor filled with given value. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `fill_value` - The value with which to fill the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor filled with given value - fn full( - shape: Shape, - fill_value: FloatElem, - device: &Device, - ) -> FloatTensor { - Self::add_scalar(Self::zeros(shape, device), fill_value) - } - - /// Gets the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn shape(tensor: &FloatTensor) -> Shape; - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn to_data(tensor: &FloatTensor) -> Reader, D>> { - Self::into_data(tensor.clone()) - } - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn into_data(tensor: FloatTensor) -> Reader, D>>; - - /// Gets the device of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device of the tensor. - fn device(tensor: &FloatTensor) -> Device; - - /// Moves the tensor to the given device. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `device` - The device to move the tensor to. - /// - /// # Returns - /// - /// The tensor on the given device. - fn to_device( - tensor: FloatTensor, - device: &Device, - ) -> FloatTensor; - - /// Creates a new tensor with values from the given range. - /// - /// # Arguments - /// - /// * `range` - The range of values. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given values. - /// - /// # Remarks - /// - /// Uses `arange_step` with a step size of 1 under the hood. - fn arange(range: Range, device: &Device) -> IntTensor { - Self::arange_step(range, 1, device) - } - - /// Converts float tensor to int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The int tensor with the same data as the float tensor. 
- fn into_int(tensor: FloatTensor) -> IntTensor; - - /// Creates a new tensor with values from the given range with the given step size. - /// - /// # Arguments - /// - /// * `range` - The range of values. - /// * `step` - The step size. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given values. - fn arange_step(range: Range, step: usize, device: &Device) -> IntTensor { - let value = range - .step_by(step) - .map(|i| (i as i64).elem()) - .collect::>>(); - let shape = Shape::new([value.len()]); - let data = Data::new(value, shape); - B::int_from_data(data, device) - } - - /// Creates an empty tensor with the given shape. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The empty tensor with the given shape. - fn empty(shape: Shape, device: &Device) -> FloatTensor; - - /// Repeat the tensor along the given dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to repeat. - /// * `times` - The number of times to repeat the dimension. - /// - /// # Returns - /// - /// The tensor with the given dimension repeated. - fn repeat( - tensor: FloatTensor, - dim: usize, - times: usize, - ) -> FloatTensor { - let mut shape = B::shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } - shape.dims[dim] = times; - - let mut i = 0; - let indices_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); - - let mut tensor_output = B::empty(shape, &B::device(&tensor)); - for i in 0..times { - let mut indices = indices_select_all.clone(); - indices[dim] = i..i + 1; - tensor_output = B::slice_assign(tensor_output, indices, tensor.clone()); - } - - tensor_output - } - - /// Adds two tensors together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of adding the two tensors together. - fn add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Adds a scalar to a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of adding the scalar to the tensor. - fn add_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - // Default implementation - let mask = Self::lower_elem(tensor.clone(), min); - B::mask_fill(tensor, mask, min) - } - - /// Clamps a tensor over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - // Default implementation - let mask = Self::greater_elem(tensor.clone(), max); - B::mask_fill(tensor, mask, max) - } - - /// Clamps a tensor between a minimum and maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. 
- fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - // Default implementation - Self::clamp_min(Self::clamp_max(tensor, max), min) - } - - /// Subtracts two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of subtracting the two tensors. - fn sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Subtracts a scalar from a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of subtracting the scalar from the tensor. - fn sub_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Multiplies two tensors together element-wise. - fn mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Multiplies a tensor by a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of multiplying the tensor by the scalar. - fn mul_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Divides two tensors element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of dividing the two tensors. - fn div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Divides a tensor by a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of dividing the tensor by the scalar. - fn div_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Multiplies two tensors together using matrix multiplication. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of multiplying the two tensors together using matrix multiplication. - fn matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Negates a tensor element-wise. - fn neg(tensor: FloatTensor) -> FloatTensor { - Self::mul_scalar(tensor, (-1.0_f32).elem::>()) - } - - /// Calculates the reciprocals elementwise - fn recip(tensor: FloatTensor) -> FloatTensor; - - /// Transposes a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn transpose(tensor: FloatTensor) -> FloatTensor { - Self::swap_dims(tensor, D - 2, D - 1) - } - - /// Swaps two dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor; - - /// Reshapes a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to reshape. - /// * `shape` - The new shape of the tensor. - /// - /// # Returns - /// - /// The tensor with the new shape. - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor; - - /// Gather elements from a tensor. - /// - /// # Arguments - /// - /// * `dim` - The dimension to gather from. - /// * `tensor` - The tensor to gather from. - /// * `indices` - The indices to gather. 
- /// - /// # Returns - /// - /// The gathered elements. - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor; - - /// Scatter elements into a tensor. - /// - /// # Arguments - /// - /// * `dim` - The dimension to scatter into. - /// * `tensor` - The tensor to scatter into. - /// * `indices` - The indices to scatter into. - /// * `value` - The value to scatter. - /// - /// # Returns - /// - /// The tensor with the scattered elements. - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor; - - /// Select tensor elements along the given dimension corresponding for the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices to select. - /// - /// # Returns - /// - /// The selected elements. - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor; - - /// Assign the selected elements along the given dimension corresponding for the given indices - /// to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor; - - /// Select tensor elements corresponding for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `ranges` - The ranges to select. - /// - /// # Returns - /// - /// The selected elements in a new tensor. - fn slice( - tensor: FloatTensor, - ranges: [Range; D2], - ) -> FloatTensor; - - /// Assign the selected elements corresponding for the given ranges to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `ranges` - The ranges to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn slice_assign( - tensor: FloatTensor, - ranges: [Range; D2], - value: FloatTensor, - ) -> FloatTensor; - - /// Update the given tensor with the value tensor where the mask is true. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `mask` - The boolean mask to select with. - /// * `value` - The value to assign to the selected elements from the value tensor. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor; - - /// Update the given tensor with the value where the mask is true. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `mask` - The boolean mask to select with. - /// * `value` - The value to assign to the selected elements. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor; - - /// Equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. 
- fn equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; - - /// Equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; - - /// Greater than comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; - - /// Greater than comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; - - /// Greater than or equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor; - - /// Greater than or equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor; - - /// Less than comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; - - /// Less than comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; - - /// Less than or equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor; - - /// Less than or equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor; - - /// Detaches a tensor from the computation graph. - fn detach(tensor: FloatTensor) -> FloatTensor { - // Should only be overridden by autodiff backends. - tensor - } - - /// Sets the `require_grad` flag of a tensor. - fn set_require_grad( - tensor: FloatTensor, - _require_grad: bool, - ) -> FloatTensor { - // Should only be overridden by autodiff backends. - tensor - } - - /// Returns the `require_grad` flag of a tensor. 
- fn is_require_grad(_tensor: &FloatTensor) -> bool { - // Should only be overridden by autodiff backends. - false - } - - /// Sum of all elements in a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// - /// # Returns - /// - /// A scalar tensor with the sum of all elements in `tensor`. - fn sum(tensor: FloatTensor) -> FloatTensor; - - /// Sum of all elements in a tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// * `dim` - The dimension along which to sum. - /// - /// # Returns - /// - /// A tensor with the sum of all elements in `tensor` along `dim`. - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; - - /// Mean of all elements in a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to mean. - /// - /// # Returns - /// - /// A scalar tensor with the mean of all elements in `tensor`. - fn mean(tensor: FloatTensor) -> FloatTensor { - let num_elems = B::shape(&tensor).num_elements(); - B::div_scalar(B::sum(tensor), (num_elems as i64).elem()) - } - - /// Mean of all elements in a tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to mean. - /// * `dim` - The dimension along which to mean. - /// - /// # Returns - /// - /// A tensor with the mean of all elements in `tensor` along `dim`. - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; - - /// Converts a tensor to full precision. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to convert. - /// - /// # Returns - /// - /// A tensor with the same values as `tensor` but with full precision. - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D>; - - /// Converts a tensor from full precision. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to convert. - /// - /// # Returns - /// - /// A tensor with the same values as `tensor` but with the precision of the backend. - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor; - - /// Returns a new tensor with exponential values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to exponentiate. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with exponential values. - fn exp(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with natural logarithm values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the logarithm of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with natural logarithm values. - fn log(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with logarithm values of (1 + Xi). - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the logarithm of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). - fn log1p(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with values raised to the power of `value`. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to exponentiate. - /// * `value` - The exponent. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with values raised to the power of `value`. - fn powf(tensor: FloatTensor, value: f32) -> FloatTensor; - - /// Returns a new tensor with square root values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the square root of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with square root values. 
- fn sqrt(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with absolute values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take absolute value of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with absolute values. - fn abs(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with cosine values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the cosine of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with cosine values. - fn cos(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with sine values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the sine of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with sine values. - fn sin(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with tangent values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the tangent of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with tangent values. - fn tanh(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with the error function values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the error function of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with error function values. - fn erf(tensor: FloatTensor) -> FloatTensor; - - /// Catcatenates tensors along a dimension. - /// - /// # Arguments - /// - /// * `tensors` - The tensors to catcatenate. - /// * `dim` - The dimension along which to catcatenate. - /// - /// # Returns - /// - /// A tensor with the catcatenated tensors along `dim`. - fn cat(tensors: Vec>, dim: usize) -> FloatTensor; - - /// Gets the indices of the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// * `dim` - The dimension along which to get the maximum elements. - /// - /// # Returns - /// - /// A tensor with the indices of the maximum elements of `tensor` along `dim`. - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor; - - /// Gets the indices of the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// * `dim` - The dimension along which to get the minimum elements. - /// - /// # Returns - /// - /// A tensor with the indices of the minimum elements of `tensor` along `dim`. - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor; - - /// Gets the maximum element of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// - /// # Returns - /// - /// A tensor with the maximum element of `tensor`. - fn max(tensor: FloatTensor) -> FloatTensor { - let shape = B::shape(&tensor); - let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); - - B::max_dim(tensor, 0) - } - - /// Gets the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// * `dim` - The dimension along which to get the maximum elements. - /// - /// # Returns - /// - /// A tensor with the maximum elements of `tensor` along `dim`. - fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - let index = B::argmax(tensor.clone(), dim); - - B::gather(D - 1, tensor, index) - } - - /// Gets the maximum elements of a tensor along an axis and their indices. 
- /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// * `dim` - The dimension along which to get the maximum elements. - /// - /// # Returns - /// - /// A tuple with the maximum elements of `tensor` along `dim` and their indices. - fn max_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - let index = B::argmax(tensor.clone(), dim); - let values = B::gather(D - 1, tensor, index.clone()); - - (values, index) + shape.dims[dim] = times; + + let mut i = 0; + let indices_select_all = [0; D].map(|_| { + let start = 0; + let end = shape.dims[i]; + i += 1; + start..end + }); + + let mut tensor_output = B::empty(shape, &B::device(&tensor)); + for i in 0..times { + let mut indices = indices_select_all.clone(); + indices[dim] = i..i + 1; + tensor_output = B::slice_assign(tensor_output, indices, tensor.clone()); } - /// Gets the minimum element of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// - /// # Returns - /// - /// A tensor with the minimum element of `tensor`. - fn min(tensor: FloatTensor) -> FloatTensor { - let shape = B::shape(&tensor); - let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); - - B::min_dim(tensor, 0) - } - - /// Gets the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// * `dim` - The dimension along which to get the minimum elements. - /// - /// # Returns - /// - /// A tensor with the minimum elements of `tensor` along `dim`. - fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - let index = B::argmin(tensor.clone(), dim); - - B::gather(D - 1, tensor, index) - } - - /// Gets the minimum elements of a tensor along an axis and their indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// * `dim` - The dimension along which to get the minimum elements. - /// - /// # Returns - /// - /// A tuple with the minimum elements of `tensor` along `dim` and their indices. - fn min_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - let index = B::argmin(tensor.clone(), dim); - let values = B::gather(D - 1, tensor, index.clone()); - - (values, index) - } + tensor_output + } + + /// Adds two tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. + fn add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Adds a scalar to a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of adding the scalar to the tensor. + fn add_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn clamp_min(tensor: FloatTensor, min: FloatElem) -> FloatTensor { + // Default implementation + let mask = Self::lower_elem(tensor.clone(), min); + B::mask_fill(tensor, mask, min) + } + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. 
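The `max_dim` and `max_dim_with_indices` defaults shown in this hunk recover the maxima by gathering at the argmax positions. The sketch below states the same equivalence through the public Tensor API; check_max_dim_via_gather is illustrative only.

use burn_tensor::{backend::Backend, Tensor};

// Illustrates the default strategy: the per-`dim` maximum is the value found
// at the `argmax` position, recovered with `gather`.
fn check_max_dim_via_gather<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) {
    let indices = tensor.clone().argmax(dim);
    let gathered = tensor.clone().gather(dim, indices);

    let (values, _indices) = tensor.max_dim_with_indices(dim);

    gathered
        .into_data()
        .assert_approx_eq(&values.into_data(), 3);
}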
+ fn clamp_max(tensor: FloatTensor, max: FloatElem) -> FloatTensor { + // Default implementation + let mask = Self::greater_elem(tensor.clone(), max); + B::mask_fill(tensor, mask, max) + } + + /// Clamps a tensor between a minimum and maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + // Default implementation + Self::clamp_min(Self::clamp_max(tensor, max), min) + } + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Subtracts a scalar from a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of subtracting the scalar from the tensor. + fn sub_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Multiplies two tensors together element-wise. + fn mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Multiplies a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of multiplying the tensor by the scalar. + fn mul_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Divides two tensors element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of dividing the two tensors. + fn div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Divides a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of dividing the tensor by the scalar. + fn div_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Multiplies two tensors together using matrix multiplication. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of multiplying the two tensors together using matrix multiplication. + fn matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Negates a tensor element-wise. + fn neg(tensor: FloatTensor) -> FloatTensor { + Self::mul_scalar(tensor, (-1.0_f32).elem::>()) + } + + /// Calculates the reciprocals elementwise + fn recip(tensor: FloatTensor) -> FloatTensor; + + /// Transposes a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn transpose(tensor: FloatTensor) -> FloatTensor { + Self::swap_dims(tensor, D - 2, D - 1) + } + + /// Swaps two dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor; + + /// Reshapes a tensor. 
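`clamp_min`, `clamp_max` and `clamp` above are defined in terms of one another: a comparison builds a mask, `mask_fill` applies the bound, and `clamp` is the two bounds composed. The hypothetical check below restates that composition at the Tensor level; it is a sketch, not code from the patch.

use burn_tensor::{backend::Backend, Tensor};

// Mirrors the default implementation: clamping to [min, max] is the same as
// clamping the upper bound first and the lower bound second.
fn check_clamp_composition<B: Backend, const D: usize>(tensor: Tensor<B, D>, min: f32, max: f32) {
    let composed = tensor.clone().clamp_max(max).clamp_min(min);
    let direct = tensor.clamp(min, max);

    direct
        .into_data()
        .assert_approx_eq(&composed.into_data(), 3);
}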
+ /// + /// # Arguments + /// + /// * `tensor` - The tensor to reshape. + /// * `shape` - The new shape of the tensor. + /// + /// # Returns + /// + /// The tensor with the new shape. + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor; + + /// Gather elements from a tensor. + /// + /// # Arguments + /// + /// * `dim` - The dimension to gather from. + /// * `tensor` - The tensor to gather from. + /// * `indices` - The indices to gather. + /// + /// # Returns + /// + /// The gathered elements. + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor; + + /// Scatter elements into a tensor. + /// + /// # Arguments + /// + /// * `dim` - The dimension to scatter into. + /// * `tensor` - The tensor to scatter into. + /// * `indices` - The indices to scatter into. + /// * `value` - The value to scatter. + /// + /// # Returns + /// + /// The tensor with the scattered elements. + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor; + + /// Select tensor elements along the given dimension corresponding for the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices to select. + /// + /// # Returns + /// + /// The selected elements. + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor; + + /// Assign the selected elements along the given dimension corresponding for the given indices + /// to the given value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices to select. + /// * `value` - The value to assign. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor; + + /// Select tensor elements corresponding for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `ranges` - The ranges to select. + /// + /// # Returns + /// + /// The selected elements in a new tensor. + fn slice( + tensor: FloatTensor, + ranges: [Range; D2], + ) -> FloatTensor; + + /// Assign the selected elements corresponding for the given ranges to the given value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `ranges` - The ranges to select. + /// * `value` - The value to assign. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn slice_assign( + tensor: FloatTensor, + ranges: [Range; D2], + value: FloatTensor, + ) -> FloatTensor; + + /// Update the given tensor with the value tensor where the mask is true. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `mask` - The boolean mask to select with. + /// * `value` - The value to assign to the selected elements from the value tensor. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor; + + /// Update the given tensor with the value where the mask is true. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `mask` - The boolean mask to select with. 
+ /// * `value` - The value to assign to the selected elements. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor; + + /// Equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; + + /// Equal comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; + + /// Greater than comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; + + /// Greater than comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; + + /// Greater than or equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor; + + /// Greater than or equal comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor; + + /// Less than comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; + + /// Less than comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; + + /// Less than or equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor; + + /// Less than or equal comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. 
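Combining the comparison ops documented here with `mask_where` gives element-wise selection between two tensors. For example, an element-wise maximum can be sketched as follows; elementwise_max is a hypothetical helper, not part of this trait.

use burn_tensor::{backend::Backend, Tensor};

// Element-wise maximum built from the primitives documented in this hunk:
// `greater` produces the selection mask and `mask_where` takes the value from
// `lhs` wherever the mask is true, keeping `rhs` elsewhere.
fn elementwise_max<B: Backend, const D: usize>(lhs: Tensor<B, D>, rhs: Tensor<B, D>) -> Tensor<B, D> {
    let mask = lhs.clone().greater(rhs.clone());
    rhs.mask_where(mask, lhs)
}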
+ fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor; + + /// Detaches a tensor from the computation graph. + fn detach(tensor: FloatTensor) -> FloatTensor { + // Should only be overridden by autodiff backends. + tensor + } + + /// Sets the `require_grad` flag of a tensor. + fn set_require_grad( + tensor: FloatTensor, + _require_grad: bool, + ) -> FloatTensor { + // Should only be overridden by autodiff backends. + tensor + } + + /// Returns the `require_grad` flag of a tensor. + fn is_require_grad(_tensor: &FloatTensor) -> bool { + // Should only be overridden by autodiff backends. + false + } + + /// Sum of all elements in a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// + /// # Returns + /// + /// A scalar tensor with the sum of all elements in `tensor`. + fn sum(tensor: FloatTensor) -> FloatTensor; + + /// Sum of all elements in a tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension along which to sum. + /// + /// # Returns + /// + /// A tensor with the sum of all elements in `tensor` along `dim`. + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; + + /// Mean of all elements in a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to mean. + /// + /// # Returns + /// + /// A scalar tensor with the mean of all elements in `tensor`. + fn mean(tensor: FloatTensor) -> FloatTensor { + let num_elems = B::shape(&tensor).num_elements(); + B::div_scalar(B::sum(tensor), (num_elems as i64).elem()) + } + + /// Mean of all elements in a tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to mean. + /// * `dim` - The dimension along which to mean. + /// + /// # Returns + /// + /// A tensor with the mean of all elements in `tensor` along `dim`. + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; + + /// Converts a tensor to full precision. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to convert. + /// + /// # Returns + /// + /// A tensor with the same values as `tensor` but with full precision. + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D>; + + /// Converts a tensor from full precision. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to convert. + /// + /// # Returns + /// + /// A tensor with the same values as `tensor` but with the precision of the backend. + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor; + + /// Returns a new tensor with exponential values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to exponentiate. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with exponential values. + fn exp(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with natural logarithm values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the logarithm of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with natural logarithm values. + fn log(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with logarithm values of (1 + Xi). + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the logarithm of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). + fn log1p(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with values raised to the power of `value`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to exponentiate. 
+ /// * `value` - The exponent. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with values raised to the power of `value`. + fn powf(tensor: FloatTensor, value: f32) -> FloatTensor; + + /// Returns a new tensor with square root values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the square root of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with square root values. + fn sqrt(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with absolute values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take absolute value of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with absolute values. + fn abs(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with cosine values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the cosine of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with cosine values. + fn cos(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with sine values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the sine of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with sine values. + fn sin(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with tangent values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the tangent of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with tangent values. + fn tanh(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with the error function values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the error function of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with error function values. + fn erf(tensor: FloatTensor) -> FloatTensor; + + /// Catcatenates tensors along a dimension. + /// + /// # Arguments + /// + /// * `tensors` - The tensors to catcatenate. + /// * `dim` - The dimension along which to catcatenate. + /// + /// # Returns + /// + /// A tensor with the catcatenated tensors along `dim`. + fn cat(tensors: Vec>, dim: usize) -> FloatTensor; + + /// Gets the indices of the maximum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// * `dim` - The dimension along which to get the maximum elements. + /// + /// # Returns + /// + /// A tensor with the indices of the maximum elements of `tensor` along `dim`. + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor; + + /// Gets the indices of the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// * `dim` - The dimension along which to get the minimum elements. + /// + /// # Returns + /// + /// A tensor with the indices of the minimum elements of `tensor` along `dim`. + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor; + + /// Gets the maximum element of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// + /// # Returns + /// + /// A tensor with the maximum element of `tensor`. + fn max(tensor: FloatTensor) -> FloatTensor { + let shape = B::shape(&tensor); + let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); + + B::max_dim(tensor, 0) + } + + /// Gets the maximum elements of a tensor along an axis. 
+ /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// * `dim` - The dimension along which to get the maximum elements. + /// + /// # Returns + /// + /// A tensor with the maximum elements of `tensor` along `dim`. + fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let index = B::argmax(tensor.clone(), dim); + + B::gather(D - 1, tensor, index) + } + + /// Gets the maximum elements of a tensor along an axis and their indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// * `dim` - The dimension along which to get the maximum elements. + /// + /// # Returns + /// + /// A tuple with the maximum elements of `tensor` along `dim` and their indices. + fn max_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + let index = B::argmax(tensor.clone(), dim); + let values = B::gather(D - 1, tensor, index.clone()); + + (values, index) + } + + /// Gets the minimum element of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// + /// # Returns + /// + /// A tensor with the minimum element of `tensor`. + fn min(tensor: FloatTensor) -> FloatTensor { + let shape = B::shape(&tensor); + let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); + + B::min_dim(tensor, 0) + } + + /// Gets the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// * `dim` - The dimension along which to get the minimum elements. + /// + /// # Returns + /// + /// A tensor with the minimum elements of `tensor` along `dim`. + fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let index = B::argmin(tensor.clone(), dim); + + B::gather(D - 1, tensor, index) + } + + /// Gets the minimum elements of a tensor along an axis and their indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// * `dim` - The dimension along which to get the minimum elements. + /// + /// # Returns + /// + /// A tuple with the minimum elements of `tensor` along `dim` and their indices. + fn min_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + let index = B::argmin(tensor.clone(), dim); + let values = B::gather(D - 1, tensor, index.clone()); + + (values, index) + } } diff --git a/burn-tensor/src/tensor/shape.rs b/burn-tensor/src/tensor/shape.rs index 00c69501b6..b2bf9744c9 100644 --- a/burn-tensor/src/tensor/shape.rs +++ b/burn-tensor/src/tensor/shape.rs @@ -3,76 +3,76 @@ use alloc::vec::Vec; /// Shape of a tensor. #[derive(new, Debug, Clone, PartialEq, Eq)] pub struct Shape { - /// The dimensions of the tensor. - pub dims: [usize; D], + /// The dimensions of the tensor. 
+ pub dims: [usize; D], } impl Shape { - /// Returns the total number of elements of a tensor having this shape - pub fn num_elements(&self) -> usize { - let mut num_elements = 1; - for i in 0..D { - num_elements *= self.dims[i]; - } - - num_elements + /// Returns the total number of elements of a tensor having this shape + pub fn num_elements(&self) -> usize { + let mut num_elements = 1; + for i in 0..D { + num_elements *= self.dims[i]; } + + num_elements + } } impl From<[usize; D]> for Shape { - fn from(dims: [usize; D]) -> Self { - Shape::new(dims) - } + fn from(dims: [usize; D]) -> Self { + Shape::new(dims) + } } impl From> for Shape { - fn from(shape: Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.into_iter().enumerate() { - dims[i] = dim as usize; - } - Self::new(dims) + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim as usize; } + Self::new(dims) + } } impl From> for Shape { - fn from(shape: Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.into_iter().enumerate() { - dims[i] = dim as usize; - } - Self::new(dims) + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim as usize; } + Self::new(dims) + } } impl From> for Shape { - fn from(shape: Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.into_iter().enumerate() { - dims[i] = dim; - } - Self::new(dims) + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim; } + Self::new(dims) + } } impl From<&Vec> for Shape { - fn from(shape: &Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.iter().enumerate() { - dims[i] = *dim; - } - Self::new(dims) + fn from(shape: &Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.iter().enumerate() { + dims[i] = *dim; } + Self::new(dims) + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn num_elements() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - assert_eq!(120, shape.num_elements()); - } + #[test] + fn num_elements() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + assert_eq!(120, shape.num_elements()); + } } diff --git a/burn-tensor/src/tensor/stats/mod.rs b/burn-tensor/src/tensor/stats/mod.rs index 0ad39dc69a..70dda6288a 100644 --- a/burn-tensor/src/tensor/stats/mod.rs +++ b/burn-tensor/src/tensor/stats/mod.rs @@ -1,38 +1,38 @@ use crate::{backend::Backend, Tensor}; pub fn var(tensor: Tensor, dim: usize) -> Tensor { - let mean = tensor.clone().mean_dim(dim); - var_with_mean(tensor, mean, dim) + let mean = tensor.clone().mean_dim(dim); + var_with_mean(tensor, mean, dim) } pub fn var_with_mean( - tensor: Tensor, - mean: Tensor, - dim: usize, + tensor: Tensor, + mean: Tensor, + dim: usize, ) -> Tensor { - let n = tensor.shape().dims[dim] - 1; - var_with_mean_n(tensor, mean, dim, n) + let n = tensor.shape().dims[dim] - 1; + var_with_mean_n(tensor, mean, dim, n) } pub fn var_bias(tensor: Tensor, dim: usize) -> Tensor { - let mean = tensor.clone().mean_dim(dim); - var_with_mean_bias(tensor, mean, dim) + let mean = tensor.clone().mean_dim(dim); + var_with_mean_bias(tensor, mean, dim) } pub fn var_with_mean_bias( - tensor: Tensor, - mean: Tensor, - dim: usize, + tensor: Tensor, + mean: Tensor, + dim: usize, ) -> Tensor { - let n = tensor.shape().dims[dim]; - var_with_mean_n(tensor, mean, dim, n) + let n = tensor.shape().dims[dim]; + var_with_mean_n(tensor, mean, dim, n) } pub 
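The `var*` helpers in this hunk all reduce to the same formula, sum((x - mean)^2) / n along `dim`, with n = len - 1 for the unbiased `var` and n = len for `var_bias`. The snippet below restates the unbiased case through the public Tensor API; unbiased_var is a hypothetical name used only for illustration.

use burn_tensor::{backend::Backend, Tensor};

// Unbiased variance along `dim`, written out the same way as
// `var_with_mean_n`: subtract the mean, square, sum, divide by (len - 1).
fn unbiased_var<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    let n = tensor.shape().dims[dim] - 1;
    let mean = tensor.clone().mean_dim(dim);

    tensor
        .sub(mean)
        .powf(2.0)
        .sum_dim(dim)
        .div_scalar(n as f32)
}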
fn var_with_mean_n<B: Backend, const D: usize>(
-    tensor: Tensor<B, D>,
-    mean: Tensor<B, D>,
-    dim: usize,
-    n: usize,
+  tensor: Tensor<B, D>,
+  mean: Tensor<B, D>,
+  dim: usize,
+  n: usize,
 ) -> Tensor<B, D> {
-    tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(n as f32)
+  tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(n as f32)
 }
diff --git a/burn-tensor/src/tests/activation/gelu.rs b/burn-tensor/src/tests/activation/gelu.rs
index a6dc2a617d..aad5288645 100644
--- a/burn-tensor/src/tests/activation/gelu.rs
+++ b/burn-tensor/src/tests/activation/gelu.rs
@@ -1,21 +1,21 @@
 #[burn_tensor_testgen::testgen(gelu)]
 mod tests {
-    use super::*;
-    use burn_tensor::{activation, Data, Tensor};
+  use super::*;
+  use burn_tensor::{activation, Data, Tensor};
 
-    #[test]
-    fn test_gelu() {
-        let data = Data::from([[
-            0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
-        ]]);
-        let tensor = Tensor::<TestBackend, 2>::from_data(data).clone().clone();
+  #[test]
+  fn test_gelu() {
+    let data = Data::from([[
+      0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
+    ]]);
+    let tensor = Tensor::<TestBackend, 2>::from_data(data).clone().clone();
 
-        let data_actual = activation::gelu(tensor).to_data();
+    let data_actual = activation::gelu(tensor).to_data();
 
-        let data_expected = Data::from([[
-            0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
-        ]]);
-        data_expected.assert_approx_eq(&data_actual, 2); // Low precision to allow approximation
-                                                         // implementation using tanh
-    }
+    let data_expected = Data::from([[
+      0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
+    ]]);
+    data_expected.assert_approx_eq(&data_actual, 2); // Low precision to allow approximation
+                                                     // implementation using tanh
+  }
 }
diff --git a/burn-tensor/src/tests/activation/quiet_softmax.rs b/burn-tensor/src/tests/activation/quiet_softmax.rs
new file mode 100644
index 0000000000..7a3733db70
--- /dev/null
+++ b/burn-tensor/src/tests/activation/quiet_softmax.rs
@@ -0,0 +1,16 @@
+#[burn_tensor_testgen::testgen(quiet_softmax)]
+mod tests {
+  use super::*;
+  use burn_tensor::{activation, Data, Tensor};
+
+  #[test]
+  fn test_quiet_softmax_d2() {
+    let data = Data::from([[1.0, 7.0], [13.0, -3.0]]);
+    let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
+    let data_actual = activation::quiet_softmax(tensor, 1).to_data();
+
+    let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
+    data_actual.assert_approx_eq(&data_expected, 4);
+  }
+}
diff --git a/burn-tensor/src/tests/activation/relu.rs b/burn-tensor/src/tests/activation/relu.rs
index b9e5ff6623..cbd23cccb4 100644
--- a/burn-tensor/src/tests/activation/relu.rs
+++ b/burn-tensor/src/tests/activation/relu.rs
@@ -1,16 +1,16 @@
 #[burn_tensor_testgen::testgen(relu)]
 mod tests {
-    use super::*;
-    use burn_tensor::{activation, Data, Tensor};
+  use super::*;
+  use burn_tensor::{activation, Data, Tensor};
 
-    #[test]
-    fn test_relu_d2() {
-        let data = Data::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);
-        let tensor = Tensor::<TestBackend, 2>::from_data(data);
+  #[test]
+  fn test_relu_d2() {
+    let data = Data::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);
+    let tensor = Tensor::<TestBackend, 2>::from_data(data);
 
-        let data_actual = activation::relu(tensor).to_data();
+    let data_actual = activation::relu(tensor).to_data();
 
-        let data_expected = Data::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]);
-        assert_eq!(data_expected, data_actual);
-    }
+    let data_expected = Data::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]);
+    assert_eq!(data_expected, data_actual);
+  }
 }
diff --git a/burn-tensor/src/tests/activation/sigmoid.rs
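The new quiet_softmax test above exercises the activation this patch introduces. Conceptually, the quiet ("softmax1") variant adds 1 to the softmax denominator so that a row of strongly negative logits can push all outputs toward zero instead of being forced to sum to one. The snippet below is only a sketch of that common formulation, with the usual max subtraction assumed for numerical stability; it is not necessarily the exact code added elsewhere in this patch.

use burn_tensor::{backend::Backend, Tensor};

// Sketch of the quiet-softmax idea: exp(x - max) / (1 + sum(exp(x - max))).
// With the extra 1 in the denominator, a row of very negative logits can
// produce outputs that all approach zero rather than summing to 1.
fn quiet_softmax_sketch<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    let max = tensor.clone().max_dim(dim);
    let x = tensor.sub(max).exp();
    let denom = x.clone().sum_dim(dim).add_scalar(1.0);

    x.div(denom)
}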
b/burn-tensor/src/tests/activation/sigmoid.rs index 17d55ba7ca..54b889a49b 100644 --- a/burn-tensor/src/tests/activation/sigmoid.rs +++ b/burn-tensor/src/tests/activation/sigmoid.rs @@ -1,27 +1,27 @@ #[burn_tensor_testgen::testgen(sigmoid)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_sigmoid() { - let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_sigmoid() { + let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::sigmoid(tensor).to_data(); + let data_actual = activation::sigmoid(tensor).to_data(); - let data_expected = Data::from([[0.7311, 0.9991], [1.0, 0.0474]]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[0.7311, 0.9991], [1.0, 0.0474]]); + data_actual.assert_approx_eq(&data_expected, 4); + } - #[test] - fn test_sigmoid_overflow() { - let data = Data::from([f32::MAX, f32::MIN]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_sigmoid_overflow() { + let data = Data::from([f32::MAX, f32::MIN]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::sigmoid(tensor).to_data(); + let data_actual = activation::sigmoid(tensor).to_data(); - let data_expected = Data::from([1.0, 0.0]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([1.0, 0.0]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/activation/silu.rs b/burn-tensor/src/tests/activation/silu.rs index f207bc6145..32728a6427 100644 --- a/burn-tensor/src/tests/activation/silu.rs +++ b/burn-tensor/src/tests/activation/silu.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(silu)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_silu() { - let data = Data::from([[1.0, 2.0], [3.0, 4.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_silu() { + let data = Data::from([[1.0, 2.0], [3.0, 4.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::silu(tensor).to_data(); + let data_actual = activation::silu(tensor).to_data(); - let data_expected = Data::from([[0.7311, 1.7616], [2.8577, 3.9281]]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[0.7311, 1.7616], [2.8577, 3.9281]]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/activation/softmax.rs b/burn-tensor/src/tests/activation/softmax.rs index 7e04168bff..f3a761de90 100644 --- a/burn-tensor/src/tests/activation/softmax.rs +++ b/burn-tensor/src/tests/activation/softmax.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(softmax)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_softmax_d2() { - let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_softmax_d2() { + let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::softmax(tensor, 1).to_data(); + let data_actual = activation::softmax(tensor, 1).to_data(); - let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]); - 
data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/activation/tanh_activation.rs b/burn-tensor/src/tests/activation/tanh_activation.rs index 1aaa9e3d0d..a3012b1ac4 100644 --- a/burn-tensor/src/tests/activation/tanh_activation.rs +++ b/burn-tensor/src/tests/activation/tanh_activation.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(tanh_activation)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_tanh() { - let data = Data::from([[1., 2.], [3., 4.]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_tanh() { + let data = Data::from([[1., 2.], [3., 4.]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::tanh(tensor).to_data(); + let data_actual = activation::tanh(tensor).to_data(); - let data_expected = Data::from([[0.7616, 0.9640], [0.9951, 0.9993]]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[0.7616, 0.9640], [0.9951, 0.9993]]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/clone_invariance.rs b/burn-tensor/src/tests/clone_invariance.rs index a70973e4b3..879b07a55f 100644 --- a/burn-tensor/src/tests/clone_invariance.rs +++ b/burn-tensor/src/tests/clone_invariance.rs @@ -6,713 +6,713 @@ /// and use different kernels in such cases. We ensure that the results are consistent regardless /// of the approach and that the input tensors are not modified when cloned. mod tests { - use super::*; - use burn_tensor::activation::{ - gelu, log_sigmoid, log_softmax, relu, sigmoid, silu, softmax, tanh, - }; - use burn_tensor::{Data, Distribution, Tensor}; + use super::*; + use burn_tensor::activation::{ + gelu, log_sigmoid, log_softmax, relu, sigmoid, silu, softmax, tanh, + }; + use burn_tensor::{Data, Distribution, Tensor}; - pub trait CloneInvarianceTest { - type Args; + pub trait CloneInvarianceTest { + type Args; - fn args(&self) -> Self::Args; + fn args(&self) -> Self::Args; - fn run(&self, args: &Self::Args, inplace: bool) -> Data; + fn run(&self, args: &Self::Args, inplace: bool) -> Data; - fn check(&self) { - let args = self.args(); - let out = self.run(&args, false); - let out_inplace = self.run(&args, true); + fn check(&self) { + let args = self.args(); + let out = self.run(&args, false); + let out_inplace = self.run(&args, true); - out.assert_approx_eq(&out_inplace, 4); - } + out.assert_approx_eq(&out_inplace, 4); } - - macro_rules! 
clone_invariance_test { - (unary: $name:ident, ops_float: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = Data; - - fn args(&self) -> Self::Args { - TestTensor::random([32, 32], Distribution::Default) - .into_data() - .convert() - } - - fn run(&self, args: &Self::Args, inplace: bool) -> Data { - let lhs = TestTensor::from_data(args.clone().convert()); - - if inplace { - $ops(lhs).into_data().convert() - } else { - let out = $ops(lhs.clone()).into_data().convert(); - lhs.into_data().assert_approx_eq(args, 4); - out - } - } - } - - CloneInvarianceTest::<2>::check(&$name); - } - }; - - (binary: $name:ident, ops_float: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = (Data, Data); - - fn args(&self) -> Self::Args { - ( - TestTensor::random([32, 32], Distribution::Default) - .into_data() - .convert(), - // Avoid div by zero. - TestTensor::random([32, 32], Distribution::Uniform(1., 3.)) - .into_data() - .convert(), - ) - } - - fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { - let lhs = TestTensor::from_data(lhs_arg.clone().convert()); - let rhs = TestTensor::from_data(rhs_arg.clone().convert()); - - if inplace { - $ops(lhs, rhs).into_data().convert() - } else { - let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); - - lhs.into_data().assert_approx_eq(lhs_arg, 4); - rhs.into_data().assert_approx_eq(rhs_arg, 4); - - out - } - } - } - - CloneInvarianceTest::<2>::check(&$name); + } + + macro_rules! clone_invariance_test { + (unary: $name:ident, ops_float: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = Data; + + fn args(&self) -> Self::Args { + TestTensor::random([32, 32], Distribution::Default) + .into_data() + .convert() + } + + fn run(&self, args: &Self::Args, inplace: bool) -> Data { + let lhs = TestTensor::from_data(args.clone().convert()); + + if inplace { + $ops(lhs).into_data().convert() + } else { + let out = $ops(lhs.clone()).into_data().convert(); + lhs.into_data().assert_approx_eq(args, 4); + out } - }; - - (unary: $name:ident, ops_int: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = Data; - - fn args(&self) -> Self::Args { - TestTensor::random([32, 32], Distribution::Uniform(0.0, 50.0)) - .into_data() - .convert() - } - - fn run(&self, args: &Self::Args, inplace: bool) -> Data { - let lhs = TestTensorInt::from_data(args.clone().convert()); - - if inplace { - $ops(lhs).into_data().convert() - } else { - let out = $ops(lhs.clone()).into_data().convert(); - lhs.into_data().convert().assert_approx_eq(args, 4); - out - } - } - } - - CloneInvarianceTest::<2>::check(&$name); - } - }; - - (binary: $name:ident, ops_int: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = (Data, Data); - - fn args(&self) -> Self::Args { - ( - TestTensor::random([32, 32], Distribution::Uniform(0., 50.)) - .into_data() - .convert(), - // Avoid div by zero. 
- TestTensor::random([32, 32], Distribution::Uniform(1., 51.)) - .into_data() - .convert(), - ) - } - - fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { - let lhs = TestTensorInt::from_data(lhs_arg.clone().convert()); - let rhs = TestTensorInt::from_data(rhs_arg.clone().convert()); - - if inplace { - $ops(lhs, rhs).into_data().convert() - } else { - let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); - - lhs.into_data().convert().assert_approx_eq(lhs_arg, 4); - rhs.into_data().convert().assert_approx_eq(rhs_arg, 4); - - out - } - } - } - - CloneInvarianceTest::<2>::check(&$name); - } - }; - } + } + } - mod float { - use super::*; - - // Unary ops - clone_invariance_test!( - unary: AddScalar, - ops_float: |tensor: TestTensor<2>| tensor.add_scalar(2.0) - ); - clone_invariance_test!( - unary: SubScalar, - ops_float: |tensor: TestTensor<2>| tensor.sub_scalar(2.0) - ); - clone_invariance_test!( - unary: DivScalar, - ops_float: |tensor: TestTensor<2>| tensor.div_scalar(2.0) - ); - clone_invariance_test!( - unary: MulScalar, - ops_float: |tensor: TestTensor<2>| tensor.mul_scalar(2.0) - ); - clone_invariance_test!( - unary: PowScalar, - ops_float: |tensor: TestTensor<2>| tensor.powf(2.0) - ); - clone_invariance_test!( - unary: Sqrt, - ops_float: |tensor: TestTensor<2>| tensor.sqrt() - ); - clone_invariance_test!( - unary: Exp, - ops_float: |tensor: TestTensor<2>| tensor.exp() - ); - clone_invariance_test!( - unary: Neg, - ops_float: |tensor: TestTensor<2>| tensor.neg() - ); - clone_invariance_test!( - unary: MeanDim, - ops_float: |tensor: TestTensor<2>| tensor.mean_dim(1) - ); - clone_invariance_test!( - unary: SumDim, - ops_float: |tensor: TestTensor<2>| tensor.sum_dim(1) - ); - clone_invariance_test!( - unary: Sum, - ops_float: |tensor: TestTensor<2>| tensor.sum().unsqueeze() - ); - clone_invariance_test!( - unary: Mean, - ops_float: |tensor: TestTensor<2>| tensor.mean().unsqueeze() - ); - clone_invariance_test!( - unary: Clamp, - ops_float: |tensor: TestTensor<2>| tensor.clamp(-2., 2.) - ); - clone_invariance_test!( - unary: ClampMin, - ops_float: |tensor: TestTensor<2>| tensor.clamp_min(-2.) - ); - clone_invariance_test!( - unary: ClampMax, - ops_float: |tensor: TestTensor<2>| tensor.clamp_max(2.) 
- ); - clone_invariance_test!( - unary: Abs, - ops_float: |tensor: TestTensor<2>| tensor.abs() - ); - clone_invariance_test!( - unary: Cos, - ops_float: |tensor: TestTensor<2>| tensor.cos() - ); - clone_invariance_test!( - unary: Sin, - ops_float: |tensor: TestTensor<2>| tensor.sin() - ); - clone_invariance_test!( - unary: Log, - ops_float: |tensor: TestTensor<2>| tensor.log() - ); - clone_invariance_test!( - unary: Log1P, - ops_float: |tensor: TestTensor<2>| tensor.log1p() - ); - clone_invariance_test!( - unary: SwapDims, - ops_float: |tensor: TestTensor<2>| tensor.swap_dims(0, 1) - ); - clone_invariance_test!( - unary: Transpose, - ops_float: |tensor: TestTensor<2>| tensor.transpose() - ); - clone_invariance_test!( - unary: Slice, - ops_float: |tensor: TestTensor<2>| tensor.slice([0..12, 12..24]) - ); - clone_invariance_test!( - unary: Erf, - ops_float: |tensor: TestTensor<2>| tensor.erf() - ); - clone_invariance_test!( - unary: EqualElem, - ops_float: |tensor: TestTensor<2>| tensor.equal_elem(0.5) - ); - clone_invariance_test!( - unary: GreaterElem, - ops_float: |tensor: TestTensor<2>| tensor.greater_elem(0.5) - ); - clone_invariance_test!( - unary: GreaterEqualElem, - ops_float: |tensor: TestTensor<2>| tensor.greater_equal_elem(0.5) - ); - clone_invariance_test!( - unary: LowerElem, - ops_float: |tensor: TestTensor<2>| tensor.lower_elem(0.5) - ); - clone_invariance_test!( - unary: LowerEqualElem, - ops_float: |tensor: TestTensor<2>| tensor.lower_equal_elem(0.5) - ); - clone_invariance_test!( - unary: Argmax, - ops_float: |tensor: TestTensor<2>| tensor.argmax(0) - ); - clone_invariance_test!( - unary: Argmin, - ops_float: |tensor: TestTensor<2>| tensor.argmin(0) - ); - clone_invariance_test!( - unary: Max, - ops_float: |tensor: TestTensor<2>| tensor.max().unsqueeze() - ); - clone_invariance_test!( - unary: Min, - ops_float: |tensor: TestTensor<2>| tensor.min().unsqueeze() - ); - clone_invariance_test!( - unary: MaxDim, - ops_float: |tensor: TestTensor<2>| tensor.max_dim(1) - ); - clone_invariance_test!( - unary: MaxDimWithIndices, - ops_float: |tensor: TestTensor<2>| tensor.max_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDimWithIndices, - ops_float: |tensor: TestTensor<2>| tensor.min_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDim, - ops_float: |tensor: TestTensor<2>| tensor.min_dim(1) - ); - clone_invariance_test!( - unary: Repeat, - ops_float: |tensor: TestTensor<2>| { - tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) - } - ); - clone_invariance_test!( - unary: Reshape, - ops_float: |tensor: TestTensor<2>| { - let shape = tensor.shape(); - let new_shape = [shape.num_elements(), 1]; - tensor.reshape(new_shape) - } - ); - clone_invariance_test!( - unary: Gatter, - ops_float: |tensor: TestTensor<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.gather(0, indices) - } - ); - clone_invariance_test!( - unary: Select, - ops_float: |tensor: TestTensor<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - tensor.select(0, indices) - } - ); - clone_invariance_test!( - unary: MaskFill, - ops_float: |tensor: TestTensor<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_fill(mask, 77.0) - } - ); - - // Activation - clone_invariance_test!( - unary: Softmax, - ops_float: |tensor: TestTensor<2>| softmax(tensor, 1) - ); - clone_invariance_test!( - unary: LogSoftmax, - ops_float: |tensor: TestTensor<2>| log_softmax(tensor, 1) - ); - clone_invariance_test!( - unary: Sigmoid, - 
ops_float: |tensor: TestTensor<2>| sigmoid(tensor) - ); - clone_invariance_test!( - unary: LogSigmoid, - ops_float: |tensor: TestTensor<2>| log_sigmoid(tensor) - ); - clone_invariance_test!( - unary: Relu, - ops_float: |tensor: TestTensor<2>| relu(tensor) - ); - clone_invariance_test!( - unary: Gelu, - ops_float: |tensor: TestTensor<2>| gelu(tensor) - ); - clone_invariance_test!( - unary: Silu, - ops_float: |tensor: TestTensor<2>| silu(tensor) - ); - clone_invariance_test!( - unary: Tanh, - ops_float: |tensor: TestTensor<2>| tanh(tensor) - ); - - // Binary ops - clone_invariance_test!( - binary: Add, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.add(rhs) - ); - clone_invariance_test!( - binary: Sub, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.sub(rhs) - ); - clone_invariance_test!( - binary: Div, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.div(rhs) - ); - clone_invariance_test!( - binary: Mul, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.mul(rhs) - ); - clone_invariance_test!( - binary: Matmul, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.matmul(rhs) - ); - clone_invariance_test!( - binary: Equal, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.equal(rhs) - ); - clone_invariance_test!( - binary: Greater, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater(rhs) - ); - clone_invariance_test!( - binary: GreaterEqual, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater_equal(rhs) - ); - clone_invariance_test!( - binary: Lower, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower(rhs) - ); - clone_invariance_test!( - binary: LowerEqual, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower_equal(rhs) - ); - clone_invariance_test!( - binary: Cat, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| { - let lhs = lhs.reshape([1usize, 32, 32]); - let rhs = rhs.reshape([1usize, 32, 32]); - - TestTensor::cat(vec![lhs, rhs], 0).reshape([64, 32]) - } - ); - clone_invariance_test!( - binary: Scatter, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.scatter(0, indices, values) - } - ); - clone_invariance_test!( - binary: SliceAssign, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) - } - ); - clone_invariance_test!( - binary: MaskWhere, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_where(mask, values) - } - ); - clone_invariance_test!( - binary: SelectAssign, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - let values = values.select(0, indices.clone()); - tensor.select_assign(0, indices, values) - } - ); - } + CloneInvarianceTest::<2>::check(&$name); + } + }; - mod int { - use super::*; - - // Unary ops - clone_invariance_test!( - unary: AddScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.add_scalar(2.0) - ); - clone_invariance_test!( - unary: SubScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.sub_scalar(2.0) - ); - clone_invariance_test!( - unary: DivScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.div_scalar(2.0) - ); - clone_invariance_test!( - unary: MulScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.mul_scalar(2.0) - ); - clone_invariance_test!( - unary: Neg, - ops_int: |tensor: TestTensorInt<2>| tensor.neg() - ); - 
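Each clone_invariance_test! invocation in this file expands to roughly the following check: run the op once while a clone of the input is kept alive, confirm the clone left the original data untouched, then run it again where the backend is free to reuse the buffer, and compare the two results. Below is a condensed, standalone version of that property for a single unary op, using softmax as the example; check_clone_invariance is a hypothetical helper, not part of the generated tests.

use burn_tensor::{activation, backend::Backend, Data, Tensor};

// Condensed version of what each generated test checks: cloning the input
// must not change its data, and the result must match the run that is free
// to reuse (mutate) the input buffer.
fn check_clone_invariance<B: Backend>(args: Data<f32, 2>) {
    // Run with a clone kept alive, then confirm the original data survived.
    let tensor = Tensor::<B, 2>::from_data(args.clone().convert());
    let out: Data<f32, 2> = activation::softmax(tensor.clone(), 1).into_data().convert();
    tensor.into_data().convert().assert_approx_eq(&args, 4);

    // Run again without the extra clone (the backend may work in place).
    let tensor = Tensor::<B, 2>::from_data(args.clone().convert());
    let out_inplace: Data<f32, 2> = activation::softmax(tensor, 1).into_data().convert();

    out.assert_approx_eq(&out_inplace, 4);
}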
clone_invariance_test!( - unary: MeanDim, - ops_int: |tensor: TestTensorInt<2>| tensor.mean_dim(1) - ); - clone_invariance_test!( - unary: SumDim, - ops_int: |tensor: TestTensorInt<2>| tensor.sum_dim(1) - ); - clone_invariance_test!( - unary: Sum, - ops_int: |tensor: TestTensorInt<2>| tensor.sum().unsqueeze() - ); - clone_invariance_test!( - unary: Mean, - ops_int: |tensor: TestTensorInt<2>| tensor.mean().unsqueeze() - ); - clone_invariance_test!( - unary: Clamp, - ops_int: |tensor: TestTensorInt<2>| tensor.clamp(-2., 2.) - ); - clone_invariance_test!( - unary: ClampMin, - ops_int: |tensor: TestTensorInt<2>| tensor.clamp_min(-2.) - ); - clone_invariance_test!( - unary: ClampMax, - ops_int: |tensor: TestTensorInt<2>| tensor.clamp_max(2.) - ); - clone_invariance_test!( - unary: Abs, - ops_int: |tensor: TestTensorInt<2>| tensor.abs() - ); - clone_invariance_test!( - unary: SwapDims, - ops_int: |tensor: TestTensorInt<2>| tensor.swap_dims(0, 1) - ); - clone_invariance_test!( - unary: Transpose, - ops_int: |tensor: TestTensorInt<2>| tensor.transpose() - ); - clone_invariance_test!( - unary: Slice, - ops_int: |tensor: TestTensorInt<2>| tensor.slice([0..12, 12..24]) - ); - clone_invariance_test!( - unary: EqualElem, - ops_int: |tensor: TestTensorInt<2>| tensor.equal_elem(25) - ); - clone_invariance_test!( - unary: GreaterElem, - ops_int: |tensor: TestTensorInt<2>| tensor.greater_elem(25) - ); - clone_invariance_test!( - unary: GreaterEqualElem, - ops_int: |tensor: TestTensorInt<2>| tensor.greater_equal_elem(25) - ); - clone_invariance_test!( - unary: LowerElem, - ops_int: |tensor: TestTensorInt<2>| tensor.lower_elem(25) - ); - clone_invariance_test!( - unary: LowerEqualElem, - ops_int: |tensor: TestTensorInt<2>| tensor.lower_equal_elem(25) - ); - clone_invariance_test!( - unary: Argmax, - ops_int: |tensor: TestTensorInt<2>| tensor.argmax(0) - ); - clone_invariance_test!( - unary: Argmin, - ops_int: |tensor: TestTensorInt<2>| tensor.argmin(0) - ); - clone_invariance_test!( - unary: Max, - ops_int: |tensor: TestTensorInt<2>| tensor.max().unsqueeze() - ); - clone_invariance_test!( - unary: Min, - ops_int: |tensor: TestTensorInt<2>| tensor.min().unsqueeze() - ); - clone_invariance_test!( - unary: MaxDim, - ops_int: |tensor: TestTensorInt<2>| tensor.max_dim(1) - ); - clone_invariance_test!( - unary: MaxDimWithIndices, - ops_int: |tensor: TestTensorInt<2>| tensor.max_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDimWithIndices, - ops_int: |tensor: TestTensorInt<2>| tensor.min_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDim, - ops_int: |tensor: TestTensorInt<2>| tensor.min_dim(1) - ); - clone_invariance_test!( - unary: Repeat, - ops_int: |tensor: TestTensorInt<2>| { - tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) - } - ); - clone_invariance_test!( - unary: Reshape, - ops_int: |tensor: TestTensorInt<2>| { - let shape = tensor.shape(); - let new_shape = [shape.num_elements(), 1]; - tensor.reshape(new_shape) - } - ); - clone_invariance_test!( - unary: Gatter, - ops_int: |tensor: TestTensorInt<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.gather(0, indices) - } - ); - clone_invariance_test!( - unary: Select, - ops_int: |tensor: TestTensorInt<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - tensor.select(0, indices) - } - ); - clone_invariance_test!( - unary: MaskFill, - ops_int: |tensor: TestTensorInt<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_fill(mask, 77.0) + (binary: 
$name:ident, ops_float: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = (Data, Data); + + fn args(&self) -> Self::Args { + ( + TestTensor::random([32, 32], Distribution::Default) + .into_data() + .convert(), + // Avoid div by zero. + TestTensor::random([32, 32], Distribution::Uniform(1., 3.)) + .into_data() + .convert(), + ) + } + + fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { + let lhs = TestTensor::from_data(lhs_arg.clone().convert()); + let rhs = TestTensor::from_data(rhs_arg.clone().convert()); + + if inplace { + $ops(lhs, rhs).into_data().convert() + } else { + let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); + + lhs.into_data().assert_approx_eq(lhs_arg, 4); + rhs.into_data().assert_approx_eq(rhs_arg, 4); + + out } - ); - - // Binary ops - clone_invariance_test!( - binary: Add, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.add(rhs) - ); - clone_invariance_test!( - binary: Sub, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.sub(rhs) - ); - clone_invariance_test!( - binary: Div, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.div(rhs) - ); - clone_invariance_test!( - binary: Mul, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.mul(rhs) - ); - clone_invariance_test!( - binary: Equal, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.equal(rhs) - ); - clone_invariance_test!( - binary: Greater, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater(rhs) - ); - clone_invariance_test!( - binary: GreaterEqual, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater_equal(rhs) - ); - clone_invariance_test!( - binary: Lower, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower(rhs) - ); - clone_invariance_test!( - binary: LowerEqual, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower_equal(rhs) - ); - clone_invariance_test!( - binary: Cat, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| { - let lhs = lhs.reshape([1usize, 32, 32]); - let rhs = rhs.reshape([1usize, 32, 32]); - - TestTensorInt::cat(vec![lhs, rhs], 0).reshape([64, 32]) - } - ); - clone_invariance_test!( - binary: Scatter, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.scatter(0, indices, values) - } - ); - clone_invariance_test!( - binary: SliceAssign, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) - } - ); - clone_invariance_test!( - binary: MaskWhere, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_where(mask, values) + } + } + + CloneInvarianceTest::<2>::check(&$name); + } + }; + + (unary: $name:ident, ops_int: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = Data; + + fn args(&self) -> Self::Args { + TestTensor::random([32, 32], Distribution::Uniform(0.0, 50.0)) + .into_data() + .convert() + } + + fn run(&self, args: &Self::Args, inplace: bool) -> Data { + let lhs = TestTensorInt::from_data(args.clone().convert()); + + if inplace { + $ops(lhs).into_data().convert() + } else { + let out = $ops(lhs.clone()).into_data().convert(); + lhs.into_data().convert().assert_approx_eq(args, 4); + out } - ); - 
clone_invariance_test!( - binary: SelectAssign, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - let values = values.select(0, indices.clone()); - tensor.select_assign(0, indices, values) + } + } + + CloneInvarianceTest::<2>::check(&$name); + } + }; + + (binary: $name:ident, ops_int: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = (Data, Data); + + fn args(&self) -> Self::Args { + ( + TestTensor::random([32, 32], Distribution::Uniform(0., 50.)) + .into_data() + .convert(), + // Avoid div by zero. + TestTensor::random([32, 32], Distribution::Uniform(1., 51.)) + .into_data() + .convert(), + ) + } + + fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { + let lhs = TestTensorInt::from_data(lhs_arg.clone().convert()); + let rhs = TestTensorInt::from_data(rhs_arg.clone().convert()); + + if inplace { + $ops(lhs, rhs).into_data().convert() + } else { + let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); + + lhs.into_data().convert().assert_approx_eq(lhs_arg, 4); + rhs.into_data().convert().assert_approx_eq(rhs_arg, 4); + + out } - ); - } + } + } + + CloneInvarianceTest::<2>::check(&$name); + } + }; + } + + mod float { + use super::*; + + // Unary ops + clone_invariance_test!( + unary: AddScalar, + ops_float: |tensor: TestTensor<2>| tensor.add_scalar(2.0) + ); + clone_invariance_test!( + unary: SubScalar, + ops_float: |tensor: TestTensor<2>| tensor.sub_scalar(2.0) + ); + clone_invariance_test!( + unary: DivScalar, + ops_float: |tensor: TestTensor<2>| tensor.div_scalar(2.0) + ); + clone_invariance_test!( + unary: MulScalar, + ops_float: |tensor: TestTensor<2>| tensor.mul_scalar(2.0) + ); + clone_invariance_test!( + unary: PowScalar, + ops_float: |tensor: TestTensor<2>| tensor.powf(2.0) + ); + clone_invariance_test!( + unary: Sqrt, + ops_float: |tensor: TestTensor<2>| tensor.sqrt() + ); + clone_invariance_test!( + unary: Exp, + ops_float: |tensor: TestTensor<2>| tensor.exp() + ); + clone_invariance_test!( + unary: Neg, + ops_float: |tensor: TestTensor<2>| tensor.neg() + ); + clone_invariance_test!( + unary: MeanDim, + ops_float: |tensor: TestTensor<2>| tensor.mean_dim(1) + ); + clone_invariance_test!( + unary: SumDim, + ops_float: |tensor: TestTensor<2>| tensor.sum_dim(1) + ); + clone_invariance_test!( + unary: Sum, + ops_float: |tensor: TestTensor<2>| tensor.sum().unsqueeze() + ); + clone_invariance_test!( + unary: Mean, + ops_float: |tensor: TestTensor<2>| tensor.mean().unsqueeze() + ); + clone_invariance_test!( + unary: Clamp, + ops_float: |tensor: TestTensor<2>| tensor.clamp(-2., 2.) + ); + clone_invariance_test!( + unary: ClampMin, + ops_float: |tensor: TestTensor<2>| tensor.clamp_min(-2.) + ); + clone_invariance_test!( + unary: ClampMax, + ops_float: |tensor: TestTensor<2>| tensor.clamp_max(2.) 
+ ); + clone_invariance_test!( + unary: Abs, + ops_float: |tensor: TestTensor<2>| tensor.abs() + ); + clone_invariance_test!( + unary: Cos, + ops_float: |tensor: TestTensor<2>| tensor.cos() + ); + clone_invariance_test!( + unary: Sin, + ops_float: |tensor: TestTensor<2>| tensor.sin() + ); + clone_invariance_test!( + unary: Log, + ops_float: |tensor: TestTensor<2>| tensor.log() + ); + clone_invariance_test!( + unary: Log1P, + ops_float: |tensor: TestTensor<2>| tensor.log1p() + ); + clone_invariance_test!( + unary: SwapDims, + ops_float: |tensor: TestTensor<2>| tensor.swap_dims(0, 1) + ); + clone_invariance_test!( + unary: Transpose, + ops_float: |tensor: TestTensor<2>| tensor.transpose() + ); + clone_invariance_test!( + unary: Slice, + ops_float: |tensor: TestTensor<2>| tensor.slice([0..12, 12..24]) + ); + clone_invariance_test!( + unary: Erf, + ops_float: |tensor: TestTensor<2>| tensor.erf() + ); + clone_invariance_test!( + unary: EqualElem, + ops_float: |tensor: TestTensor<2>| tensor.equal_elem(0.5) + ); + clone_invariance_test!( + unary: GreaterElem, + ops_float: |tensor: TestTensor<2>| tensor.greater_elem(0.5) + ); + clone_invariance_test!( + unary: GreaterEqualElem, + ops_float: |tensor: TestTensor<2>| tensor.greater_equal_elem(0.5) + ); + clone_invariance_test!( + unary: LowerElem, + ops_float: |tensor: TestTensor<2>| tensor.lower_elem(0.5) + ); + clone_invariance_test!( + unary: LowerEqualElem, + ops_float: |tensor: TestTensor<2>| tensor.lower_equal_elem(0.5) + ); + clone_invariance_test!( + unary: Argmax, + ops_float: |tensor: TestTensor<2>| tensor.argmax(0) + ); + clone_invariance_test!( + unary: Argmin, + ops_float: |tensor: TestTensor<2>| tensor.argmin(0) + ); + clone_invariance_test!( + unary: Max, + ops_float: |tensor: TestTensor<2>| tensor.max().unsqueeze() + ); + clone_invariance_test!( + unary: Min, + ops_float: |tensor: TestTensor<2>| tensor.min().unsqueeze() + ); + clone_invariance_test!( + unary: MaxDim, + ops_float: |tensor: TestTensor<2>| tensor.max_dim(1) + ); + clone_invariance_test!( + unary: MaxDimWithIndices, + ops_float: |tensor: TestTensor<2>| tensor.max_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDimWithIndices, + ops_float: |tensor: TestTensor<2>| tensor.min_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDim, + ops_float: |tensor: TestTensor<2>| tensor.min_dim(1) + ); + clone_invariance_test!( + unary: Repeat, + ops_float: |tensor: TestTensor<2>| { + tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) + } + ); + clone_invariance_test!( + unary: Reshape, + ops_float: |tensor: TestTensor<2>| { + let shape = tensor.shape(); + let new_shape = [shape.num_elements(), 1]; + tensor.reshape(new_shape) + } + ); + clone_invariance_test!( + unary: Gatter, + ops_float: |tensor: TestTensor<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.gather(0, indices) + } + ); + clone_invariance_test!( + unary: Select, + ops_float: |tensor: TestTensor<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + tensor.select(0, indices) + } + ); + clone_invariance_test!( + unary: MaskFill, + ops_float: |tensor: TestTensor<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_fill(mask, 77.0) + } + ); + + // Activation + clone_invariance_test!( + unary: Softmax, + ops_float: |tensor: TestTensor<2>| softmax(tensor, 1) + ); + clone_invariance_test!( + unary: LogSoftmax, + ops_float: |tensor: TestTensor<2>| log_softmax(tensor, 1) + ); + clone_invariance_test!( + unary: Sigmoid, + 
ops_float: |tensor: TestTensor<2>| sigmoid(tensor) + ); + clone_invariance_test!( + unary: LogSigmoid, + ops_float: |tensor: TestTensor<2>| log_sigmoid(tensor) + ); + clone_invariance_test!( + unary: Relu, + ops_float: |tensor: TestTensor<2>| relu(tensor) + ); + clone_invariance_test!( + unary: Gelu, + ops_float: |tensor: TestTensor<2>| gelu(tensor) + ); + clone_invariance_test!( + unary: Silu, + ops_float: |tensor: TestTensor<2>| silu(tensor) + ); + clone_invariance_test!( + unary: Tanh, + ops_float: |tensor: TestTensor<2>| tanh(tensor) + ); + + // Binary ops + clone_invariance_test!( + binary: Add, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.add(rhs) + ); + clone_invariance_test!( + binary: Sub, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.sub(rhs) + ); + clone_invariance_test!( + binary: Div, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.div(rhs) + ); + clone_invariance_test!( + binary: Mul, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.mul(rhs) + ); + clone_invariance_test!( + binary: Matmul, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.matmul(rhs) + ); + clone_invariance_test!( + binary: Equal, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.equal(rhs) + ); + clone_invariance_test!( + binary: Greater, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater(rhs) + ); + clone_invariance_test!( + binary: GreaterEqual, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater_equal(rhs) + ); + clone_invariance_test!( + binary: Lower, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower(rhs) + ); + clone_invariance_test!( + binary: LowerEqual, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower_equal(rhs) + ); + clone_invariance_test!( + binary: Cat, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| { + let lhs = lhs.reshape([1usize, 32, 32]); + let rhs = rhs.reshape([1usize, 32, 32]); + + TestTensor::cat(vec![lhs, rhs], 0).reshape([64, 32]) + } + ); + clone_invariance_test!( + binary: Scatter, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.scatter(0, indices, values) + } + ); + clone_invariance_test!( + binary: SliceAssign, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) + } + ); + clone_invariance_test!( + binary: MaskWhere, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_where(mask, values) + } + ); + clone_invariance_test!( + binary: SelectAssign, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + let values = values.select(0, indices.clone()); + tensor.select_assign(0, indices, values) + } + ); + } + + mod int { + use super::*; + + // Unary ops + clone_invariance_test!( + unary: AddScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.add_scalar(2.0) + ); + clone_invariance_test!( + unary: SubScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.sub_scalar(2.0) + ); + clone_invariance_test!( + unary: DivScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.div_scalar(2.0) + ); + clone_invariance_test!( + unary: MulScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.mul_scalar(2.0) + ); + clone_invariance_test!( + unary: Neg, + ops_int: |tensor: TestTensorInt<2>| tensor.neg() + ); + clone_invariance_test!( + unary: MeanDim, + ops_int: |tensor: 
TestTensorInt<2>| tensor.mean_dim(1) + ); + clone_invariance_test!( + unary: SumDim, + ops_int: |tensor: TestTensorInt<2>| tensor.sum_dim(1) + ); + clone_invariance_test!( + unary: Sum, + ops_int: |tensor: TestTensorInt<2>| tensor.sum().unsqueeze() + ); + clone_invariance_test!( + unary: Mean, + ops_int: |tensor: TestTensorInt<2>| tensor.mean().unsqueeze() + ); + clone_invariance_test!( + unary: Clamp, + ops_int: |tensor: TestTensorInt<2>| tensor.clamp(-2., 2.) + ); + clone_invariance_test!( + unary: ClampMin, + ops_int: |tensor: TestTensorInt<2>| tensor.clamp_min(-2.) + ); + clone_invariance_test!( + unary: ClampMax, + ops_int: |tensor: TestTensorInt<2>| tensor.clamp_max(2.) + ); + clone_invariance_test!( + unary: Abs, + ops_int: |tensor: TestTensorInt<2>| tensor.abs() + ); + clone_invariance_test!( + unary: SwapDims, + ops_int: |tensor: TestTensorInt<2>| tensor.swap_dims(0, 1) + ); + clone_invariance_test!( + unary: Transpose, + ops_int: |tensor: TestTensorInt<2>| tensor.transpose() + ); + clone_invariance_test!( + unary: Slice, + ops_int: |tensor: TestTensorInt<2>| tensor.slice([0..12, 12..24]) + ); + clone_invariance_test!( + unary: EqualElem, + ops_int: |tensor: TestTensorInt<2>| tensor.equal_elem(25) + ); + clone_invariance_test!( + unary: GreaterElem, + ops_int: |tensor: TestTensorInt<2>| tensor.greater_elem(25) + ); + clone_invariance_test!( + unary: GreaterEqualElem, + ops_int: |tensor: TestTensorInt<2>| tensor.greater_equal_elem(25) + ); + clone_invariance_test!( + unary: LowerElem, + ops_int: |tensor: TestTensorInt<2>| tensor.lower_elem(25) + ); + clone_invariance_test!( + unary: LowerEqualElem, + ops_int: |tensor: TestTensorInt<2>| tensor.lower_equal_elem(25) + ); + clone_invariance_test!( + unary: Argmax, + ops_int: |tensor: TestTensorInt<2>| tensor.argmax(0) + ); + clone_invariance_test!( + unary: Argmin, + ops_int: |tensor: TestTensorInt<2>| tensor.argmin(0) + ); + clone_invariance_test!( + unary: Max, + ops_int: |tensor: TestTensorInt<2>| tensor.max().unsqueeze() + ); + clone_invariance_test!( + unary: Min, + ops_int: |tensor: TestTensorInt<2>| tensor.min().unsqueeze() + ); + clone_invariance_test!( + unary: MaxDim, + ops_int: |tensor: TestTensorInt<2>| tensor.max_dim(1) + ); + clone_invariance_test!( + unary: MaxDimWithIndices, + ops_int: |tensor: TestTensorInt<2>| tensor.max_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDimWithIndices, + ops_int: |tensor: TestTensorInt<2>| tensor.min_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDim, + ops_int: |tensor: TestTensorInt<2>| tensor.min_dim(1) + ); + clone_invariance_test!( + unary: Repeat, + ops_int: |tensor: TestTensorInt<2>| { + tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) + } + ); + clone_invariance_test!( + unary: Reshape, + ops_int: |tensor: TestTensorInt<2>| { + let shape = tensor.shape(); + let new_shape = [shape.num_elements(), 1]; + tensor.reshape(new_shape) + } + ); + clone_invariance_test!( + unary: Gatter, + ops_int: |tensor: TestTensorInt<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.gather(0, indices) + } + ); + clone_invariance_test!( + unary: Select, + ops_int: |tensor: TestTensorInt<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + tensor.select(0, indices) + } + ); + clone_invariance_test!( + unary: MaskFill, + ops_int: |tensor: TestTensorInt<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_fill(mask, 77.0) + } + ); + + // Binary ops + clone_invariance_test!( + binary: Add, + 
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.add(rhs) + ); + clone_invariance_test!( + binary: Sub, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.sub(rhs) + ); + clone_invariance_test!( + binary: Div, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.div(rhs) + ); + clone_invariance_test!( + binary: Mul, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.mul(rhs) + ); + clone_invariance_test!( + binary: Equal, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.equal(rhs) + ); + clone_invariance_test!( + binary: Greater, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater(rhs) + ); + clone_invariance_test!( + binary: GreaterEqual, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater_equal(rhs) + ); + clone_invariance_test!( + binary: Lower, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower(rhs) + ); + clone_invariance_test!( + binary: LowerEqual, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower_equal(rhs) + ); + clone_invariance_test!( + binary: Cat, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| { + let lhs = lhs.reshape([1usize, 32, 32]); + let rhs = rhs.reshape([1usize, 32, 32]); + + TestTensorInt::cat(vec![lhs, rhs], 0).reshape([64, 32]) + } + ); + clone_invariance_test!( + binary: Scatter, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.scatter(0, indices, values) + } + ); + clone_invariance_test!( + binary: SliceAssign, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) + } + ); + clone_invariance_test!( + binary: MaskWhere, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_where(mask, values) + } + ); + clone_invariance_test!( + binary: SelectAssign, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + let values = values.select(0, indices.clone()); + tensor.select_assign(0, indices, values) + } + ); + } } diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs index 1d0cd93807..3108659006 100644 --- a/burn-tensor/src/tests/mod.rs +++ b/burn-tensor/src/tests/mod.rs @@ -7,79 +7,79 @@ mod stats; #[allow(missing_docs)] #[macro_export] macro_rules! 
testgen_all { - () => { - // test activation - burn_tensor::testgen_gelu!(); - burn_tensor::testgen_relu!(); - burn_tensor::testgen_softmax!(); - burn_tensor::testgen_sigmoid!(); - burn_tensor::testgen_silu!(); - burn_tensor::testgen_tanh_activation!(); + () => { + // test activation + burn_tensor::testgen_gelu!(); + burn_tensor::testgen_relu!(); + burn_tensor::testgen_softmax!(); + burn_tensor::testgen_sigmoid!(); + burn_tensor::testgen_silu!(); + burn_tensor::testgen_tanh_activation!(); - // test module - burn_tensor::testgen_module_forward!(); - burn_tensor::testgen_module_conv1d!(); - burn_tensor::testgen_module_conv2d!(); - burn_tensor::testgen_module_conv_transpose1d!(); - burn_tensor::testgen_module_conv_transpose2d!(); - burn_tensor::testgen_module_unfold4d!(); - burn_tensor::testgen_module_max_pool1d!(); - burn_tensor::testgen_module_max_pool2d!(); - burn_tensor::testgen_module_avg_pool1d!(); - burn_tensor::testgen_module_avg_pool2d!(); - burn_tensor::testgen_module_adaptive_avg_pool1d!(); - burn_tensor::testgen_module_adaptive_avg_pool2d!(); + // test module + burn_tensor::testgen_module_forward!(); + burn_tensor::testgen_module_conv1d!(); + burn_tensor::testgen_module_conv2d!(); + burn_tensor::testgen_module_conv_transpose1d!(); + burn_tensor::testgen_module_conv_transpose2d!(); + burn_tensor::testgen_module_unfold4d!(); + burn_tensor::testgen_module_max_pool1d!(); + burn_tensor::testgen_module_max_pool2d!(); + burn_tensor::testgen_module_avg_pool1d!(); + burn_tensor::testgen_module_avg_pool2d!(); + burn_tensor::testgen_module_adaptive_avg_pool1d!(); + burn_tensor::testgen_module_adaptive_avg_pool2d!(); - // test ops - burn_tensor::testgen_add!(); - burn_tensor::testgen_aggregation!(); - burn_tensor::testgen_arange!(); - burn_tensor::testgen_arange_step!(); - burn_tensor::testgen_arg!(); - burn_tensor::testgen_cast!(); - burn_tensor::testgen_cat!(); - burn_tensor::testgen_clamp!(); - burn_tensor::testgen_cos!(); - burn_tensor::testgen_create_like!(); - burn_tensor::testgen_div!(); - burn_tensor::testgen_erf!(); - burn_tensor::testgen_exp!(); - burn_tensor::testgen_flatten!(); - burn_tensor::testgen_full!(); - burn_tensor::testgen_gather_scatter!(); - burn_tensor::testgen_init!(); - burn_tensor::testgen_iter_dim!(); - burn_tensor::testgen_log!(); - burn_tensor::testgen_log1p!(); - burn_tensor::testgen_map_comparison!(); - burn_tensor::testgen_mask!(); - burn_tensor::testgen_matmul!(); - burn_tensor::testgen_maxmin!(); - burn_tensor::testgen_mul!(); - burn_tensor::testgen_neg!(); - burn_tensor::testgen_one_hot!(); - burn_tensor::testgen_powf!(); - burn_tensor::testgen_random!(); - burn_tensor::testgen_recip!(); - burn_tensor::testgen_repeat!(); - burn_tensor::testgen_reshape!(); - burn_tensor::testgen_select!(); - burn_tensor::testgen_sin!(); - burn_tensor::testgen_slice!(); - burn_tensor::testgen_sqrt!(); - burn_tensor::testgen_abs!(); - burn_tensor::testgen_squeeze!(); - burn_tensor::testgen_sub!(); - burn_tensor::testgen_tanh!(); - burn_tensor::testgen_transpose!(); + // test ops + burn_tensor::testgen_add!(); + burn_tensor::testgen_aggregation!(); + burn_tensor::testgen_arange!(); + burn_tensor::testgen_arange_step!(); + burn_tensor::testgen_arg!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_clamp!(); + burn_tensor::testgen_cos!(); + burn_tensor::testgen_create_like!(); + burn_tensor::testgen_div!(); + burn_tensor::testgen_erf!(); + burn_tensor::testgen_exp!(); + burn_tensor::testgen_flatten!(); + burn_tensor::testgen_full!(); + 
burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_init!(); + burn_tensor::testgen_iter_dim!(); + burn_tensor::testgen_log!(); + burn_tensor::testgen_log1p!(); + burn_tensor::testgen_map_comparison!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_matmul!(); + burn_tensor::testgen_maxmin!(); + burn_tensor::testgen_mul!(); + burn_tensor::testgen_neg!(); + burn_tensor::testgen_one_hot!(); + burn_tensor::testgen_powf!(); + burn_tensor::testgen_random!(); + burn_tensor::testgen_recip!(); + burn_tensor::testgen_repeat!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_select!(); + burn_tensor::testgen_sin!(); + burn_tensor::testgen_slice!(); + burn_tensor::testgen_sqrt!(); + burn_tensor::testgen_abs!(); + burn_tensor::testgen_squeeze!(); + burn_tensor::testgen_sub!(); + burn_tensor::testgen_tanh!(); + burn_tensor::testgen_transpose!(); - // test stats - burn_tensor::testgen_var!(); - burn_tensor::testgen_cov!(); - burn_tensor::testgen_diagonal!(); - burn_tensor::testgen_display!(); + // test stats + burn_tensor::testgen_var!(); + burn_tensor::testgen_cov!(); + burn_tensor::testgen_diagonal!(); + burn_tensor::testgen_display!(); - // test clone invariance - burn_tensor::testgen_clone_invariance!(); - }; + // test clone invariance + burn_tensor::testgen_clone_invariance!(); + }; } diff --git a/burn-tensor/src/tests/module/adaptive_avgpool1d.rs b/burn-tensor/src/tests/module/adaptive_avgpool1d.rs index f655ed8496..3173613d21 100644 --- a/burn-tensor/src/tests/module/adaptive_avgpool1d.rs +++ b/burn-tensor/src/tests/module/adaptive_avgpool1d.rs @@ -1,73 +1,73 @@ #[burn_tensor_testgen::testgen(module_adaptive_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_adaptive_avg_pool1d_simple() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 8, - length_out: 4, - }; + #[test] + fn test_adaptive_avg_pool1d_simple() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 8, + length_out: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5, 2.5, 4.5, 6.5], - [8.5, 10.5, 12.5, 14.5], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5, 2.5, 4.5, 6.5], + [8.5, 10.5, 12.5, 14.5], + ]])); + } - #[test] - fn test_adaptive_avg_pool1d_dyn_filter_size() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 7, - length_out: 3, - }; + #[test] + fn test_adaptive_avg_pool1d_dyn_filter_size() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 7, + length_out: 3, + }; - test.assert_output(TestTensor::from_floats([[ - [1.0, 3.0, 5.0], - [8.0, 10.0, 12.0], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [1.0, 3.0, 5.0], + [8.0, 10.0, 12.0], + ]])); + } - #[test] - fn test_adaptive_avg_pool1d_bigger_output() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 4, - length_out: 8, - }; + #[test] + fn test_adaptive_avg_pool1d_bigger_output() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 4, + length_out: 8, + }; - test.assert_output(TestTensor::from_floats([[ - [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0], - [4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0], + [4.0, 
4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0], + ]])); + } - struct AdaptiveAvgPool1dTestCase { - batch_size: usize, - channels: usize, - length: usize, - length_out: usize, - } + struct AdaptiveAvgPool1dTestCase { + batch_size: usize, + channels: usize, + length: usize, + length_out: usize, + } - impl AdaptiveAvgPool1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = adaptive_avg_pool1d(x, self.length_out); + impl AdaptiveAvgPool1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = adaptive_avg_pool1d(x, self.length_out); - y.into_data().assert_approx_eq(&output.into_data(), 3); - } + y.into_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/adaptive_avgpool2d.rs b/burn-tensor/src/tests/module/adaptive_avgpool2d.rs index 2711948388..a484cdc9fa 100644 --- a/burn-tensor/src/tests/module/adaptive_avgpool2d.rs +++ b/burn-tensor/src/tests/module/adaptive_avgpool2d.rs @@ -1,103 +1,103 @@ #[burn_tensor_testgen::testgen(module_adaptive_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_adaptive_avg_pool2d_simple() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 8, - width: 6, - height_out: 4, - width_out: 4, - }; + #[test] + fn test_adaptive_avg_pool2d_simple() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 8, + width: 6, + height_out: 4, + width_out: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [3.5000, 4.5000, 6.5000, 7.5000], - [15.5000, 16.5000, 18.5000, 19.5000], - [27.5000, 28.5000, 30.5000, 31.5000], - [39.5000, 40.5000, 42.5000, 43.5000], - ], - [ - [51.5000, 52.5000, 54.5000, 55.5000], - [63.5000, 64.5000, 66.5000, 67.5000], - [75.5000, 76.5000, 78.5000, 79.5000], - [87.5000, 88.5000, 90.5000, 91.5000], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [3.5000, 4.5000, 6.5000, 7.5000], + [15.5000, 16.5000, 18.5000, 19.5000], + [27.5000, 28.5000, 30.5000, 31.5000], + [39.5000, 40.5000, 42.5000, 43.5000], + ], + [ + [51.5000, 52.5000, 54.5000, 55.5000], + [63.5000, 64.5000, 66.5000, 67.5000], + [75.5000, 76.5000, 78.5000, 79.5000], + [87.5000, 88.5000, 90.5000, 91.5000], + ], + ]])); + } - #[test] - fn test_adaptive_avg_pool2d_dyn_filter_size() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 5, - width: 7, - height_out: 3, - width_out: 2, - }; + #[test] + fn test_adaptive_avg_pool2d_dyn_filter_size() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 5, + width: 7, + height_out: 3, + width_out: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]], - [[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]], + [[40.0000, 
43.0000], [50.5000, 53.5000], [61.0000, 64.0000]], + ]])); + } - #[test] - fn test_adaptive_avg_pool2d_bigger_output() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 4, - width: 3, - height_out: 5, - width_out: 4, - }; + #[test] + fn test_adaptive_avg_pool2d_bigger_output() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 4, + width: 3, + height_out: 5, + width_out: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [0.0000, 0.5000, 1.5000, 2.0000], - [1.5000, 2.0000, 3.0000, 3.5000], - [4.5000, 5.0000, 6.0000, 6.5000], - [7.5000, 8.0000, 9.0000, 9.5000], - [9.0000, 9.5000, 10.5000, 11.0000], - ], - [ - [12.0000, 12.5000, 13.5000, 14.0000], - [13.5000, 14.0000, 15.0000, 15.5000], - [16.5000, 17.0000, 18.0000, 18.5000], - [19.5000, 20.0000, 21.0000, 21.5000], - [21.0000, 21.5000, 22.5000, 23.0000], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [0.0000, 0.5000, 1.5000, 2.0000], + [1.5000, 2.0000, 3.0000, 3.5000], + [4.5000, 5.0000, 6.0000, 6.5000], + [7.5000, 8.0000, 9.0000, 9.5000], + [9.0000, 9.5000, 10.5000, 11.0000], + ], + [ + [12.0000, 12.5000, 13.5000, 14.0000], + [13.5000, 14.0000, 15.0000, 15.5000], + [16.5000, 17.0000, 18.0000, 18.5000], + [19.5000, 20.0000, 21.0000, 21.5000], + [21.0000, 21.5000, 22.5000, 23.0000], + ], + ]])); + } - struct AdaptiveAvgPool2dTestCase { - batch_size: usize, - channels: usize, - height: usize, - width: usize, - height_out: usize, - width_out: usize, - } + struct AdaptiveAvgPool2dTestCase { + batch_size: usize, + channels: usize, + height: usize, + width: usize, + height_out: usize, + width_out: usize, + } - impl AdaptiveAvgPool2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]); + impl AdaptiveAvgPool2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/avgpool1d.rs b/burn-tensor/src/tests/module/avgpool1d.rs index ce1a95fd55..0706bfc8fc 100644 --- a/burn-tensor/src/tests/module/avgpool1d.rs +++ b/burn-tensor/src/tests/module/avgpool1d.rs @@ -1,88 +1,88 @@ #[burn_tensor_testgen::testgen(module_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool1d_simple() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 1, - kernel_size: 3, - padding: 0, - stride: 1, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_simple() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 1, + kernel_size: 3, + padding: 0, + stride: 1, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[1., 2., 3., 4.]]])); - } + 
test.assert_output(TestTensor::from_floats([[[1., 2., 3., 4.]]])); + } - #[test] - fn test_avg_pool1d_complex() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_complex() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[ - [0.3333, 2.0000, 4.0000], - [4.3333, 8.0000, 10.0000], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.3333, 2.0000, 4.0000], + [4.3333, 8.0000, 10.0000], + ]])); + } - #[test] - fn test_avg_pool1d_complex_dont_count_pad() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool1d_complex_dont_count_pad() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5000, 2.0000, 4.0000], - [6.5000, 8.0000, 10.0000], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5000, 2.0000, 4.0000], + [6.5000, 8.0000, 10.0000], + ]])); + } - struct AvgPool1dTestCase { - batch_size: usize, - channels: usize, - kernel_size: usize, - padding: usize, - stride: usize, - length: usize, - count_include_pad: bool, - } + struct AvgPool1dTestCase { + batch_size: usize, + channels: usize, + kernel_size: usize, + padding: usize, + stride: usize, + length: usize, + count_include_pad: bool, + } - impl AvgPool1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = avg_pool1d( - x, - self.kernel_size, - self.stride, - self.padding, - self.count_include_pad, - ); + impl AvgPool1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = avg_pool1d( + x, + self.kernel_size, + self.stride, + self.padding, + self.count_include_pad, + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/avgpool2d.rs b/burn-tensor/src/tests/module/avgpool2d.rs index 0207014326..ca9ffcf321 100644 --- a/burn-tensor/src/tests/module/avgpool2d.rs +++ b/burn-tensor/src/tests/module/avgpool2d.rs @@ -1,113 +1,113 @@ #[burn_tensor_testgen::testgen(module_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool2d_simple() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - height: 6, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_simple() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 3, 
+ padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + height: 6, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [7., 8., 9., 10.], - [13., 14., 15., 16.], - [19., 20., 21., 22.], - [25., 26., 27., 28.], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [7., 8., 9., 10.], + [13., 14., 15., 16.], + [19., 20., 21., 22.], + [25., 26., 27., 28.], + ]]])); + } - #[test] - fn test_avg_pool2d_complex() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_complex() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [1.1667, 3.0000, 4.3333, 2.5000], - [3.2500, 7.5000, 9.5000, 5.2500], - [6.2500, 13.5000, 15.5000, 8.2500], - [5.1667, 11.0000, 12.3333, 6.5000], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [1.1667, 3.0000, 4.3333, 2.5000], + [3.2500, 7.5000, 9.5000, 5.2500], + [6.2500, 13.5000, 15.5000, 8.2500], + [5.1667, 11.0000, 12.3333, 6.5000], + ]]])); + } - #[test] - fn test_avg_pool2d_complex_dont_include_pad() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool2d_complex_dont_include_pad() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[[ - [3.5000, 4.5000, 6.5000, 7.5000], - [6.5000, 7.5000, 9.5000, 10.5000], - [12.5000, 13.5000, 15.5000, 16.5000], - [15.5000, 16.5000, 18.5000, 19.5000], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [3.5000, 4.5000, 6.5000, 7.5000], + [6.5000, 7.5000, 9.5000, 10.5000], + [12.5000, 13.5000, 15.5000, 16.5000], + [15.5000, 16.5000, 18.5000, 19.5000], + ]]])); + } - struct AvgPool2dTestCase { - batch_size: usize, - channels: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - stride_1: usize, - stride_2: usize, - height: usize, - width: usize, - count_include_pad: bool, - } + struct AvgPool2dTestCase { + batch_size: usize, + channels: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + height: usize, + width: usize, + count_include_pad: bool, + } - impl AvgPool2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = avg_pool2d( - x, - [self.kernel_size_1, self.kernel_size_2], - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - self.count_include_pad, - ); + impl AvgPool2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = 
TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = avg_pool2d( + x, + [self.kernel_size_1, self.kernel_size_2], + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + self.count_include_pad, + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/conv1d.rs b/burn-tensor/src/tests/module/conv1d.rs index 662ba8a4bd..77bd82c10d 100644 --- a/burn-tensor/src/tests/module/conv1d.rs +++ b/burn-tensor/src/tests/module/conv1d.rs @@ -1,135 +1,135 @@ #[burn_tensor_testgen::testgen(module_conv1d)] mod tests { - use super::*; - use burn_tensor::module::conv1d; - use burn_tensor::ops::ConvOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv1d; + use burn_tensor::ops::ConvOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv1d_simple() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv1d_simple() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[43., 67., 82., 49.], [104., 176., 227., 158.]], - [[139., 187., 202., 113.], [392., 584., 635., 414.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[43., 67., 82., 49.], [104., 176., 227., 158.]], + [[139., 187., 202., 113.], [392., 584., 635., 414.]], + ])); + } - #[test] - fn test_conv1d_dilation() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 2, - groups: 1, - length: 4, - }; + #[test] + fn test_conv1d_dilation() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 2, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[62., 38.], [159., 111.]], - [[158., 102.], [447., 367.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[62., 38.], [159., 111.]], + [[158., 102.], [447., 367.]], + ])); + } - #[test] - fn test_conv1d_groups() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 2, - length: 4, - }; + #[test] + fn test_conv1d_groups() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 2, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[2., 5., 8., 3.], [42., 63., 75., 47.]], - [[26., 29., 32., 11.], [114., 159., 171., 103.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[2., 5., 8., 3.], [42., 63., 75., 47.]], + [[26., 29., 32., 11.], [114., 159., 171., 103.]], + ])); + } - #[test] - fn test_conv1d_complex() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 3, - channels_out: 4, - kernel_size: 3, - padding: 1, - stride: 2, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv1d_complex() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 3, + channels_out: 4, + kernel_size: 3, + padding: 1, + stride: 2, + dilation: 1, + 
groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]], - [[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]], + [[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]], + ])); + } - struct Conv1dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size: usize, - padding: usize, - stride: usize, - dilation: usize, - groups: usize, - length: usize, - } + struct Conv1dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size: usize, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + length: usize, + } - impl Conv1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size, - ]); - let weight = TestTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv1d( - x, - weight, - Some(bias), - ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), - ); + impl Conv1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size, + ]); + let weight = TestTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv1d( + x, + weight, + Some(bias), + ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/conv2d.rs b/burn-tensor/src/tests/module/conv2d.rs index ba7292ea39..7d92170fdc 100644 --- a/burn-tensor/src/tests/module/conv2d.rs +++ b/burn-tensor/src/tests/module/conv2d.rs @@ -1,165 +1,165 @@ #[burn_tensor_testgen::testgen(module_conv2d)] mod tests { - use super::*; - use burn_tensor::module::conv2d; - use burn_tensor::ops::ConvOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv2d; + use burn_tensor::ops::ConvOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv2d_simple() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; + #[test] + fn test_conv2d_simple() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + 
kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [1196., 1796., 1916., 1264.], - [1881., 2793., 2946., 1923.], - [2313., 3405., 3558., 2307.], - [1424., 2072., 2156., 1380.], - ], - [ - [2709., 4173., 4509., 3065.], - [4582., 7006., 7483., 5056.], - [5878., 8914., 9391., 6304.], - [4089., 6177., 6477., 4333.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [1196., 1796., 1916., 1264.], + [1881., 2793., 2946., 1923.], + [2313., 3405., 3558., 2307.], + [1424., 2072., 2156., 1380.], + ], + [ + [2709., 4173., 4509., 3065.], + [4582., 7006., 7483., 5056.], + [5878., 8914., 9391., 6304.], + [4089., 6177., 6477., 4333.], + ], + ]])); + } - #[test] - fn test_conv2d_groups() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 5, - width: 5, - }; + #[test] + fn test_conv2d_groups() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 5, + width: 5, + }; - test.assert_output(TestTensor::from_floats([[ - [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]], - [ - [3724., 3841., 3958.], - [4309., 4426., 4543.], - [4894., 5011., 5128.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]], + [ + [3724., 3841., 3958.], + [4309., 4426., 4543.], + [4894., 5011., 5128.], + ], + ]])); + } - #[test] - fn test_conv2d_complex() { - let test = Conv2dTestCase { - batch_size: 2, - channels_in: 3, - channels_out: 4, - kernel_size_1: 3, - kernel_size_2: 2, - padding_1: 1, - padding_2: 2, - stride_1: 2, - stride_2: 3, - dilation_1: 1, - dilation_2: 2, - groups: 1, - height: 4, - width: 5, - }; + #[test] + fn test_conv2d_complex() { + let test = Conv2dTestCase { + batch_size: 2, + channels_in: 3, + channels_out: 4, + kernel_size_1: 3, + kernel_size_2: 2, + padding_1: 1, + padding_2: 2, + stride_1: 2, + stride_2: 3, + dilation_1: 1, + dilation_2: 2, + groups: 1, + height: 4, + width: 5, + }; - test.assert_output(TestTensor::from_floats([ - [ - [[1845., 3789., 1926.], [3210., 6465., 3228.]], - [[4276., 9082., 4789.], [8071., 16834., 8737.]], - [[6707., 14375., 7652.], [12932., 27203., 14246.]], - [[9138., 19668., 10515.], [17793., 37572., 19755.]], - ], - [ - [[5445., 10629., 5166.], [8070., 15645., 7548.]], - [[14356., 28882., 14509.], [22651., 45454., 22777.]], - [[23267., 47135., 23852.], [37232., 75263., 38006.]], - [[32178., 65388., 33195.], [51813., 105072., 53235.]], - ], - ])); - } + test.assert_output(TestTensor::from_floats([ + [ + [[1845., 3789., 1926.], [3210., 6465., 3228.]], + [[4276., 9082., 4789.], [8071., 16834., 8737.]], + [[6707., 14375., 7652.], [12932., 27203., 14246.]], + [[9138., 19668., 10515.], [17793., 37572., 19755.]], + ], + [ + [[5445., 10629., 5166.], [8070., 15645., 7548.]], + [[14356., 28882., 14509.], [22651., 45454., 22777.]], + [[23267., 47135., 23852.], [37232., 75263., 38006.]], + [[32178., 65388., 33195.], [51813., 105072., 53235.]], + ], + ])); + } - struct Conv2dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - 
kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - stride_1: usize, - stride_2: usize, - dilation_1: usize, - dilation_2: usize, - groups: usize, - height: usize, - width: usize, - } + struct Conv2dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + dilation_1: usize, + dilation_2: usize, + groups: usize, + height: usize, + width: usize, + } - impl Conv2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size_1, - self.kernel_size_2, - ]); - let weight = TestTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv2d( - x, - weight, - Some(bias), - ConvOptions::new( - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - [self.dilation_1, self.dilation_2], - self.groups, - ), - ); + impl Conv2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size_1, + self.kernel_size_2, + ]); + let weight = TestTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv2d( + x, + weight, + Some(bias), + ConvOptions::new( + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + [self.dilation_1, self.dilation_2], + self.groups, + ), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/conv_transpose1d.rs b/burn-tensor/src/tests/module/conv_transpose1d.rs index 349e1f1cfd..d7b487869d 100644 --- a/burn-tensor/src/tests/module/conv_transpose1d.rs +++ b/burn-tensor/src/tests/module/conv_transpose1d.rs @@ -1,146 +1,146 @@ #[burn_tensor_testgen::testgen(module_conv_transpose1d)] mod tests { - use super::*; - use burn_tensor::module::conv_transpose1d; - use burn_tensor::ops::ConvTransposeOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv_transpose1d; + use burn_tensor::ops::ConvTransposeOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv_transpose1d_diff_channels() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 3, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv_transpose1d_diff_channels() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 3, + channels_out: 2, + kernel_size: 3, + 
padding: 1, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [270., 453., 516., 387.], - [352., 589., 679., 505.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [270., 453., 516., 387.], + [352., 589., 679., 505.], + ]])); + } - #[test] - fn test_conv_transpose1d_stride() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 1, - stride: 2, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv_transpose1d_stride() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + padding_out: 1, + stride: 2, + dilation: 1, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [28., 62., 36., 78., 44., 94., 52., 62.], - [41., 93., 55., 121., 69., 149., 83., 93.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [28., 62., 36., 78., 44., 94., 52., 62.], + [41., 93., 55., 121., 69., 149., 83., 93.], + ]])); + } - #[test] - fn test_conv_transpose1d_dilation() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 0, - stride: 1, - dilation: 2, - groups: 1, - length: 4, - }; + #[test] + fn test_conv_transpose1d_dilation() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + padding_out: 0, + stride: 1, + dilation: 2, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [30., 64., 78., 76., 94., 52.], - [49., 101., 127., 113., 143., 77.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [30., 64., 78., 76., 94., 52.], + [49., 101., 127., 113., 143., 77.], + ]])); + } - #[test] - fn test_conv_transpose1d_groups() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 2, - length: 4, - }; + #[test] + fn test_conv_transpose1d_groups() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 2, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [0., 1., 4., 7.], - [32., 59., 71., 59.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0., 1., 4., 7.], + [32., 59., 71., 59.], + ]])); + } - struct ConvTranspose1dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size: usize, - padding: usize, - padding_out: usize, - stride: usize, - dilation: usize, - groups: usize, - length: usize, - } + struct ConvTranspose1dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size: usize, + padding: usize, + padding_out: usize, + stride: usize, + dilation: usize, + groups: usize, + length: usize, + } - impl ConvTranspose1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); - let shape_weights = Shape::new([ - self.channels_in, - self.channels_out / self.groups, - self.kernel_size, - ]); - let weights = TestTensor::from_data( - TestTensorInt::arange(0..shape_weights.num_elements()) - .reshape(shape_weights) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - 
TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv_transpose1d( - x, - weights, - Some(bias), - ConvTransposeOptions::new( - [self.stride], - [self.padding], - [self.padding_out], - [self.dilation], - self.groups, - ), - ); + impl ConvTranspose1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); + let shape_weights = Shape::new([ + self.channels_in, + self.channels_out / self.groups, + self.kernel_size, + ]); + let weights = TestTensor::from_data( + TestTensorInt::arange(0..shape_weights.num_elements()) + .reshape(shape_weights) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv_transpose1d( + x, + weights, + Some(bias), + ConvTransposeOptions::new( + [self.stride], + [self.padding], + [self.padding_out], + [self.dilation], + self.groups, + ), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/conv_transpose2d.rs b/burn-tensor/src/tests/module/conv_transpose2d.rs index d8b3a3e05d..4fb76daf44 100644 --- a/burn-tensor/src/tests/module/conv_transpose2d.rs +++ b/burn-tensor/src/tests/module/conv_transpose2d.rs @@ -1,336 +1,336 @@ #[burn_tensor_testgen::testgen(module_conv_transpose2d)] mod tests { - use super::*; - use burn_tensor::module::conv_transpose2d; - use burn_tensor::ops::ConvTransposeOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv_transpose2d; + use burn_tensor::ops::ConvTransposeOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv_transpose2d_simple_1() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 1, - channels_out: 1, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_simple_1() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 1, + channels_out: 1, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 0, + padding_out_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[[[5.0, 11.0], [23.0, 29.0]]]])); - } - #[test] - fn test_conv_transpose2d_simple_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 3, - channels_out: 3, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; + test.assert_output(TestTensor::from_floats([[[[5.0, 11.0], [23.0, 29.0]]]])); + } + #[test] + fn test_conv_transpose2d_simple_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 3, + channels_out: 3, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 0, + 
padding_out_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [9855., 15207., 15738., 10797.], - [16290., 25119., 25956., 17793.], - [18486., 28467., 29304., 20061.], - [13593., 20913., 21498., 14703.], - ], - [ - [11854., 18286., 18979., 13012.], - [19612., 30223., 31303., 21439.], - [22456., 34543., 35623., 24355.], - [16456., 25288., 26035., 17782.], - ], - [ - [13853., 21365., 22220., 15227.], - [22934., 35327., 36650., 25085.], - [26426., 40619., 41942., 28649.], - [19319., 29663., 30572., 20861.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [9855., 15207., 15738., 10797.], + [16290., 25119., 25956., 17793.], + [18486., 28467., 29304., 20061.], + [13593., 20913., 21498., 14703.], + ], + [ + [11854., 18286., 18979., 13012.], + [19612., 30223., 31303., 21439.], + [22456., 34543., 35623., 24355.], + [16456., 25288., 26035., 17782.], + ], + [ + [13853., 21365., 22220., 15227.], + [22934., 35327., 36650., 25085.], + [26426., 40619., 41942., 28649.], + [19319., 29663., 30572., 20861.], + ], + ]])); + } - #[test] - fn test_conv_transpose2d_stride_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 1, - channels_out: 1, - kernel_size_1: 2, - kernel_size_2: 2, - padding_1: 0, - padding_2: 0, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 2, - stride_2: 2, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_stride_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 1, + channels_out: 1, + kernel_size_1: 2, + kernel_size_2: 2, + padding_1: 0, + padding_2: 0, + padding_out_1: 0, + padding_out_2: 0, + stride_1: 2, + stride_2: 2, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 2.0, 3.0], - [0.0, 2.0, 0.0, 3.0], - [4.0, 6.0, 6.0, 9.0], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 2.0, 3.0], + [0.0, 2.0, 0.0, 3.0], + [4.0, 6.0, 6.0, 9.0], + ]]])); + } - #[test] - fn test_conv_transpose2d_dilation_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 1, - padding_out_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 2, - dilation_2: 2, - groups: 1, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_dilation_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 1, + padding_out_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 2, + dilation_2: 2, + groups: 1, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [126., 116., 136., 124., 146.], - [108., 88., 114., 92., 120.], - [156., 140., 166., 148., 176.], - [126., 100., 132., 104., 138.], - [186., 164., 196., 172., 206.], - ], - [ - [217., 189., 227., 197., 237.], - [163., 125., 169., 129., 175.], - [247., 213., 257., 221., 267.], - [181., 137., 187., 141., 193.], - [277., 237., 287., 245., 297.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [126., 116., 136., 124., 146.], + [108., 88., 114., 92., 120.], + [156., 140., 166., 148., 176.], + [126., 100., 132., 104., 138.], + [186., 164., 196., 
172., 206.], + ], + [ + [217., 189., 227., 197., 237.], + [163., 125., 169., 129., 175.], + [247., 213., 257., 221., 267.], + [181., 137., 187., 141., 193.], + [277., 237., 287., 245., 297.], + ], + ]])); + } - #[test] - fn test_conv_transpose2d_stride2_out_padding() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 1, - padding_out_2: 1, - stride_1: 2, - stride_2: 2, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; + #[test] + fn test_conv_transpose2d_stride2_out_padding() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 1, + padding_out_2: 1, + stride_1: 2, + stride_2: 2, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [352., 728., 378., 780., 404., 832., 430., 452.], - [784., 1616., 836., 1720., 888., 1824., 940., 992.], - [456., 936., 482., 988., 508., 1040., 534., 564.], - [992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.], - [560., 1144., 586., 1196., 612., 1248., 638., 676.], - [1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.], - [664., 1352., 690., 1404., 716., 1456., 742., 788.], - [784., 1598., 816., 1662., 848., 1726., 880., 926.], - ], - [ - [497., 1035., 541., 1123., 585., 1211., 629., 651.], - [1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.], - [673., 1387., 717., 1475., 761., 1563., 805., 835.], - [1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.], - [849., 1739., 893., 1827., 937., 1915., 981., 1019.], - [1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.], - [1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.], - [1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [352., 728., 378., 780., 404., 832., 430., 452.], + [784., 1616., 836., 1720., 888., 1824., 940., 992.], + [456., 936., 482., 988., 508., 1040., 534., 564.], + [992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.], + [560., 1144., 586., 1196., 612., 1248., 638., 676.], + [1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.], + [664., 1352., 690., 1404., 716., 1456., 742., 788.], + [784., 1598., 816., 1662., 848., 1726., 880., 926.], + ], + [ + [497., 1035., 541., 1123., 585., 1211., 629., 651.], + [1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.], + [673., 1387., 717., 1475., 761., 1563., 805., 835.], + [1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.], + [849., 1739., 893., 1827., 937., 1915., 981., 1019.], + [1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.], + [1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.], + [1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.], + ], + ]])); + } - #[test] - fn test_conv_transpose2d_groups_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_groups_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 0, + padding_out_2: 0, + 
stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [[5., 11.], [23., 29.]], - [[236., 258.], [302., 324.]], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [[5., 11.], [23., 29.]], + [[236., 258.], [302., 324.]], + ]])); + } - #[test] - fn test_conv_transpose2d_groups_different_channels() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 6, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_groups_different_channels() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 6, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + padding_out_1: 0, + padding_out_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00], - [0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01], - [6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01], - [1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01], - ], - [ - [1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01], - [1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01], - [2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01], - [3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01], - ], - [ - [2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01], - [3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01], - [4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01], - [5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01], - ], - [ - [1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02], - [2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02], - [3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02], - [2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02], - ], - [ - [1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02], - [3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02], - [4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02], - [2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02], - ], - [ - [1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02], - [4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02], - [4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02], - [3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00], + [0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01], + [6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01], + [1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01], + ], + [ + [1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01], + [1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01], + [2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01], + [3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01], + ], + [ + [2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01], + [3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01], + [4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01], + [5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01], + ], + [ + [1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02], + [2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02], + [3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02], + [2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02], + ], + [ + [1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02], + [3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02], + [4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02], + [2.5600e+02, 
5.5600e+02, 5.6900e+02, 3.1200e+02], + ], + [ + [1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02], + [4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02], + [4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02], + [3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02], + ], + ]])); + } - struct ConvTranspose2dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - padding_out_1: usize, - padding_out_2: usize, - stride_1: usize, - stride_2: usize, - dilation_1: usize, - dilation_2: usize, - groups: usize, - height: usize, - width: usize, - } + struct ConvTranspose2dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + padding_out_1: usize, + padding_out_2: usize, + stride_1: usize, + stride_2: usize, + dilation_1: usize, + dilation_2: usize, + groups: usize, + height: usize, + width: usize, + } - impl ConvTranspose2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let shape_weights = Shape::new([ - self.channels_in, - self.channels_out / self.groups, - self.kernel_size_1, - self.kernel_size_2, - ]); - let weights = TestTensor::from_data( - TestTensorInt::arange(0..shape_weights.num_elements()) - .reshape(shape_weights) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv_transpose2d( - x, - weights, - Some(bias), - ConvTransposeOptions::new( - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - [self.padding_out_1, self.padding_out_2], - [self.dilation_1, self.dilation_2], - self.groups, - ), - ); + impl ConvTranspose2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let shape_weights = Shape::new([ + self.channels_in, + self.channels_out / self.groups, + self.kernel_size_1, + self.kernel_size_2, + ]); + let weights = TestTensor::from_data( + TestTensorInt::arange(0..shape_weights.num_elements()) + .reshape(shape_weights) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv_transpose2d( + x, + weights, + Some(bias), + ConvTransposeOptions::new( + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + [self.padding_out_1, self.padding_out_2], + [self.dilation_1, self.dilation_2], + self.groups, + ), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/module/forward.rs b/burn-tensor/src/tests/module/forward.rs index 7ff629140a..2ea81da5e4 100644 --- a/burn-tensor/src/tests/module/forward.rs +++ b/burn-tensor/src/tests/module/forward.rs @@ -1,20 +1,20 @@ #[burn_tensor_testgen::testgen(module_forward)] mod tests { - use super::*; - use burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; + use super::*; + use 
burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; - #[test] - fn test_embedding_forward() { - let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = Data::from([[0, 1], [1, 1]]); - let weights = Tensor::::from_data(weights); - let indices = Tensor::::from_data(indices); + #[test] + fn test_embedding_forward() { + let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = Data::from([[0, 1], [1, 1]]); + let weights = Tensor::::from_data(weights); + let indices = Tensor::::from_data(indices); - let output = embedding(weights, indices); - let expected = Data::from([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], - ]); - assert_eq!(output.to_data(), expected); - } + let output = embedding(weights, indices); + let expected = Data::from([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], + ]); + assert_eq!(output.to_data(), expected); + } } diff --git a/burn-tensor/src/tests/module/maxpool1d.rs b/burn-tensor/src/tests/module/maxpool1d.rs index 97c26129da..89cff90ec4 100644 --- a/burn-tensor/src/tests/module/maxpool1d.rs +++ b/burn-tensor/src/tests/module/maxpool1d.rs @@ -1,116 +1,116 @@ #[burn_tensor_testgen::testgen(module_max_pool1d)] mod tests { - use super::*; - use burn_tensor::module::{max_pool1d, max_pool1d_with_indices}; - use burn_tensor::{backend::Backend, Data, Tensor}; - - type IntElem = ::IntElem; - - #[test] - fn test_max_pool1d_simple() { - let kernel_size = 3; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[ - [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], - [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], - ]]); - let y = TestTensor::from_floats([[ - [0.9861, 0.5474, 0.4477, 0.8221], - [0.949, 0.949, 0.949, 0.789], - ]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_different_padding_stride_kernel() { - let kernel_size = 3; - let padding = 1; - let stride = 2; - let dilation = 1; - - let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]); - let y = TestTensor::from_floats([[[0.6309, 0.6998]]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_with_neg() { - let kernel_size = 3; - let padding = 1; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]); - let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_with_dilation() { - let kernel_size = 2; - let padding = 1; - let stride = 1; - let dilation = 2; - - let x = TestTensor::from_floats([[ - [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], - [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], - ]]); - let y = TestTensor::from_floats([[ - [0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548], - [0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537], - ]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_with_indices() { - let kernel_size = 2; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 
0.5742]]]); - let indices = Data::::from([[[1, 1, 3]]]); - let y = TestTensor::from_floats([[[0.6386, 0.6386, 0.5742]]]); - - let (output, output_indices) = - max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } - - #[test] - fn test_max_pool1d_complex() { - let kernel_size = 4; - let padding = 2; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]); - let indices = Data::::from([[[0, 2, 3, 3, 3, 3]]]); - let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]); - - let (output, output_indices) = - max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } + use super::*; + use burn_tensor::module::{max_pool1d, max_pool1d_with_indices}; + use burn_tensor::{backend::Backend, Data, Tensor}; + + type IntElem = ::IntElem; + + #[test] + fn test_max_pool1d_simple() { + let kernel_size = 3; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[ + [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], + [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], + ]]); + let y = TestTensor::from_floats([[ + [0.9861, 0.5474, 0.4477, 0.8221], + [0.949, 0.949, 0.949, 0.789], + ]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_different_padding_stride_kernel() { + let kernel_size = 3; + let padding = 1; + let stride = 2; + let dilation = 1; + + let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]); + let y = TestTensor::from_floats([[[0.6309, 0.6998]]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_with_neg() { + let kernel_size = 3; + let padding = 1; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]); + let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_with_dilation() { + let kernel_size = 2; + let padding = 1; + let stride = 1; + let dilation = 2; + + let x = TestTensor::from_floats([[ + [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], + [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], + ]]); + let y = TestTensor::from_floats([[ + [0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548], + [0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537], + ]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_with_indices() { + let kernel_size = 2; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]); + let indices = Data::::from([[[1, 1, 3]]]); + let y = TestTensor::from_floats([[[0.6386, 0.6386, 0.5742]]]); + + let (output, output_indices) = + max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + 
} + + #[test] + fn test_max_pool1d_complex() { + let kernel_size = 4; + let padding = 2; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]); + let indices = Data::::from([[[0, 2, 3, 3, 3, 3]]]); + let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]); + + let (output, output_indices) = + max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + } } diff --git a/burn-tensor/src/tests/module/maxpool2d.rs b/burn-tensor/src/tests/module/maxpool2d.rs index 47a843fa93..d70f194cea 100644 --- a/burn-tensor/src/tests/module/maxpool2d.rs +++ b/burn-tensor/src/tests/module/maxpool2d.rs @@ -1,324 +1,324 @@ #[burn_tensor_testgen::testgen(module_max_pool2d)] mod tests { - use super::*; - use burn_tensor::module::{max_pool2d, max_pool2d_with_indices}; - use burn_tensor::{backend::Backend, Data, Tensor}; + use super::*; + use burn_tensor::module::{max_pool2d, max_pool2d_with_indices}; + use burn_tensor::{backend::Backend, Data, Tensor}; - type IntElem = ::IntElem; + type IntElem = ::IntElem; - #[test] - fn test_max_pool2d_simple() { - let batch_size = 2; - let channels_in = 2; - let kernel_size_1 = 3; - let kernel_size_2 = 3; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_simple() { + let batch_size = 2; + let channels_in = 2; + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([ - [ - [ - [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], - [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], - [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182], - [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392], - [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605], - [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068], - ], - [ - [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473], - [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947], - [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149], - [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426], - [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544], - [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572], - ], - ], - [ - [ - [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910], - [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546], - [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584], - [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441], - [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881], - [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467], - ], - [ - [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952], - [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454], - [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628], - [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527], - [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827], - [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421], - ], - ], - ]); - let y = TestTensor::from_floats([ - [ - [ - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], - [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], - [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], - [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], - ], - [ - [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947], - [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 
0.9149], - [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149], - [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149], - [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025], - [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775], - ], - ], - [ - [ - [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], - [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], - [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546], - [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], - [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], - [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037], - ], - [ - [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378], - [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378], - [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445], - [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128], - [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128], - [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128], - ], - ], - ]); + let x = TestTensor::from_floats([ + [ + [ + [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], + [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], + [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182], + [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392], + [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605], + [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068], + ], + [ + [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473], + [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947], + [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149], + [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426], + [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544], + [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572], + ], + ], + [ + [ + [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910], + [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546], + [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584], + [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441], + [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881], + [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467], + ], + [ + [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952], + [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454], + [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628], + [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527], + [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827], + [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421], + ], + ], + ]); + let y = TestTensor::from_floats([ + [ + [ + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], + [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], + [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], + [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], + ], + [ + [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947], + [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149], + [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149], + [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149], + [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025], + [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775], + ], + ], + [ + [ + [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], + [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], + [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546], + [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], + [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], + [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037], + ], + [ + [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378], + [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378], + [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445], + [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128], + [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128], + [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 
0.6128], + ], + ], + ]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - #[test] - fn test_max_pool2d_different_padding_stride_kernel() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 3; - let kernel_size_2 = 1; - let padding_1 = 1; - let padding_2 = 0; - let stride_1 = 1; - let stride_2 = 2; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_different_padding_stride_kernel() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 3; + let kernel_size_2 = 1; + let padding_1 = 1; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 2; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.6309, 0.6112, 0.6998], - [0.4708, 0.9161, 0.5402], - [0.4577, 0.7397, 0.9870], - [0.6380, 0.4352, 0.5884], - [0.6277, 0.5139, 0.4525], - [0.9333, 0.9846, 0.5006], - ]]]); - let y = TestTensor::from_floats([[[ - [0.6309, 0.6998], - [0.6309, 0.9870], - [0.6380, 0.9870], - [0.6380, 0.9870], - [0.9333, 0.5884], - [0.9333, 0.5006], - ]]]); + let x = TestTensor::from_floats([[[ + [0.6309, 0.6112, 0.6998], + [0.4708, 0.9161, 0.5402], + [0.4577, 0.7397, 0.9870], + [0.6380, 0.4352, 0.5884], + [0.6277, 0.5139, 0.4525], + [0.9333, 0.9846, 0.5006], + ]]]); + let y = TestTensor::from_floats([[[ + [0.6309, 0.6998], + [0.6309, 0.9870], + [0.6380, 0.9870], + [0.6380, 0.9870], + [0.9333, 0.5884], + [0.9333, 0.5006], + ]]]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - #[test] - fn test_max_pool2d_with_neg() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 3; - let kernel_size_2 = 3; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_with_neg() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.6309, 0.6112, 0.6998], - [0.4708, 0.9161, 0.5402], - [0.4577, 0.7397, 0.9870], - [0.6380, 0.4352, 0.5884], - [0.6277, 0.5139, 0.4525], - [0.9333, 0.9846, 0.5006], - ]]]) - .neg(); - let y = TestTensor::from_floats([[[ - [-0.4708, -0.4708, -0.5402], - [-0.4577, -0.4577, -0.5402], - [-0.4352, -0.4352, -0.4352], - [-0.4352, -0.4352, -0.4352], - [-0.4352, -0.4352, -0.4352], - [-0.5139, -0.4525, -0.4525], - ]]]); + let x = TestTensor::from_floats([[[ + [0.6309, 0.6112, 0.6998], + [0.4708, 0.9161, 0.5402], + [0.4577, 0.7397, 0.9870], + [0.6380, 0.4352, 0.5884], + [0.6277, 0.5139, 0.4525], + [0.9333, 0.9846, 0.5006], + ]]]) + .neg(); + let y = TestTensor::from_floats([[[ + [-0.4708, -0.4708, -0.5402], + [-0.4577, -0.4577, -0.5402], + [-0.4352, -0.4352, -0.4352], + [-0.4352, 
-0.4352, -0.4352], + [-0.4352, -0.4352, -0.4352], + [-0.5139, -0.4525, -0.4525], + ]]]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - #[test] - fn test_max_pool2d_with_dilation() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 0; - let padding_2 = 0; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 2; - let dilation_2 = 2; + #[test] + fn test_max_pool2d_with_dilation() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 0; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 2; + let dilation_2 = 2; - let x = TestTensor::from_floats([[[ - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], - [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], - [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], - [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], - ]]]); - let y = TestTensor::from_floats([[[ - [0.9861, 0.9861, 0.9540, 0.9490], - [0.9861, 0.9861, 0.9540, 0.9490], - [0.9540, 0.9540, 0.9540, 0.9490], - [0.9540, 0.9540, 0.9540, 0.9432], - ]]]); + let x = TestTensor::from_floats([[[ + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], + [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], + [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], + [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], + ]]]); + let y = TestTensor::from_floats([[[ + [0.9861, 0.9861, 0.9540, 0.9490], + [0.9861, 0.9861, 0.9540, 0.9490], + [0.9540, 0.9540, 0.9540, 0.9490], + [0.9540, 0.9540, 0.9540, 0.9432], + ]]]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - fn test_max_pool2d_with_indices() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; + fn test_max_pool2d_with_indices() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]); - let indices = Data::::from([[[ - [0, 1, 1, 3, 3], - [4, 4, 1, 7, 7], - [4, 9, 9, 7, 7], - [8, 9, 9, 14, 11], - [12, 12, 14, 14, 15], - ]]]); - let y = TestTensor::from_floats([[[ - [0.2479, 0.6386, 0.6386, 0.5742, 0.5742], - [0.7065, 0.7065, 0.6386, 0.8959, 0.8959], - [0.7065, 
0.8602, 0.8602, 0.8959, 0.8959], - [0.5416, 0.8602, 0.8602, 0.8293, 0.1662], - [0.3358, 0.3358, 0.8293, 0.8293, 0.0990], - ]]]); + let x = TestTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]); + let indices = Data::::from([[[ + [0, 1, 1, 3, 3], + [4, 4, 1, 7, 7], + [4, 9, 9, 7, 7], + [8, 9, 9, 14, 11], + [12, 12, 14, 14, 15], + ]]]); + let y = TestTensor::from_floats([[[ + [0.2479, 0.6386, 0.6386, 0.5742, 0.5742], + [0.7065, 0.7065, 0.6386, 0.8959, 0.8959], + [0.7065, 0.8602, 0.8602, 0.8959, 0.8959], + [0.5416, 0.8602, 0.8602, 0.8293, 0.1662], + [0.3358, 0.3358, 0.8293, 0.8293, 0.0990], + ]]]); - let (output, output_indices) = max_pool2d_with_indices( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let (output, output_indices) = max_pool2d_with_indices( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + } - #[test] - fn test_max_pool2d_complex() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 4; - let kernel_size_2 = 2; - let padding_1 = 2; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 2; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_complex() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 4; + let kernel_size_2 = 2; + let padding_1 = 2; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 2; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], - [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], - [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], - [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], - [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], - ]]]); - let indices = Data::::from([[[ - [5, 7, 3], - [5, 7, 3], - [5, 16, 3], - [5, 16, 8], - [15, 16, 24], - [15, 16, 24], - ]]]); - let y = TestTensor::from_floats([[[ - [0.9154, 0.9089, 0.8316], - [0.9154, 0.9089, 0.8316], - [0.9154, 0.9963, 0.8316], - [0.9154, 0.9963, 0.8016], - [0.4384, 0.9963, 0.688], - [0.4384, 0.9963, 0.688], - ]]]); - let (output, output_indices) = max_pool2d_with_indices( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let x = TestTensor::from_floats([[[ + [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], + [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], + [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], + [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], + [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], + ]]]); + let indices = Data::::from([[[ + [5, 7, 3], + [5, 7, 3], + [5, 16, 3], + [5, 16, 8], + [15, 16, 24], + [15, 16, 24], + ]]]); + let y = TestTensor::from_floats([[[ + [0.9154, 0.9089, 0.8316], + [0.9154, 0.9089, 0.8316], + [0.9154, 0.9963, 0.8316], + [0.9154, 0.9963, 0.8016], + [0.4384, 0.9963, 0.688], + [0.4384, 0.9963, 0.688], + ]]]); + let (output, output_indices) = max_pool2d_with_indices( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } 
+ y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + } } diff --git a/burn-tensor/src/tests/module/unfold4d.rs b/burn-tensor/src/tests/module/unfold4d.rs index afb1df24b8..0ead03612d 100644 --- a/burn-tensor/src/tests/module/unfold4d.rs +++ b/burn-tensor/src/tests/module/unfold4d.rs @@ -1,132 +1,132 @@ #[burn_tensor_testgen::testgen(module_unfold4d)] mod tests { - use super::*; - use burn_tensor::module::unfold4d; - use burn_tensor::ops::UnfoldOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::unfold4d; + use burn_tensor::ops::UnfoldOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_unfold4d_shape() { - let test = Unfold4dTestCase { - batch_size: 2, - channels_in: 5, - kernel_size: [2, 3], - padding: [0, 0], - stride: [1, 1], - dilation: [1, 1], - height: 3, - width: 4, - }; + #[test] + fn test_unfold4d_shape() { + let test = Unfold4dTestCase { + batch_size: 2, + channels_in: 5, + kernel_size: [2, 3], + padding: [0, 0], + stride: [1, 1], + dilation: [1, 1], + height: 3, + width: 4, + }; - test.assert_shape([2, 30, 4]); - } + test.assert_shape([2, 30, 4]); + } - #[test] - fn test_unfold4d_simple() { - let test = Unfold4dTestCase { - batch_size: 1, - channels_in: 2, - kernel_size: [2, 2], - padding: [0, 0], - stride: [1, 1], - dilation: [1, 1], - height: 4, - width: 4, - }; + #[test] + fn test_unfold4d_simple() { + let test = Unfold4dTestCase { + batch_size: 1, + channels_in: 2, + kernel_size: [2, 2], + padding: [0, 0], + stride: [1, 1], + dilation: [1, 1], + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_data([[ - [0., 1., 2., 4., 5., 6., 8., 9., 10.], - [1., 2., 3., 5., 6., 7., 9., 10., 11.], - [4., 5., 6., 8., 9., 10., 12., 13., 14.], - [5., 6., 7., 9., 10., 11., 13., 14., 15.], - [16., 17., 18., 20., 21., 22., 24., 25., 26.], - [17., 18., 19., 21., 22., 23., 25., 26., 27.], - [20., 21., 22., 24., 25., 26., 28., 29., 30.], - [21., 22., 23., 25., 26., 27., 29., 30., 31.], - ]])); - } + test.assert_output(TestTensor::from_data([[ + [0., 1., 2., 4., 5., 6., 8., 9., 10.], + [1., 2., 3., 5., 6., 7., 9., 10., 11.], + [4., 5., 6., 8., 9., 10., 12., 13., 14.], + [5., 6., 7., 9., 10., 11., 13., 14., 15.], + [16., 17., 18., 20., 21., 22., 24., 25., 26.], + [17., 18., 19., 21., 22., 23., 25., 26., 27.], + [20., 21., 22., 24., 25., 26., 28., 29., 30.], + [21., 22., 23., 25., 26., 27., 29., 30., 31.], + ]])); + } - #[test] - fn test_unfold4d_complex() { - let test = Unfold4dTestCase { - batch_size: 1, - channels_in: 2, - kernel_size: [2, 3], - padding: [0, 1], - stride: [1, 2], - dilation: [1, 2], - height: 3, - width: 4, - }; + #[test] + fn test_unfold4d_complex() { + let test = Unfold4dTestCase { + batch_size: 1, + channels_in: 2, + kernel_size: [2, 3], + padding: [0, 1], + stride: [1, 2], + dilation: [1, 2], + height: 3, + width: 4, + }; - test.assert_output(TestTensor::from_data([[ - [0., 0.], - [1., 5.], - [3., 7.], - [0., 0.], - [5., 9.], - [7., 11.], - [0., 0.], - [13., 17.], - [15., 19.], - [0., 0.], - [17., 21.], - [19., 23.], - ]])); - } + test.assert_output(TestTensor::from_data([[ + [0., 0.], + [1., 5.], + [3., 7.], + [0., 0.], + [5., 9.], + [7., 11.], + [0., 0.], + [13., 17.], + [15., 19.], + [0., 0.], + [17., 21.], + [19., 23.], + ]])); + } - struct Unfold4dTestCase { - batch_size: usize, - channels_in: usize, - kernel_size: [usize; 2], - padding: [usize; 2], - stride: [usize; 2], - dilation: [usize; 2], - height: usize, - width: usize, - } 
+ struct Unfold4dTestCase { + batch_size: usize, + channels_in: usize, + kernel_size: [usize; 2], + padding: [usize; 2], + stride: [usize; 2], + dilation: [usize; 2], + height: usize, + width: usize, + } - impl Unfold4dTestCase { - fn assert_shape(self, expected_shape: [usize; 3]) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); + impl Unfold4dTestCase { + fn assert_shape(self, expected_shape: [usize; 3]) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); - let output = unfold4d( - x, - self.kernel_size, - UnfoldOptions::new(self.stride, self.padding, self.dilation), - ); + let output = unfold4d( + x, + self.kernel_size, + UnfoldOptions::new(self.stride, self.padding, self.dilation), + ); - assert_eq!( - output.shape().dims, - expected_shape, - "Expected shape doesn't match the actual shape" - ); - } + assert_eq!( + output.shape().dims, + expected_shape, + "Expected shape doesn't match the actual shape" + ); + } - fn assert_output(self, expected: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); + fn assert_output(self, expected: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); - let output = unfold4d( - x, - self.kernel_size, - UnfoldOptions::new(self.stride, self.padding, self.dilation), - ); + let output = unfold4d( + x, + self.kernel_size, + UnfoldOptions::new(self.stride, self.padding, self.dilation), + ); - output - .into_data() - .assert_approx_eq(&expected.into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&expected.into_data(), 3); } + } } diff --git a/burn-tensor/src/tests/ops/abs.rs b/burn-tensor/src/tests/ops/abs.rs index f87b87a6fb..ad34a4581a 100644 --- a/burn-tensor/src/tests/ops/abs.rs +++ b/burn-tensor/src/tests/ops/abs.rs @@ -1,27 +1,27 @@ #[burn_tensor_testgen::testgen(abs)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn should_support_abs_ops_float() { - let data = Data::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_abs_ops_float() { + let data = Data::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.abs().into_data(); + let data_actual = tensor.abs().into_data(); - let data_expected = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_abs_ops_int() { - let data = Data::from([[0, -1, 2], [3, 4, -5]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_abs_ops_int() { + let data = Data::from([[0, -1, 2], [3, 4, -5]]); + let tensor = Tensor::::from_data(data); - let data_actual = 
tensor.abs().into_data(); + let data_actual = tensor.abs().into_data(); - let data_expected = Data::from([[0, 1, 2], [3, 4, 5]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0, 1, 2], [3, 4, 5]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/add.rs b/burn-tensor/src/tests/ops/add.rs index bd45b4376d..08a013f297 100644 --- a/burn-tensor/src/tests/ops/add.rs +++ b/burn-tensor/src/tests/ops/add.rs @@ -1,83 +1,83 @@ #[burn_tensor_testgen::testgen(add)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_add_d2() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_add_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_add_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor + scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_add_d2_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[6, 8, 10], [12, 14, 16]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_add_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[3, 5, 7], [6, 8, 10]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_add_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor + scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[2, 3, 4], [5, 6, 7]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_add_d2() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_add_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = 
Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_add_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor + scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_add_d2_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[6, 8, 10], [12, 14, 16]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_add_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[3, 5, 7], [6, 8, 10]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_add_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor + scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[2, 3, 4], [5, 6, 7]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/aggregation.rs b/burn-tensor/src/tests/ops/aggregation.rs index 45b94a89bb..e234b795b2 100644 --- a/burn-tensor/src/tests/ops/aggregation.rs +++ b/burn-tensor/src/tests/ops/aggregation.rs @@ -1,125 +1,125 @@ #[burn_tensor_testgen::testgen(aggregation)] mod tests { - use super::*; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_should_mean() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_mean() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.mean().to_data(); + let data_actual = tensor.mean().to_data(); - data_actual.assert_approx_eq(&Data::from([15.0 / 6.0]), 3); - } + data_actual.assert_approx_eq(&Data::from([15.0 / 6.0]), 3); + } - #[test] - fn test_should_mean_int() { - let tensor = TestTensorInt::from_data([[2, 2, 2], [3, 4, 5]]); + #[test] + fn test_should_mean_int() { + let tensor = TestTensorInt::from_data([[2, 2, 2], [3, 4, 5]]); - let data_actual = tensor.mean().to_data(); + let data_actual = tensor.mean().to_data(); - assert_eq!(data_actual, Data::from([3])); - } + assert_eq!(data_actual, Data::from([3])); + } - #[test] - fn test_should_sum() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_sum() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.sum().to_data(); + let data_actual = tensor.sum().to_data(); - assert_eq!(data_actual, Data::from([15.0])); - } + assert_eq!(data_actual, Data::from([15.0])); + } - #[test] - fn test_should_sum_int() { - let tensor = 
TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + #[test] + fn test_should_sum_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); - let data_actual = tensor.sum().to_data(); + let data_actual = tensor.sum().to_data(); - assert_eq!(data_actual, Data::from([15])); - } + assert_eq!(data_actual, Data::from([15])); + } - #[test] - fn test_should_mean_last_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_mean_last_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.mean_dim(1).to_data(); + let data_actual = tensor.mean_dim(1).to_data(); - data_actual.assert_approx_eq(&Data::from([[3.0 / 3.0], [12.0 / 3.0]]), 3); - } + data_actual.assert_approx_eq(&Data::from([[3.0 / 3.0], [12.0 / 3.0]]), 3); + } - #[test] - fn test_should_sum_last_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_sum_last_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.sum_dim(1).to_data(); + let data_actual = tensor.sum_dim(1).to_data(); - assert_eq!(data_actual, Data::from([[3.0], [12.0]])); - } + assert_eq!(data_actual, Data::from([[3.0], [12.0]])); + } - #[test] - fn test_should_mean_last_dim_int() { - let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + #[test] + fn test_should_mean_last_dim_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); - let data_actual = tensor.mean_dim(1).to_data(); + let data_actual = tensor.mean_dim(1).to_data(); - assert_eq!(data_actual, Data::from([[1], [4]])); - } + assert_eq!(data_actual, Data::from([[1], [4]])); + } - #[test] - fn test_should_sum_last_dim_int() { - let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + #[test] + fn test_should_sum_last_dim_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); - let data_actual = tensor.sum_dim(1).to_data(); + let data_actual = tensor.sum_dim(1).to_data(); - assert_eq!(data_actual, Data::from([[3], [12]])); - } + assert_eq!(data_actual, Data::from([[3], [12]])); + } - #[test] - fn test_should_sum_first_dim() { - let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); + #[test] + fn test_should_sum_first_dim() { + let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); - let data_actual = tensor.sum_dim(0).to_data(); + let data_actual = tensor.sum_dim(0).to_data(); - assert_eq!(data_actual, Data::from([[7.0, 3.0, 5.0]])); - } + assert_eq!(data_actual, Data::from([[7.0, 3.0, 5.0]])); + } - #[test] - fn test_should_mean_first_dim() { - let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); + #[test] + fn test_should_mean_first_dim() { + let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); - let data_actual = tensor.mean_dim(0).to_data(); + let data_actual = tensor.mean_dim(0).to_data(); - assert_eq!(data_actual, Data::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]])); - } + assert_eq!(data_actual, Data::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]])); + } - #[test] - fn test_should_sum_mid_dim_3d_non_contiguous_1() { - let tensor = TestTensor::from_data([ - [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], - [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], - ]); + #[test] + fn test_should_sum_mid_dim_3d_non_contiguous_1() { + let tensor = TestTensor::from_data([ + [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], + [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], + ]); - let data_actual = tensor.swap_dims(0, 
2).sum_dim(1).into_data(); + let data_actual = tensor.swap_dims(0, 2).sum_dim(1).into_data(); - assert_eq!( - data_actual, - Data::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], Shape::new([3, 1, 2])) - ); - } + assert_eq!( + data_actual, + Data::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], Shape::new([3, 1, 2])) + ); + } - #[test] - fn test_should_sum_mid_dim_3d_non_contiguous_2() { - let tensor = TestTensor::from_data([ - [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], - [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], - ]); + #[test] + fn test_should_sum_mid_dim_3d_non_contiguous_2() { + let tensor = TestTensor::from_data([ + [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], + [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], + ]); - let data_actual = tensor.swap_dims(0, 1).sum_dim(1).into_data(); + let data_actual = tensor.swap_dims(0, 1).sum_dim(1).into_data(); - assert_eq!( - data_actual, - Data::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], Shape::new([2, 1, 3])) - ); - } + assert_eq!( + data_actual, + Data::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], Shape::new([2, 1, 3])) + ); + } } diff --git a/burn-tensor/src/tests/ops/arange.rs b/burn-tensor/src/tests/ops/arange.rs index 5abe68aff1..4552943183 100644 --- a/burn-tensor/src/tests/ops/arange.rs +++ b/burn-tensor/src/tests/ops/arange.rs @@ -1,21 +1,21 @@ #[burn_tensor_testgen::testgen(arange)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn test_arange() { - let tensor = Tensor::::arange(2..5); - assert_eq!(tensor.into_data(), Data::from([2, 3, 4])); - } + #[test] + fn test_arange() { + let tensor = Tensor::::arange(2..5); + assert_eq!(tensor.into_data(), Data::from([2, 3, 4])); + } - #[test] - fn test_arange_device() { - let device = ::Device::default(); + #[test] + fn test_arange_device() { + let device = ::Device::default(); - let tensor = Tensor::::arange_device(2..5, &device); - assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4])); - assert_eq!(tensor.device(), device); - } + let tensor = Tensor::::arange_device(2..5, &device); + assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4])); + assert_eq!(tensor.device(), device); + } } diff --git a/burn-tensor/src/tests/ops/arange_step.rs b/burn-tensor/src/tests/ops/arange_step.rs index 127f234eca..0922ac5620 100644 --- a/burn-tensor/src/tests/ops/arange_step.rs +++ b/burn-tensor/src/tests/ops/arange_step.rs @@ -1,46 +1,46 @@ #[burn_tensor_testgen::testgen(arange_step)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_arange_step() { - // Test correct sequence of numbers when the range is 0..9 and the step is 1 - let tensor = Tensor::::arange_step(0..9, 1); - assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); - - // Test correct sequence of numbers when the range is 0..3 and the step is 2 - let tensor = Tensor::::arange_step(0..3, 2); - assert_eq!(tensor.into_data(), Data::from([0, 2])); - - // Test correct sequence of numbers when the range is 0..2 and the step is 5 - let tensor = Tensor::::arange_step(0..2, 5); - assert_eq!(tensor.into_data(), Data::from([0])); - } - - #[test] - fn test_arange_step_device() { - let device = ::Device::default(); - - // Test correct sequence of numbers when the range is 0..9 and the step is 1 - let tensor = Tensor::::arange_step_device(0..9, 1, &device); - assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); - - // Test 
correct sequence of numbers when the range is 0..3 and the step is 2 - let tensor = Tensor::::arange_step_device(0..3, 2, &device); - assert_eq!(tensor.into_data(), Data::from([0, 2])); - - // Test correct sequence of numbers when the range is 0..2 and the step is 5 - let tensor = Tensor::::arange_step_device(0..2, 5, &device); - assert_eq!(tensor.clone().into_data(), Data::from([0])); - assert_eq!(tensor.device(), device); - } - - #[test] - #[should_panic] - fn should_panic_when_step_is_zero() { - // Test that arange_step panics when the step is 0 - let _tensor = Tensor::::arange_step(0..3, 0); - } + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_arange_step() { + // Test correct sequence of numbers when the range is 0..9 and the step is 1 + let tensor = Tensor::::arange_step(0..9, 1); + assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); + + // Test correct sequence of numbers when the range is 0..3 and the step is 2 + let tensor = Tensor::::arange_step(0..3, 2); + assert_eq!(tensor.into_data(), Data::from([0, 2])); + + // Test correct sequence of numbers when the range is 0..2 and the step is 5 + let tensor = Tensor::::arange_step(0..2, 5); + assert_eq!(tensor.into_data(), Data::from([0])); + } + + #[test] + fn test_arange_step_device() { + let device = ::Device::default(); + + // Test correct sequence of numbers when the range is 0..9 and the step is 1 + let tensor = Tensor::::arange_step_device(0..9, 1, &device); + assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); + + // Test correct sequence of numbers when the range is 0..3 and the step is 2 + let tensor = Tensor::::arange_step_device(0..3, 2, &device); + assert_eq!(tensor.into_data(), Data::from([0, 2])); + + // Test correct sequence of numbers when the range is 0..2 and the step is 5 + let tensor = Tensor::::arange_step_device(0..2, 5, &device); + assert_eq!(tensor.clone().into_data(), Data::from([0])); + assert_eq!(tensor.device(), device); + } + + #[test] + #[should_panic] + fn should_panic_when_step_is_zero() { + // Test that arange_step panics when the step is 0 + let _tensor = Tensor::::arange_step(0..3, 0); + } } diff --git a/burn-tensor/src/tests/ops/arg.rs b/burn-tensor/src/tests/ops/arg.rs index fd6f282b76..c954ec5739 100644 --- a/burn-tensor/src/tests/ops/arg.rs +++ b/burn-tensor/src/tests/ops/arg.rs @@ -1,71 +1,71 @@ #[burn_tensor_testgen::testgen(arg)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn test_argmax_2d_dim0() { - let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmax_2d_dim0() { + let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmax(0); + let data_actual = tensor.argmax(0); - let data_expected = Data::from([[0, 0, 1]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 0, 1]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmin_2d_dim0() { - let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmin_2d_dim0() { + let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmin(0); + let data_actual = tensor.argmin(0); - let data_expected 
= Data::from([[0, 1, 0]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 1, 0]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmax_2d_dim0_int() { - let data = Data::from([[10, 11, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmax_2d_dim0_int() { + let data = Data::from([[10, 11, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmax(0); + let data_actual = tensor.argmax(0); - let data_expected = Data::from([[0, 0, 1]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 0, 1]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmin_2d_dim0_int() { - let data = Data::from([[10, 11, 2], [30, 4, 5]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmin_2d_dim0_int() { + let data = Data::from([[10, 11, 2], [30, 4, 5]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmin(0); + let data_actual = tensor.argmin(0); - let data_expected = Data::from([[0, 1, 0]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 1, 0]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmax_2d_dim1() { - let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmax_2d_dim1() { + let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmax(1); + let data_actual = tensor.argmax(1); - let data_expected = Data::from([[1], [2]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[1], [2]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmin_2d_dim1() { - let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmin_2d_dim1() { + let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmin(1); + let data_actual = tensor.argmin(1); - let data_expected = Data::from([[2], [1]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[2], [1]]); + assert_eq!(data_expected, data_actual.to_data()); + } } diff --git a/burn-tensor/src/tests/ops/cast.rs b/burn-tensor/src/tests/ops/cast.rs index e6ac40ae74..057901273f 100644 --- a/burn-tensor/src/tests/ops/cast.rs +++ b/burn-tensor/src/tests/ops/cast.rs @@ -1,43 +1,43 @@ #[burn_tensor_testgen::testgen(cast)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; - - #[test] - fn cast_float_to_int() { - let tensor = Tensor::::from_data([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]); - - let actual = tensor.int().into_data(); - let expected = Data::from([[1, 2, 3], [4, 5, 6]]); - assert_eq!(expected, actual); - } - - #[test] - fn cast_int_to_float_tensor() { - let tensor = Tensor::::from_data([[1, 2, 3], [4, 5, 6]]); - - let actual = tensor.float().into_data(); - let expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - assert_eq!(expected, actual); - } - - #[test] - fn cast_bool_to_int_tensor() { - let tensor = - Tensor::::from_data([[true, false, true], [false, false, true]]); - - let actual = tensor.int().into_data(); - let expected = Data::from([[1, 0, 1], [0, 0, 1]]); - assert_eq!(expected, actual); - } - - #[test] - fn 
cast_bool_to_float_tensor() { - let tensor = - Tensor::::from_data([[true, false, true], [false, false, true]]); - - let actual = tensor.float().into_data(); - let expected = Data::from([[1., 0., 1.], [0., 0., 1.]]); - assert_eq!(expected, actual); - } + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; + + #[test] + fn cast_float_to_int() { + let tensor = Tensor::::from_data([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]); + + let actual = tensor.int().into_data(); + let expected = Data::from([[1, 2, 3], [4, 5, 6]]); + assert_eq!(expected, actual); + } + + #[test] + fn cast_int_to_float_tensor() { + let tensor = Tensor::::from_data([[1, 2, 3], [4, 5, 6]]); + + let actual = tensor.float().into_data(); + let expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + assert_eq!(expected, actual); + } + + #[test] + fn cast_bool_to_int_tensor() { + let tensor = + Tensor::::from_data([[true, false, true], [false, false, true]]); + + let actual = tensor.int().into_data(); + let expected = Data::from([[1, 0, 1], [0, 0, 1]]); + assert_eq!(expected, actual); + } + + #[test] + fn cast_bool_to_float_tensor() { + let tensor = + Tensor::::from_data([[true, false, true], [false, false, true]]); + + let actual = tensor.float().into_data(); + let expected = Data::from([[1., 0., 1.], [0., 0., 1.]]); + assert_eq!(expected, actual); + } } diff --git a/burn-tensor/src/tests/ops/cat.rs b/burn-tensor/src/tests/ops/cat.rs index f01311b2d6..519600b619 100644 --- a/burn-tensor/src/tests/ops/cat.rs +++ b/burn-tensor/src/tests/ops/cat.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(cat)] mod tests { - use super::*; - use alloc::vec::Vec; - use burn_tensor::{Bool, Data, Int, Tensor}; - #[test] - fn should_support_cat_ops_2d_dim0() { - let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); - let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); - - let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_cat_ops_int() { - let tensor_1 = Tensor::::from_data([[1, 2, 3]]); - let tensor_2 = Tensor::::from_data([[4, 5, 6]]); - - let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[1, 2, 3], [4, 5, 6]]); - assert_eq!(&data_actual, &data_expected); - } - - #[test] - fn should_support_cat_ops_bool() { - let tensor_1 = Tensor::::from_data([[false, true, true]]); - let tensor_2 = Tensor::::from_data([[true, true, false]]); - - let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[false, true, true], [true, true, false]]); - assert_eq!(&data_actual, &data_expected); - } - - #[test] - fn should_support_cat_ops_2d_dim1() { - let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); - let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); - - let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 1).into_data(); - - let data_expected = Data::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_cat_ops_3d() { - let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); - let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); - - let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); - 
data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - #[should_panic] - fn should_panic_when_dimensions_are_not_the_same() { - let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]); - let tensor_2 = TestTensor::from_data([[4.0, 5.0]]); - - TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - } - - #[test] - #[should_panic] - fn should_panic_when_list_of_vectors_is_empty() { - let tensor: Vec> = vec![]; - TestTensor::cat(tensor, 0).into_data(); - } - - #[test] - #[should_panic] - fn should_panic_when_cat_exceeds_dimension() { - let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); - let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); - - TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data(); - } + use super::*; + use alloc::vec::Vec; + use burn_tensor::{Bool, Data, Int, Tensor}; + #[test] + fn should_support_cat_ops_2d_dim0() { + let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); + let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); + + let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_cat_ops_int() { + let tensor_1 = Tensor::::from_data([[1, 2, 3]]); + let tensor_2 = Tensor::::from_data([[4, 5, 6]]); + + let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[1, 2, 3], [4, 5, 6]]); + assert_eq!(&data_actual, &data_expected); + } + + #[test] + fn should_support_cat_ops_bool() { + let tensor_1 = Tensor::::from_data([[false, true, true]]); + let tensor_2 = Tensor::::from_data([[true, true, false]]); + + let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[false, true, true], [true, true, false]]); + assert_eq!(&data_actual, &data_expected); + } + + #[test] + fn should_support_cat_ops_2d_dim1() { + let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); + let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); + + let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 1).into_data(); + + let data_expected = Data::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_cat_ops_3d() { + let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); + let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); + + let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + #[should_panic] + fn should_panic_when_dimensions_are_not_the_same() { + let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]); + let tensor_2 = TestTensor::from_data([[4.0, 5.0]]); + + TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + } + + #[test] + #[should_panic] + fn should_panic_when_list_of_vectors_is_empty() { + let tensor: Vec> = vec![]; + TestTensor::cat(tensor, 0).into_data(); + } + + #[test] + #[should_panic] + fn should_panic_when_cat_exceeds_dimension() { + let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); + let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); + + TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data(); + } } diff --git a/burn-tensor/src/tests/ops/clamp.rs b/burn-tensor/src/tests/ops/clamp.rs 
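A minimal sketch, independent of the patch and using plain `[usize; D]` arrays rather than burn's `Shape`, of the shape rule the cat tests above exercise: the concatenated dimension sums, every other dimension must match, and an empty tensor list or an out-of-range dimension panics.

fn cat_shape<const D: usize>(shapes: &[[usize; D]], dim: usize) -> [usize; D] {
    assert!(!shapes.is_empty(), "cat panics on an empty list of tensors");
    assert!(dim < D, "cat panics when the dimension exceeds the tensor rank");
    let mut out = shapes[0];
    for shape in &shapes[1..] {
        for d in 0..D {
            if d == dim {
                out[d] += shape[d]; // the concatenated dimension grows
            } else {
                // all other dimensions must agree, mirroring the should_panic tests above
                assert_eq!(out[d], shape[d], "non-cat dimensions must match");
            }
        }
    }
    out
}

// e.g. cat_shape(&[[1, 3], [1, 3]], 0) == [2, 3], matching should_support_cat_ops_2d_dim0.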
index 2acda4c0dd..6c8ddd7b85 100644 --- a/burn-tensor/src/tests/ops/clamp.rs +++ b/burn-tensor/src/tests/ops/clamp.rs @@ -1,60 +1,60 @@ #[burn_tensor_testgen::testgen(clamp)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn clamp_min() { - // test float tensor - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.clamp_min(2.0).into_data(); - - let data_expected = Data::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]); - assert_eq!(data_expected, data_actual); - - // test int tensor - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp_min(2).into_data(); - let data_expected = Data::from([[2, 2, 2], [3, 4, 5]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn clamp_max() { - // test float tensor - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.clamp_max(2.0).into_data(); - - let data_expected = Data::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]); - assert_eq!(data_expected, data_actual); - - // test int tensor - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp_max(4).into_data(); - let data_expected = Data::from([[0, 1, 2], [3, 4, 4]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn clamp_min_max() { - // test float tensor - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp(1.0, 4.0).into_data(); - let data_expected = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]); - assert_eq!(data_expected, data_actual); - - // test int tensor - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp(1, 4).into_data(); - let data_expected = Data::from([[1, 1, 2], [3, 4, 4]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn clamp_min() { + // test float tensor + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.clamp_min(2.0).into_data(); + + let data_expected = Data::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]); + assert_eq!(data_expected, data_actual); + + // test int tensor + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp_min(2).into_data(); + let data_expected = Data::from([[2, 2, 2], [3, 4, 5]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn clamp_max() { + // test float tensor + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.clamp_max(2.0).into_data(); + + let data_expected = Data::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]); + assert_eq!(data_expected, data_actual); + + // test int tensor + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp_max(4).into_data(); + let data_expected = Data::from([[0, 1, 2], [3, 4, 4]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn clamp_min_max() { + // test float tensor + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp(1.0, 4.0).into_data(); + let data_expected = Data::from([[1.0, 
1.0, 2.0], [3.0, 4.0, 4.0]]); + assert_eq!(data_expected, data_actual); + + // test int tensor + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp(1, 4).into_data(); + let data_expected = Data::from([[1, 1, 2], [3, 4, 4]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/cos.rs b/burn-tensor/src/tests/ops/cos.rs index c099b1e4e9..b193cd991a 100644 --- a/burn-tensor/src/tests/ops/cos.rs +++ b/burn-tensor/src/tests/ops/cos.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(cos)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_cos_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_cos_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.cos().into_data(); + let data_actual = tensor.cos().into_data(); - let data_expected = Data::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/create_like.rs b/burn-tensor/src/tests/ops/create_like.rs index ea54aeabc6..80374043d6 100644 --- a/burn-tensor/src/tests/ops/create_like.rs +++ b/burn-tensor/src/tests/ops/create_like.rs @@ -1,52 +1,49 @@ #[burn_tensor_testgen::testgen(create_like)] mod tests { - use super::*; - use burn_tensor::{Data, Distribution, Tensor}; + use super::*; + use burn_tensor::{Data, Distribution, Tensor}; - #[test] - fn should_support_zeros_like() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_zeros_like() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.zeros_like().into_data(); + let data_actual = tensor.zeros_like().into_data(); - let data_expected = - Data::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]); + let data_expected = Data::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_ones_like() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_ones_like() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.ones_like().into_data(); + let data_actual = tensor.ones_like().into_data(); - let data_expected = - Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); + let data_expected = Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_randoms_like() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_randoms_like() { + let tensor = 
TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor - .random_like(Distribution::Uniform(0.99999, 1.)) - .into_data(); + let data_actual = tensor + .random_like(Distribution::Uniform(0.99999, 1.)) + .into_data(); - let data_expected = - Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); + let data_expected = Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/div.rs b/burn-tensor/src/tests/ops/div.rs index fb3e91f070..a7f17fe19b 100644 --- a/burn-tensor/src/tests/ops/div.rs +++ b/burn-tensor/src/tests/ops/div.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(div)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn should_support_div_ops() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 / tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_div_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 / tensor_2).into_data(); - - let data_expected = Data::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_div_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor / scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_div_ops_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[1, 1, 2], [1, 1, 2]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 / tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 1, 1], [3, 4, 2]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_div_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[1, 1, 2], [3, 4, 5]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 / tensor_2).into_data(); - - let data_expected = Data::from([[0, 1, 1], [0, 0, 0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_div_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor / scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 0, 1], [1, 2, 2]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn should_support_div_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 
5.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 / tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_div_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 / tensor_2).into_data(); + + let data_expected = Data::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_div_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor / scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_div_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[1, 1, 2], [1, 1, 2]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 / tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 1, 1], [3, 4, 2]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_div_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[1, 1, 2], [3, 4, 5]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 / tensor_2).into_data(); + + let data_expected = Data::from([[0, 1, 1], [0, 0, 0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_div_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor / scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 0, 1], [1, 2, 2]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/erf.rs b/burn-tensor/src/tests/ops/erf.rs index 14afb899d9..aeac9ac496 100644 --- a/burn-tensor/src/tests/ops/erf.rs +++ b/burn-tensor/src/tests/ops/erf.rs @@ -1,30 +1,30 @@ #[burn_tensor_testgen::testgen(erf)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_erf_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_erf_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.erf().into_data(); + let data_actual = tensor.erf().into_data(); - let data_expected = Data::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_erf_ops_with_negative_number() { - let data = Data::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_erf_ops_with_negative_number() { + let data = Data::from([[-0.056, 
-0.043, -0.089], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.erf().into_data(); + let data_actual = tensor.erf().into_data(); - let data_expected = Data::from([ - [-0.06312324, -0.048490416, -0.10016122], - [1.0000, 1.0000, 1.0000], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([ + [-0.06312324, -0.048490416, -0.10016122], + [1.0000, 1.0000, 1.0000], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/exp.rs b/burn-tensor/src/tests/ops/exp.rs index 278b7b6a44..f0c1d203dc 100644 --- a/burn-tensor/src/tests/ops/exp.rs +++ b/burn-tensor/src/tests/ops/exp.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(exp)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_exp_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_exp_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.exp().into_data(); + let data_actual = tensor.exp().into_data(); - let data_expected = Data::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/flatten.rs b/burn-tensor/src/tests/ops/flatten.rs index 7876bfac3d..65f05477ac 100644 --- a/burn-tensor/src/tests/ops/flatten.rs +++ b/burn-tensor/src/tests/ops/flatten.rs @@ -1,58 +1,58 @@ #[burn_tensor_testgen::testgen(flatten)] mod tests { - use super::*; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Shape, Tensor}; - /// Test if the function can successfully flatten a 4D tensor to a 1D tensor. - #[test] - fn should_flatten_to_1d() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(0, 3); - let expected_shape = Shape::new([120]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten a 4D tensor to a 1D tensor. + #[test] + fn should_flatten_to_1d() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(0, 3); + let expected_shape = Shape::new([120]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function can successfully flatten the middle dimensions of a 4D tensor. - #[test] - fn should_flatten_middle() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(1, 2); - let expected_shape = Shape::new([2, 12, 5]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten the middle dimensions of a 4D tensor. + #[test] + fn should_flatten_middle() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(1, 2); + let expected_shape = Shape::new([2, 12, 5]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function can successfully flatten the first dimensions of a 4D tensor. 
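// A minimal sketch (not part of the patch itself) of the shape arithmetic the
// flatten tests in this file rely on: dimensions in `start..=end` collapse into
// their product while the remaining dimensions are kept as-is.
fn flattened_dim(shape: &[usize], start: usize, end: usize) -> usize {
    assert!(start <= end && end < shape.len(), "flatten panics on an invalid range");
    shape[start..=end].iter().product()
}
// e.g. flattened_dim(&[2, 3, 4, 5], 1, 2) == 12, which is why flattening dims 1..=2
// of a [2, 3, 4, 5] tensor yields the expected [2, 12, 5] shape in the tests here.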
- #[test] - fn should_flatten_begin() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(0, 2); - let expected_shape = Shape::new([24, 5]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten the first dimensions of a 4D tensor. + #[test] + fn should_flatten_begin() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(0, 2); + let expected_shape = Shape::new([24, 5]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function can successfully flatten the last dimensions of a 4D tensor. - #[test] - fn should_flatten_end() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(1, 3); - let expected_shape = Shape::new([2, 60]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten the last dimensions of a 4D tensor. + #[test] + fn should_flatten_end() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(1, 3); + let expected_shape = Shape::new([2, 60]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function panics when the start dimension is greater than the end dimension. - #[test] - #[should_panic] - fn should_flatten_panic() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(2, 0); - } + /// Test if the function panics when the start dimension is greater than the end dimension. + #[test] + #[should_panic] + fn should_flatten_panic() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(2, 0); + } - #[test] - #[should_panic] - fn not_enough_destination_dimension() { - let tensor = Tensor::::ones(Shape::new([1, 5, 15])); - let flattened_tensor: Tensor = tensor.flatten(1, 2); - let expected_shape = Shape::new([75]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + #[test] + #[should_panic] + fn not_enough_destination_dimension() { + let tensor = Tensor::::ones(Shape::new([1, 5, 15])); + let flattened_tensor: Tensor = tensor.flatten(1, 2); + let expected_shape = Shape::new([75]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } } diff --git a/burn-tensor/src/tests/ops/full.rs b/burn-tensor/src/tests/ops/full.rs index c1de8e8592..d2e4a7abc2 100644 --- a/burn-tensor/src/tests/ops/full.rs +++ b/burn-tensor/src/tests/ops/full.rs @@ -1,25 +1,25 @@ #[burn_tensor_testgen::testgen(full)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Shape, Tensor}; - #[test] - fn test_data_full() { - let data_actual = Data::full([2, 3].into(), 2.0); - let data_expected = Data::from([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]); - assert_eq!(data_expected, data_actual); - } + #[test] + fn test_data_full() { + let data_actual = Data::full([2, 3].into(), 2.0); + let data_expected = Data::from([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn test_tensor_full() { - // Test full with f32 - let tensor = Tensor::::full([2, 3], 2.1); - let data_expected = Data::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]); - assert_eq!(data_expected, tensor.into_data()); + #[test] + fn test_tensor_full() { + // Test full with f32 + let tensor = Tensor::::full([2, 3], 2.1); + let data_expected = Data::from([[2.1, 2.1, 
2.1], [2.1, 2.1, 2.1]]); + assert_eq!(data_expected, tensor.into_data()); - // Test full with Int - let int_tensor = Tensor::::full([2, 2], 2); - let data_expected = Data::from([[2, 2], [2, 2]]); - assert_eq!(data_expected, int_tensor.into_data()); - } + // Test full with Int + let int_tensor = Tensor::::full([2, 2], 2); + let data_expected = Data::from([[2, 2], [2, 2]]); + assert_eq!(data_expected, int_tensor.into_data()); + } } diff --git a/burn-tensor/src/tests/ops/gather_scatter.rs b/burn-tensor/src/tests/ops/gather_scatter.rs index 7b9abc7820..38f0710d34 100644 --- a/burn-tensor/src/tests/ops/gather_scatter.rs +++ b/burn-tensor/src/tests/ops/gather_scatter.rs @@ -1,177 +1,177 @@ #[burn_tensor_testgen::testgen(gather_scatter)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_gather_1d_dim0() { - let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]); - let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); + #[test] + fn should_gather_1d_dim0() { + let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]); + let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); - let output = tensor.gather(0, indices); + let output = tensor.gather(0, indices); - assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); - } - - #[test] - fn should_gather_1d_dim0_int() { - let tensor = TestTensorInt::from_ints([5, 6, 7]); - let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); - - let output = tensor.gather(0, indices); - - assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); - } + assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); + } + + #[test] + fn should_gather_1d_dim0_int() { + let tensor = TestTensorInt::from_ints([5, 6, 7]); + let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); + + let output = tensor.gather(0, indices); + + assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); + } - #[test] - fn should_gather_2d_dim0() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]); - - let output = tensor.gather(0, indices); - - assert_eq!( - output.into_data(), - Data::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]) - ); - } - - #[test] - fn should_gather_2d_dim1() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]); - - let output = tensor.gather(1, indices); - - assert_eq!( - output.into_data(), - Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]) - ); - } - - #[test] - fn should_gather_3d_dim1() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); - - let output = tensor.gather(1, indices); - - assert_eq!( - output.into_data(), - Data::from([ - [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], - [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]] - ]) - ); - } - - #[test] - fn should_gather_2d_only_1dim() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]); - - let output = tensor.gather(1, indices); - - assert_eq!(output.into_data(), Data::from([[1.0], [5.0]])); - } - - #[test] - fn should_scatter_1d() { - let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); - let values = TestTensor::from_floats([5.0, 4.0, 3.0]); - let indices = 
TestTensorInt::from_ints([1, 0, 2]); - - let output = tensor.scatter(0, indices, values); - - assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0])); - } - - #[test] - fn should_scatter_1d_int() { - let tensor = TestTensorInt::from_ints([0, 0, 0]); - let values = TestTensorInt::from_ints([5, 4, 3]); - let indices = TestTensorInt::from_ints([1, 0, 2]); - - let output = tensor.scatter(0, indices, values); - - assert_eq!(output.into_data(), Data::from([4, 5, 3])); - } - - #[test] - fn should_scatter_2d_dim0() { - let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); - let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]); - - let output = tensor.scatter(0, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]) - ); - } - - #[test] - fn should_scatter_2d_dim1() { - let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); - let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]); - - let output = tensor.scatter(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]) - ); - } - - #[test] - fn should_scatter_3d_dim1() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - let values = TestTensor::from_floats([ - [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], - ]); - let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); - - let output = tensor.scatter(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([ - [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], - [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]] - ]) - ); - } - - #[test] - fn should_scatter_2d_dim1_diff_shape() { - let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); - let values = TestTensor::from_floats([[1.0], [4.0]]); - let indices = TestTensorInt::from_ints([[1], [2]]); - - let output = tensor.scatter(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]) - ); - } - - #[test] - #[should_panic] - fn scatter_should_panic_on_mismatch_of_shapes() { - let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); - let values = TestTensor::from_floats([5.0, 4.0]); - let indices = TestTensorInt::from_ints([1, 0, 2]); - - tensor.scatter(0, indices, values); - } + #[test] + fn should_gather_2d_dim0() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]); + + let output = tensor.gather(0, indices); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]) + ); + } + + #[test] + fn should_gather_2d_dim1() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]); + + let output = tensor.gather(1, indices); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]) + ); + } + + #[test] + fn should_gather_3d_dim1() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); + + let output = tensor.gather(1, indices); + + 
assert_eq!( + output.into_data(), + Data::from([ + [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], + [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]] + ]) + ); + } + + #[test] + fn should_gather_2d_only_1dim() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]); + + let output = tensor.gather(1, indices); + + assert_eq!(output.into_data(), Data::from([[1.0], [5.0]])); + } + + #[test] + fn should_scatter_1d() { + let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); + let values = TestTensor::from_floats([5.0, 4.0, 3.0]); + let indices = TestTensorInt::from_ints([1, 0, 2]); + + let output = tensor.scatter(0, indices, values); + + assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0])); + } + + #[test] + fn should_scatter_1d_int() { + let tensor = TestTensorInt::from_ints([0, 0, 0]); + let values = TestTensorInt::from_ints([5, 4, 3]); + let indices = TestTensorInt::from_ints([1, 0, 2]); + + let output = tensor.scatter(0, indices, values); + + assert_eq!(output.into_data(), Data::from([4, 5, 3])); + } + + #[test] + fn should_scatter_2d_dim0() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]); + + let output = tensor.scatter(0, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]) + ); + } + + #[test] + fn should_scatter_2d_dim1() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]); + + let output = tensor.scatter(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]) + ); + } + + #[test] + fn should_scatter_3d_dim1() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + let values = TestTensor::from_floats([ + [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ]); + let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); + + let output = tensor.scatter(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([ + [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], + [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]] + ]) + ); + } + + #[test] + fn should_scatter_2d_dim1_diff_shape() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0], [4.0]]); + let indices = TestTensorInt::from_ints([[1], [2]]); + + let output = tensor.scatter(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]) + ); + } + + #[test] + #[should_panic] + fn scatter_should_panic_on_mismatch_of_shapes() { + let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); + let values = TestTensor::from_floats([5.0, 4.0]); + let indices = TestTensorInt::from_ints([1, 0, 2]); + + tensor.scatter(0, indices, values); + } } diff --git a/burn-tensor/src/tests/ops/init.rs b/burn-tensor/src/tests/ops/init.rs index 7a89527cd3..599d0ef95a 100644 --- a/burn-tensor/src/tests/ops/init.rs +++ b/burn-tensor/src/tests/ops/init.rs @@ -1,58 +1,58 @@ #[burn_tensor_testgen::testgen(init)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; + use super::*; + use 
burn_tensor::{Bool, Data, Int, Tensor}; - #[test] - fn should_support_float_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } + #[test] + fn should_support_float_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } - #[test] - fn should_support_int_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } + #[test] + fn should_support_int_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } - #[test] - fn should_support_float_zeros() { - let shape = [2, 2]; - let tensor = Tensor::::zeros(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[0., 0.], [0., 0.]])) - } + #[test] + fn should_support_float_zeros() { + let shape = [2, 2]; + let tensor = Tensor::::zeros(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[0., 0.], [0., 0.]])) + } - #[test] - fn should_support_int_zeros() { - let shape = [2, 2]; - let tensor = Tensor::::zeros(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[0, 0], [0, 0]])) - } + #[test] + fn should_support_int_zeros() { + let shape = [2, 2]; + let tensor = Tensor::::zeros(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[0, 0], [0, 0]])) + } - #[test] - fn should_support_float_ones() { - let shape = [2, 2]; - let tensor = Tensor::::ones(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[1., 1.], [1., 1.]])) - } + #[test] + fn should_support_float_ones() { + let shape = [2, 2]; + let tensor = Tensor::::ones(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[1., 1.], [1., 1.]])) + } - #[test] - fn should_support_int_ones() { - let shape = [2, 2]; - let tensor = Tensor::::ones(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[1, 1], [1, 1]])) - } + #[test] + fn should_support_int_ones() { + let shape = [2, 2]; + let tensor = Tensor::::ones(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[1, 1], [1, 1]])) + } - #[test] - fn should_support_bool_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } + #[test] + fn should_support_bool_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } } diff --git a/burn-tensor/src/tests/ops/iter_dim.rs b/burn-tensor/src/tests/ops/iter_dim.rs index 08d581a339..d12c9ed787 100644 --- a/burn-tensor/src/tests/ops/iter_dim.rs +++ b/burn-tensor/src/tests/ops/iter_dim.rs @@ -1,46 +1,46 @@ #[burn_tensor_testgen::testgen(iter_dim)] mod test { - use super::*; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn test_1d_iter_last_item() { - let data = [1, 2, 3, 4]; - let tensor = Tensor::::from_ints(data); - assert_eq!( - Tensor::::from_ints([4]).into_data(), - tensor.iter_dim(0).last().unwrap().into_data() - ) - } + #[test] + fn test_1d_iter_last_item() { + let data = [1, 2, 3, 4]; + let tensor = Tensor::::from_ints(data); + assert_eq!( + Tensor::::from_ints([4]).into_data(), + tensor.iter_dim(0).last().unwrap().into_data() + ) + } - #[test] - 
#[should_panic] - fn test_too_high_dimension() { - Tensor::::zeros([10]).iter_dim(1); - } + #[test] + #[should_panic] + fn test_too_high_dimension() { + Tensor::::zeros([10]).iter_dim(1); + } - #[test] - fn test_transposed() { - let data = [ - [1., 2., 3., 1., 2.], - [4., 5., 6., 1., 2.], - [7., 8., 9., 1., 2.], - ]; - let tensor = Tensor::::from_floats(data); - let lhs = tensor.clone().slice([1..2, 0..5]); - let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap(); - assert_eq!(lhs.into_data().value, rhs.into_data().value); - } + #[test] + fn test_transposed() { + let data = [ + [1., 2., 3., 1., 2.], + [4., 5., 6., 1., 2.], + [7., 8., 9., 1., 2.], + ]; + let tensor = Tensor::::from_floats(data); + let lhs = tensor.clone().slice([1..2, 0..5]); + let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap(); + assert_eq!(lhs.into_data().value, rhs.into_data().value); + } - fn test_iteration_over_low_dim() { - let data = [[ - [1., 2., 3., 1., 2.], - [4., 5., 6., 1., 2.], - [7., 8., 9., 1., 2.], - ]; 5]; - let tensor = Tensor::::from_floats(data); - let lhs = tensor.iter_dim(2).nth(1).unwrap(); - let rhs = Data::from([2., 5., 8.]); - assert_eq!(lhs.into_data().value, rhs.value); - } + fn test_iteration_over_low_dim() { + let data = [[ + [1., 2., 3., 1., 2.], + [4., 5., 6., 1., 2.], + [7., 8., 9., 1., 2.], + ]; 5]; + let tensor = Tensor::::from_floats(data); + let lhs = tensor.iter_dim(2).nth(1).unwrap(); + let rhs = Data::from([2., 5., 8.]); + assert_eq!(lhs.into_data().value, rhs.value); + } } diff --git a/burn-tensor/src/tests/ops/log.rs b/burn-tensor/src/tests/ops/log.rs index f71387317a..4532643487 100644 --- a/burn-tensor/src/tests/ops/log.rs +++ b/burn-tensor/src/tests/ops/log.rs @@ -1,19 +1,19 @@ #[burn_tensor_testgen::testgen(log)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_log_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_log_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.log().into_data(); + let data_actual = tensor.log().into_data(); - let data_expected = Data::from([ - [-f32::INFINITY, 0.0, core::f32::consts::LN_2], - [1.0986, 1.3862, 1.6094], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([ + [-f32::INFINITY, 0.0, core::f32::consts::LN_2], + [1.0986, 1.3862, 1.6094], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/log1p.rs b/burn-tensor/src/tests/ops/log1p.rs index 346e01b2e3..fe6f2b5b4e 100644 --- a/burn-tensor/src/tests/ops/log1p.rs +++ b/burn-tensor/src/tests/ops/log1p.rs @@ -1,19 +1,19 @@ #[burn_tensor_testgen::testgen(log1p)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_exp_log1p() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_exp_log1p() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.log1p().into_data(); + let data_actual = tensor.log1p().into_data(); - let data_expected = Data::from([ - [0.0, core::f32::consts::LN_2, 1.0986], - [1.3862, 1.6094, 1.7917], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let 
data_expected = Data::from([ + [0.0, core::f32::consts::LN_2, 1.0986], + [1.3862, 1.6094, 1.7917], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/map_comparison.rs b/burn-tensor/src/tests/ops/map_comparison.rs index 76eee5dc02..0906c2bc5b 100644 --- a/burn-tensor/src/tests/ops/map_comparison.rs +++ b/burn-tensor/src/tests/ops/map_comparison.rs @@ -1,308 +1,308 @@ #[burn_tensor_testgen::testgen(map_comparison)] mod tests { - use super::*; - use burn_tensor::{ - backend::Backend, BasicOps, Bool, Data, Element, Float, Int, Numeric, Tensor, TensorKind, - }; - - type IntElem = ::IntElem; - type FloatElem = ::FloatElem; - - #[test] - fn test_equal() { - equal::() - } - - #[test] - fn test_int_equal() { - equal::() - } - - #[test] - fn test_equal_elem() { - equal_elem::() - } - - #[test] - fn test_int_equal_elem() { - equal_elem::() - } - - #[test] - fn test_greater_elem() { - greater_elem::() - } - - #[test] - fn test_int_greater_elem() { - greater_elem::() - } - - #[test] - fn test_greater_equal_elem() { - greater_equal_elem::() - } - - #[test] - fn test_int_greater_equal_elem() { - greater_equal_elem::() - } - - #[test] - fn test_greater() { - greater::() - } - - #[test] - fn test_int_greater() { - greater::() - } - - #[test] - fn test_greater_equal() { - greater_equal::() - } - - #[test] - fn test_int_greater_equal() { - greater_equal::() - } - - #[test] - fn test_lower_elem() { - lower_elem::() - } - - #[test] - fn test_int_lower_elem() { - lower_elem::() - } - - #[test] - fn test_lower_equal_elem() { - lower_equal_elem::() - } - - #[test] - fn test_int_lower_equal_elem() { - lower_equal_elem::() - } - - #[test] - fn test_lower() { - lower::() - } - - #[test] - fn test_int_lower() { - lower::() - } - - #[test] - fn test_lower_equal() { - lower_equal::() - } - - #[test] - fn test_int_lower_equal() { - lower_equal::() - } - - fn equal() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.equal(tensor_2); - - let data_expected = Data::from([[false, true, false], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn equal_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().equal_elem(2); - let data_actual_inplace = tensor_1.equal_elem(2); - - let data_expected = Data::from([[false, false, true], [false, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn greater_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().greater_elem(4); - let data_actual_inplace = tensor_1.greater_elem(4); - - let data_expected = Data::from([[false, false, false], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, 
data_actual_inplace.into_data()); - } - - fn greater_equal_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0); - let data_actual_inplace = tensor_1.greater_equal_elem(4.0); - - let data_expected = Data::from([[false, false, false], [false, true, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn greater() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); - let data_actual_inplace = tensor_1.greater(tensor_2); - - let data_expected = Data::from([[false, false, true], [false, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn greater_equal() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.greater_equal(tensor_2); - - let data_expected = Data::from([[false, true, true], [false, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn lower_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().lower_elem(4.0); - let data_actual_inplace = tensor_1.lower_elem(4.0); - - let data_expected = Data::from([[true, true, true], [true, false, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn lower_equal_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0); - let data_actual_inplace = tensor_1.lower_equal_elem(4.0); - - let data_expected = Data::from([[true, true, true], [true, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn lower() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone()); - let data_actual_inplace = tensor_1.lower(tensor_2); - - let data_expected = Data::from([[true, false, false], [true, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, 
data_actual_inplace.into_data()); - } - - fn lower_equal() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.lower_equal(tensor_2); - - let data_expected = Data::from([[true, true, false], [true, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - #[test] - fn should_support_bool_equal() { - let data_1 = Data::from([[false, true, true], [true, false, true]]); - let data_2 = Data::from([[false, false, true], [false, true, true]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.equal(tensor_2); - - let data_expected = Data::from([[true, false, true], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - #[test] - fn should_support_bool_not() { - let data_1 = Data::from([[false, true, true], [true, true, false]]); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().bool_not(); - let data_actual_inplace = tensor_1.bool_not(); - - let data_expected = Data::from([[true, false, false], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } + use super::*; + use burn_tensor::{ + backend::Backend, BasicOps, Bool, Data, Element, Float, Int, Numeric, Tensor, TensorKind, + }; + + type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + + #[test] + fn test_equal() { + equal::() + } + + #[test] + fn test_int_equal() { + equal::() + } + + #[test] + fn test_equal_elem() { + equal_elem::() + } + + #[test] + fn test_int_equal_elem() { + equal_elem::() + } + + #[test] + fn test_greater_elem() { + greater_elem::() + } + + #[test] + fn test_int_greater_elem() { + greater_elem::() + } + + #[test] + fn test_greater_equal_elem() { + greater_equal_elem::() + } + + #[test] + fn test_int_greater_equal_elem() { + greater_equal_elem::() + } + + #[test] + fn test_greater() { + greater::() + } + + #[test] + fn test_int_greater() { + greater::() + } + + #[test] + fn test_greater_equal() { + greater_equal::() + } + + #[test] + fn test_int_greater_equal() { + greater_equal::() + } + + #[test] + fn test_lower_elem() { + lower_elem::() + } + + #[test] + fn test_int_lower_elem() { + lower_elem::() + } + + #[test] + fn test_lower_equal_elem() { + lower_equal_elem::() + } + + #[test] + fn test_int_lower_equal_elem() { + lower_equal_elem::() + } + + #[test] + fn test_lower() { + lower::() + } + + #[test] + fn test_int_lower() { + lower::() + } + + #[test] + fn test_lower_equal() { + lower_equal::() + } + + #[test] + fn test_int_lower_equal() { + lower_equal::() + } + + fn equal() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = 
tensor_1.clone().equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.equal(tensor_2); + + let data_expected = Data::from([[false, true, false], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn equal_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().equal_elem(2); + let data_actual_inplace = tensor_1.equal_elem(2); + + let data_expected = Data::from([[false, false, true], [false, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().greater_elem(4); + let data_actual_inplace = tensor_1.greater_elem(4); + + let data_expected = Data::from([[false, false, false], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater_equal_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0); + let data_actual_inplace = tensor_1.greater_equal_elem(4.0); + + let data_expected = Data::from([[false, false, false], [false, true, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); + let data_actual_inplace = tensor_1.greater(tensor_2); + + let data_expected = Data::from([[false, false, true], [false, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater_equal() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.greater_equal(tensor_2); + + let data_expected = Data::from([[false, true, true], [false, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().lower_elem(4.0); + let data_actual_inplace = tensor_1.lower_elem(4.0); + + let data_expected = Data::from([[true, true, 
true], [true, false, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower_equal_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0); + let data_actual_inplace = tensor_1.lower_equal_elem(4.0); + + let data_expected = Data::from([[true, true, true], [true, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone()); + let data_actual_inplace = tensor_1.lower(tensor_2); + + let data_expected = Data::from([[true, false, false], [true, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower_equal() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.lower_equal(tensor_2); + + let data_expected = Data::from([[true, true, false], [true, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn should_support_bool_equal() { + let data_1 = Data::from([[false, true, true], [true, false, true]]); + let data_2 = Data::from([[false, false, true], [false, true, true]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.equal(tensor_2); + + let data_expected = Data::from([[true, false, true], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn should_support_bool_not() { + let data_1 = Data::from([[false, true, true], [true, true, false]]); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().bool_not(); + let data_actual_inplace = tensor_1.bool_not(); + + let data_expected = Data::from([[true, false, false], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } } diff --git a/burn-tensor/src/tests/ops/mask.rs b/burn-tensor/src/tests/ops/mask.rs index 735c31127e..6a815ef2ff 100644 --- a/burn-tensor/src/tests/ops/mask.rs +++ b/burn-tensor/src/tests/ops/mask.rs @@ -1,55 +1,55 @@ #[burn_tensor_testgen::testgen(mask)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; - - #[test] - fn should_support_mask_where_ops() { - let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); - let mask = - 
Tensor::::from_bool(Data::from([[true, false], [false, true]])); - let value = Tensor::::from_data(Data::from([[1.8, 2.8], [3.8, 4.8]])); - - let data_actual = tensor.mask_where(mask, value).into_data(); - - let data_expected = Data::from([[1.8, 7.0], [2.0, 4.8]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mask_fill_ops() { - let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); - let mask = - Tensor::::from_bool(Data::from([[true, false], [false, true]])); - - let data_actual = tensor.mask_fill(mask, 2.0).to_data(); - - let data_expected = Data::from([[2.0, 7.0], [2.0, 2.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_int_mask_where_ops() { - let tensor = Tensor::::from_data([[1, 7], [2, 3]]); - let mask = - Tensor::::from_bool(Data::from([[true, false], [false, true]])); - let value = Tensor::::from_data(Data::from([[8, 9], [10, 11]])); - - let data_actual = tensor.mask_where(mask, value).into_data(); - - let data_expected = Data::from([[8, 7], [2, 11]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_int_mask_fill_ops() { - let tensor = Tensor::::from_data([[1, 7], [2, 3]]); - let mask = - Tensor::::from_bool(Data::from([[true, false], [false, true]])); - - let data_actual = tensor.mask_fill(mask, 9).to_data(); - - let data_expected = Data::from([[9, 7], [2, 9]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; + + #[test] + fn should_support_mask_where_ops() { + let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + let value = Tensor::::from_data(Data::from([[1.8, 2.8], [3.8, 4.8]])); + + let data_actual = tensor.mask_where(mask, value).into_data(); + + let data_expected = Data::from([[1.8, 7.0], [2.0, 4.8]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mask_fill_ops() { + let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + + let data_actual = tensor.mask_fill(mask, 2.0).to_data(); + + let data_expected = Data::from([[2.0, 7.0], [2.0, 2.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_int_mask_where_ops() { + let tensor = Tensor::::from_data([[1, 7], [2, 3]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + let value = Tensor::::from_data(Data::from([[8, 9], [10, 11]])); + + let data_actual = tensor.mask_where(mask, value).into_data(); + + let data_expected = Data::from([[8, 7], [2, 11]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_int_mask_fill_ops() { + let tensor = Tensor::::from_data([[1, 7], [2, 3]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + + let data_actual = tensor.mask_fill(mask, 9).to_data(); + + let data_expected = Data::from([[9, 7], [2, 9]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/matmul.rs b/burn-tensor/src/tests/ops/matmul.rs index 6edf59067c..adc0090c5f 100644 --- a/burn-tensor/src/tests/ops/matmul.rs +++ b/burn-tensor/src/tests/ops/matmul.rs @@ -1,108 +1,105 @@ #[burn_tensor_testgen::testgen(matmul)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; - - #[test] - fn test_matmul_d2() { - let tensor_1 = TestTensor::from_floats([[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]]); - let tensor_2 = 
TestTensor::from_floats([[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]) - ); - } - - #[test] - fn test_matmul_d3() { - let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); - let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[[18.0, 28.0], [14.0, 23.0]]]) - ); - } - - #[test] - fn test_matmul_broadcast_1() { - let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); - let tensor_2 = - TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]) - ); - } - - #[test] - fn test_matmul_simple_1() { - let tensor_1 = TestTensor::from_floats([[5.0, 14.0], [14.0, 50.0]]); - let tensor_2 = TestTensor::from_floats([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]]) - ); - } - - #[test] - fn test_matmul_simple_2() { - let tensor_1 = TestTensor::from_floats([[1.0, 2.0, 3.0, 4.0]]); - let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!(tensor_3.into_data(), Data::from([[50.0]])); - } - - #[test] - fn test_matmul_simple_3() { - let tensor_1 = - TestTensor::from_floats([[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]); - let tensor_2 = - TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([ - [9., 18., 27., 36.], - [12., 24., 36., 48.], - [15., 30., 45., 60.], - [18., 36., 54., 72.] - ]) - ); - } - - #[test] - #[should_panic] - fn should_panic_when_inner_dimensions_are_not_equal() { - let tensor_1 = TestTensor::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]); - let tensor_2 = - TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([ - [9., 18., 27., 36.], - [12., 24., 36., 48.], - [15., 30., 45., 60.], - [18., 36., 54., 72.] 
- ]) - ); - } + use super::*; + use burn_tensor::{Data, Tensor}; + + #[test] + fn test_matmul_d2() { + let tensor_1 = TestTensor::from_floats([[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]]); + let tensor_2 = TestTensor::from_floats([[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]) + ); + } + + #[test] + fn test_matmul_d3() { + let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); + let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[[18.0, 28.0], [14.0, 23.0]]]) + ); + } + + #[test] + fn test_matmul_broadcast_1() { + let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); + let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]) + ); + } + + #[test] + fn test_matmul_simple_1() { + let tensor_1 = TestTensor::from_floats([[5.0, 14.0], [14.0, 50.0]]); + let tensor_2 = TestTensor::from_floats([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]]) + ); + } + + #[test] + fn test_matmul_simple_2() { + let tensor_1 = TestTensor::from_floats([[1.0, 2.0, 3.0, 4.0]]); + let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!(tensor_3.into_data(), Data::from([[50.0]])); + } + + #[test] + fn test_matmul_simple_3() { + let tensor_1 = + TestTensor::from_floats([[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]); + let tensor_2 = TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([ + [9., 18., 27., 36.], + [12., 24., 36., 48.], + [15., 30., 45., 60.], + [18., 36., 54., 72.] + ]) + ); + } + + #[test] + #[should_panic] + fn should_panic_when_inner_dimensions_are_not_equal() { + let tensor_1 = TestTensor::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]); + let tensor_2 = TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([ + [9., 18., 27., 36.], + [12., 24., 36., 48.], + [15., 30., 45., 60.], + [18., 36., 54., 72.] 
+ ]) + ); + } } diff --git a/burn-tensor/src/tests/ops/maxmin.rs b/burn-tensor/src/tests/ops/maxmin.rs index 0dea58522d..3cbda1fd0b 100644 --- a/burn-tensor/src/tests/ops/maxmin.rs +++ b/burn-tensor/src/tests/ops/maxmin.rs @@ -1,51 +1,51 @@ #[burn_tensor_testgen::testgen(maxmin)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn test_max_dim_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_max_dim_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let output_actual = tensor.max_dim(1); + let output_actual = tensor.max_dim(1); - let output_expected = Data::from([[2.], [5.]]); - assert_eq!(output_expected, output_actual.into_data()); - } + let output_expected = Data::from([[2.], [5.]]); + assert_eq!(output_expected, output_actual.into_data()); + } - #[test] - fn test_max_dim_with_indices_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_max_dim_with_indices_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let (output_actual, index_actual) = tensor.max_dim_with_indices(1); + let (output_actual, index_actual) = tensor.max_dim_with_indices(1); - let output_expected = Data::from([[2.], [5.]]); - let index_expected = Data::from([[2], [2]]); + let output_expected = Data::from([[2.], [5.]]); + let index_expected = Data::from([[2], [2]]); - assert_eq!(output_expected, output_actual.into_data()); - assert_eq!(index_expected, index_actual.into_data()); - } + assert_eq!(output_expected, output_actual.into_data()); + assert_eq!(index_expected, index_actual.into_data()); + } - #[test] - fn test_min_dim_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_min_dim_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let output_actual = tensor.min_dim(1); + let output_actual = tensor.min_dim(1); - let output_expected = Data::from([[0.], [3.]]); - assert_eq!(output_expected, output_actual.into_data()); - } + let output_expected = Data::from([[0.], [3.]]); + assert_eq!(output_expected, output_actual.into_data()); + } - #[test] - fn test_min_dim_with_indices_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_min_dim_with_indices_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let (output_actual, index_actual) = tensor.min_dim_with_indices(1); + let (output_actual, index_actual) = tensor.min_dim_with_indices(1); - let output_expected = Data::from([[0.], [3.]]); - let index_expected = Data::from([[0], [0]]); + let output_expected = Data::from([[0.], [3.]]); + let index_expected = Data::from([[0], [0]]); - assert_eq!(output_expected, output_actual.into_data()); - assert_eq!(index_expected, index_actual.into_data()); - } + assert_eq!(output_expected, output_actual.into_data()); + assert_eq!(index_expected, index_actual.into_data()); + } } diff --git a/burn-tensor/src/tests/ops/mul.rs b/burn-tensor/src/tests/ops/mul.rs index 81337b808f..f3983a4005 100644 --- a/burn-tensor/src/tests/ops/mul.rs +++ b/burn-tensor/src/tests/ops/mul.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(mul)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn should_support_mul_ops() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[0.0, 1.0, 
2.0], [3.0, 4.0, 5.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 * tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_mul_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 * tensor_2).into_data(); - - let data_expected = Data::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mul_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor * scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mul_ops_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 * tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 1, 4], [9, 16, 25]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_mul_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 * tensor_2).into_data(); - - let data_expected = Data::from([[0, 4, 10], [0, 7, 16]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mul_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor * scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 2, 4], [6, 8, 10]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn should_support_mul_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 * tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_mul_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 * tensor_2).into_data(); + + let data_expected = Data::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mul_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor * scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); + assert_eq!(data_expected, 
data_actual); + } + + #[test] + fn should_support_mul_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 * tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 1, 4], [9, 16, 25]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_mul_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 * tensor_2).into_data(); + + let data_expected = Data::from([[0, 4, 10], [0, 7, 16]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mul_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor * scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 2, 4], [6, 8, 10]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/neg.rs b/burn-tensor/src/tests/ops/neg.rs index bcea87b40c..3393418c41 100644 --- a/burn-tensor/src/tests/ops/neg.rs +++ b/burn-tensor/src/tests/ops/neg.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(neg)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_neg_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.neg().into_data(); + let data_actual = tensor.neg().into_data(); - let data_expected = Data::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/one_hot.rs b/burn-tensor/src/tests/ops/one_hot.rs index 1dd09382d9..7dc96e7547 100644 --- a/burn-tensor/src/tests/ops/one_hot.rs +++ b/burn-tensor/src/tests/ops/one_hot.rs @@ -1,32 +1,32 @@ #[burn_tensor_testgen::testgen(one_hot)] mod tests { - use super::*; - use burn_tensor::{Data, Int}; + use super::*; + use burn_tensor::{Data, Int}; - #[test] - fn should_support_one_hot() { - let tensor = TestTensor::<1>::one_hot(0, 5); - assert_eq!(tensor.to_data(), Data::from([1., 0., 0., 0., 0.])); + #[test] + fn should_support_one_hot() { + let tensor = TestTensor::<1>::one_hot(0, 5); + assert_eq!(tensor.to_data(), Data::from([1., 0., 0., 0., 0.])); - let tensor = TestTensor::<1>::one_hot(1, 5); - assert_eq!(tensor.to_data(), Data::from([0., 1., 0., 0., 0.])); + let tensor = TestTensor::<1>::one_hot(1, 5); + assert_eq!(tensor.to_data(), Data::from([0., 1., 0., 0., 0.])); - let tensor = TestTensor::<1>::one_hot(4, 5); - assert_eq!(tensor.to_data(), Data::from([0., 0., 0., 0., 1.])); + let tensor = TestTensor::<1>::one_hot(4, 5); + assert_eq!(tensor.to_data(), Data::from([0., 0., 0., 0., 1.])); - let tensor = TestTensor::<1>::one_hot(1, 2); - assert_eq!(tensor.to_data(), Data::from([0., 1.])); - } + let tensor = TestTensor::<1>::one_hot(1, 2); + assert_eq!(tensor.to_data(), Data::from([0., 1.])); + } - #[test] - #[should_panic] 
- fn should_panic_when_index_exceeds_number_of_classes() { - let tensor = TestTensor::<1>::one_hot(1, 1); - } + #[test] + #[should_panic] + fn should_panic_when_index_exceeds_number_of_classes() { + let tensor = TestTensor::<1>::one_hot(1, 1); + } - #[test] - #[should_panic] - fn should_panic_when_number_of_classes_is_zero() { - let tensor = TestTensor::<1>::one_hot(0, 0); - } + #[test] + #[should_panic] + fn should_panic_when_number_of_classes_is_zero() { + let tensor = TestTensor::<1>::one_hot(0, 0); + } } diff --git a/burn-tensor/src/tests/ops/powf.rs b/burn-tensor/src/tests/ops/powf.rs index 59f9abca5d..b98c56578f 100644 --- a/burn-tensor/src/tests/ops/powf.rs +++ b/burn-tensor/src/tests/ops/powf.rs @@ -1,49 +1,49 @@ #[burn_tensor_testgen::testgen(powf)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_powf_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_powf_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(0.71).into_data(); + let data_actual = tensor.powf(0.71).into_data(); - let data_expected = Data::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_neg_power() { - let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_power() { + let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(-0.33).into_data(); + let data_actual = tensor.powf(-0.33).into_data(); - let data_expected = Data::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_neg_values_with_even_power() { - let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_values_with_even_power() { + let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(4.0).into_data(); + let data_actual = tensor.powf(4.0).into_data(); - let data_expected = Data::from([[0.0, 1.0, 16.0], [81.0, 256.0, 625.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 1.0, 16.0], [81.0, 256.0, 625.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_neg_values_with_odd_power() { - let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_values_with_odd_power() { + let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(3.0).into_data(); + let data_actual = tensor.powf(3.0).into_data(); - let data_expected = Data::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = 
Data::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/random.rs b/burn-tensor/src/tests/ops/random.rs index 5aeccfa7a4..1bc15a1c3a 100644 --- a/burn-tensor/src/tests/ops/random.rs +++ b/burn-tensor/src/tests/ops/random.rs @@ -1,27 +1,27 @@ #[burn_tensor_testgen::testgen(random)] mod tests { - use super::*; - use burn_tensor::{Distribution, Tensor}; + use super::*; + use burn_tensor::{Distribution, Tensor}; - #[test] - fn rand_default() { - let tensor = Tensor::::random([20], Distribution::Default); + #[test] + fn rand_default() { + let tensor = Tensor::::random([20], Distribution::Default); - // check that the tensor is within the range of [0..1) (1 is exclusive) - tensor.into_data().assert_within_range(0.0..1.0); - } + // check that the tensor is within the range of [0..1) (1 is exclusive) + tensor.into_data().assert_within_range(0.0..1.0); + } - #[test] - fn rand_uniform() { - let tensor = Tensor::::random([20], Distribution::Uniform(4., 5.)); + #[test] + fn rand_uniform() { + let tensor = Tensor::::random([20], Distribution::Uniform(4., 5.)); - tensor.into_data().assert_within_range(4.0..5.0); - } + tensor.into_data().assert_within_range(4.0..5.0); + } - #[test] - fn rand_bernoulli() { - let tensor = Tensor::::random([20], Distribution::Bernoulli(1.)); + #[test] + fn rand_bernoulli() { + let tensor = Tensor::::random([20], Distribution::Bernoulli(1.)); - assert_eq!(tensor.into_data(), [1.; 20].into()); - } + assert_eq!(tensor.into_data(), [1.; 20].into()); + } } diff --git a/burn-tensor/src/tests/ops/recip.rs b/burn-tensor/src/tests/ops/recip.rs index 70395fd60b..9e700d67fa 100644 --- a/burn-tensor/src/tests/ops/recip.rs +++ b/burn-tensor/src/tests/ops/recip.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(recip)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_recip_ops() { - let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_recip_ops() { + let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.recip().into_data(); + let data_actual = tensor.recip().into_data(); - let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/repeat.rs b/burn-tensor/src/tests/ops/repeat.rs index 8725decb26..d436b11330 100644 --- a/burn-tensor/src/tests/ops/repeat.rs +++ b/burn-tensor/src/tests/ops/repeat.rs @@ -1,21 +1,21 @@ #[burn_tensor_testgen::testgen(repeat)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_repeat_ops() { - let data = Data::from([[0.0, 1.0, 2.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_repeat_ops() { + let data = Data::from([[0.0, 1.0, 2.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.repeat(0, 4).into_data(); + let data_actual = tensor.repeat(0, 4).into_data(); - let data_expected = Data::from([ - [0.0, 1.0, 2.0], - [0.0, 1.0, 2.0], - [0.0, 1.0, 2.0], - [0.0, 1.0, 2.0], - ]); - assert_eq!(data_expected, data_actual); - } + let data_expected = 
Data::from([ + [0.0, 1.0, 2.0], + [0.0, 1.0, 2.0], + [0.0, 1.0, 2.0], + [0.0, 1.0, 2.0], + ]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/reshape.rs b/burn-tensor/src/tests/ops/reshape.rs index 9c02dd5132..6de805dc7c 100644 --- a/burn-tensor/src/tests/ops/reshape.rs +++ b/burn-tensor/src/tests/ops/reshape.rs @@ -1,88 +1,88 @@ #[burn_tensor_testgen::testgen(reshape)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; - #[test] - fn should_support_reshape_1d() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([1, 3]).into_data(); - let data_expected = Data::from([[0.0, 1.0, 2.0]]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([1, 3]).into_data(); + let data_expected = Data::from([[0.0, 1.0, 2.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_reshape_int() { - let data = Data::from([0, 1, 2]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_int() { + let data = Data::from([0, 1, 2]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([1, 3]).into_data(); - let data_expected = Data::from([[0, 1, 2]]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([1, 3]).into_data(); + let data_expected = Data::from([[0, 1, 2]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_reshape_bool() { - let data = Data::from([false, true, false]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_bool() { + let data = Data::from([false, true, false]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([1, 3]).into_data(); - let data_expected = Data::from([[false, true, false]]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([1, 3]).into_data(); + let data_expected = Data::from([[false, true, false]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_reshape_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([6]).into_data(); - let data_expected = Data::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([6]).into_data(); + let data_expected = Data::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_dim_infererence() { - let data = Data::from([ - [0.0, 1.0, 2.0], - [3.0, 4.0, 5.0], - [6.0, 7.0, 8.0], - [9.0, 10.0, 11.0], - ]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_dim_infererence() { + let data = Data::from([ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + ]); + let tensor = Tensor::::from_data(data); - // Infer the dimension via -1 - let reshaped = tensor.clone().reshape([2, -1]); - assert_eq!(reshaped.shape(), [2, 6].into()); + // Infer the dimension via -1 + let reshaped = tensor.clone().reshape([2, 
-1]); + assert_eq!(reshaped.shape(), [2, 6].into()); - // Infer the dimension via 0 (keep from the source) and -1 (infer) - let reshaped = reshaped.reshape([0, 2, -1]); - assert_eq!(reshaped.shape(), [2, 2, 3].into()); + // Infer the dimension via 0 (keep from the source) and -1 (infer) + let reshaped = reshaped.reshape([0, 2, -1]); + assert_eq!(reshaped.shape(), [2, 2, 3].into()); - // This is effectively as if we did a flatten - let reshaped = tensor.clone().reshape([-1]); - assert_eq!(reshaped.shape(), [12].into()); + // This is effectively as if we did a flatten + let reshaped = tensor.clone().reshape([-1]); + assert_eq!(reshaped.shape(), [12].into()); - // Keeping the first dimension the same (using 0) - let reshaped = tensor.clone().reshape([0, 3]); - assert_eq!(reshaped.shape(), [4, 3].into()); - } + // Keeping the first dimension the same (using 0) + let reshaped = tensor.clone().reshape([0, 3]); + assert_eq!(reshaped.shape(), [4, 3].into()); + } - #[test] - #[should_panic] - fn multiple_neg_ones() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.reshape([-1, -1]).into_data(); - } + #[test] + #[should_panic] + fn multiple_neg_ones() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.reshape([-1, -1]).into_data(); + } - #[test] - #[should_panic] - fn neg_value() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.reshape([-2, -1]).into_data(); - } + #[test] + #[should_panic] + fn neg_value() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.reshape([-2, -1]).into_data(); + } } diff --git a/burn-tensor/src/tests/ops/select.rs b/burn-tensor/src/tests/ops/select.rs index d62fdaeae7..823168618a 100644 --- a/burn-tensor/src/tests/ops/select.rs +++ b/burn-tensor/src/tests/ops/select.rs @@ -1,128 +1,128 @@ #[burn_tensor_testgen::testgen(select)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_select_1d() { - let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - let output = tensor.select(0, indices); - - assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); - } - - #[test] - fn should_select_1d_int() { - let tensor = TestTensorInt::from_data([5, 6, 7]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - let output = tensor.select(0, indices); - - assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); - } - - #[test] - fn should_select_2d_dim0_same_num_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data(([1, 0])); - - let output = tensor.select(0, indices); - - assert_eq!( - output.into_data(), - Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]) - ); - } - - #[test] - fn should_select_2d_dim0_more_num_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data([1, 0, 1, 1]); - - let output = tensor.select(0, indices); - - assert_eq!( - output.into_data(), - Data::from([ - [3.0, 4.0, 5.0], - [0.0, 1.0, 2.0], - [3.0, 4.0, 5.0], - [3.0, 4.0, 5.0] - ]) - ); - } - - #[test] - fn should_select_2d_dim1() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - 
let output = tensor.select(1, indices); - - assert_eq!( - output.into_data(), - Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]) - ); - } - - #[test] - fn should_select_assign_1d() { - let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); - let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0]); - let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - - let output = tensor.select_assign(0, indices, values); - - assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0])); - } - - #[test] - fn should_select_assign_1d_int() { - let tensor = TestTensorInt::from_data([7, 8, 9]); - let values = TestTensorInt::from_data([5, 4, 3, 2, 1]); - let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - - let output = tensor.select_assign(0, indices, values); - - assert_eq!(output.into_data(), Data::from([10, 19, 10])); - } - - #[test] - fn should_select_assign_2d_dim0() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_data(Data::from([1, 0])); - - let output = tensor.select_assign(0, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]) - ); - } - - #[test] - fn should_select_assign_2d_dim1() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_data(Data::from([1, 0, 2])); - - let output = tensor.select_assign(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]]) - ); - } - - #[test] - #[should_panic] - fn should_select_panic_invalid_dimension() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - tensor.select(10, indices); - } + #[test] + fn should_select_1d() { + let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + let output = tensor.select(0, indices); + + assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); + } + + #[test] + fn should_select_1d_int() { + let tensor = TestTensorInt::from_data([5, 6, 7]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + let output = tensor.select(0, indices); + + assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); + } + + #[test] + fn should_select_2d_dim0_same_num_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data(([1, 0])); + + let output = tensor.select(0, indices); + + assert_eq!( + output.into_data(), + Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]) + ); + } + + #[test] + fn should_select_2d_dim0_more_num_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data([1, 0, 1, 1]); + + let output = tensor.select(0, indices); + + assert_eq!( + output.into_data(), + Data::from([ + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [3.0, 4.0, 5.0] + ]) + ); + } + + #[test] + fn should_select_2d_dim1() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + let output = tensor.select(1, indices); + + assert_eq!( + output.into_data(), + Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]) + ); + } + + #[test] + fn 
should_select_assign_1d() { + let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); + let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0]); + let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); + + let output = tensor.select_assign(0, indices, values); + + assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0])); + } + + #[test] + fn should_select_assign_1d_int() { + let tensor = TestTensorInt::from_data([7, 8, 9]); + let values = TestTensorInt::from_data([5, 4, 3, 2, 1]); + let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); + + let output = tensor.select_assign(0, indices, values); + + assert_eq!(output.into_data(), Data::from([10, 19, 10])); + } + + #[test] + fn should_select_assign_2d_dim0() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_data(Data::from([1, 0])); + + let output = tensor.select_assign(0, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]) + ); + } + + #[test] + fn should_select_assign_2d_dim1() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_data(Data::from([1, 0, 2])); + + let output = tensor.select_assign(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]]) + ); + } + + #[test] + #[should_panic] + fn should_select_panic_invalid_dimension() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + tensor.select(10, indices); + } } diff --git a/burn-tensor/src/tests/ops/sin.rs b/burn-tensor/src/tests/ops/sin.rs index c7f9685947..24518d0f96 100644 --- a/burn-tensor/src/tests/ops/sin.rs +++ b/burn-tensor/src/tests/ops/sin.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(sin)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_sin_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_sin_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.sin().into_data(); + let data_actual = tensor.sin().into_data(); - let data_expected = Data::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/slice.rs b/burn-tensor/src/tests/ops/slice.rs index 7db32a6722..84ccd02d6e 100644 --- a/burn-tensor/src/tests/ops/slice.rs +++ b/burn-tensor/src/tests/ops/slice.rs @@ -1,150 +1,150 @@ #[burn_tensor_testgen::testgen(slice)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_full_sliceing_1d() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + fn should_support_full_sliceing_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([0..3]).into_data(); + let data_actual = 
tensor.slice([0..3]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - fn should_support_partial_sliceing_1d() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_partial_sliceing_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.slice([1..3]).into_data(); + let data_actual = tensor.slice([1..3]).into_data(); - let data_expected = Data::from([1.0, 2.0]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([1.0, 2.0]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_full_sliceing_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + fn should_support_full_sliceing_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual_1 = tensor.clone().slice([0..2]).into_data(); - let data_actual_2 = tensor.slice([0..2, 0..3]).into_data(); + let data_actual_1 = tensor.clone().slice([0..2]).into_data(); + let data_actual_2 = tensor.slice([0..2, 0..3]).into_data(); - assert_eq!(data, data_actual_1); - assert_eq!(data, data_actual_2); - } + assert_eq!(data, data_actual_1); + assert_eq!(data, data_actual_2); + } - #[test] - fn should_support_partial_sliceing_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_partial_sliceing_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.slice([0..2, 0..2]).into_data(); + let data_actual = tensor.slice([0..2, 0..2]).into_data(); - let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_partial_sliceing_3d() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_partial_sliceing_3d() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.slice([1..2, 1..2, 0..2]).into_data(); + let data_actual = tensor.slice([1..2, 1..2, 0..2]).into_data(); - let data_expected = Data::from([[[9.0, 10.0]]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[[9.0, 10.0]]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_partial_sliceing_3d_non_contiguous() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_partial_sliceing_3d_non_contiguous() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.transpose().slice([1..2, 1..2, 0..2]).into_data(); + let data_actual = tensor.transpose().slice([1..2, 1..2, 0..2]).into_data(); - let data_expected = Data::from([[[7.0, 10.0]]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[[7.0, 10.0]]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_slice_assign_1d() { - let data = Data::from([0.0, 1.0, 
2.0]); - let data_assigned = Data::from([10.0, 5.0]); + #[test] + fn should_support_slice_assign_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let data_assigned = Data::from([10.0, 5.0]); - let tensor = Tensor::::from_data(data); - let tensor_assigned = Tensor::::from_data(data_assigned); + let tensor = Tensor::::from_data(data); + let tensor_assigned = Tensor::::from_data(data_assigned); - let data_actual = tensor.slice_assign([0..2], tensor_assigned).into_data(); + let data_actual = tensor.slice_assign([0..2], tensor_assigned).into_data(); - let data_expected = Data::from([10.0, 5.0, 2.0]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([10.0, 5.0, 2.0]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_slice_assign_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_assigned = Data::from([[10.0, 5.0]]); + #[test] + fn should_support_slice_assign_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_assigned = Data::from([[10.0, 5.0]]); - let tensor = Tensor::::from_data(data); - let tensor_assigned = Tensor::::from_data(data_assigned); + let tensor = Tensor::::from_data(data); + let tensor_assigned = Tensor::::from_data(data_assigned); - let data_actual = tensor - .slice_assign([1..2, 0..2], tensor_assigned) - .into_data(); + let data_actual = tensor + .slice_assign([1..2, 0..2], tensor_assigned) + .into_data(); - let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_exceeds_dimension() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn should_panic_when_slice_exceeds_dimension() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([0..4]).into_data(); + let data_actual = tensor.slice([0..4]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_with_too_many_dimensions() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn should_panic_when_slice_with_too_many_dimensions() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([0..1, 0..1]).into_data(); + let data_actual = tensor.slice([0..1, 0..1]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_is_desc() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn should_panic_when_slice_is_desc() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - #[allow(clippy::reversed_empty_ranges)] - let data_actual = tensor.slice([2..1]).into_data(); + #[allow(clippy::reversed_empty_ranges)] + let data_actual = tensor.slice([2..1]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_is_equal() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn 
should_panic_when_slice_is_equal() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([1..1]).into_data(); + let data_actual = tensor.slice([1..1]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/sqrt.rs b/burn-tensor/src/tests/ops/sqrt.rs index 3044b1d9bc..593e399c1a 100644 --- a/burn-tensor/src/tests/ops/sqrt.rs +++ b/burn-tensor/src/tests/ops/sqrt.rs @@ -1,17 +1,17 @@ #[burn_tensor_testgen::testgen(sqrt)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; - use core::f32::consts::SQRT_2; + use super::*; + use burn_tensor::{Data, Tensor}; + use core::f32::consts::SQRT_2; - #[test] - fn should_support_sqrt_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_sqrt_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.sqrt().into_data(); + let data_actual = tensor.sqrt().into_data(); - let data_expected = Data::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/squeeze.rs b/burn-tensor/src/tests/ops/squeeze.rs index d8de064bdd..8d1e9fb3cf 100644 --- a/burn-tensor/src/tests/ops/squeeze.rs +++ b/burn-tensor/src/tests/ops/squeeze.rs @@ -1,37 +1,37 @@ #[burn_tensor_testgen::testgen(squeeze)] mod tests { - use super::*; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Shape, Tensor}; - /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor. - #[test] - fn should_squeeze() { - let tensor = Tensor::::ones(Shape::new([2, 1, 4])); - let squeezed_tensor: Tensor = tensor.squeeze(1); - let expected_shape = Shape::new([2, 4]); - assert_eq!(squeezed_tensor.shape(), expected_shape); - } - /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor. - #[test] - fn should_squeeze_first() { - let tensor = Tensor::::ones(Shape::new([1, 3, 4, 5])); - let squeezed_tensor: Tensor = tensor.squeeze(0); - let expected_shape = Shape::new([3, 4, 5]); - assert_eq!(squeezed_tensor.shape(), expected_shape); - } - /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor. - #[test] - fn should_squeeze_last() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 1])); - let squeezed_tensor: Tensor = tensor.squeeze(3); - let expected_shape = Shape::new([2, 3, 4]); - assert_eq!(squeezed_tensor.shape(), expected_shape); - } - /// Test if the function panics when the squeezed dimension is not of size 1. - #[test] - #[should_panic] - fn should_squeeze_panic() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let squeezed_tensor: Tensor = tensor.squeeze(2); - } + /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor. + #[test] + fn should_squeeze() { + let tensor = Tensor::::ones(Shape::new([2, 1, 4])); + let squeezed_tensor: Tensor = tensor.squeeze(1); + let expected_shape = Shape::new([2, 4]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor. 
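// --- Editor's sketch (illustrative, not part of the patch): what `squeeze` does in the tests
// below. Removing a size-1 axis maps a rank-D tensor to rank D-1, and the target rank is fixed
// by the type annotation. The const-generic parameters written out here (`<TestBackend, 3>`,
// `<TestBackend, 2>`) are an assumption, since the angle-bracket generics were dropped from this
// copy of the diff.
fn squeeze_sketch() {
    use burn_tensor::{Shape, Tensor};

    let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 1, 4]));
    // Drop the middle, size-1 axis: [2, 1, 4] -> [2, 4].
    let squeezed: Tensor<TestBackend, 2> = tensor.squeeze(1);
    assert_eq!(squeezed.shape(), Shape::new([2, 4]));
    // Squeezing an axis whose size is not 1 panics, as `should_squeeze_panic` checks.
}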
+ #[test] + fn should_squeeze_first() { + let tensor = Tensor::::ones(Shape::new([1, 3, 4, 5])); + let squeezed_tensor: Tensor = tensor.squeeze(0); + let expected_shape = Shape::new([3, 4, 5]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor. + #[test] + fn should_squeeze_last() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 1])); + let squeezed_tensor: Tensor = tensor.squeeze(3); + let expected_shape = Shape::new([2, 3, 4]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function panics when the squeezed dimension is not of size 1. + #[test] + #[should_panic] + fn should_squeeze_panic() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let squeezed_tensor: Tensor = tensor.squeeze(2); + } } diff --git a/burn-tensor/src/tests/ops/sub.rs b/burn-tensor/src/tests/ops/sub.rs index 3293379abd..d2d1adba54 100644 --- a/burn-tensor/src/tests/ops/sub.rs +++ b/burn-tensor/src/tests/ops/sub.rs @@ -1,83 +1,83 @@ #[burn_tensor_testgen::testgen(sub)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn should_support_sub_ops() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); - let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_sub_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - let data_expected = Data::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_sub_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor - scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_sub_ops_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); - let data_expected = Data::from([[-6, -6, -6], [-6, -6, -6]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_sub_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - let data_expected = Data::from([[-3, -3, -3], [-6, -6, -6]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_sub_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor - scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[-2, -1, 0], [1, 2, 3]]); - assert_eq!(data_expected, data_actual); 
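// --- Editor's sketch (illustrative, not part of the patch): the broadcasting rule that the
// `test_sub_broadcast*` cases in this file rely on. A size-1 axis is expanded to match the other
// operand before the elementwise subtraction. The generic parameters (`<TestBackend, 2>`) are an
// assumption, since they were dropped from this copy of the diff.
fn broadcast_sub_sketch() {
    use burn_tensor::{Data, Tensor};

    let row = Tensor::<TestBackend, 2>::from_data(Data::from([[0.0, 1.0, 2.0]]));
    let grid = Tensor::<TestBackend, 2>::from_data(Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]));

    // The single row is repeated along dim 0, so row 0 gives [-3, -3, -3] and row 1 gives [-6, -6, -6].
    assert_eq!(
        (row - grid).into_data(),
        Data::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]])
    );
}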
- } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn should_support_sub_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); + let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_sub_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + let data_expected = Data::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_sub_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor - scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_sub_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); + let data_expected = Data::from([[-6, -6, -6], [-6, -6, -6]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_sub_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + let data_expected = Data::from([[-3, -3, -3], [-6, -6, -6]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_sub_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor - scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[-2, -1, 0], [1, 2, 3]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/tanh.rs b/burn-tensor/src/tests/ops/tanh.rs index b65b1be7ab..5d95b948f7 100644 --- a/burn-tensor/src/tests/ops/tanh.rs +++ b/burn-tensor/src/tests/ops/tanh.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(tanh)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_tanh_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_tanh_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.tanh().into_data(); + let data_actual = tensor.tanh().into_data(); - let data_expected = Data::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/transpose.rs 
b/burn-tensor/src/tests/ops/transpose.rs index 83ae67b045..fcd5566aff 100644 --- a/burn-tensor/src/tests/ops/transpose.rs +++ b/burn-tensor/src/tests/ops/transpose.rs @@ -1,97 +1,93 @@ #[burn_tensor_testgen::testgen(transpose)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; - - #[test] - fn should_support_transpose_ops() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - - let data_actual = tensor.transpose().into_data(); - - let data_expected = Data::from([ - [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], - [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_swap_dims() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - - let data_actual = tensor.swap_dims(0, 2).into_data(); - - let data_expected = Data::from([ - [[0.0, 6.0], [3.0, 9.0]], - [[1.0, 7.0], [4.0, 10.0]], - [[2.0, 8.0], [5.0, 11.0]], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_transpose_ops_int() { - let tensor = Tensor::::from_data([ - [[0, 1, 2], [3, 4, 5]], - [[6, 7, 8], [9, 10, 11]], - ]); - - let data_actual = tensor.transpose().into_data(); - - let data_expected = Data::from([[[0, 3], [1, 4], [2, 5]], [[6, 9], [7, 10], [8, 11]]]); - assert_eq!(&data_expected, &data_actual); - } - - #[test] - fn should_support_swap_dims_int() { - let tensor = Tensor::::from_data([ - [[0, 1, 2], [3, 4, 5]], - [[6, 7, 8], [9, 10, 11]], - ]); - - let data_actual = tensor.swap_dims(0, 2).into_data(); - - let data_expected = Data::from([[[0, 6], [3, 9]], [[1, 7], [4, 10]], [[2, 8], [5, 11]]]); - assert_eq!(&data_expected, &data_actual); - } - - #[test] - fn should_support_transpose_bool() { - let tensor = Tensor::::from_data([ - [[false, true, false], [false, false, false]], - [[false, false, true], [false, false, true]], - ]); - - let data_actual = tensor.transpose().into_data(); - - let data_expected = Data::from([ - [[false, false], [true, false], [false, false]], - [[false, false], [false, false], [true, true]], - ]); - assert_eq!(&data_expected, &data_actual); - } - - #[test] - fn should_support_swap_dims_bool() { - let tensor = Tensor::::from_data([ - [[false, true, false], [false, false, false]], - [[false, false, true], [false, false, true]], - ]); - - let data_actual = tensor.swap_dims(0, 2).into_data(); - - let data_expected = Data::from([ - [[false, false], [false, false]], - [[true, false], [false, false]], - [[false, true], [false, true]], - ]); - assert_eq!(&data_expected, &data_actual); - } + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; + + #[test] + fn should_support_transpose_ops() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + + let data_actual = tensor.transpose().into_data(); + + let data_expected = Data::from([ + [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], + [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_swap_dims() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + + let data_actual = tensor.swap_dims(0, 2).into_data(); + + let data_expected = Data::from([ + [[0.0, 6.0], [3.0, 9.0]], + [[1.0, 7.0], [4.0, 10.0]], + [[2.0, 8.0], [5.0, 11.0]], + ]); + 
data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_transpose_ops_int() { + let tensor = + Tensor::::from_data([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]); + + let data_actual = tensor.transpose().into_data(); + + let data_expected = Data::from([[[0, 3], [1, 4], [2, 5]], [[6, 9], [7, 10], [8, 11]]]); + assert_eq!(&data_expected, &data_actual); + } + + #[test] + fn should_support_swap_dims_int() { + let tensor = + Tensor::::from_data([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]); + + let data_actual = tensor.swap_dims(0, 2).into_data(); + + let data_expected = Data::from([[[0, 6], [3, 9]], [[1, 7], [4, 10]], [[2, 8], [5, 11]]]); + assert_eq!(&data_expected, &data_actual); + } + + #[test] + fn should_support_transpose_bool() { + let tensor = Tensor::::from_data([ + [[false, true, false], [false, false, false]], + [[false, false, true], [false, false, true]], + ]); + + let data_actual = tensor.transpose().into_data(); + + let data_expected = Data::from([ + [[false, false], [true, false], [false, false]], + [[false, false], [false, false], [true, true]], + ]); + assert_eq!(&data_expected, &data_actual); + } + + #[test] + fn should_support_swap_dims_bool() { + let tensor = Tensor::::from_data([ + [[false, true, false], [false, false, false]], + [[false, false, true], [false, false, true]], + ]); + + let data_actual = tensor.swap_dims(0, 2).into_data(); + + let data_expected = Data::from([ + [[false, false], [false, false]], + [[true, false], [false, false]], + [[false, true], [false, true]], + ]); + assert_eq!(&data_expected, &data_actual); + } } diff --git a/burn-tensor/src/tests/stats/cov.rs b/burn-tensor/src/tests/stats/cov.rs index f6fe8a91c5..e93ecd5c29 100644 --- a/burn-tensor/src/tests/stats/cov.rs +++ b/burn-tensor/src/tests/stats/cov.rs @@ -1,61 +1,61 @@ #[burn_tensor_testgen::testgen(cov)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Tensor}; - - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; - - #[test] - fn test_cov_1() { - let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.cov(1, 1).into_data(); - - let data_expected = Data::from([[2.4892, -1.7333], [-1.7333, 15.3333]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_cov_4() { - let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.cov(1, 0).into_data(); - - let data_expected = Data::from([[1.8668, -1.2999], [-1.2999, 11.5]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_cov_2() { - let data = Data::from([[0.5, 1.8], [0.2, -2.0], [3.0, -4.0], [5.0, 0.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.cov(1, 1).into_data(); - - let data_expected = Data::from([ - [0.845, -1.43, -4.55, -3.25], - [-1.43, 2.42, 7.7, 5.5], - [-4.55, 7.7, 24.5, 17.5], - [-3.25, 5.5, 17.5, 12.5], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_cov_3() { - let data = Data::from([ - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - ]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.cov(0, 1).into_data(); - let data_expected = Tensor::::zeros([4, 4, 4]).to_data(); - 
data_expected.assert_approx_eq(&data_actual, 3); - } + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Tensor}; + + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; + + #[test] + fn test_cov_1() { + let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.cov(1, 1).into_data(); + + let data_expected = Data::from([[2.4892, -1.7333], [-1.7333, 15.3333]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_cov_4() { + let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.cov(1, 0).into_data(); + + let data_expected = Data::from([[1.8668, -1.2999], [-1.2999, 11.5]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_cov_2() { + let data = Data::from([[0.5, 1.8], [0.2, -2.0], [3.0, -4.0], [5.0, 0.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.cov(1, 1).into_data(); + + let data_expected = Data::from([ + [0.845, -1.43, -4.55, -3.25], + [-1.43, 2.42, 7.7, 5.5], + [-4.55, 7.7, 24.5, 17.5], + [-3.25, 5.5, 17.5, 12.5], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_cov_3() { + let data = Data::from([ + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + ]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.cov(0, 1).into_data(); + let data_expected = Tensor::::zeros([4, 4, 4]).to_data(); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/stats/diagonal.rs b/burn-tensor/src/tests/stats/diagonal.rs index f7fda216ad..c62f0f45f8 100644 --- a/burn-tensor/src/tests/stats/diagonal.rs +++ b/burn-tensor/src/tests/stats/diagonal.rs @@ -1,18 +1,18 @@ #[burn_tensor_testgen::testgen(diagonal)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Tensor}; - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; - #[test] - fn test_diagonal() { - let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]; - let lhs = Tensor::::from_floats(data); - let rhs = Tensor::::diagonal(3); - lhs.to_data().assert_approx_eq(&rhs.to_data(), 3); - } + #[test] + fn test_diagonal() { + let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]; + let lhs = Tensor::::from_floats(data); + let rhs = Tensor::::diagonal(3); + lhs.to_data().assert_approx_eq(&rhs.to_data(), 3); + } } diff --git a/burn-tensor/src/tests/stats/display.rs b/burn-tensor/src/tests/stats/display.rs index bb4238038f..33c8aa3c44 100644 --- a/burn-tensor/src/tests/stats/display.rs +++ b/burn-tensor/src/tests/stats/display.rs @@ -1,21 +1,21 @@ #[burn_tensor_testgen::testgen(display)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Shape, Tensor}; - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; - #[test] - fn test_display_2d_int_tensor() { - let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]); - let tensor_int: burn_tensor::Tensor = - 
Tensor::from_data(int_data); + #[test] + fn test_display_2d_int_tensor() { + let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]); + let tensor_int: burn_tensor::Tensor = + Tensor::from_data(int_data); - let output = format!("{}", tensor_int); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor_int); + let expected = format!( + r#"Tensor {{ data: [[1, 2, 3], [4, 5, 6], @@ -26,22 +26,22 @@ mod tests { kind: "Int", dtype: "{dtype}", }}"#, - tensor_int.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor_int.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_2d_float_tensor() { - let float_data = Data::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]); - let tensor_float: burn_tensor::Tensor = - Tensor::from_data(float_data); + #[test] + fn test_display_2d_float_tensor() { + let float_data = Data::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]); + let tensor_float: burn_tensor::Tensor = + Tensor::from_data(float_data); - let output = format!("{}", tensor_float); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor_float); + let expected = format!( + r#"Tensor {{ data: [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], @@ -52,26 +52,26 @@ mod tests { kind: "Float", dtype: "{dtype}", }}"#, - tensor_float.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor_float.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_2d_bool_tensor() { - let bool_data = Data::from([ - [true, false, true], - [false, true, false], - [false, true, true], - ]); - let tensor_bool: burn_tensor::Tensor = - Tensor::from_data(bool_data); + #[test] + fn test_display_2d_bool_tensor() { + let bool_data = Data::from([ + [true, false, true], + [false, true, false], + [false, true, true], + ]); + let tensor_bool: burn_tensor::Tensor = + Tensor::from_data(bool_data); - let output = format!("{}", tensor_bool); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor_bool); + let expected = format!( + r#"Tensor {{ data: [[true, false, true], [false, true, false], @@ -82,23 +82,23 @@ mod tests { kind: "Bool", dtype: "bool", }}"#, - tensor_bool.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor_bool.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_3d_tensor() { - let data = Data::from([ - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], - [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], - ]); - let tensor: burn_tensor::Tensor = Tensor::from_data(data); + #[test] + fn test_display_3d_tensor() { + let data = Data::from([ + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], + ]); + let tensor: burn_tensor::Tensor = Tensor::from_data(data); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[1, 2, 3, 4], [5, 6, 7, 8], @@ -112,25 +112,25 @@ mod tests { kind: "Int", dtype: "{dtype}", }}"#, - tensor.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, 
expected); + } - #[test] - fn test_display_4d_tensor() { - let data = Data::from([ - [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], - [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]], - ]); + #[test] + fn test_display_4d_tensor() { + let data = Data::from([ + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]], + ]); - let tensor: burn_tensor::Tensor = Tensor::from_data(data); + let tensor: burn_tensor::Tensor = Tensor::from_data(data); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[1, 2, 3], [4, 5, 6]], @@ -146,21 +146,21 @@ mod tests { kind: "Int", dtype: "{dtype}", }}"#, - tensor.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_tensor_summarize_1() { - let tensor: burn_tensor::Tensor = - Tensor::zeros(Shape::new([2, 2, 2, 1000])); + #[test] + fn test_display_tensor_summarize_1() { + let tensor: burn_tensor::Tensor = + Tensor::zeros(Shape::new([2, 2, 2, 1000])); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]], @@ -176,20 +176,20 @@ mod tests { kind: "Float", dtype: "f32", }}"#, - tensor.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_tensor_summarize_2() { - let tensor: burn_tensor::Tensor = - Tensor::zeros(Shape::new([2, 2, 20, 100])); + #[test] + fn test_display_tensor_summarize_2() { + let tensor: burn_tensor::Tensor = + Tensor::zeros(Shape::new([2, 2, 20, 100])); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], @@ -225,20 +225,20 @@ mod tests { kind: "Float", dtype: "f32", }}"#, - tensor.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_tensor_summarize_3() { - let tensor: burn_tensor::Tensor = - Tensor::zeros(Shape::new([2, 2, 200, 6])); + #[test] + fn test_display_tensor_summarize_3() { + let tensor: burn_tensor::Tensor = + Tensor::zeros(Shape::new([2, 2, 200, 6])); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], @@ -274,9 +274,9 @@ mod tests { kind: "Float", dtype: "f32", }}"#, - tensor.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } } diff --git a/burn-tensor/src/tests/stats/var.rs b/burn-tensor/src/tests/stats/var.rs index ac6ccf2f12..dfde1862e8 100644 --- a/burn-tensor/src/tests/stats/var.rs +++ b/burn-tensor/src/tests/stats/var.rs @@ -1,55 +1,55 @@ #[burn_tensor_testgen::testgen(var)] mod tests { - use super::*; - use 
burn_tensor::backend::Backend; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Tensor}; - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; - #[test] - fn test_var() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let data_actual = tensor.var(1).into_data(); + let data_actual = tensor.var(1).into_data(); - let data_expected = Data::from([[2.4892], [15.3333]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[2.4892], [15.3333]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn test_var_mean() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var_mean() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let (var, mean) = tensor.var_mean(1); + let (var, mean) = tensor.var_mean(1); - let var_expected = Data::from([[2.4892], [15.3333]]); - let mean_expected = Data::from([[0.125], [1.]]); + let var_expected = Data::from([[2.4892], [15.3333]]); + let mean_expected = Data::from([[0.125], [1.]]); - var_expected.assert_approx_eq(&(var.into_data()), 3); - mean_expected.assert_approx_eq(&(mean.into_data()), 3); - } + var_expected.assert_approx_eq(&(var.into_data()), 3); + mean_expected.assert_approx_eq(&(mean.into_data()), 3); + } - #[test] - fn test_var_bias() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var_bias() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let data_actual = tensor.var_bias(1).into_data(); + let data_actual = tensor.var_bias(1).into_data(); - let data_expected = Data::from([[1.86688], [11.5]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.86688], [11.5]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn test_var_mean_bias() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var_mean_bias() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let (var, mean) = tensor.var_mean_bias(1); + let (var, mean) = tensor.var_mean_bias(1); - let var_expected = Data::from([[1.86688], [11.5]]); - let mean_expected = Data::from([[0.125], [1.]]); + let var_expected = Data::from([[1.86688], [11.5]]); + let mean_expected = Data::from([[0.125], [1.]]); - var_expected.assert_approx_eq(&(var.into_data()), 3); - mean_expected.assert_approx_eq(&(mean.into_data()), 3); - } + var_expected.assert_approx_eq(&(var.into_data()), 3); + mean_expected.assert_approx_eq(&(mean.into_data()), 3); + } } diff --git a/burn-train/src/checkpoint/async_checkpoint.rs b/burn-train/src/checkpoint/async_checkpoint.rs index 83ec9fcfff..805f1dc014 100644 --- a/burn-train/src/checkpoint/async_checkpoint.rs +++ b/burn-train/src/checkpoint/async_checkpoint.rs @@ -3,118 +3,122 @@ use burn_core::record::Record; use std::sync::mpsc; enum Message { - Restore(usize, mpsc::SyncSender>), - Save(usize, R), - Delete(usize), - End, + Restore(usize, mpsc::SyncSender>), + Save(usize, R), + Delete(usize), + End, } #[derive(new)] struct CheckpointerThread { - checkpointer: C, - receiver: mpsc::Receiver>, + checkpointer: C, + receiver: 
mpsc::Receiver>, } impl, R: Record> CheckpointerThread { - fn run(self) { - for item in self.receiver.iter() { - match item { - Message::Restore(epoch, callback) => { - let record = self.checkpointer.restore(epoch); - callback - .send(record) - .expect("Can send response through callback channel."); - } - Message::Save(epoch, state) => self - .checkpointer - .save(epoch, state) - .expect("Can save the state."), - Message::Delete(epoch) => self - .checkpointer - .delete(epoch) - .expect("Can delete the state."), - Message::End => { - return; - } - }; + fn run(self) { + for item in self.receiver.iter() { + match item { + Message::Restore(epoch, callback) => { + let record = self.checkpointer.restore(epoch); + callback + .send(record) + .expect("Can send response through callback channel."); } + Message::Save(epoch, state) => self + .checkpointer + .save(epoch, state) + .expect("Can save the state."), + Message::Delete(epoch) => self + .checkpointer + .delete(epoch) + .expect("Can delete the state."), + Message::End => { + return; + } + }; } + } } /// Async checkpointer. pub struct AsyncCheckpointer { - sender: mpsc::SyncSender>, - handler: Option>, + sender: mpsc::SyncSender>, + handler: Option>, } impl AsyncCheckpointer { - /// Create a new async checkpointer. - /// - /// # Arguments - /// - /// * `checkpointer` - The checkpointer. - /// - /// # Returns - /// - /// The async checkpointer. - pub fn new(checkpointer: C) -> Self - where - C: Checkpointer + Send + 'static, - { - // Only on checkpoint can be done in advance. - let (sender, receiver) = mpsc::sync_channel(0); - let thread = CheckpointerThread::new(checkpointer, receiver); - let handler = Some(std::thread::spawn(move || thread.run())); + /// Create a new async checkpointer. + /// + /// # Arguments + /// + /// * `checkpointer` - The checkpointer. + /// + /// # Returns + /// + /// The async checkpointer. + pub fn new(checkpointer: C) -> Self + where + C: Checkpointer + Send + 'static, + { + // Only on checkpoint can be done in advance. 
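// --- Editor's note (illustrative, not part of the patch): "on checkpoint" above should read
// "one checkpoint". A zero-capacity `sync_channel` is a rendezvous channel: `send` blocks until
// the checkpointer thread receives, which is what keeps at most one checkpoint request in
// flight. A minimal standalone sketch of that behaviour, using only std:
fn rendezvous_channel_sketch() {
    use std::sync::mpsc;
    use std::thread;

    let (sender, receiver) = mpsc::sync_channel::<u32>(0);
    let worker = thread::spawn(move || receiver.iter().sum::<u32>());

    // Each `send` returns only once the worker has taken the value.
    sender.send(1).unwrap();
    sender.send(2).unwrap();
    drop(sender); // closing the channel ends `receiver.iter()`

    assert_eq!(worker.join().unwrap(), 3);
}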
+ let (sender, receiver) = mpsc::sync_channel(0); + let thread = CheckpointerThread::new(checkpointer, receiver); + let handler = Some(std::thread::spawn(move || thread.run())); - Self { sender, handler } - } + Self { sender, handler } + } } impl Checkpointer for AsyncCheckpointer where - R: Record + 'static, + R: Record + 'static, { - fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { - self.sender - .send(Message::Save(epoch, record)) - .expect("Can send message to checkpointer thread."); + fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { + self + .sender + .send(Message::Save(epoch, record)) + .expect("Can send message to checkpointer thread."); - Ok(()) - } + Ok(()) + } - fn restore(&self, epoch: usize) -> Result { - let (sender, receiver) = mpsc::sync_channel(1); - self.sender - .send(Message::Restore(epoch, sender)) - .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; + fn restore(&self, epoch: usize) -> Result { + let (sender, receiver) = mpsc::sync_channel(1); + self + .sender + .send(Message::Restore(epoch, sender)) + .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; - if let Ok(record) = receiver.recv() { - return record; - }; + if let Ok(record) = receiver.recv() { + return record; + }; - Err(CheckpointerError::Unknown("Channel error.".to_string())) - } + Err(CheckpointerError::Unknown("Channel error.".to_string())) + } - fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { - self.sender - .send(Message::Delete(epoch)) - .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; + fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { + self + .sender + .send(Message::Delete(epoch)) + .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; - Ok(()) - } + Ok(()) + } } impl Drop for AsyncCheckpointer { - fn drop(&mut self) { - self.sender - .send(Message::End) - .expect("Can send the end message to the checkpointer thread."); - let handler = self.handler.take(); + fn drop(&mut self) { + self + .sender + .send(Message::End) + .expect("Can send the end message to the checkpointer thread."); + let handler = self.handler.take(); - if let Some(handler) = handler { - handler - .join() - .expect("The checkpointer thread should stop."); - } + if let Some(handler) = handler { + handler + .join() + .expect("The checkpointer thread should stop."); } + } } diff --git a/burn-train/src/checkpoint/base.rs b/burn-train/src/checkpoint/base.rs index 61a2dca986..2104db82fb 100644 --- a/burn-train/src/checkpoint/base.rs +++ b/burn-train/src/checkpoint/base.rs @@ -3,37 +3,37 @@ use burn_core::record::{Record, RecorderError}; /// The error type for checkpointer. #[derive(Debug)] pub enum CheckpointerError { - /// IO error. - IOError(std::io::Error), + /// IO error. + IOError(std::io::Error), - /// Recorder error. - RecorderError(RecorderError), + /// Recorder error. + RecorderError(RecorderError), - /// Other errors. - Unknown(String), + /// Other errors. + Unknown(String), } /// The trait for checkpointer. pub trait Checkpointer { - /// Save the record. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. - /// * `record` - The record. - fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>; + /// Save the record. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + /// * `record` - The record. + fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>; - /// Delete the record at the given epoch if present. 
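// --- Editor's sketch (hypothetical, not part of the patch): the shape of a `Checkpointer`
// implementation. The record type parameter (`R: Record` on the trait) was dropped from this
// copy of the diff; `InMemoryCheckpointer` below is invented for illustration, and the extra
// `Clone` bound is only needed by this sketch. A `Mutex` is used because the trait methods
// take `&self`.
use std::collections::HashMap;
use std::sync::Mutex;

struct InMemoryCheckpointer<R> {
    records: Mutex<HashMap<usize, R>>,
}

impl<R: Record + Clone> Checkpointer<R> for InMemoryCheckpointer<R> {
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
        self.records.lock().unwrap().insert(epoch, record);
        Ok(())
    }

    fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
        self.records
            .lock()
            .unwrap()
            .get(&epoch)
            .cloned()
            .ok_or_else(|| CheckpointerError::Unknown(format!("No checkpoint for epoch {epoch}")))
    }

    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
        self.records.lock().unwrap().remove(&epoch);
        Ok(())
    }
}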
- fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>; + /// Delete the record at the given epoch if present. + fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>; - /// Restore the record. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. - /// - /// # Returns - /// - /// The record. - fn restore(&self, epoch: usize) -> Result; + /// Restore the record. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + /// + /// # Returns + /// + /// The record. + fn restore(&self, epoch: usize) -> Result; } diff --git a/burn-train/src/checkpoint/file.rs b/burn-train/src/checkpoint/file.rs index 24c9c717f3..afc5e88542 100644 --- a/burn-train/src/checkpoint/file.rs +++ b/burn-train/src/checkpoint/file.rs @@ -3,68 +3,69 @@ use burn_core::record::{FileRecorder, Record}; /// The file checkpointer. pub struct FileCheckpointer { - directory: String, - name: String, - recorder: FR, + directory: String, + name: String, + recorder: FR, } impl FileCheckpointer { - /// Creates a new file checkpointer. - /// - /// # Arguments - /// - /// * `recorder` - The file recorder. - /// * `directory` - The directory to save the checkpoints. - /// * `name` - The name of the checkpoint. - pub fn new(recorder: FR, directory: &str, name: &str) -> Self { - std::fs::create_dir_all(directory).ok(); + /// Creates a new file checkpointer. + /// + /// # Arguments + /// + /// * `recorder` - The file recorder. + /// * `directory` - The directory to save the checkpoints. + /// * `name` - The name of the checkpoint. + pub fn new(recorder: FR, directory: &str, name: &str) -> Self { + std::fs::create_dir_all(directory).ok(); - Self { - directory: directory.to_string(), - name: name.to_string(), - recorder, - } - } - fn path_for_epoch(&self, epoch: usize) -> String { - format!("{}/{}-{}", self.directory, self.name, epoch) + Self { + directory: directory.to_string(), + name: name.to_string(), + recorder, } + } + fn path_for_epoch(&self, epoch: usize) -> String { + format!("{}/{}-{}", self.directory, self.name, epoch) + } } impl Checkpointer for FileCheckpointer where - R: Record, - FR: FileRecorder, + R: Record, + FR: FileRecorder, { - fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { - let file_path = self.path_for_epoch(epoch); - log::info!("Saving checkpoint {} to {}", epoch, file_path); + fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { + let file_path = self.path_for_epoch(epoch); + log::info!("Saving checkpoint {} to {}", epoch, file_path); - self.recorder - .record(record, file_path.into()) - .map_err(CheckpointerError::RecorderError)?; + self + .recorder + .record(record, file_path.into()) + .map_err(CheckpointerError::RecorderError)?; - Ok(()) - } + Ok(()) + } - fn restore(&self, epoch: usize) -> Result { - let file_path = self.path_for_epoch(epoch); - log::info!("Restoring checkpoint {} from {}", epoch, file_path); - let record = self - .recorder - .load(file_path.into()) - .map_err(CheckpointerError::RecorderError)?; + fn restore(&self, epoch: usize) -> Result { + let file_path = self.path_for_epoch(epoch); + log::info!("Restoring checkpoint {} from {}", epoch, file_path); + let record = self + .recorder + .load(file_path.into()) + .map_err(CheckpointerError::RecorderError)?; - Ok(record) - } + Ok(record) + } - fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { - let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),); + fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { + let 
file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),); - if std::path::Path::new(&file_to_remove).exists() { - log::info!("Removing checkpoint {}", file_to_remove); - std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?; - } - - Ok(()) + if std::path::Path::new(&file_to_remove).exists() { + log::info!("Removing checkpoint {}", file_to_remove); + std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?; } + + Ok(()) + } } diff --git a/burn-train/src/checkpoint/strategy/base.rs b/burn-train/src/checkpoint/strategy/base.rs index f16acfeb41..7d4131718d 100644 --- a/burn-train/src/checkpoint/strategy/base.rs +++ b/burn-train/src/checkpoint/strategy/base.rs @@ -5,30 +5,30 @@ use crate::metric::store::EventStoreClient; /// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer). #[derive(Clone, PartialEq, Debug)] pub enum CheckpointingAction { - /// Delete the given epoch. - Delete(usize), - /// Save the current record. - Save, + /// Delete the given epoch. + Delete(usize), + /// Save the current record. + Save, } /// Define when checkpoint should be saved and deleted. pub trait CheckpointingStrategy { - /// Based on the epoch, determine if the checkpoint should be saved. - fn checkpointing( - &mut self, - epoch: usize, - collector: &EventStoreClient, - ) -> Vec; + /// Based on the epoch, determine if the checkpoint should be saved. + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec; } // We make dyn box implement the checkpointing strategy so that it can be used with generic, but // still be dynamic. impl CheckpointingStrategy for Box { - fn checkpointing( - &mut self, - epoch: usize, - collector: &EventStoreClient, - ) -> Vec { - self.deref_mut().checkpointing(epoch, collector) - } + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec { + self.deref_mut().checkpointing(epoch, collector) + } } diff --git a/burn-train/src/checkpoint/strategy/composed.rs b/burn-train/src/checkpoint/strategy/composed.rs index 8029c9ed78..38d8354aad 100644 --- a/burn-train/src/checkpoint/strategy/composed.rs +++ b/burn-train/src/checkpoint/strategy/composed.rs @@ -6,141 +6,144 @@ use std::collections::HashSet; /// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an /// epoch to be deleted. pub struct ComposedCheckpointingStrategy { - strategies: Vec>, - deleted: Vec>, + strategies: Vec>, + deleted: Vec>, } /// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones. #[derive(Default)] pub struct ComposedCheckpointingStrategyBuilder { - strategies: Vec>, + strategies: Vec>, } impl ComposedCheckpointingStrategyBuilder { - /// Add a new [checkpointing strategy](CheckpointingStrategy). - #[allow(clippy::should_implement_trait)] - pub fn add(mut self, strategy: S) -> Self - where - S: CheckpointingStrategy + 'static, - { - self.strategies.push(Box::new(strategy)); - self - } - - /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy). - pub fn build(self) -> ComposedCheckpointingStrategy { - ComposedCheckpointingStrategy::new(self.strategies) - } + /// Add a new [checkpointing strategy](CheckpointingStrategy). 
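// --- Editor's sketch (illustrative, not part of the patch): composing strategies with this
// builder, mirroring the unit test at the bottom of this file.
let strategy = ComposedCheckpointingStrategy::builder()
    .add(KeepLastNCheckpoints::new(1))
    .add(KeepLastNCheckpoints::new(2))
    .build();
// With these two combined, epoch 1 is only deleted at epoch 3, once *both* strategies
// have flagged it for deletion.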
+ #[allow(clippy::should_implement_trait)] + pub fn add(mut self, strategy: S) -> Self + where + S: CheckpointingStrategy + 'static, + { + self.strategies.push(Box::new(strategy)); + self + } + + /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy). + pub fn build(self) -> ComposedCheckpointingStrategy { + ComposedCheckpointingStrategy::new(self.strategies) + } } impl ComposedCheckpointingStrategy { - fn new(strategies: Vec>) -> Self { - Self { - deleted: strategies.iter().map(|_| HashSet::new()).collect(), - strategies, - } - } - /// Create a new builder which help compose multiple - /// [checkpointing strategies](CheckpointingStrategy). - pub fn builder() -> ComposedCheckpointingStrategyBuilder { - ComposedCheckpointingStrategyBuilder::default() + fn new(strategies: Vec>) -> Self { + Self { + deleted: strategies.iter().map(|_| HashSet::new()).collect(), + strategies, } + } + /// Create a new builder which help compose multiple + /// [checkpointing strategies](CheckpointingStrategy). + pub fn builder() -> ComposedCheckpointingStrategyBuilder { + ComposedCheckpointingStrategyBuilder::default() + } } impl CheckpointingStrategy for ComposedCheckpointingStrategy { - fn checkpointing( - &mut self, - epoch: usize, - collector: &EventStoreClient, - ) -> Vec { - let mut saved = false; - let mut actions = Vec::new(); - let mut epochs_to_check = Vec::new(); - - for (i, strategy) in self.strategies.iter_mut().enumerate() { - let actions = strategy.checkpointing(epoch, collector); - // We assume that the strategy would not want the current epoch to be saved. - // So we flag it as deleted. - if actions.is_empty() { - self.deleted - .get_mut(i) - .expect("As many 'deleted' as 'strategies'.") - .insert(epoch); - } - - for action in actions { - match action { - CheckpointingAction::Delete(epoch) => { - self.deleted - .get_mut(i) - .expect("As many 'deleted' as 'strategies'.") - .insert(epoch); - epochs_to_check.push(epoch); - } - CheckpointingAction::Save => saved = true, - } - } + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec { + let mut saved = false; + let mut actions = Vec::new(); + let mut epochs_to_check = Vec::new(); + + for (i, strategy) in self.strategies.iter_mut().enumerate() { + let actions = strategy.checkpointing(epoch, collector); + // We assume that the strategy would not want the current epoch to be saved. + // So we flag it as deleted. 
+ if actions.is_empty() { + self + .deleted + .get_mut(i) + .expect("As many 'deleted' as 'strategies'.") + .insert(epoch); + } + + for action in actions { + match action { + CheckpointingAction::Delete(epoch) => { + self + .deleted + .get_mut(i) + .expect("As many 'deleted' as 'strategies'.") + .insert(epoch); + epochs_to_check.push(epoch); + } + CheckpointingAction::Save => saved = true, } + } + } - if saved { - actions.push(CheckpointingAction::Save); - } + if saved { + actions.push(CheckpointingAction::Save); + } - for epoch in epochs_to_check.into_iter() { - let mut num_true = 0; - for i in 0..self.strategies.len() { - if self - .deleted - .get(i) - .expect("Ad many 'deleted' as 'strategies'.") - .contains(&epoch) - { - num_true += 1; - } - } - - if num_true == self.strategies.len() { - actions.push(CheckpointingAction::Delete(epoch)); - - for i in 0..self.strategies.len() { - self.deleted - .get_mut(i) - .expect("As many 'deleted' as 'strategies'.") - .remove(&epoch); - } - } + for epoch in epochs_to_check.into_iter() { + let mut num_true = 0; + for i in 0..self.strategies.len() { + if self + .deleted + .get(i) + .expect("Ad many 'deleted' as 'strategies'.") + .contains(&epoch) + { + num_true += 1; } + } + + if num_true == self.strategies.len() { + actions.push(CheckpointingAction::Delete(epoch)); - actions + for i in 0..self.strategies.len() { + self + .deleted + .get_mut(i) + .expect("As many 'deleted' as 'strategies'.") + .remove(&epoch); + } + } } + + actions + } } #[cfg(test)] mod tests { - use super::*; - use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore}; - - #[test] - fn should_delete_when_both_deletes() { - let store = EventStoreClient::new(LogEventStore::default()); - let mut strategy = ComposedCheckpointingStrategy::builder() - .add(KeepLastNCheckpoints::new(1)) - .add(KeepLastNCheckpoints::new(2)) - .build(); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(1, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(2, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], - strategy.checkpointing(3, &store) - ); - } + use super::*; + use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore}; + + #[test] + fn should_delete_when_both_deletes() { + let store = EventStoreClient::new(LogEventStore::default()); + let mut strategy = ComposedCheckpointingStrategy::builder() + .add(KeepLastNCheckpoints::new(1)) + .add(KeepLastNCheckpoints::new(2)) + .build(); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(1, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(2, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], + strategy.checkpointing(3, &store) + ); + } } diff --git a/burn-train/src/checkpoint/strategy/lastn.rs b/burn-train/src/checkpoint/strategy/lastn.rs index 66f5df91bf..f810f0dc8d 100644 --- a/burn-train/src/checkpoint/strategy/lastn.rs +++ b/burn-train/src/checkpoint/strategy/lastn.rs @@ -7,50 +7,46 @@ use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient}; /// resumed even if something goes wrong. 
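// --- Editor's note (illustrative, not part of the patch): a worked trace of
// `KeepLastNCheckpoints::new(2)`, matching the unit test at the bottom of this file.
// `usize::checked_sub` returns `None` while `epoch < num_keep`, and epoch 0 is never deleted,
// so deletions only start once more than `num_keep` checkpoints exist:
//
//   epoch 1 -> [Save]              (1 - 2 underflows: nothing to delete)
//   epoch 2 -> [Save]              (2 - 2 == 0: no epoch 0 to delete)
//   epoch 3 -> [Save, Delete(1)]   (3 - 2 == 1: only epochs 2 and 3 are kept)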
#[derive(new)] pub struct KeepLastNCheckpoints { - num_keep: usize, + num_keep: usize, } impl CheckpointingStrategy for KeepLastNCheckpoints { - fn checkpointing( - &mut self, - epoch: usize, - _store: &EventStoreClient, - ) -> Vec { - let mut actions = vec![CheckpointingAction::Save]; - - if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) { - if epoch > 0 { - actions.push(CheckpointingAction::Delete(epoch)); - } - } - - actions + fn checkpointing(&mut self, epoch: usize, _store: &EventStoreClient) -> Vec { + let mut actions = vec![CheckpointingAction::Save]; + + if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) { + if epoch > 0 { + actions.push(CheckpointingAction::Delete(epoch)); + } } + + actions + } } #[cfg(test)] mod tests { - use super::*; - use crate::metric::store::LogEventStore; - - #[test] - fn should_always_delete_lastn_epoch_if_higher_than_one() { - let mut strategy = KeepLastNCheckpoints::new(2); - let store = EventStoreClient::new(LogEventStore::default()); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(1, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(2, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], - strategy.checkpointing(3, &store) - ); - } + use super::*; + use crate::metric::store::LogEventStore; + + #[test] + fn should_always_delete_lastn_epoch_if_higher_than_one() { + let mut strategy = KeepLastNCheckpoints::new(2); + let store = EventStoreClient::new(LogEventStore::default()); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(1, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(2, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], + strategy.checkpointing(3, &store) + ); + } } diff --git a/burn-train/src/checkpoint/strategy/metric.rs b/burn-train/src/checkpoint/strategy/metric.rs index f2aa58efeb..a746241066 100644 --- a/burn-train/src/checkpoint/strategy/metric.rs +++ b/burn-train/src/checkpoint/strategy/metric.rs @@ -1,133 +1,129 @@ use super::CheckpointingStrategy; use crate::{ - checkpoint::CheckpointingAction, - metric::{ - store::{Aggregate, Direction, EventStoreClient, Split}, - Metric, - }, + checkpoint::CheckpointingAction, + metric::{ + store::{Aggregate, Direction, EventStoreClient, Split}, + Metric, + }, }; /// Keep the best checkpoint based on a metric. pub struct MetricCheckpointingStrategy { - current: Option, - aggregate: Aggregate, - direction: Direction, - split: Split, - name: String, + current: Option, + aggregate: Aggregate, + direction: Direction, + split: Split, + name: String, } impl MetricCheckpointingStrategy { - /// Create a new metric strategy. - pub fn new(aggregate: Aggregate, direction: Direction, split: Split) -> Self - where - M: Metric, - { - Self { - current: None, - name: M::NAME.to_string(), - aggregate, - direction, - split, - } + /// Create a new metric strategy. 
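// --- Editor's sketch (illustrative, not part of the patch): constructing the strategy. The
// metric type parameter (written `<LossMetric<TestBackend>>` here) is an assumption, since the
// turbofish generics were dropped from this copy of the diff; the unit test at the bottom of
// this file builds the same strategy.
let strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
    Aggregate::Mean,
    Direction::Lowest,
    Split::Train,
);
// Keeps only the checkpoint of the epoch with the lowest mean training loss seen so far.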
+ pub fn new(aggregate: Aggregate, direction: Direction, split: Split) -> Self + where + M: Metric, + { + Self { + current: None, + name: M::NAME.to_string(), + aggregate, + direction, + split, } + } } impl CheckpointingStrategy for MetricCheckpointingStrategy { - fn checkpointing( - &mut self, - epoch: usize, - store: &EventStoreClient, - ) -> Vec { - let best_epoch = - match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) { - Some(epoch_best) => epoch_best, - None => epoch, - }; - - let mut actions = Vec::new(); - - if let Some(current) = self.current { - if current != best_epoch { - actions.push(CheckpointingAction::Delete(current)); - } - } - - if best_epoch == epoch { - actions.push(CheckpointingAction::Save); - } - - self.current = Some(best_epoch); - - actions + fn checkpointing(&mut self, epoch: usize, store: &EventStoreClient) -> Vec { + let best_epoch = match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) + { + Some(epoch_best) => epoch_best, + None => epoch, + }; + + let mut actions = Vec::new(); + + if let Some(current) = self.current { + if current != best_epoch { + actions.push(CheckpointingAction::Delete(current)); + } } + + if best_epoch == epoch { + actions.push(CheckpointingAction::Save); + } + + self.current = Some(best_epoch); + + actions + } } #[cfg(test)] mod tests { - use crate::{ - logger::InMemoryMetricLogger, - metric::{ - processor::{ - test_utils::{end_epoch, process_train}, - Metrics, MinimalEventProcessor, - }, - store::LogEventStore, - LossMetric, - }, - TestBackend, - }; - use std::sync::Arc; - - use super::*; - - #[test] - fn always_keep_the_best_epoch() { - let mut store = LogEventStore::default(); - let mut strategy = MetricCheckpointingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Train, - ); - let mut metrics = Metrics::::default(); - // Register an in memory logger. - store.register_logger_train(InMemoryMetricLogger::default()); - // Register the loss metric. - metrics.register_train_metric_numeric(LossMetric::::new()); - let store = Arc::new(EventStoreClient::new(store)); - let mut processor = MinimalEventProcessor::new(metrics, store.clone()); - - // Two points for the first epoch. Mean 0.75 - let mut epoch = 1; - process_train(&mut processor, 1.0, epoch); - process_train(&mut processor, 0.5, epoch); - end_epoch(&mut processor, epoch); - - // Should save the current record. - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(epoch, &store) - ); - - // Two points for the second epoch. Mean 0.4 - epoch += 1; - process_train(&mut processor, 0.5, epoch); - process_train(&mut processor, 0.3, epoch); - end_epoch(&mut processor, epoch); - - // Should save the current record and delete the pervious one. - assert_eq!( - vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], - strategy.checkpointing(epoch, &store) - ); - - // Two points for the last epoch. Mean 2.0 - epoch += 1; - process_train(&mut processor, 1.0, epoch); - process_train(&mut processor, 3.0, epoch); - end_epoch(&mut processor, epoch); - - // Should not delete the previous record, since it's the best one, and should not save a - // new one. 
- assert!(strategy.checkpointing(epoch, &store).is_empty()); - } + use crate::{ + logger::InMemoryMetricLogger, + metric::{ + processor::{ + test_utils::{end_epoch, process_train}, + Metrics, MinimalEventProcessor, + }, + store::LogEventStore, + LossMetric, + }, + TestBackend, + }; + use std::sync::Arc; + + use super::*; + + #[test] + fn always_keep_the_best_epoch() { + let mut store = LogEventStore::default(); + let mut strategy = MetricCheckpointingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Train, + ); + let mut metrics = Metrics::::default(); + // Register an in memory logger. + store.register_logger_train(InMemoryMetricLogger::default()); + // Register the loss metric. + metrics.register_train_metric_numeric(LossMetric::::new()); + let store = Arc::new(EventStoreClient::new(store)); + let mut processor = MinimalEventProcessor::new(metrics, store.clone()); + + // Two points for the first epoch. Mean 0.75 + let mut epoch = 1; + process_train(&mut processor, 1.0, epoch); + process_train(&mut processor, 0.5, epoch); + end_epoch(&mut processor, epoch); + + // Should save the current record. + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(epoch, &store) + ); + + // Two points for the second epoch. Mean 0.4 + epoch += 1; + process_train(&mut processor, 0.5, epoch); + process_train(&mut processor, 0.3, epoch); + end_epoch(&mut processor, epoch); + + // Should save the current record and delete the pervious one. + assert_eq!( + vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], + strategy.checkpointing(epoch, &store) + ); + + // Two points for the last epoch. Mean 2.0 + epoch += 1; + process_train(&mut processor, 1.0, epoch); + process_train(&mut processor, 3.0, epoch); + end_epoch(&mut processor, epoch); + + // Should not delete the previous record, since it's the best one, and should not save a + // new one. + assert!(strategy.checkpointing(epoch, &store).is_empty()); + } } diff --git a/burn-train/src/components.rs b/burn-train/src/components.rs index 3eddb93b0f..16cbe9ad80 100644 --- a/burn-train/src/components.rs +++ b/burn-train/src/components.rs @@ -1,71 +1,71 @@ use crate::{ - checkpoint::{Checkpointer, CheckpointingStrategy}, - metric::processor::EventProcessor, + checkpoint::{Checkpointer, CheckpointingStrategy}, + metric::processor::EventProcessor, }; use burn_core::{ - lr_scheduler::LrScheduler, - module::{AutodiffModule, Module}, - optim::Optimizer, - tensor::backend::AutodiffBackend, + lr_scheduler::LrScheduler, + module::{AutodiffModule, Module}, + optim::Optimizer, + tensor::backend::AutodiffBackend, }; use std::marker::PhantomData; /// All components necessary to train a model grouped in one trait. pub trait LearnerComponents { - /// The backend in used for the training. - type Backend: AutodiffBackend; - /// The learning rate scheduler used for the training. - type LrScheduler: LrScheduler; - /// The model to train. - type Model: AutodiffModule + core::fmt::Display + 'static; - /// The optimizer used for the training. - type Optimizer: Optimizer; - /// The checkpointer used for the model. - type CheckpointerModel: Checkpointer<>::Record>; - /// The checkpointer used for the optimizer. - type CheckpointerOptimizer: Checkpointer< - >::Record, - >; - /// The checkpointer used for the scheduler. - type CheckpointerLrScheduler: Checkpointer<::Record>; - type EventProcessor: EventProcessor + 'static; - /// The strategy to save and delete checkpoints. 
- type CheckpointerStrategy: CheckpointingStrategy; + /// The backend in used for the training. + type Backend: AutodiffBackend; + /// The learning rate scheduler used for the training. + type LrScheduler: LrScheduler; + /// The model to train. + type Model: AutodiffModule + core::fmt::Display + 'static; + /// The optimizer used for the training. + type Optimizer: Optimizer; + /// The checkpointer used for the model. + type CheckpointerModel: Checkpointer<>::Record>; + /// The checkpointer used for the optimizer. + type CheckpointerOptimizer: Checkpointer< + >::Record, + >; + /// The checkpointer used for the scheduler. + type CheckpointerLrScheduler: Checkpointer<::Record>; + type EventProcessor: EventProcessor + 'static; + /// The strategy to save and delete checkpoints. + type CheckpointerStrategy: CheckpointingStrategy; } /// Concrete type that implements [training components trait](TrainingComponents). pub struct LearnerComponentsMarker { - _backend: PhantomData, - _lr_scheduler: PhantomData, - _model: PhantomData, - _optimizer: PhantomData, - _checkpointer_model: PhantomData, - _checkpointer_optim: PhantomData, - _checkpointer_scheduler: PhantomData, - _event_processor: PhantomData, - _strategy: S, + _backend: PhantomData, + _lr_scheduler: PhantomData, + _model: PhantomData, + _optimizer: PhantomData, + _checkpointer_model: PhantomData, + _checkpointer_optim: PhantomData, + _checkpointer_scheduler: PhantomData, + _event_processor: PhantomData, + _strategy: S, } impl LearnerComponents - for LearnerComponentsMarker + for LearnerComponentsMarker where - B: AutodiffBackend, - LR: LrScheduler, - M: AutodiffModule + core::fmt::Display + 'static, - O: Optimizer, - CM: Checkpointer, - CO: Checkpointer, - CS: Checkpointer, - EP: EventProcessor + 'static, - S: CheckpointingStrategy, + B: AutodiffBackend, + LR: LrScheduler, + M: AutodiffModule + core::fmt::Display + 'static, + O: Optimizer, + CM: Checkpointer, + CO: Checkpointer, + CS: Checkpointer, + EP: EventProcessor + 'static, + S: CheckpointingStrategy, { - type Backend = B; - type LrScheduler = LR; - type Model = M; - type Optimizer = O; - type CheckpointerModel = CM; - type CheckpointerOptimizer = CO; - type CheckpointerLrScheduler = CS; - type EventProcessor = EP; - type CheckpointerStrategy = S; + type Backend = B; + type LrScheduler = LR; + type Model = M; + type Optimizer = O; + type CheckpointerModel = CM; + type CheckpointerOptimizer = CO; + type CheckpointerLrScheduler = CS; + type EventProcessor = EP; + type CheckpointerStrategy = S; } diff --git a/burn-train/src/learner/base.rs b/burn-train/src/learner/base.rs index 55a5515ef7..ea6f0ff245 100644 --- a/burn-train/src/learner/base.rs +++ b/burn-train/src/learner/base.rs @@ -13,115 +13,121 @@ use std::sync::Arc; /// /// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct. 
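Before the `Learner` struct itself, a hedged sketch of the builder flow that the doc comment refers to. The `model`, `optim`, `lr_scheduler`, `device`, and dataloader values are user-provided placeholders, and `CompactRecorder` is assumed to be an available `FileRecorder` implementation:

    // Illustrative only: typical Learner wiring.
    let learner = LearnerBuilder::new("/tmp/experiment")
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device])
        .num_epochs(10)
        .build(model, optim, lr_scheduler);

    let model_trained = learner.fit(dataloader_train, dataloader_valid);

All of the builder methods used in this sketch are defined in `burn-train/src/learner/builder.rs` further down in this patch.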
pub struct Learner { - pub(crate) model: LC::Model, - pub(crate) optim: LC::Optimizer, - pub(crate) lr_scheduler: LC::LrScheduler, - pub(crate) num_epochs: usize, - pub(crate) checkpoint: Option, - pub(crate) grad_accumulation: Option, - pub(crate) checkpointer: Option>, - pub(crate) devices: Vec<::Device>, - pub(crate) interrupter: TrainingInterrupter, - pub(crate) early_stopping: Option>, - pub(crate) event_processor: LC::EventProcessor, - pub(crate) event_store: Arc, + pub(crate) model: LC::Model, + pub(crate) optim: LC::Optimizer, + pub(crate) lr_scheduler: LC::LrScheduler, + pub(crate) num_epochs: usize, + pub(crate) checkpoint: Option, + pub(crate) grad_accumulation: Option, + pub(crate) checkpointer: Option>, + pub(crate) devices: Vec<::Device>, + pub(crate) interrupter: TrainingInterrupter, + pub(crate) early_stopping: Option>, + pub(crate) event_processor: LC::EventProcessor, + pub(crate) event_store: Arc, } #[derive(new)] pub(crate) struct LearnerCheckpointer { - model: LC::CheckpointerModel, - optim: LC::CheckpointerOptimizer, - lr_scheduler: LC::CheckpointerLrScheduler, - strategy: LC::CheckpointerStrategy, + model: LC::CheckpointerModel, + optim: LC::CheckpointerOptimizer, + lr_scheduler: LC::CheckpointerLrScheduler, + strategy: LC::CheckpointerStrategy, } impl LearnerCheckpointer { - pub(crate) fn checkpoint( - &mut self, - model: &LC::Model, - optim: &LC::Optimizer, - scheduler: &LC::LrScheduler, - epoch: usize, - store: &EventStoreClient, - ) { - let actions = self.strategy.checkpointing(epoch, store); + pub(crate) fn checkpoint( + &mut self, + model: &LC::Model, + optim: &LC::Optimizer, + scheduler: &LC::LrScheduler, + epoch: usize, + store: &EventStoreClient, + ) { + let actions = self.strategy.checkpointing(epoch, store); - for action in actions { - match action { - CheckpointingAction::Delete(epoch) => { - self.model - .delete(epoch) - .expect("Can delete model checkpoint."); - self.optim - .delete(epoch) - .expect("Can delete optimizer checkpoint."); - self.lr_scheduler - .delete(epoch) - .expect("Can delete learning rate scheduler checkpoint."); - } - CheckpointingAction::Save => { - self.model - .save(epoch, model.clone().into_record()) - .expect("Can save model checkpoint."); - self.optim - .save(epoch, optim.to_record()) - .expect("Can save optimizer checkpoint."); - self.lr_scheduler - .save(epoch, scheduler.to_record()) - .expect("Can save learning rate scheduler checkpoint."); - } - } + for action in actions { + match action { + CheckpointingAction::Delete(epoch) => { + self + .model + .delete(epoch) + .expect("Can delete model checkpoint."); + self + .optim + .delete(epoch) + .expect("Can delete optimizer checkpoint."); + self + .lr_scheduler + .delete(epoch) + .expect("Can delete learning rate scheduler checkpoint."); + } + CheckpointingAction::Save => { + self + .model + .save(epoch, model.clone().into_record()) + .expect("Can save model checkpoint."); + self + .optim + .save(epoch, optim.to_record()) + .expect("Can save optimizer checkpoint."); + self + .lr_scheduler + .save(epoch, scheduler.to_record()) + .expect("Can save learning rate scheduler checkpoint."); } + } } + } - pub(crate) fn load_checkpoint( - &self, - model: LC::Model, - optim: LC::Optimizer, - scheduler: LC::LrScheduler, - epoch: usize, - ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) { - let record = self - .model - .restore(epoch) - .expect("Can load model checkpoint."); - let model = model.load_record(record); + pub(crate) fn load_checkpoint( + &self, + model: LC::Model, + optim: 
LC::Optimizer, + scheduler: LC::LrScheduler, + epoch: usize, + ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) { + let record = self + .model + .restore(epoch) + .expect("Can load model checkpoint."); + let model = model.load_record(record); - let record = self - .optim - .restore(epoch) - .expect("Can load optimizer checkpoint."); - let optim = optim.load_record(record); + let record = self + .optim + .restore(epoch) + .expect("Can load optimizer checkpoint."); + let optim = optim.load_record(record); - let record = self - .lr_scheduler - .restore(epoch) - .expect("Can load learning rate scheduler checkpoint."); - let scheduler = scheduler.load_record(record); + let record = self + .lr_scheduler + .restore(epoch) + .expect("Can load learning rate scheduler checkpoint."); + let scheduler = scheduler.load_record(record); - (model, optim, scheduler) - } + (model, optim, scheduler) + } } #[derive(Clone, Default)] /// A handle that allows aborting the training process early. pub struct TrainingInterrupter { - state: Arc, + state: Arc, } impl TrainingInterrupter { - /// Create a new instance. - pub fn new() -> Self { - Self::default() - } + /// Create a new instance. + pub fn new() -> Self { + Self::default() + } - /// Notify the learner that it should stop. - pub fn stop(&self) { - self.state.store(true, Ordering::Relaxed); - } + /// Notify the learner that it should stop. + pub fn stop(&self) { + self.state.store(true, Ordering::Relaxed); + } - /// True if .stop() has been called. - pub fn should_stop(&self) -> bool { - self.state.load(Ordering::Relaxed) - } + /// True if .stop() has been called. + pub fn should_stop(&self) -> bool { + self.state.load(Ordering::Relaxed) + } } diff --git a/burn-train/src/learner/builder.rs b/burn-train/src/learner/builder.rs index 99e6d90bea..b78760eae7 100644 --- a/burn-train/src/learner/builder.rs +++ b/burn-train/src/learner/builder.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use super::log::install_file_logger; use super::Learner; use crate::checkpoint::{ - AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, - KeepLastNCheckpoints, MetricCheckpointingStrategy, + AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, + KeepLastNCheckpoints, MetricCheckpointingStrategy, }; use crate::components::LearnerComponentsMarker; use crate::learner::base::TrainingInterrupter; @@ -24,319 +24,317 @@ use burn_core::tensor::backend::AutodiffBackend; /// Struct to configure and create a [learner](Learner). pub struct LearnerBuilder where - T: Send + Sync + 'static, - V: Send + Sync + 'static, - B: AutodiffBackend, - M: AutodiffModule, - O: Optimizer, - S: LrScheduler, + T: Send + Sync + 'static, + V: Send + Sync + 'static, + B: AutodiffBackend, + M: AutodiffModule, + O: Optimizer, + S: LrScheduler, { - // Not that complex and very convenient when the traits are - // already constrained correctly. Extracting in another type - // would be more complex. - #[allow(clippy::type_complexity)] - checkpointers: Option<( - AsyncCheckpointer, - AsyncCheckpointer, - AsyncCheckpointer, - )>, - num_epochs: usize, - checkpoint: Option, - directory: String, - grad_accumulation: Option, - devices: Vec, - renderer: Option>, - metrics: Metrics, - event_store: LogEventStore, - interrupter: TrainingInterrupter, - log_to_file: bool, - num_loggers: usize, - checkpointer_strategy: Box, - early_stopping: Option>, + // Not that complex and very convenient when the traits are + // already constrained correctly. 
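Since `TrainingInterrupter` is just a shared atomic flag, it can be polled cheaply from the training loops. A hedged sketch; in real use the handle would come from `LearnerBuilder::interrupter()` (shown later in this patch) rather than being created directly:

    // Illustrative only: the handle is Clone, so it can be handed to another
    // thread (e.g. a Ctrl-C handler) while the epoch loops poll `should_stop()`
    // between batches.
    let interrupter = TrainingInterrupter::new();
    let handle = interrupter.clone();

    std::thread::spawn(move || {
        // ... once some external condition is met:
        handle.stop();
    });

    // Meanwhile, the epoch loops call `interrupter.should_stop()` between
    // batches and break out of training as soon as it returns true.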
Extracting in another type + // would be more complex. + #[allow(clippy::type_complexity)] + checkpointers: Option<( + AsyncCheckpointer, + AsyncCheckpointer, + AsyncCheckpointer, + )>, + num_epochs: usize, + checkpoint: Option, + directory: String, + grad_accumulation: Option, + devices: Vec, + renderer: Option>, + metrics: Metrics, + event_store: LogEventStore, + interrupter: TrainingInterrupter, + log_to_file: bool, + num_loggers: usize, + checkpointer_strategy: Box, + early_stopping: Option>, } impl LearnerBuilder where - B: AutodiffBackend, - T: Send + Sync + 'static, - V: Send + Sync + 'static, - M: AutodiffModule + core::fmt::Display + 'static, - O: Optimizer, - S: LrScheduler, + B: AutodiffBackend, + T: Send + Sync + 'static, + V: Send + Sync + 'static, + M: AutodiffModule + core::fmt::Display + 'static, + O: Optimizer, + S: LrScheduler, { - /// Creates a new learner builder. - /// - /// # Arguments - /// - /// * `directory` - The directory to save the checkpoints. - pub fn new(directory: &str) -> Self { - Self { - num_epochs: 1, - checkpoint: None, - checkpointers: None, - directory: directory.to_string(), - grad_accumulation: None, - devices: vec![B::Device::default()], - metrics: Metrics::default(), - event_store: LogEventStore::default(), - renderer: None, - interrupter: TrainingInterrupter::new(), - log_to_file: true, - num_loggers: 0, - checkpointer_strategy: Box::new( - ComposedCheckpointingStrategy::builder() - .add(KeepLastNCheckpoints::new(2)) - .add(MetricCheckpointingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Valid, - )) - .build(), - ), - early_stopping: None, - } + /// Creates a new learner builder. + /// + /// # Arguments + /// + /// * `directory` - The directory to save the checkpoints. + pub fn new(directory: &str) -> Self { + Self { + num_epochs: 1, + checkpoint: None, + checkpointers: None, + directory: directory.to_string(), + grad_accumulation: None, + devices: vec![B::Device::default()], + metrics: Metrics::default(), + event_store: LogEventStore::default(), + renderer: None, + interrupter: TrainingInterrupter::new(), + log_to_file: true, + num_loggers: 0, + checkpointer_strategy: Box::new( + ComposedCheckpointingStrategy::builder() + .add(KeepLastNCheckpoints::new(2)) + .add(MetricCheckpointingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + )) + .build(), + ), + early_stopping: None, } + } - /// Replace the default metric loggers with the provided ones. - /// - /// # Arguments - /// - /// * `logger_train` - The training logger. - /// * `logger_valid` - The validation logger. - pub fn metric_loggers(mut self, logger_train: MT, logger_valid: MV) -> Self - where - MT: MetricLogger + 'static, - MV: MetricLogger + 'static, - { - self.event_store.register_logger_train(logger_train); - self.event_store.register_logger_valid(logger_valid); - self.num_loggers += 1; - self - } + /// Replace the default metric loggers with the provided ones. + /// + /// # Arguments + /// + /// * `logger_train` - The training logger. + /// * `logger_valid` - The validation logger. + pub fn metric_loggers(mut self, logger_train: MT, logger_valid: MV) -> Self + where + MT: MetricLogger + 'static, + MV: MetricLogger + 'static, + { + self.event_store.register_logger_train(logger_train); + self.event_store.register_logger_valid(logger_valid); + self.num_loggers += 1; + self + } - /// Update the checkpointing_strategy. 
- pub fn with_checkpointing_strategy(&mut self, strategy: CS) - where - CS: CheckpointingStrategy + 'static, - { - self.checkpointer_strategy = Box::new(strategy); - } + /// Update the checkpointing_strategy. + pub fn with_checkpointing_strategy(&mut self, strategy: CS) + where + CS: CheckpointingStrategy + 'static, + { + self.checkpointer_strategy = Box::new(strategy); + } - /// Replace the default CLI renderer with a custom one. - /// - /// # Arguments - /// - /// * `renderer` - The custom renderer. - pub fn renderer(mut self, renderer: MR) -> Self - where - MR: MetricsRenderer + 'static, - { - self.renderer = Some(Box::new(renderer)); - self - } + /// Replace the default CLI renderer with a custom one. + /// + /// # Arguments + /// + /// * `renderer` - The custom renderer. + pub fn renderer(mut self, renderer: MR) -> Self + where + MR: MetricsRenderer + 'static, + { + self.renderer = Some(Box::new(renderer)); + self + } - /// Register a training metric. - pub fn metric_train(mut self, metric: Me) -> Self - where - T: Adaptor, - { - self.metrics.register_metric_train(metric); - self - } + /// Register a training metric. + pub fn metric_train(mut self, metric: Me) -> Self + where + T: Adaptor, + { + self.metrics.register_metric_train(metric); + self + } - /// Register a validation metric. - pub fn metric_valid(mut self, metric: Me) -> Self - where - V: Adaptor, - { - self.metrics.register_valid_metric(metric); - self - } + /// Register a validation metric. + pub fn metric_valid(mut self, metric: Me) -> Self + where + V: Adaptor, + { + self.metrics.register_valid_metric(metric); + self + } - /// Enable gradients accumulation. - /// - /// # Notes - /// - /// When you enable gradients accumulation, the gradients object used by the optimizer will be - /// the sum of all gradients generated by each backward pass. It might be a good idea to - /// reduce the learning to compensate. - /// - /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation` - /// amount. - pub fn grads_accumulation(mut self, accumulation: usize) -> Self { - self.grad_accumulation = Some(accumulation); - self - } + /// Enable gradients accumulation. + /// + /// # Notes + /// + /// When you enable gradients accumulation, the gradients object used by the optimizer will be + /// the sum of all gradients generated by each backward pass. It might be a good idea to + /// reduce the learning to compensate. + /// + /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation` + /// amount. + pub fn grads_accumulation(mut self, accumulation: usize) -> Self { + self.grad_accumulation = Some(accumulation); + self + } - /// Register a [numeric](crate::metric::Numeric) training [metric](Metric). - pub fn metric_train_numeric(mut self, metric: Me) -> Self - where - Me: Metric + crate::metric::Numeric + 'static, - T: Adaptor, - { - self.metrics.register_train_metric_numeric(metric); - self - } + /// Register a [numeric](crate::metric::Numeric) training [metric](Metric). + pub fn metric_train_numeric(mut self, metric: Me) -> Self + where + Me: Metric + crate::metric::Numeric + 'static, + T: Adaptor, + { + self.metrics.register_train_metric_numeric(metric); + self + } - /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric). 
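The `with_checkpointing_strategy` method below accepts any `CheckpointingStrategy`, so custom policies can be plugged in; note that, unlike the other builder methods, it takes `&mut self` and therefore does not chain. A hedged sketch of a hypothetical strategy that saves every `period` epochs and never deletes anything:

    // Illustrative only, not part of this patch.
    struct EveryN {
        period: usize,
    }

    impl CheckpointingStrategy for EveryN {
        fn checkpointing(
            &mut self,
            epoch: usize,
            _store: &EventStoreClient,
        ) -> Vec<CheckpointingAction> {
            if epoch % self.period == 0 {
                vec![CheckpointingAction::Save]
            } else {
                vec![]
            }
        }
    }

    // builder.with_checkpointing_strategy(EveryN { period: 5 });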
- pub fn metric_valid_numeric( - mut self, - metric: Me, - ) -> Self - where - V: Adaptor, - { - self.metrics.register_valid_metric_numeric(metric); - self - } + /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric). + pub fn metric_valid_numeric( + mut self, + metric: Me, + ) -> Self + where + V: Adaptor, + { + self.metrics.register_valid_metric_numeric(metric); + self + } - /// The number of epochs the training should last. - pub fn num_epochs(mut self, num_epochs: usize) -> Self { - self.num_epochs = num_epochs; - self - } + /// The number of epochs the training should last. + pub fn num_epochs(mut self, num_epochs: usize) -> Self { + self.num_epochs = num_epochs; + self + } - /// Run the training loop on multiple devices. - pub fn devices(mut self, devices: Vec) -> Self { - self.devices = devices; - self - } + /// Run the training loop on multiple devices. + pub fn devices(mut self, devices: Vec) -> Self { + self.devices = devices; + self + } - /// The epoch from which the training must resume. - pub fn checkpoint(mut self, checkpoint: usize) -> Self { - self.checkpoint = Some(checkpoint); - self - } + /// The epoch from which the training must resume. + pub fn checkpoint(mut self, checkpoint: usize) -> Self { + self.checkpoint = Some(checkpoint); + self + } - /// Provides a handle that can be used to interrupt training. - pub fn interrupter(&self) -> TrainingInterrupter { - self.interrupter.clone() - } + /// Provides a handle that can be used to interrupt training. + pub fn interrupter(&self) -> TrainingInterrupter { + self.interrupter.clone() + } - /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the - /// conditions are meet. - pub fn early_stopping(mut self, strategy: Strategy) -> Self - where - Strategy: EarlyStoppingStrategy + 'static, - { - self.early_stopping = Some(Box::new(strategy)); - self - } + /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the + /// conditions are meet. + pub fn early_stopping(mut self, strategy: Strategy) -> Self + where + Strategy: EarlyStoppingStrategy + 'static, + { + self.early_stopping = Some(Box::new(strategy)); + self + } - /// By default, Rust logs are captured and written into - /// `experiment.log`. If disabled, standard Rust log handling - /// will apply. - pub fn log_to_file(mut self, enabled: bool) -> Self { - self.log_to_file = enabled; - self - } + /// By default, Rust logs are captured and written into + /// `experiment.log`. If disabled, standard Rust log handling + /// will apply. + pub fn log_to_file(mut self, enabled: bool) -> Self { + self.log_to_file = enabled; + self + } - /// Register a checkpointer that will save the [optimizer](Optimizer), the - /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files. 
- pub fn with_file_checkpointer(mut self, recorder: FR) -> Self - where - FR: FileRecorder + 'static, - O::Record: 'static, - M::Record: 'static, - S::Record: 'static, - { - let checkpointer_model = FileCheckpointer::new( - recorder.clone(), - format!("{}/checkpoint", self.directory).as_str(), - "model", - ); - let checkpointer_optimizer = FileCheckpointer::new( - recorder.clone(), - format!("{}/checkpoint", self.directory).as_str(), - "optim", - ); - let checkpointer_scheduler = FileCheckpointer::new( - recorder, - format!("{}/checkpoint", self.directory).as_str(), - "scheduler", - ); + /// Register a checkpointer that will save the [optimizer](Optimizer), the + /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files. + pub fn with_file_checkpointer(mut self, recorder: FR) -> Self + where + FR: FileRecorder + 'static, + O::Record: 'static, + M::Record: 'static, + S::Record: 'static, + { + let checkpointer_model = FileCheckpointer::new( + recorder.clone(), + format!("{}/checkpoint", self.directory).as_str(), + "model", + ); + let checkpointer_optimizer = FileCheckpointer::new( + recorder.clone(), + format!("{}/checkpoint", self.directory).as_str(), + "optim", + ); + let checkpointer_scheduler = FileCheckpointer::new( + recorder, + format!("{}/checkpoint", self.directory).as_str(), + "scheduler", + ); - self.checkpointers = Some(( - AsyncCheckpointer::new(checkpointer_model), - AsyncCheckpointer::new(checkpointer_optimizer), - AsyncCheckpointer::new(checkpointer_scheduler), - )); + self.checkpointers = Some(( + AsyncCheckpointer::new(checkpointer_model), + AsyncCheckpointer::new(checkpointer_optimizer), + AsyncCheckpointer::new(checkpointer_scheduler), + )); - self - } + self + } - /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer). - /// The [learning rate scheduler](LrScheduler) can also be a simple - /// [learning rate](burn_core::LearningRate). - #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and - // creates a clean learner. - pub fn build( - mut self, - model: M, - optim: O, - lr_scheduler: S, - ) -> Learner< - LearnerComponentsMarker< - B, - S, - M, - O, - AsyncCheckpointer, - AsyncCheckpointer, - AsyncCheckpointer, - FullEventProcessor, - Box, - >, - > - where - M::Record: 'static, - O::Record: 'static, - S::Record: 'static, - { - if self.log_to_file { - self.init_logger(); - } - let renderer = self.renderer.unwrap_or_else(|| { - Box::new(default_renderer(self.interrupter.clone(), self.checkpoint)) - }); - let directory = &self.directory; + /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer). + /// The [learning rate scheduler](LrScheduler) can also be a simple + /// [learning rate](burn_core::LearningRate). + #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and + // creates a clean learner. 
+ pub fn build( + mut self, + model: M, + optim: O, + lr_scheduler: S, + ) -> Learner< + LearnerComponentsMarker< + B, + S, + M, + O, + AsyncCheckpointer, + AsyncCheckpointer, + AsyncCheckpointer, + FullEventProcessor, + Box, + >, + > + where + M::Record: 'static, + O::Record: 'static, + S::Record: 'static, + { + if self.log_to_file { + self.init_logger(); + } + let renderer = self + .renderer + .unwrap_or_else(|| Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))); + let directory = &self.directory; - if self.num_loggers == 0 { - self.event_store - .register_logger_train(FileMetricLogger::new( - format!("{directory}/train").as_str(), - )); - self.event_store - .register_logger_valid(FileMetricLogger::new( - format!("{directory}/valid").as_str(), - )); - } + if self.num_loggers == 0 { + self + .event_store + .register_logger_train(FileMetricLogger::new(format!("{directory}/train").as_str())); + self + .event_store + .register_logger_valid(FileMetricLogger::new(format!("{directory}/valid").as_str())); + } - let event_store = Arc::new(EventStoreClient::new(self.event_store)); - let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone()); + let event_store = Arc::new(EventStoreClient::new(self.event_store)); + let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone()); - let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { - LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy) - }); + let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { + LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy) + }); - Learner { - model, - optim, - lr_scheduler, - checkpointer, - num_epochs: self.num_epochs, - event_processor, - event_store, - checkpoint: self.checkpoint, - grad_accumulation: self.grad_accumulation, - devices: self.devices, - interrupter: self.interrupter, - early_stopping: self.early_stopping, - } + Learner { + model, + optim, + lr_scheduler, + checkpointer, + num_epochs: self.num_epochs, + event_processor, + event_store, + checkpoint: self.checkpoint, + grad_accumulation: self.grad_accumulation, + devices: self.devices, + interrupter: self.interrupter, + early_stopping: self.early_stopping, } + } - fn init_logger(&self) { - let file_path = format!("{}/experiment.log", self.directory); - install_file_logger(file_path.as_str()); - } + fn init_logger(&self) { + let file_path = format!("{}/experiment.log", self.directory); + install_file_logger(file_path.as_str()); + } } diff --git a/burn-train/src/learner/classification.rs b/burn-train/src/learner/classification.rs index f6b415fa29..ecba5bed3e 100644 --- a/burn-train/src/learner/classification.rs +++ b/burn-train/src/learner/classification.rs @@ -5,24 +5,24 @@ use burn_core::tensor::{Int, Tensor}; /// Simple classification output adapted for multiple metrics. #[derive(new)] pub struct ClassificationOutput { - /// The loss. - pub loss: Tensor, + /// The loss. + pub loss: Tensor, - /// The output. - pub output: Tensor, + /// The output. + pub output: Tensor, - /// The targets. - pub targets: Tensor, + /// The targets. 
+ pub targets: Tensor, } impl Adaptor> for ClassificationOutput { - fn adapt(&self) -> AccuracyInput { - AccuracyInput::new(self.output.clone(), self.targets.clone()) - } + fn adapt(&self) -> AccuracyInput { + AccuracyInput::new(self.output.clone(), self.targets.clone()) + } } impl Adaptor> for ClassificationOutput { - fn adapt(&self) -> LossInput { - LossInput::new(self.loss.clone()) - } + fn adapt(&self) -> LossInput { + LossInput::new(self.loss.clone()) + } } diff --git a/burn-train/src/learner/early_stopping.rs b/burn-train/src/learner/early_stopping.rs index 641d49551b..f8c0f0c5a2 100644 --- a/burn-train/src/learner/early_stopping.rs +++ b/burn-train/src/learner/early_stopping.rs @@ -1,209 +1,209 @@ use crate::metric::{ - store::{Aggregate, Direction, EventStoreClient, Split}, - Metric, + store::{Aggregate, Direction, EventStoreClient, Split}, + Metric, }; /// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow. pub enum StoppingCondition { - /// When no improvement has happened since the given number of epochs. - NoImprovementSince { - /// The number of epochs allowed to worsen before it gets better. - n_epochs: usize, - }, + /// When no improvement has happened since the given number of epochs. + NoImprovementSince { + /// The number of epochs allowed to worsen before it gets better. + n_epochs: usize, + }, } /// A strategy that checks if the training should be stopped. pub trait EarlyStoppingStrategy { - /// Update its current state and returns if the training should be stopped. - fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool; + /// Update its current state and returns if the training should be stopped. + fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool; } /// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected /// during training or validation. 
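A hedged usage sketch for the strategy defined below, wired into training through `LearnerBuilder::early_stopping`; `MyBackend` is a placeholder backend type:

    // Illustrative only: stop when the mean validation loss has not improved
    // for three consecutive epochs. The LossMetric must also be registered on
    // the builder (e.g. via `metric_valid_numeric`), otherwise no data is
    // collected and the strategy only logs a warning instead of stopping.
    let early_stopping = MetricEarlyStoppingStrategy::new::<LossMetric<MyBackend>>(
        Aggregate::Mean,
        Direction::Lowest,
        Split::Valid,
        StoppingCondition::NoImprovementSince { n_epochs: 3 },
    );

    // builder.early_stopping(early_stopping) ... as part of the LearnerBuilder chain.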
pub struct MetricEarlyStoppingStrategy { - condition: StoppingCondition, - metric_name: String, - aggregate: Aggregate, - direction: Direction, - split: Split, - best_epoch: usize, - best_value: f64, + condition: StoppingCondition, + metric_name: String, + aggregate: Aggregate, + direction: Direction, + split: Split, + best_epoch: usize, + best_value: f64, } impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy { - fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool { - let current_value = - match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) { - Some(value) => value, - None => { - log::warn!("Can't find metric for early stopping."); - return false; - } - }; - - let is_best = match self.direction { - Direction::Lowest => current_value < self.best_value, - Direction::Highest => current_value > self.best_value, - }; - - if is_best { - log::info!( - "New best epoch found {} {}: {}", - epoch, - self.metric_name, - current_value - ); - self.best_value = current_value; - self.best_epoch = epoch; - return false; + fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool { + let current_value = + match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) { + Some(value) => value, + None => { + log::warn!("Can't find metric for early stopping."); + return false; } + }; - match self.condition { - StoppingCondition::NoImprovementSince { n_epochs } => { - let should_stop = epoch - self.best_epoch >= n_epochs; + let is_best = match self.direction { + Direction::Lowest => current_value < self.best_value, + Direction::Highest => current_value > self.best_value, + }; - if should_stop { - log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value); - } + if is_best { + log::info!( + "New best epoch found {} {}: {}", + epoch, + self.metric_name, + current_value + ); + self.best_value = current_value; + self.best_epoch = epoch; + return false; + } - should_stop - } + match self.condition { + StoppingCondition::NoImprovementSince { n_epochs } => { + let should_stop = epoch - self.best_epoch >= n_epochs; + + if should_stop { + log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value); } + + should_stop + } } + } } impl MetricEarlyStoppingStrategy { - /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected - /// during training or validation. - /// - /// # Notes - /// - /// The metric should be registered for early stopping to work, otherwise no data is collected. - pub fn new( - aggregate: Aggregate, - direction: Direction, - split: Split, - condition: StoppingCondition, - ) -> Self { - let init_value = match direction { - Direction::Lowest => f64::MAX, - Direction::Highest => f64::MIN, - }; - - Self { - metric_name: Me::NAME.to_string(), - condition, - aggregate, - direction, - split, - best_epoch: 1, - best_value: init_value, - } + /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected + /// during training or validation. + /// + /// # Notes + /// + /// The metric should be registered for early stopping to work, otherwise no data is collected. 
+ pub fn new( + aggregate: Aggregate, + direction: Direction, + split: Split, + condition: StoppingCondition, + ) -> Self { + let init_value = match direction { + Direction::Lowest => f64::MAX, + Direction::Highest => f64::MIN, + }; + + Self { + metric_name: Me::NAME.to_string(), + condition, + aggregate, + direction, + split, + best_epoch: 1, + best_value: init_value, } + } } #[cfg(test)] mod tests { - use std::sync::Arc; - - use crate::{ - logger::InMemoryMetricLogger, - metric::{ - processor::{ - test_utils::{end_epoch, process_train}, - Metrics, MinimalEventProcessor, - }, - store::LogEventStore, - LossMetric, - }, - TestBackend, - }; - - use super::*; - - #[test] - fn never_early_stop_while_it_is_improving() { - test_early_stopping( - 1, - &[ - (&[0.5, 0.3], false, "Should not stop first epoch"), - (&[0.4, 0.3], false, "Should not stop when improving"), - (&[0.3, 0.3], false, "Should not stop when improving"), - (&[0.2, 0.3], false, "Should not stop when improving"), - ], - ); - } - - #[test] - fn early_stop_when_no_improvement_since_two_epochs() { - test_early_stopping( - 2, - &[ - (&[1.0, 0.5], false, "Should not stop first epoch"), - (&[0.5, 0.3], false, "Should not stop when improving"), - ( - &[1.0, 3.0], - false, - "Should not stop first time it gets worse", - ), - ( - &[1.0, 2.0], - true, - "Should stop since two following epochs didn't improve", - ), - ], - ); - } - - #[test] - fn early_stop_when_stays_equal() { - test_early_stopping( - 2, - &[ - (&[0.5, 0.3], false, "Should not stop first epoch"), - ( - &[0.5, 0.3], - false, - "Should not stop first time it stars the same", - ), - ( - &[0.5, 0.3], - true, - "Should stop since two following epochs didn't improve", - ), - ], - ); - } - - fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) { - let mut early_stopping = MetricEarlyStoppingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Train, - StoppingCondition::NoImprovementSince { n_epochs }, - ); - let mut store = LogEventStore::default(); - let mut metrics = Metrics::::default(); - - store.register_logger_train(InMemoryMetricLogger::default()); - metrics.register_train_metric_numeric(LossMetric::::new()); - - let store = Arc::new(EventStoreClient::new(store)); - let mut processor = MinimalEventProcessor::new(metrics, store.clone()); - - let mut epoch = 1; - for (points, should_start, comment) in data { - for point in points.iter() { - process_train(&mut processor, *point, epoch); - } - end_epoch(&mut processor, epoch); - - assert_eq!( - *should_start, - early_stopping.should_stop(epoch, &store), - "{comment}" - ); - epoch += 1; - } + use std::sync::Arc; + + use crate::{ + logger::InMemoryMetricLogger, + metric::{ + processor::{ + test_utils::{end_epoch, process_train}, + Metrics, MinimalEventProcessor, + }, + store::LogEventStore, + LossMetric, + }, + TestBackend, + }; + + use super::*; + + #[test] + fn never_early_stop_while_it_is_improving() { + test_early_stopping( + 1, + &[ + (&[0.5, 0.3], false, "Should not stop first epoch"), + (&[0.4, 0.3], false, "Should not stop when improving"), + (&[0.3, 0.3], false, "Should not stop when improving"), + (&[0.2, 0.3], false, "Should not stop when improving"), + ], + ); + } + + #[test] + fn early_stop_when_no_improvement_since_two_epochs() { + test_early_stopping( + 2, + &[ + (&[1.0, 0.5], false, "Should not stop first epoch"), + (&[0.5, 0.3], false, "Should not stop when improving"), + ( + &[1.0, 3.0], + false, + "Should not stop first time it gets worse", + ), + ( + &[1.0, 2.0], + true, + 
"Should stop since two following epochs didn't improve", + ), + ], + ); + } + + #[test] + fn early_stop_when_stays_equal() { + test_early_stopping( + 2, + &[ + (&[0.5, 0.3], false, "Should not stop first epoch"), + ( + &[0.5, 0.3], + false, + "Should not stop first time it stars the same", + ), + ( + &[0.5, 0.3], + true, + "Should stop since two following epochs didn't improve", + ), + ], + ); + } + + fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) { + let mut early_stopping = MetricEarlyStoppingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Train, + StoppingCondition::NoImprovementSince { n_epochs }, + ); + let mut store = LogEventStore::default(); + let mut metrics = Metrics::::default(); + + store.register_logger_train(InMemoryMetricLogger::default()); + metrics.register_train_metric_numeric(LossMetric::::new()); + + let store = Arc::new(EventStoreClient::new(store)); + let mut processor = MinimalEventProcessor::new(metrics, store.clone()); + + let mut epoch = 1; + for (points, should_start, comment) in data { + for point in points.iter() { + process_train(&mut processor, *point, epoch); + } + end_epoch(&mut processor, epoch); + + assert_eq!( + *should_start, + early_stopping.should_stop(epoch, &store), + "{comment}" + ); + epoch += 1; } + } } diff --git a/burn-train/src/learner/epoch.rs b/burn-train/src/learner/epoch.rs index 06eba7b7d9..eca2ff0b49 100644 --- a/burn-train/src/learner/epoch.rs +++ b/burn-train/src/learner/epoch.rs @@ -1,6 +1,6 @@ use burn_core::{ - data::dataloader::DataLoader, lr_scheduler::LrScheduler, module::AutodiffModule, - optim::GradientsAccumulator, tensor::backend::Backend, + data::dataloader::DataLoader, lr_scheduler::LrScheduler, module::AutodiffModule, + optim::GradientsAccumulator, tensor::backend::Backend, }; use std::sync::Arc; @@ -11,237 +11,237 @@ use crate::{MultiDevicesTrainStep, TrainStep, ValidStep}; /// A validation epoch. #[derive(new)] pub struct ValidEpoch { - dataloader: Arc>, - epoch: usize, - epoch_total: usize, + dataloader: Arc>, + epoch: usize, + epoch_total: usize, } /// A training epoch. #[derive(new)] pub struct TrainEpoch { - dataloader: Arc>, - epoch: usize, - epoch_total: usize, - grad_accumulation: Option, + dataloader: Arc>, + epoch: usize, + epoch_total: usize, + grad_accumulation: Option, } impl ValidEpoch { - /// Runs the validation epoch. - /// - /// # Arguments - /// - /// * `model` - The model to validate. - /// * `processor` - The event processor to use. - pub fn run( - &self, - model: &LC::Model, - processor: &mut LC::EventProcessor, - interrupter: &TrainingInterrupter, - ) where - LC::EventProcessor: EventProcessor, - >::InnerModule: ValidStep, - { - log::info!("Executing validation step for epoch {}", self.epoch); - let model = model.valid(); - - let mut iterator = self.dataloader.iter(); - let mut iteration = 0; - - while let Some(item) = iterator.next() { - let progress = iterator.progress(); - iteration += 1; - - let item = model.step(item); - let item = LearnerItem::new( - item, - progress, - self.epoch, - self.epoch_total, - iteration, - None, - ); - - processor.process_valid(Event::ProcessedItem(item)); - - if interrupter.should_stop() { - log::info!("Training interrupted."); - break; - } - } - processor.process_valid(Event::EndEpoch(self.epoch)); + /// Runs the validation epoch. + /// + /// # Arguments + /// + /// * `model` - The model to validate. + /// * `processor` - The event processor to use. 
+ pub fn run( + &self, + model: &LC::Model, + processor: &mut LC::EventProcessor, + interrupter: &TrainingInterrupter, + ) where + LC::EventProcessor: EventProcessor, + >::InnerModule: ValidStep, + { + log::info!("Executing validation step for epoch {}", self.epoch); + let model = model.valid(); + + let mut iterator = self.dataloader.iter(); + let mut iteration = 0; + + while let Some(item) = iterator.next() { + let progress = iterator.progress(); + iteration += 1; + + let item = model.step(item); + let item = LearnerItem::new( + item, + progress, + self.epoch, + self.epoch_total, + iteration, + None, + ); + + processor.process_valid(Event::ProcessedItem(item)); + + if interrupter.should_stop() { + log::info!("Training interrupted."); + break; + } } + processor.process_valid(Event::EndEpoch(self.epoch)); + } } impl TrainEpoch { - /// Runs the training epoch. - /// - /// # Arguments - /// - /// * `model` - The model to train. - /// * `optim` - The optimizer to use. - /// * `scheduler` - The learning rate scheduler to use. - /// * `processor` - The event processor to use. - /// - /// # Returns - /// - /// The trained model and the optimizer. - pub fn run( - &self, - mut model: LC::Model, - mut optim: LC::Optimizer, - scheduler: &mut LC::LrScheduler, - processor: &mut LC::EventProcessor, - interrupter: &TrainingInterrupter, - ) -> (LC::Model, LC::Optimizer) - where - LC::EventProcessor: EventProcessor, - LC::Model: TrainStep, - { - log::info!("Executing training step for epoch {}", self.epoch,); - - let mut iterator = self.dataloader.iter(); - let mut iteration = 0; - let mut accumulator = GradientsAccumulator::new(); - let mut accumulation_current = 0; - - while let Some(item) = iterator.next() { - iteration += 1; - let lr = scheduler.step(); - log::info!("Iteration {}", iteration); - - let progress = iterator.progress(); - let item = model.step(item); - - match self.grad_accumulation { - Some(accumulation) => { - accumulator.accumulate(&model, item.grads); - accumulation_current += 1; - - if accumulation <= accumulation_current { - let grads = accumulator.grads(); - model = model.optimize(&mut optim, lr, grads); - accumulation_current = 0; - } - } - None => model = model.optimize(&mut optim, lr, item.grads), - } - - let item = LearnerItem::new( - item.item, - progress, - self.epoch, - self.epoch_total, - iteration, - Some(lr), - ); - - processor.process_train(Event::ProcessedItem(item)); - - if interrupter.should_stop() { - log::info!("Training interrupted."); - break; - } + /// Runs the training epoch. + /// + /// # Arguments + /// + /// * `model` - The model to train. + /// * `optim` - The optimizer to use. + /// * `scheduler` - The learning rate scheduler to use. + /// * `processor` - The event processor to use. + /// + /// # Returns + /// + /// The trained model and the optimizer. 
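Before the training-epoch loop, a brief hedged note on how its accumulation branch relates to `grads_accumulation` on the builder: the optimizer takes one step per `accumulation` micro-batches, built from the sum of their gradients. The helper below is hypothetical and only makes that arithmetic explicit:

    // Hypothetical helper, not part of this patch: each optimizer step is
    // effectively computed on a larger batch when accumulation is enabled.
    fn effective_batch_size(dataloader_batch_size: usize, accumulation: usize) -> usize {
        dataloader_batch_size * accumulation
    }

    #[test]
    fn accumulation_scales_batch_size() {
        // e.g. builder.grads_accumulation(4) with a dataloader batch size of 32
        assert_eq!(effective_batch_size(32, 4), 128);
    }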
+ pub fn run( + &self, + mut model: LC::Model, + mut optim: LC::Optimizer, + scheduler: &mut LC::LrScheduler, + processor: &mut LC::EventProcessor, + interrupter: &TrainingInterrupter, + ) -> (LC::Model, LC::Optimizer) + where + LC::EventProcessor: EventProcessor, + LC::Model: TrainStep, + { + log::info!("Executing training step for epoch {}", self.epoch,); + + let mut iterator = self.dataloader.iter(); + let mut iteration = 0; + let mut accumulator = GradientsAccumulator::new(); + let mut accumulation_current = 0; + + while let Some(item) = iterator.next() { + iteration += 1; + let lr = scheduler.step(); + log::info!("Iteration {}", iteration); + + let progress = iterator.progress(); + let item = model.step(item); + + match self.grad_accumulation { + Some(accumulation) => { + accumulator.accumulate(&model, item.grads); + accumulation_current += 1; + + if accumulation <= accumulation_current { + let grads = accumulator.grads(); + model = model.optimize(&mut optim, lr, grads); + accumulation_current = 0; + } } - processor.process_train(Event::EndEpoch(self.epoch)); - - (model, optim) + None => model = model.optimize(&mut optim, lr, item.grads), + } + + let item = LearnerItem::new( + item.item, + progress, + self.epoch, + self.epoch_total, + iteration, + Some(lr), + ); + + processor.process_train(Event::ProcessedItem(item)); + + if interrupter.should_stop() { + log::info!("Training interrupted."); + break; + } } + processor.process_train(Event::EndEpoch(self.epoch)); + + (model, optim) + } } impl TrainEpoch { - /// Runs the training epoch on multiple devices. - /// - /// # Arguments - /// - /// * `model` - The model to train. - /// * `optim` - The optimizer to use. - /// * `lr_scheduler` - The learning rate scheduler to use. - /// * `processor` - The event processor to use. - /// * `devices` - The devices to use. - /// - /// # Returns - /// - /// The trained model and the optimizer. - pub fn run_multi_device( - &self, - mut model: LC::Model, - mut optim: LC::Optimizer, - lr_scheduler: &mut LC::LrScheduler, - processor: &mut LC::EventProcessor, - devices: Vec<::Device>, - interrupter: &TrainingInterrupter, - ) -> (LC::Model, LC::Optimizer) - where - LC::EventProcessor: EventProcessor, - LC::Model: TrainStep, - TO: Send + 'static, - TI: Send + 'static, - { - log::info!( - "Executing training step for epoch {} on devices {:?}", - self.epoch, - devices + /// Runs the training epoch on multiple devices. + /// + /// # Arguments + /// + /// * `model` - The model to train. + /// * `optim` - The optimizer to use. + /// * `lr_scheduler` - The learning rate scheduler to use. + /// * `processor` - The event processor to use. + /// * `devices` - The devices to use. + /// + /// # Returns + /// + /// The trained model and the optimizer. 
+ pub fn run_multi_device( + &self, + mut model: LC::Model, + mut optim: LC::Optimizer, + lr_scheduler: &mut LC::LrScheduler, + processor: &mut LC::EventProcessor, + devices: Vec<::Device>, + interrupter: &TrainingInterrupter, + ) -> (LC::Model, LC::Optimizer) + where + LC::EventProcessor: EventProcessor, + LC::Model: TrainStep, + TO: Send + 'static, + TI: Send + 'static, + { + log::info!( + "Executing training step for epoch {} on devices {:?}", + self.epoch, + devices + ); + + let mut iterator = self.dataloader.iter(); + let mut iteration = 0; + let mut accumulator = GradientsAccumulator::new(); + let mut accumulation_current = 0; + + let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len(); + let step = MultiDevicesTrainStep::new(&devices); + + // The main device is always the first in the list. + let device_main = devices.get(0).expect("A minimum of one device.").clone(); + let mut interrupted = false; + + loop { + let items = step.step(&mut iterator, &model); + if items.is_empty() { + break; + } + + for item in items { + iteration += 1; + let lr = lr_scheduler.step(); + let progress = iterator.progress(); + + let grads = item.grads.to_device(&device_main, &model); + + accumulator.accumulate(&model, grads); + accumulation_current += 1; + + if accumulation <= accumulation_current { + let grads = accumulator.grads(); + model = model.optimize(&mut optim, lr, grads); + accumulation_current = 0; + } + + let item = LearnerItem::new( + item.item, + progress, + self.epoch, + self.epoch_total, + iteration, + Some(lr), ); - let mut iterator = self.dataloader.iter(); - let mut iteration = 0; - let mut accumulator = GradientsAccumulator::new(); - let mut accumulation_current = 0; - - let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len(); - let step = MultiDevicesTrainStep::new(&devices); - - // The main device is always the first in the list. 
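For the multi-device variant, the accumulation window additionally scales with the number of devices, since each step dispatches one batch per device. A hypothetical helper mirroring the `grad_accumulation.unwrap_or(1) * devices.len()` computation in the body:

    // Hypothetical helper, not part of this patch.
    fn effective_accumulation(grad_accumulation: Option<usize>, num_devices: usize) -> usize {
        grad_accumulation.unwrap_or(1) * num_devices
    }

    #[test]
    fn accumulation_scales_with_devices() {
        assert_eq!(effective_accumulation(None, 2), 2); // no accumulation requested
        assert_eq!(effective_accumulation(Some(4), 2), 8); // 4 per device
    }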
- let device_main = devices.get(0).expect("A minimum of one device.").clone(); - let mut interrupted = false; - - loop { - let items = step.step(&mut iterator, &model); - if items.is_empty() { - break; - } - - for item in items { - iteration += 1; - let lr = lr_scheduler.step(); - let progress = iterator.progress(); - - let grads = item.grads.to_device(&device_main, &model); - - accumulator.accumulate(&model, grads); - accumulation_current += 1; - - if accumulation <= accumulation_current { - let grads = accumulator.grads(); - model = model.optimize(&mut optim, lr, grads); - accumulation_current = 0; - } - - let item = LearnerItem::new( - item.item, - progress, - self.epoch, - self.epoch_total, - iteration, - Some(lr), - ); - - processor.process_train(Event::ProcessedItem(item)); - - if interrupter.should_stop() { - log::info!("Training interrupted."); - interrupted = true; - break; - } - } - - if interrupted { - break; - } - } + processor.process_train(Event::ProcessedItem(item)); - processor.process_train(Event::EndEpoch(self.epoch)); + if interrupter.should_stop() { + log::info!("Training interrupted."); + interrupted = true; + break; + } + } - (model, optim) + if interrupted { + break; + } } + + processor.process_train(Event::EndEpoch(self.epoch)); + + (model, optim) + } } diff --git a/burn-train/src/learner/log.rs b/burn-train/src/learner/log.rs index 35162cfc73..e8e1025a02 100644 --- a/burn-train/src/learner/log.rs +++ b/burn-train/src/learner/log.rs @@ -7,40 +7,41 @@ use tracing_subscriber::{registry, Layer}; /// If a global tracing subscriber is not already configured, set up logging to a file, /// and add our custom panic hook. pub(crate) fn install_file_logger(file_path: &str) { - let path = Path::new(file_path); - let writer = tracing_appender::rolling::never( - path.parent().unwrap_or_else(|| Path::new(".")), - path.file_name() - .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")), - ); - let layer = tracing_subscriber::fmt::layer() - .with_ansi(false) - .with_writer(writer) - .with_filter(LevelFilter::INFO) - .with_filter(filter_fn(|m| { - if let Some(path) = m.module_path() { - // The wgpu crate is logging too much, so we skip `info` level. - if path.starts_with("wgpu") && *m.level() >= Level::INFO { - return false; - } - } - true - })); + let path = Path::new(file_path); + let writer = tracing_appender::rolling::never( + path.parent().unwrap_or_else(|| Path::new(".")), + path + .file_name() + .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")), + ); + let layer = tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_writer(writer) + .with_filter(LevelFilter::INFO) + .with_filter(filter_fn(|m| { + if let Some(path) = m.module_path() { + // The wgpu crate is logging too much, so we skip `info` level. 
+ if path.starts_with("wgpu") && *m.level() >= Level::INFO { + return false; + } + } + true + })); - if registry().with(layer).try_init().is_ok() { - update_panic_hook(file_path); - } + if registry().with(layer).try_init().is_ok() { + update_panic_hook(file_path); + } } fn update_panic_hook(file_path: &str) { - let hook = std::panic::take_hook(); - let file_path = file_path.to_owned(); + let hook = std::panic::take_hook(); + let file_path = file_path.to_owned(); - std::panic::set_hook(Box::new(move |info| { - log::error!("PANIC => {}", info.to_string()); - eprintln!( + std::panic::set_hook(Box::new(move |info| { + log::error!("PANIC => {}", info.to_string()); + eprintln!( "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => '{file_path}'\n=============" ); - hook(info); - })); + hook(info); + })); } diff --git a/burn-train/src/learner/regression.rs b/burn-train/src/learner/regression.rs index 9aa5db2e94..d6d647db61 100644 --- a/burn-train/src/learner/regression.rs +++ b/burn-train/src/learner/regression.rs @@ -5,18 +5,18 @@ use burn_core::tensor::Tensor; /// Simple regression output adapted for multiple metrics. #[derive(new)] pub struct RegressionOutput { - /// The loss. - pub loss: Tensor, + /// The loss. + pub loss: Tensor, - /// The output. - pub output: Tensor, + /// The output. + pub output: Tensor, - /// The targets. - pub targets: Tensor, + /// The targets. + pub targets: Tensor, } impl Adaptor> for RegressionOutput { - fn adapt(&self) -> LossInput { - LossInput::new(self.loss.clone()) - } + fn adapt(&self) -> LossInput { + LossInput::new(self.loss.clone()) + } } diff --git a/burn-train/src/learner/step/train.rs b/burn-train/src/learner/step/train.rs index c000e8b661..c8c46ae6d8 100644 --- a/burn-train/src/learner/step/train.rs +++ b/burn-train/src/learner/step/train.rs @@ -1,139 +1,139 @@ use crate::{TrainOutput, TrainStep}; use burn_core::{ - data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend, + data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend, }; use std::sync::mpsc::{Receiver, Sender}; use std::thread::spawn; /// Multi devices train step. 
pub struct MultiDevicesTrainStep { - workers: Vec>, - receiver: Receiver>, + workers: Vec>, + receiver: Receiver>, } struct Message { - item: TI, - model: M, + item: TI, + model: M, } struct Worker { - sender_input: Sender>, - device: B::Device, + sender_input: Sender>, + device: B::Device, } impl Worker where - B: AutodiffBackend, - M: AutodiffModule, + B: AutodiffBackend, + M: AutodiffModule, { - fn register(&self, item: TI, model: &M) { - let message = Message { - item, - model: model.clone(), - }; - self.sender_input.send(message).unwrap(); - } + fn register(&self, item: TI, model: &M) { + let message = Message { + item, + model: model.clone(), + }; + self.sender_input.send(message).unwrap(); + } - fn start( - &self, - sender_output: Sender>, - receiver_input: Receiver>, - ) where - TI: Send + 'static, - TO: Send + 'static, - M: TrainStep + Send + 'static, - { - let device = self.device.clone(); + fn start( + &self, + sender_output: Sender>, + receiver_input: Receiver>, + ) where + TI: Send + 'static, + TO: Send + 'static, + M: TrainStep + Send + 'static, + { + let device = self.device.clone(); - spawn(move || loop { - match receiver_input.recv() { - Ok(item) => { - let step = item.model.fork(&device); - let output = step.step(item.item); + spawn(move || loop { + match receiver_input.recv() { + Ok(item) => { + let step = item.model.fork(&device); + let output = step.step(item.item); - sender_output.send(output).unwrap(); - } - Err(_err) => { - log::info!("Closing thread on device {:?}", device); - break; - } - } - }); - } + sender_output.send(output).unwrap(); + } + Err(_err) => { + log::info!("Closing thread on device {:?}", device); + break; + } + } + }); + } } impl MultiDevicesTrainStep where - B: AutodiffBackend, - M: AutodiffModule + TrainStep + Send + Clone + 'static, - TI: Send + 'static, - TO: Send + 'static, + B: AutodiffBackend, + M: AutodiffModule + TrainStep + Send + Clone + 'static, + TI: Send + 'static, + TO: Send + 'static, { - /// Create a new multi devices train step. - /// - /// # Arguments - /// - /// * `devices` - Devices. - /// - /// # Returns - /// - /// MultiDevicesTrainStep instance. - pub fn new(devices: &[B::Device]) -> Self - where - TI: Send + 'static, - { - let (sender_output, receiver_output) = std::sync::mpsc::channel(); - let workers = devices - .iter() - .map(|device| { - let (sender_input, receiver_input) = std::sync::mpsc::channel(); - let worker = Worker { - sender_input, - device: device.clone(), - }; + /// Create a new multi devices train step. + /// + /// # Arguments + /// + /// * `devices` - Devices. + /// + /// # Returns + /// + /// MultiDevicesTrainStep instance. + pub fn new(devices: &[B::Device]) -> Self + where + TI: Send + 'static, + { + let (sender_output, receiver_output) = std::sync::mpsc::channel(); + let workers = devices + .iter() + .map(|device| { + let (sender_input, receiver_input) = std::sync::mpsc::channel(); + let worker = Worker { + sender_input, + device: device.clone(), + }; - worker.start(sender_output.clone(), receiver_input); - worker - }) - .collect(); + worker.start(sender_output.clone(), receiver_input); + worker + }) + .collect(); - Self { - workers, - receiver: receiver_output, - } + Self { + workers, + receiver: receiver_output, } + } - /// Collect outputs from workers for one step. - /// - /// # Arguments - /// - /// * `dataloader` - Dataloader. - /// * `model` - Model. - /// - /// # Returns - /// - /// Outputs. 
- pub fn step<'a>( - &self, - dataloader: &mut Box + 'a>, - model: &M, - ) -> Vec> { - let mut num_send = 0; - - for worker in self.workers.iter() { - if let Some(item) = dataloader.next() { - worker.register(item, model); - num_send += 1; - } - } + /// Collect outputs from workers for one step. + /// + /// # Arguments + /// + /// * `dataloader` - Dataloader. + /// * `model` - Model. + /// + /// # Returns + /// + /// Outputs. + pub fn step<'a>( + &self, + dataloader: &mut Box + 'a>, + model: &M, + ) -> Vec> { + let mut num_send = 0; - let mut outputs = Vec::with_capacity(num_send); + for worker in self.workers.iter() { + if let Some(item) = dataloader.next() { + worker.register(item, model); + num_send += 1; + } + } - for _ in 0..num_send { - let output = self.receiver.recv().unwrap(); - outputs.push(output); - } + let mut outputs = Vec::with_capacity(num_send); - outputs + for _ in 0..num_send { + let output = self.receiver.recv().unwrap(); + outputs.push(output); } + + outputs + } } diff --git a/burn-train/src/learner/train_val.rs b/burn-train/src/learner/train_val.rs index b8b16dddf9..6a9f902697 100644 --- a/burn-train/src/learner/train_val.rs +++ b/burn-train/src/learner/train_val.rs @@ -9,33 +9,33 @@ use std::sync::Arc; /// A training output. pub struct TrainOutput { - /// The gradients. - pub grads: GradientsParams, + /// The gradients. + pub grads: GradientsParams, - /// The item. - pub item: TO, + /// The item. + pub item: TO, } impl TrainOutput { - /// Creates a new training output. - /// - /// # Arguments - /// - /// * `module` - The module. - /// * `grads` - The gradients. - /// * `item` - The item. - /// - /// # Returns - /// - /// A new training output. - pub fn new>( - module: &M, - grads: B::Gradients, - item: TO, - ) -> Self { - let grads = GradientsParams::from_grads(grads, module); - Self { grads, item } - } + /// Creates a new training output. + /// + /// # Arguments + /// + /// * `module` - The module. + /// * `grads` - The gradients. + /// * `item` - The item. + /// + /// # Returns + /// + /// A new training output. + pub fn new>( + module: &M, + grads: B::Gradients, + item: TO, + ) -> Self { + let grads = GradientsParams::from_grads(grads, module); + Self { grads, item } + } } /// Trait to be implemented for training models. @@ -52,152 +52,144 @@ impl TrainOutput { /// also implement the [AutodiffModule] trait, which is done automatically with the /// [Module](burn_core::module::Module) derive. pub trait TrainStep { - /// Runs the training step, which executes the forward and backward passes. - /// - /// # Arguments - /// - /// * `item` - The training input for the model. - /// - /// # Returns - /// - /// The training output containing the model output and the gradients. - fn step(&self, item: TI) -> TrainOutput; - /// Optimize the current module with the provided gradients and learning rate. - /// - /// # Arguments - /// - /// * `optim`: Optimizer used for training this model. - /// * `lr`: The learning rate used for this step. - /// * `grads`: The gradients of each parameter in the current model. - /// - /// # Returns - /// - /// The updated model. - fn optimize(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self - where - B: AutodiffBackend, - O: Optimizer, - Self: AutodiffModule, - { - optim.step(lr, self, grads) - } + /// Runs the training step, which executes the forward and backward passes. + /// + /// # Arguments + /// + /// * `item` - The training input for the model. 
+ /// + /// # Returns + /// + /// The training output containing the model output and the gradients. + fn step(&self, item: TI) -> TrainOutput; + /// Optimize the current module with the provided gradients and learning rate. + /// + /// # Arguments + /// + /// * `optim`: Optimizer used for training this model. + /// * `lr`: The learning rate used for this step. + /// * `grads`: The gradients of each parameter in the current model. + /// + /// # Returns + /// + /// The updated model. + fn optimize(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self + where + B: AutodiffBackend, + O: Optimizer, + Self: AutodiffModule, + { + optim.step(lr, self, grads) + } } /// Trait to be implemented for validating models. pub trait ValidStep { - /// Runs a validation step. - /// - /// # Arguments - /// - /// * `item` - The item to validate on. - /// - /// # Returns - /// - /// The validation output. - fn step(&self, item: VI) -> VO; + /// Runs a validation step. + /// + /// # Arguments + /// + /// * `item` - The item to validate on. + /// + /// # Returns + /// + /// The validation output. + fn step(&self, item: VI) -> VO; } impl Learner { - /// Fits the model. - /// - /// # Arguments - /// - /// * `dataloader_train` - The training dataloader. - /// * `dataloader_valid` - The validation dataloader. - /// - /// # Returns - /// - /// The fitted model. - pub fn fit( - mut self, - dataloader_train: Arc>, - dataloader_valid: Arc>, - ) -> LC::Model - where - InputTrain: Send + 'static, - InputValid: Send, - OutputTrain: Send + 'static, - OutputValid: Send, - LC::Model: TrainStep, - >::InnerModule: ValidStep, - LC::EventProcessor: EventProcessor, - { - log::info!("Fitting {}", self.model.to_string()); - // The reference model is always on the first device provided. - if let Some(device) = self.devices.get(0) { - self.model = self.model.fork(device); - } + /// Fits the model. + /// + /// # Arguments + /// + /// * `dataloader_train` - The training dataloader. + /// * `dataloader_valid` - The validation dataloader. + /// + /// # Returns + /// + /// The fitted model. + pub fn fit( + mut self, + dataloader_train: Arc>, + dataloader_valid: Arc>, + ) -> LC::Model + where + InputTrain: Send + 'static, + InputValid: Send, + OutputTrain: Send + 'static, + OutputValid: Send, + LC::Model: TrainStep, + >::InnerModule: ValidStep, + LC::EventProcessor: EventProcessor, + { + log::info!("Fitting {}", self.model.to_string()); + // The reference model is always on the first device provided. 
+ if let Some(device) = self.devices.get(0) { + self.model = self.model.fork(device); + } - let starting_epoch = match self.checkpoint { - Some(checkpoint) => { - if let Some(checkpointer) = &mut self.checkpointer { - (self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint( - self.model, - self.optim, - self.lr_scheduler, - checkpoint, - ); - } - checkpoint + 1 - } - None => 1, - }; + let starting_epoch = match self.checkpoint { + Some(checkpoint) => { + if let Some(checkpointer) = &mut self.checkpointer { + (self.model, self.optim, self.lr_scheduler) = + checkpointer.load_checkpoint(self.model, self.optim, self.lr_scheduler, checkpoint); + } + checkpoint + 1 + } + None => 1, + }; - for epoch in starting_epoch..self.num_epochs + 1 { - let epoch_train = TrainEpoch::new( - dataloader_train.clone(), - epoch, - self.num_epochs, - self.grad_accumulation, - ); + for epoch in starting_epoch..self.num_epochs + 1 { + let epoch_train = TrainEpoch::new( + dataloader_train.clone(), + epoch, + self.num_epochs, + self.grad_accumulation, + ); - if self.devices.len() > 1 { - (self.model, self.optim) = epoch_train.run_multi_device::( - self.model, - self.optim, - &mut self.lr_scheduler, - &mut self.event_processor, - self.devices.clone(), - &self.interrupter, - ) - } else { - (self.model, self.optim) = epoch_train.run::( - self.model, - self.optim, - &mut self.lr_scheduler, - &mut self.event_processor, - &self.interrupter, - ); - } + if self.devices.len() > 1 { + (self.model, self.optim) = epoch_train.run_multi_device::( + self.model, + self.optim, + &mut self.lr_scheduler, + &mut self.event_processor, + self.devices.clone(), + &self.interrupter, + ) + } else { + (self.model, self.optim) = epoch_train.run::( + self.model, + self.optim, + &mut self.lr_scheduler, + &mut self.event_processor, + &self.interrupter, + ); + } - if self.interrupter.should_stop() { - break; - } + if self.interrupter.should_stop() { + break; + } - let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs); - epoch_valid.run::( - &self.model, - &mut self.event_processor, - &self.interrupter, - ); + let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs); + epoch_valid.run::(&self.model, &mut self.event_processor, &self.interrupter); - if let Some(checkpointer) = &mut self.checkpointer { - checkpointer.checkpoint( - &self.model, - &self.optim, - &self.lr_scheduler, - epoch, - &self.event_store, - ); - } + if let Some(checkpointer) = &mut self.checkpointer { + checkpointer.checkpoint( + &self.model, + &self.optim, + &self.lr_scheduler, + epoch, + &self.event_store, + ); + } - if let Some(early_stopping) = &mut self.early_stopping { - if early_stopping.should_stop(epoch, &self.event_store) { - break; - } - } + if let Some(early_stopping) = &mut self.early_stopping { + if early_stopping.should_stop(epoch, &self.event_store) { + break; } - - self.model + } } + + self.model + } } diff --git a/burn-train/src/logger/async_logger.rs b/burn-train/src/logger/async_logger.rs index c659098b1e..79308e21f0 100644 --- a/burn-train/src/logger/async_logger.rs +++ b/burn-train/src/logger/async_logger.rs @@ -2,90 +2,93 @@ use super::Logger; use std::sync::mpsc; enum Message { - Log(T), - End, - Sync(mpsc::Sender<()>), + Log(T), + End, + Sync(mpsc::Sender<()>), } /// Async logger. 
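The `TrainStep` and `ValidStep` traits in train_val.rs above are generic over the input and output item types (the flattened signatures read roughly `TrainStep<TI, TO>` with `fn step(&self, item: TI) -> TrainOutput<TO>`, and `ValidStep<VI, VO>`). A minimal sketch of how a model would implement the training side, assuming hypothetical `MyModel<B>` and `MyBatch<B>` types and a `forward_step` helper that returns a `RegressionOutput<B>`:

use burn_core::tensor::backend::AutodiffBackend;

impl<B: AutodiffBackend> TrainStep<MyBatch<B>, RegressionOutput<B>> for MyModel<B> {
    fn step(&self, batch: MyBatch<B>) -> TrainOutput<RegressionOutput<B>> {
        // Forward pass producing loss/output/targets (assumed helper, not part of this patch).
        let output = self.forward_step(batch);

        // `loss.backward()` yields the backend gradients; `TrainOutput::new` converts
        // them into `GradientsParams` tied to this module, as shown above.
        TrainOutput::new(self, output.loss.backward(), output)
    }
}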
pub struct AsyncLogger { - sender: mpsc::Sender>, - handler: Option>, + sender: mpsc::Sender>, + handler: Option>, } #[derive(new)] struct LoggerThread> { - logger: L, - receiver: mpsc::Receiver>, + logger: L, + receiver: mpsc::Receiver>, } impl LoggerThread where - L: Logger, + L: Logger, { - fn run(mut self) { - for item in self.receiver.iter() { - match item { - Message::Log(item) => { - self.logger.log(item); - } - Message::End => { - return; - } - Message::Sync(callback) => { - callback - .send(()) - .expect("Can return result with the callback channel."); - } - } + fn run(mut self) { + for item in self.receiver.iter() { + match item { + Message::Log(item) => { + self.logger.log(item); } + Message::End => { + return; + } + Message::Sync(callback) => { + callback + .send(()) + .expect("Can return result with the callback channel."); + } + } } + } } impl AsyncLogger { - /// Create a new async logger. - pub fn new(logger: L) -> Self - where - L: Logger + 'static, - { - let (sender, receiver) = mpsc::channel(); - let thread = LoggerThread::new(logger, receiver); + /// Create a new async logger. + pub fn new(logger: L) -> Self + where + L: Logger + 'static, + { + let (sender, receiver) = mpsc::channel(); + let thread = LoggerThread::new(logger, receiver); - let handler = Some(std::thread::spawn(move || thread.run())); + let handler = Some(std::thread::spawn(move || thread.run())); - Self { sender, handler } - } + Self { sender, handler } + } - /// Sync the async logger. - pub(crate) fn sync(&self) { - let (sender, receiver) = mpsc::channel(); + /// Sync the async logger. + pub(crate) fn sync(&self) { + let (sender, receiver) = mpsc::channel(); - self.sender - .send(Message::Sync(sender)) - .expect("Can send message to logger thread."); + self + .sender + .send(Message::Sync(sender)) + .expect("Can send message to logger thread."); - receiver - .recv() - .expect("Should sync, otherwise the thread is dead."); - } + receiver + .recv() + .expect("Should sync, otherwise the thread is dead."); + } } impl Logger for AsyncLogger { - fn log(&mut self, item: T) { - self.sender - .send(Message::Log(item)) - .expect("Can log using the logger thread."); - } + fn log(&mut self, item: T) { + self + .sender + .send(Message::Log(item)) + .expect("Can log using the logger thread."); + } } impl Drop for AsyncLogger { - fn drop(&mut self) { - self.sender - .send(Message::End) - .expect("Can send the end message to the logger thread."); - let handler = self.handler.take(); + fn drop(&mut self) { + self + .sender + .send(Message::End) + .expect("Can send the end message to the logger thread."); + let handler = self.handler.take(); - if let Some(handler) = handler { - handler.join().expect("The logger thread should stop."); - } + if let Some(handler) = handler { + handler.join().expect("The logger thread should stop."); } + } } diff --git a/burn-train/src/logger/base.rs b/burn-train/src/logger/base.rs index 3b37c55e61..5e3fcd677b 100644 --- a/burn-train/src/logger/base.rs +++ b/burn-train/src/logger/base.rs @@ -1,26 +1,26 @@ /// The logger trait. pub trait Logger: Send { - /// Logs an item. - /// - /// # Arguments - /// - /// * `item` - The item. - fn log(&mut self, item: T); + /// Logs an item. + /// + /// # Arguments + /// + /// * `item` - The item. + fn log(&mut self, item: T); } /// The logger backend trait. pub trait LoggerBackend { - /// The logger type. - type Logger: Logger; + /// The logger type. + type Logger: Logger; - /// Create a new logger. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. 
- /// - /// # Returns - /// - /// The logger. - fn create(&self, epoch: usize) -> Self::Logger; + /// Create a new logger. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + /// + /// # Returns + /// + /// The logger. + fn create(&self, epoch: usize) -> Self::Logger; } diff --git a/burn-train/src/logger/file.rs b/burn-train/src/logger/file.rs index 79c23b462d..7b13089af9 100644 --- a/burn-train/src/logger/file.rs +++ b/burn-train/src/logger/file.rs @@ -3,37 +3,37 @@ use std::{fs::File, io::Write}; /// File logger. pub struct FileLogger { - file: File, + file: File, } impl FileLogger { - /// Create a new file logger. - /// - /// # Arguments - /// - /// * `path` - The path. - /// - /// # Returns - /// - /// The file logger. - pub fn new(path: &str) -> Self { - let mut options = std::fs::File::options(); - let file = options - .write(true) - .truncate(true) - .create(true) - .open(path) - .unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}")); + /// Create a new file logger. + /// + /// # Arguments + /// + /// * `path` - The path. + /// + /// # Returns + /// + /// The file logger. + pub fn new(path: &str) -> Self { + let mut options = std::fs::File::options(); + let file = options + .write(true) + .truncate(true) + .create(true) + .open(path) + .unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}")); - Self { file } - } + Self { file } + } } impl Logger for FileLogger where - T: std::fmt::Display, + T: std::fmt::Display, { - fn log(&mut self, item: T) { - writeln!(&mut self.file, "{item}").expect("Can log an item."); - } + fn log(&mut self, item: T) { + writeln!(&mut self.file, "{item}").expect("Can log an item."); + } } diff --git a/burn-train/src/logger/in_memory.rs b/burn-train/src/logger/in_memory.rs index 31cf3f165c..425c8dc76e 100644 --- a/burn-train/src/logger/in_memory.rs +++ b/burn-train/src/logger/in_memory.rs @@ -3,14 +3,14 @@ use super::Logger; /// In memory logger. #[derive(Default)] pub struct InMemoryLogger { - pub(crate) values: Vec, + pub(crate) values: Vec, } impl Logger for InMemoryLogger where - T: std::fmt::Display, + T: std::fmt::Display, { - fn log(&mut self, item: T) { - self.values.push(item.to_string()); - } + fn log(&mut self, item: T) { + self.values.push(item.to_string()); + } } diff --git a/burn-train/src/logger/metric.rs b/burn-train/src/logger/metric.rs index 5751eff925..42519e7c48 100644 --- a/burn-train/src/logger/metric.rs +++ b/burn-train/src/logger/metric.rs @@ -4,169 +4,173 @@ use std::collections::HashMap; /// Metric logger. pub trait MetricLogger: Send { - /// Logs an item. - /// - /// # Arguments - /// - /// * `item` - The item. - fn log(&mut self, item: &MetricEntry); - - /// Logs an epoch. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. - fn end_epoch(&mut self, epoch: usize); - - /// Read the logs for an epoch. - fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String>; + /// Logs an item. + /// + /// # Arguments + /// + /// * `item` - The item. + fn log(&mut self, item: &MetricEntry); + + /// Logs an epoch. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + fn end_epoch(&mut self, epoch: usize); + + /// Read the logs for an epoch. + fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String>; } /// The file metric logger. pub struct FileMetricLogger { - loggers: HashMap>, - directory: String, - epoch: usize, + loggers: HashMap>, + directory: String, + epoch: usize, } impl FileMetricLogger { - /// Create a new file metric logger. 
- /// - /// # Arguments - /// - /// * `directory` - The directory. - /// - /// # Returns - /// - /// The file metric logger. - pub fn new(directory: &str) -> Self { - Self { - loggers: HashMap::new(), - directory: directory.to_string(), - epoch: 1, - } - } - - fn file_path(&self, name: &str, epoch: usize) -> String { - let directory = format!("{}/epoch-{}", self.directory, epoch); - let name = name.replace(' ', "_"); - - format!("{directory}/{name}.log") - } - fn create_directory(&self, epoch: usize) { - let directory = format!("{}/epoch-{}", self.directory, epoch); - std::fs::create_dir_all(directory).ok(); + /// Create a new file metric logger. + /// + /// # Arguments + /// + /// * `directory` - The directory. + /// + /// # Returns + /// + /// The file metric logger. + pub fn new(directory: &str) -> Self { + Self { + loggers: HashMap::new(), + directory: directory.to_string(), + epoch: 1, } + } + + fn file_path(&self, name: &str, epoch: usize) -> String { + let directory = format!("{}/epoch-{}", self.directory, epoch); + let name = name.replace(' ', "_"); + + format!("{directory}/{name}.log") + } + fn create_directory(&self, epoch: usize) { + let directory = format!("{}/epoch-{}", self.directory, epoch); + std::fs::create_dir_all(directory).ok(); + } } impl MetricLogger for FileMetricLogger { - fn log(&mut self, item: &MetricEntry) { - let key = &item.name; - let value = &item.serialize; - - let logger = match self.loggers.get_mut(key) { - Some(val) => val, - None => { - self.create_directory(self.epoch); - - let file_path = self.file_path(key, self.epoch); - let logger = FileLogger::new(&file_path); - let logger = AsyncLogger::new(logger); - - self.loggers.insert(key.clone(), logger); - self.loggers - .get_mut(key) - .expect("Can get the previously saved logger.") - } - }; - - logger.log(value.clone()); + fn log(&mut self, item: &MetricEntry) { + let key = &item.name; + let value = &item.serialize; + + let logger = match self.loggers.get_mut(key) { + Some(val) => val, + None => { + self.create_directory(self.epoch); + + let file_path = self.file_path(key, self.epoch); + let logger = FileLogger::new(&file_path); + let logger = AsyncLogger::new(logger); + + self.loggers.insert(key.clone(), logger); + self + .loggers + .get_mut(key) + .expect("Can get the previously saved logger.") + } + }; + + logger.log(value.clone()); + } + + fn end_epoch(&mut self, epoch: usize) { + self.loggers.clear(); + self.epoch = epoch + 1; + } + + fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { + if let Some(value) = self.loggers.get(name) { + value.sync() } - fn end_epoch(&mut self, epoch: usize) { - self.loggers.clear(); - self.epoch = epoch + 1; - } + let file_path = self.file_path(name, epoch); - fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { - if let Some(value) = self.loggers.get(name) { - value.sync() - } + let mut errors = false; - let file_path = self.file_path(name, epoch); - - let mut errors = false; - - let data = std::fs::read_to_string(file_path) - .unwrap_or_default() - .split('\n') - .filter_map(|value| { - if value.is_empty() { - None - } else { - match value.parse::() { - Ok(value) => Some(value), - Err(err) => { - log::error!("{err}"); - errors = true; - None - } - } - } - }) - .collect(); - - if errors { - Err("Parsing float errors".to_string()) + let data = std::fs::read_to_string(file_path) + .unwrap_or_default() + .split('\n') + .filter_map(|value| { + if value.is_empty() { + None } else { - Ok(data) + match value.parse::() { + Ok(value) 
=> Some(value), + Err(err) => { + log::error!("{err}"); + errors = true; + None + } + } } + }) + .collect(); + + if errors { + Err("Parsing float errors".to_string()) + } else { + Ok(data) } + } } /// In memory metric logger, useful when testing and debugging. #[derive(Default)] pub struct InMemoryMetricLogger { - values: HashMap>, + values: HashMap>, } impl InMemoryMetricLogger { - /// Create a new in-memory metric logger. - pub fn new() -> Self { - Self::default() - } + /// Create a new in-memory metric logger. + pub fn new() -> Self { + Self::default() + } } impl MetricLogger for InMemoryMetricLogger { - fn log(&mut self, item: &MetricEntry) { - if !self.values.contains_key(&item.name) { - self.values - .insert(item.name.clone(), vec![InMemoryLogger::default()]); - } + fn log(&mut self, item: &MetricEntry) { + if !self.values.contains_key(&item.name) { + self + .values + .insert(item.name.clone(), vec![InMemoryLogger::default()]); + } - let values = self.values.get_mut(&item.name).unwrap(); + let values = self.values.get_mut(&item.name).unwrap(); - values.last_mut().unwrap().log(item.serialize.clone()); - } + values.last_mut().unwrap().log(item.serialize.clone()); + } - fn end_epoch(&mut self, _epoch: usize) { - for (_, values) in self.values.iter_mut() { - values.push(InMemoryLogger::default()); - } + fn end_epoch(&mut self, _epoch: usize) { + for (_, values) in self.values.iter_mut() { + values.push(InMemoryLogger::default()); } - - fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { - let values = match self.values.get(name) { - Some(values) => values, - None => return Ok(Vec::new()), - }; - - match values.get(epoch - 1) { - Some(logger) => Ok(logger - .values - .iter() - .filter_map(|value| value.parse::().ok()) - .collect()), - None => Ok(Vec::new()), - } + } + + fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { + let values = match self.values.get(name) { + Some(values) => values, + None => return Ok(Vec::new()), + }; + + match values.get(epoch - 1) { + Some(logger) => Ok( + logger + .values + .iter() + .filter_map(|value| value.parse::().ok()) + .collect(), + ), + None => Ok(Vec::new()), } + } } diff --git a/burn-train/src/metric/acc.rs b/burn-train/src/metric/acc.rs index a1a0f3ff2f..7157739231 100644 --- a/burn-train/src/metric/acc.rs +++ b/burn-train/src/metric/acc.rs @@ -7,123 +7,123 @@ use burn_core::tensor::{ElementConversion, Int, Tensor}; /// The accuracy metric. #[derive(Default)] pub struct AccuracyMetric { - state: NumericMetricState, - pad_token: Option, - _b: B, + state: NumericMetricState, + pad_token: Option, + _b: B, } /// The [accuracy metric](AccuracyMetric) input type. #[derive(new)] pub struct AccuracyInput { - outputs: Tensor, - targets: Tensor, + outputs: Tensor, + targets: Tensor, } impl AccuracyMetric { - /// Creates the metric. - pub fn new() -> Self { - Self::default() - } - - /// Sets the pad token. - pub fn with_pad_token(mut self, index: usize) -> Self { - self.pad_token = Some(index); - self - } + /// Creates the metric. + pub fn new() -> Self { + Self::default() + } + + /// Sets the pad token. 
+ pub fn with_pad_token(mut self, index: usize) -> Self { + self.pad_token = Some(index); + self + } } impl Metric for AccuracyMetric { - const NAME: &'static str = "Accuracy"; - - type Input = AccuracyInput; - - fn update(&mut self, input: &AccuracyInput, _metadata: &MetricMetadata) -> MetricEntry { - let [batch_size, _n_classes] = input.outputs.dims(); - - let targets = input.targets.clone().to_device(&B::Device::default()); - let outputs = input - .outputs - .clone() - .argmax(1) - .to_device(&B::Device::default()) - .reshape([batch_size]); - - let accuracy = match self.pad_token { - Some(pad_token) => { - let mask = targets.clone().equal_elem(pad_token as i64); - let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0); - let num_pad = mask.int().sum().into_scalar().elem::(); - - matches.sum().into_scalar().elem::() / (batch_size as f64 - num_pad) - } - None => { - outputs - .equal(targets) - .int() - .sum() - .into_scalar() - .elem::() - / batch_size as f64 - } - }; - - self.state.update( - 100.0 * accuracy, - batch_size, - FormatOptions::new(Self::NAME).unit("%").precision(2), - ) - } - - fn clear(&mut self) { - self.state.reset() - } + const NAME: &'static str = "Accuracy"; + + type Input = AccuracyInput; + + fn update(&mut self, input: &AccuracyInput, _metadata: &MetricMetadata) -> MetricEntry { + let [batch_size, _n_classes] = input.outputs.dims(); + + let targets = input.targets.clone().to_device(&B::Device::default()); + let outputs = input + .outputs + .clone() + .argmax(1) + .to_device(&B::Device::default()) + .reshape([batch_size]); + + let accuracy = match self.pad_token { + Some(pad_token) => { + let mask = targets.clone().equal_elem(pad_token as i64); + let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0); + let num_pad = mask.int().sum().into_scalar().elem::(); + + matches.sum().into_scalar().elem::() / (batch_size as f64 - num_pad) + } + None => { + outputs + .equal(targets) + .int() + .sum() + .into_scalar() + .elem::() + / batch_size as f64 + } + }; + + self.state.update( + 100.0 * accuracy, + batch_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } } impl Numeric for AccuracyMetric { - fn value(&self) -> f64 { - self.state.value() - } + fn value(&self) -> f64 { + self.state.value() + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - - #[test] - fn test_accuracy_without_padding() { - let mut metric = AccuracyMetric::::new(); - let input = AccuracyInput::new( - Tensor::from_data([ - [0.0, 0.2, 0.8], // 2 - [1.0, 2.0, 0.5], // 1 - [0.4, 0.1, 0.2], // 0 - [0.6, 0.7, 0.2], // 1 - ]), - Tensor::from_data([2, 2, 1, 1]), - ); - - let _entry = metric.update(&input, &MetricMetadata::fake()); - assert_eq!(50.0, metric.value()); - } - - #[test] - fn test_accuracy_with_padding() { - let mut metric = AccuracyMetric::::new().with_pad_token(3); - let input = AccuracyInput::new( - Tensor::from_data([ - [0.0, 0.2, 0.8, 0.0], // 2 - [1.0, 2.0, 0.5, 0.0], // 1 - [0.4, 0.1, 0.2, 0.0], // 0 - [0.6, 0.7, 0.2, 0.0], // 1 - [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count - [0.0, 0.1, 0.2, 0.0], // Error on padding should not count - [0.6, 0.0, 0.2, 0.0], // Error on padding should not count - ]), - Tensor::from_data([2, 2, 1, 1, 3, 3, 3]), - ); - - let _entry = metric.update(&input, &MetricMetadata::fake()); - assert_eq!(50.0, metric.value()); - } + use super::*; + use crate::TestBackend; + + #[test] + fn test_accuracy_without_padding() { + let mut metric = 
AccuracyMetric::::new(); + let input = AccuracyInput::new( + Tensor::from_data([ + [0.0, 0.2, 0.8], // 2 + [1.0, 2.0, 0.5], // 1 + [0.4, 0.1, 0.2], // 0 + [0.6, 0.7, 0.2], // 1 + ]), + Tensor::from_data([2, 2, 1, 1]), + ); + + let _entry = metric.update(&input, &MetricMetadata::fake()); + assert_eq!(50.0, metric.value()); + } + + #[test] + fn test_accuracy_with_padding() { + let mut metric = AccuracyMetric::::new().with_pad_token(3); + let input = AccuracyInput::new( + Tensor::from_data([ + [0.0, 0.2, 0.8, 0.0], // 2 + [1.0, 2.0, 0.5, 0.0], // 1 + [0.4, 0.1, 0.2, 0.0], // 0 + [0.6, 0.7, 0.2, 0.0], // 1 + [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count + [0.0, 0.1, 0.2, 0.0], // Error on padding should not count + [0.6, 0.0, 0.2, 0.0], // Error on padding should not count + ]), + Tensor::from_data([2, 2, 1, 1, 3, 3, 3]), + ); + + let _entry = metric.update(&input, &MetricMetadata::fake()); + assert_eq!(50.0, metric.value()); + } } diff --git a/burn-train/src/metric/base.rs b/burn-train/src/metric/base.rs index 1d0f2ca49b..215c137f0a 100644 --- a/burn-train/src/metric/base.rs +++ b/burn-train/src/metric/base.rs @@ -2,36 +2,36 @@ use burn_core::{data::dataloader::Progress, LearningRate}; /// Metric metadata that can be used when computing metrics. pub struct MetricMetadata { - /// The current progress. - pub progress: Progress, + /// The current progress. + pub progress: Progress, - /// The current epoch. - pub epoch: usize, + /// The current epoch. + pub epoch: usize, - /// The total number of epochs. - pub epoch_total: usize, + /// The total number of epochs. + pub epoch_total: usize, - /// The current iteration. - pub iteration: usize, + /// The current iteration. + pub iteration: usize, - /// The current learning rate. - pub lr: Option, + /// The current learning rate. + pub lr: Option, } impl MetricMetadata { - #[cfg(test)] - pub fn fake() -> Self { - Self { - progress: Progress { - items_processed: 1, - items_total: 1, - }, - epoch: 0, - epoch_total: 1, - iteration: 0, - lr: None, - } + #[cfg(test)] + pub fn fake() -> Self { + Self { + progress: Progress { + items_processed: 1, + items_total: 1, + }, + epoch: 0, + epoch_total: 1, + iteration: 0, + lr: None, } + } } /// Metric trait. @@ -42,18 +42,18 @@ impl MetricMetadata { /// This is important since some conflict may happen when the model output is adapted for each /// metric's input type. pub trait Metric: Send + Sync { - /// The name of the metric. - /// - /// This should be unique, so avoid using short generic names, prefer using the long name. - const NAME: &'static str; + /// The name of the metric. + /// + /// This should be unique, so avoid using short generic names, prefer using the long name. + const NAME: &'static str; - /// The input type of the metric. - type Input; + /// The input type of the metric. + type Input; - /// Update the metric state and returns the current metric entry. - fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry; - /// Clear the metric state. - fn clear(&mut self); + /// Update the metric state and returns the current metric entry. + fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry; + /// Clear the metric state. + fn clear(&mut self); } /// Adaptor are used to transform types so that they can be used by metrics. 
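The `Metric` trait above only requires a unique `NAME`, an `Input` type, and the `update`/`clear` methods, while `Numeric` exposes a plain `f64` so the value can be plotted. A minimal sketch of a custom metric built on those traits, using a hypothetical item counter as the example:

/// Hypothetical metric that simply counts processed items.
#[derive(Default)]
pub struct ItemCountMetric {
    count: usize,
}

impl Metric for ItemCountMetric {
    const NAME: &'static str = "Item Count";

    type Input = usize;

    fn update(&mut self, item: &usize, _metadata: &MetricMetadata) -> MetricEntry {
        self.count += item;
        let formatted = format!("{}: {}", Self::NAME, self.count);

        MetricEntry::new(Self::NAME.to_string(), formatted, self.count.to_string())
    }

    fn clear(&mut self) {
        self.count = 0;
    }
}

impl Numeric for ItemCountMetric {
    fn value(&self) -> f64 {
        self.count as f64
    }
}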
@@ -61,35 +61,35 @@ pub trait Metric: Send + Sync { /// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are /// registered with the [leaner buidler](crate::learner::LearnerBuilder) . pub trait Adaptor { - /// Adapt the type to be passed to a [metric](Metric). - fn adapt(&self) -> T; + /// Adapt the type to be passed to a [metric](Metric). + fn adapt(&self) -> T; } /// Declare a metric to be numeric. /// /// This is useful to plot the values of a metric during training. pub trait Numeric { - /// Returns the numeric value of the metric. - fn value(&self) -> f64; + /// Returns the numeric value of the metric. + fn value(&self) -> f64; } /// Data type that contains the current state of a metric at a given time. #[derive(new, Debug, Clone)] pub struct MetricEntry { - /// The name of the metric. - pub name: String, - /// The string to be displayed. - pub formatted: String, - /// The string to be saved. - pub serialize: String, + /// The name of the metric. + pub name: String, + /// The string to be displayed. + pub formatted: String, + /// The string to be saved. + pub serialize: String, } /// Format a float with the given precision. Will use scientific notation if necessary. pub fn format_float(float: f64, precision: usize) -> String { - let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0); + let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0); - match scientific_notation_threshold >= float { - true => format!("{float:.precision$e}"), - false => format!("{float:.precision$}"), - } + match scientific_notation_threshold >= float { + true => format!("{float:.precision$e}"), + false => format!("{float:.precision$}"), + } } diff --git a/burn-train/src/metric/cpu_temp.rs b/burn-train/src/metric/cpu_temp.rs index a44aba8f05..ea96ec9f6b 100644 --- a/burn-train/src/metric/cpu_temp.rs +++ b/burn-train/src/metric/cpu_temp.rs @@ -5,51 +5,51 @@ use systemstat::{Platform, System}; /// CPU Temperature in celsius degrees pub struct CpuTemperature { - temp_celsius: f32, - sys: System, + temp_celsius: f32, + sys: System, } impl CpuTemperature { - /// Creates a new CPU temp metric - pub fn new() -> Self { - Self { - temp_celsius: 0., - sys: System::new(), - } + /// Creates a new CPU temp metric + pub fn new() -> Self { + Self { + temp_celsius: 0., + sys: System::new(), } + } } impl Default for CpuTemperature { - fn default() -> Self { - CpuTemperature::new() - } + fn default() -> Self { + CpuTemperature::new() + } } impl Metric for CpuTemperature { - const NAME: &'static str = "CPU Temperature"; + const NAME: &'static str = "CPU Temperature"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - match self.sys.cpu_temp() { - Ok(temp) => self.temp_celsius = temp, - Err(_) => self.temp_celsius = f32::NAN, - } + fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + match self.sys.cpu_temp() { + Ok(temp) => self.temp_celsius = temp, + Err(_) => self.temp_celsius = f32::NAN, + } - let formatted = match self.temp_celsius.is_nan() { - true => format!("{}: NaN °C", Self::NAME), - false => format!("{}: {:.2} °C", Self::NAME, self.temp_celsius), - }; - let raw = format!("{:.2}", self.temp_celsius); + let formatted = match self.temp_celsius.is_nan() { + true => format!("{}: NaN °C", Self::NAME), + false => format!("{}: {:.2} °C", Self::NAME, self.temp_celsius), + }; + let raw = format!("{:.2}", self.temp_celsius); - 
MetricEntry::new(Self::NAME.to_string(), formatted, raw) - } + MetricEntry::new(Self::NAME.to_string(), formatted, raw) + } - fn clear(&mut self) {} + fn clear(&mut self) {} } impl Numeric for CpuTemperature { - fn value(&self) -> f64 { - self.temp_celsius as f64 - } + fn value(&self) -> f64 { + self.temp_celsius as f64 + } } diff --git a/burn-train/src/metric/cpu_use.rs b/burn-train/src/metric/cpu_use.rs index 353165d289..41849cb916 100644 --- a/burn-train/src/metric/cpu_use.rs +++ b/burn-train/src/metric/cpu_use.rs @@ -5,65 +5,65 @@ use sysinfo::{CpuExt, CpuRefreshKind, RefreshKind, System, SystemExt}; /// General CPU Usage metric pub struct CpuUse { - last_refresh: Instant, - refresh_frequency: Duration, - sys: System, - current: f64, + last_refresh: Instant, + refresh_frequency: Duration, + sys: System, + current: f64, } impl CpuUse { - /// Creates a new CPU metric - pub fn new() -> Self { - let mut sys = System::new(); - let current = Self::refresh(&mut sys); + /// Creates a new CPU metric + pub fn new() -> Self { + let mut sys = System::new(); + let current = Self::refresh(&mut sys); - Self { - last_refresh: Instant::now(), - refresh_frequency: Duration::from_millis(200), - sys, - current, - } + Self { + last_refresh: Instant::now(), + refresh_frequency: Duration::from_millis(200), + sys, + current, } + } - fn refresh(sys: &mut System) -> f64 { - sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); + fn refresh(sys: &mut System) -> f64 { + sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); - let cpus = sys.cpus(); - let num_cpus = cpus.len(); - let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64; + let cpus = sys.cpus(); + let num_cpus = cpus.len(); + let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64; - use_percentage / num_cpus as f64 - } + use_percentage / num_cpus as f64 + } } impl Default for CpuUse { - fn default() -> Self { - CpuUse::new() - } + fn default() -> Self { + CpuUse::new() + } } impl Metric for CpuUse { - const NAME: &'static str = "CPU Usage"; + const NAME: &'static str = "CPU Usage"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - if self.last_refresh.elapsed() >= self.refresh_frequency { - self.current = Self::refresh(&mut self.sys); - self.last_refresh = Instant::now(); - } + fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + if self.last_refresh.elapsed() >= self.refresh_frequency { + self.current = Self::refresh(&mut self.sys); + self.last_refresh = Instant::now(); + } - let formatted = format!("{}: {:.2} %", Self::NAME, self.current); - let raw = format!("{:.2}", self.current); + let formatted = format!("{}: {:.2} %", Self::NAME, self.current); + let raw = format!("{:.2}", self.current); - MetricEntry::new(Self::NAME.to_string(), formatted, raw) - } + MetricEntry::new(Self::NAME.to_string(), formatted, raw) + } - fn clear(&mut self) {} + fn clear(&mut self) {} } impl Numeric for CpuUse { - fn value(&self) -> f64 { - self.current - } + fn value(&self) -> f64 { + self.current + } } diff --git a/burn-train/src/metric/cuda.rs b/burn-train/src/metric/cuda.rs index e69e11ffc1..c7f15ef1df 100644 --- a/burn-train/src/metric/cuda.rs +++ b/burn-train/src/metric/cuda.rs @@ -4,101 +4,101 @@ use nvml_wrapper::Nvml; /// Track basic cuda infos. 
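System metrics such as `CpuUse` above can also be polled outside a training loop; `update` only needs a unit input and a `MetricMetadata`. A small sketch that builds the metadata by hand purely for illustration (during training the `Learner` fills it in from the dataloader progress):

use burn_core::data::dataloader::Progress;

fn sample_cpu_use() -> f64 {
    let mut metric = CpuUse::new();

    let metadata = MetricMetadata {
        progress: Progress {
            items_processed: 0,
            items_total: 1,
        },
        epoch: 1,
        epoch_total: 1,
        iteration: 0,
        lr: None,
    };

    // Returns a formatted `MetricEntry`; the numeric value is read through `Numeric`.
    let _entry = metric.update(&(), &metadata);
    metric.value()
}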
pub struct CUDAMetric { - nvml: Option, + nvml: Option, } impl CUDAMetric { - /// Creates a new metric for CUDA. - pub fn new() -> Self { - Self { - nvml: Nvml::init().map(Some).unwrap_or_else(|err| { - log::warn!("Unable to initialize CUDA Metric: {err}"); - None - }), - } + /// Creates a new metric for CUDA. + pub fn new() -> Self { + Self { + nvml: Nvml::init().map(Some).unwrap_or_else(|err| { + log::warn!("Unable to initialize CUDA Metric: {err}"); + None + }), } + } } impl Default for CUDAMetric { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl Adaptor<()> for T { - fn adapt(&self) {} + fn adapt(&self) {} } impl Metric for CUDAMetric { - const NAME: &'static str = "CUDA Stats"; + const NAME: &'static str = "CUDA Stats"; + + type Input = (); + + fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry { + let not_available = || { + MetricEntry::new( + Self::NAME.to_string(), + "Unavailable".to_string(), + "Unavailable".to_string(), + ) + }; + + let available = |nvml: &Nvml| { + let mut formatted = String::new(); + let mut raw_running = String::new(); + + let device_count = match nvml.device_count() { + Ok(val) => val, + Err(err) => { + log::warn!("Unable to get the number of cuda devices: {err}"); + return not_available(); + } + }; + + for index in 0..device_count { + let device = match nvml.device_by_index(index) { + Ok(val) => val, + Err(err) => { + log::warn!("Unable to get device {index}: {err}"); + return not_available(); + } + }; + let memory_info = match device.memory_info() { + Ok(info) => info, + Err(err) => { + log::warn!("Unable to get memory info from device {index}: {err}"); + return not_available(); + } + }; - type Input = (); + let used_gb = memory_info.used as f64 * 1e-9; + let total_gb = memory_info.total as f64 * 1e-9; - fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry { - let not_available = || { - MetricEntry::new( - Self::NAME.to_string(), - "Unavailable".to_string(), - "Unavailable".to_string(), - ) - }; + let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb"); + let memory_info_raw = format!("{used_gb}/{total_gb}"); + + formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}"); + raw_running = format!("{memory_info_raw} "); - let available = |nvml: &Nvml| { - let mut formatted = String::new(); - let mut raw_running = String::new(); - - let device_count = match nvml.device_count() { - Ok(val) => val, - Err(err) => { - log::warn!("Unable to get the number of cuda devices: {err}"); - return not_available(); - } - }; - - for index in 0..device_count { - let device = match nvml.device_by_index(index) { - Ok(val) => val, - Err(err) => { - log::warn!("Unable to get device {index}: {err}"); - return not_available(); - } - }; - let memory_info = match device.memory_info() { - Ok(info) => info, - Err(err) => { - log::warn!("Unable to get memory info from device {index}: {err}"); - return not_available(); - } - }; - - let used_gb = memory_info.used as f64 * 1e-9; - let total_gb = memory_info.total as f64 * 1e-9; - - let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb"); - let memory_info_raw = format!("{used_gb}/{total_gb}"); - - formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}"); - raw_running = format!("{memory_info_raw} "); - - let utilization_rates = match device.utilization_rates() { - Ok(rate) => rate, - Err(err) => { - log::warn!("Unable to get utilization rates from device {index}: {err}"); - 
return not_available(); - } - }; - let utilization_rate_formatted = format!("{}%", utilization_rates.gpu); - formatted = format!("{formatted} - Usage {utilization_rate_formatted}"); - } - - MetricEntry::new(Self::NAME.to_string(), formatted, raw_running) + let utilization_rates = match device.utilization_rates() { + Ok(rate) => rate, + Err(err) => { + log::warn!("Unable to get utilization rates from device {index}: {err}"); + return not_available(); + } }; + let utilization_rate_formatted = format!("{}%", utilization_rates.gpu); + formatted = format!("{formatted} - Usage {utilization_rate_formatted}"); + } - match &self.nvml { - Some(nvml) => available(nvml), - None => not_available(), - } + MetricEntry::new(Self::NAME.to_string(), formatted, raw_running) + }; + + match &self.nvml { + Some(nvml) => available(nvml), + None => not_available(), } + } - fn clear(&mut self) {} + fn clear(&mut self) {} } diff --git a/burn-train/src/metric/learning_rate.rs b/burn-train/src/metric/learning_rate.rs index c6ca58c018..c7a542dbbc 100644 --- a/burn-train/src/metric/learning_rate.rs +++ b/burn-train/src/metric/learning_rate.rs @@ -1,48 +1,49 @@ use super::{ - state::{FormatOptions, NumericMetricState}, - MetricMetadata, Numeric, + state::{FormatOptions, NumericMetricState}, + MetricMetadata, Numeric, }; use crate::metric::{Metric, MetricEntry}; /// Track the learning rate across iterations. pub struct LearningRateMetric { - state: NumericMetricState, + state: NumericMetricState, } impl LearningRateMetric { - /// Creates a new learning rate metric. - pub fn new() -> Self { - Self { - state: NumericMetricState::new(), - } + /// Creates a new learning rate metric. + pub fn new() -> Self { + Self { + state: NumericMetricState::new(), } + } } impl Default for LearningRateMetric { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl Metric for LearningRateMetric { - const NAME: &'static str = "Learning Rate"; + const NAME: &'static str = "Learning Rate"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry { - let lr = metadata.lr.unwrap_or(0.0); + fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry { + let lr = metadata.lr.unwrap_or(0.0); - self.state - .update(lr, 1, FormatOptions::new("Learning Rate").precision(2)) - } + self + .state + .update(lr, 1, FormatOptions::new("Learning Rate").precision(2)) + } - fn clear(&mut self) { - self.state.reset() - } + fn clear(&mut self) { + self.state.reset() + } } impl Numeric for LearningRateMetric { - fn value(&self) -> f64 { - self.state.value() - } + fn value(&self) -> f64 { + self.state.value() + } } diff --git a/burn-train/src/metric/loss.rs b/burn-train/src/metric/loss.rs index 62ed71d816..877cc7ed74 100644 --- a/burn-train/src/metric/loss.rs +++ b/burn-train/src/metric/loss.rs @@ -10,42 +10,43 @@ use burn_core::tensor::Tensor; /// The loss metric. #[derive(Default)] pub struct LossMetric { - state: NumericMetricState, - _b: B, + state: NumericMetricState, + _b: B, } /// The [loss metric](LossMetric) input type. #[derive(new)] pub struct LossInput { - tensor: Tensor, + tensor: Tensor, } impl LossMetric { - /// Create the metric. - pub fn new() -> Self { - Self::default() - } + /// Create the metric. 
+ pub fn new() -> Self { + Self::default() + } } impl Metric for LossMetric { - const NAME: &'static str = "Loss"; + const NAME: &'static str = "Loss"; - type Input = LossInput; + type Input = LossInput; - fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]); + fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]); - self.state - .update(loss, 1, FormatOptions::new(Self::NAME).precision(2)) - } + self + .state + .update(loss, 1, FormatOptions::new(Self::NAME).precision(2)) + } - fn clear(&mut self) { - self.state.reset() - } + fn clear(&mut self) { + self.state.reset() + } } impl Numeric for LossMetric { - fn value(&self) -> f64 { - self.state.value() - } + fn value(&self) -> f64 { + self.state.value() + } } diff --git a/burn-train/src/metric/memory_use.rs b/burn-train/src/metric/memory_use.rs index 832c910f69..72c85285ea 100644 --- a/burn-train/src/metric/memory_use.rs +++ b/burn-train/src/metric/memory_use.rs @@ -6,74 +6,74 @@ use sysinfo::{System, SystemExt}; /// Memory information pub struct CpuMemory { - last_refresh: Instant, - refresh_frequency: Duration, - sys: System, - ram_bytes_total: u64, - ram_bytes_used: u64, + last_refresh: Instant, + refresh_frequency: Duration, + sys: System, + ram_bytes_total: u64, + ram_bytes_used: u64, } impl CpuMemory { - /// Creates a new memory metric - pub fn new() -> Self { - let mut metric = Self { - last_refresh: Instant::now(), - refresh_frequency: Duration::from_millis(200), - sys: System::new(), - ram_bytes_total: 0, - ram_bytes_used: 0, - }; - metric.refresh(); - metric - } + /// Creates a new memory metric + pub fn new() -> Self { + let mut metric = Self { + last_refresh: Instant::now(), + refresh_frequency: Duration::from_millis(200), + sys: System::new(), + ram_bytes_total: 0, + ram_bytes_used: 0, + }; + metric.refresh(); + metric + } - fn refresh(&mut self) { - self.sys.refresh_memory(); - self.last_refresh = Instant::now(); + fn refresh(&mut self) { + self.sys.refresh_memory(); + self.last_refresh = Instant::now(); - // bytes of RAM available - self.ram_bytes_total = self.sys.total_memory(); + // bytes of RAM available + self.ram_bytes_total = self.sys.total_memory(); - // bytes of RAM in use - self.ram_bytes_used = self.sys.used_memory(); - } + // bytes of RAM in use + self.ram_bytes_used = self.sys.used_memory(); + } } impl Default for CpuMemory { - fn default() -> Self { - CpuMemory::new() - } + fn default() -> Self { + CpuMemory::new() + } } impl Metric for CpuMemory { - const NAME: &'static str = "CPU Memory"; + const NAME: &'static str = "CPU Memory"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - if self.last_refresh.elapsed() >= self.refresh_frequency { - self.refresh(); - } + fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + if self.last_refresh.elapsed() >= self.refresh_frequency { + self.refresh(); + } - let raw = bytes2gb(self.ram_bytes_used); - let formatted = format!( - "RAM Used: {:.2} / {:.2} Gb", - raw, - bytes2gb(self.ram_bytes_total), - ); + let raw = bytes2gb(self.ram_bytes_used); + let formatted = format!( + "RAM Used: {:.2} / {:.2} Gb", + raw, + bytes2gb(self.ram_bytes_total), + ); - MetricEntry::new(Self::NAME.to_string(), formatted, raw.to_string()) - } + 
MetricEntry::new(Self::NAME.to_string(), formatted, raw.to_string()) + } - fn clear(&mut self) {} + fn clear(&mut self) {} } impl Numeric for CpuMemory { - fn value(&self) -> f64 { - bytes2gb(self.ram_bytes_used) - } + fn value(&self) -> f64 { + bytes2gb(self.ram_bytes_used) + } } fn bytes2gb(bytes: u64) -> f64 { - bytes as f64 / 1e9 + bytes as f64 / 1e9 } diff --git a/burn-train/src/metric/processor/base.rs b/burn-train/src/metric/processor/base.rs index 9093d26457..d82d246c50 100644 --- a/burn-train/src/metric/processor/base.rs +++ b/burn-train/src/metric/processor/base.rs @@ -3,43 +3,43 @@ use burn_core::LearningRate; /// Event happening during the training/validation process. pub enum Event { - /// Signal that an item have been processed. - ProcessedItem(LearnerItem), - /// Signal the end of an epoch. - EndEpoch(usize), + /// Signal that an item have been processed. + ProcessedItem(LearnerItem), + /// Signal the end of an epoch. + EndEpoch(usize), } /// Process events happening during training and validation. pub trait EventProcessor { - /// The training item. - type ItemTrain; - /// The validation item. - type ItemValid; - - /// Collect a training event. - fn process_train(&mut self, event: Event); - /// Collect a validation event. - fn process_valid(&mut self, event: Event); + /// The training item. + type ItemTrain; + /// The validation item. + type ItemValid; + + /// Collect a training event. + fn process_train(&mut self, event: Event); + /// Collect a validation event. + fn process_valid(&mut self, event: Event); } /// A learner item. #[derive(new)] pub struct LearnerItem { - /// The item. - pub item: T, + /// The item. + pub item: T, - /// The progress. - pub progress: Progress, + /// The progress. + pub progress: Progress, - /// The epoch. - pub epoch: usize, + /// The epoch. + pub epoch: usize, - /// The total number of epochs. - pub epoch_total: usize, + /// The total number of epochs. + pub epoch_total: usize, - /// The iteration. - pub iteration: usize, + /// The iteration. + pub iteration: usize, - /// The learning rate. - pub lr: Option, + /// The learning rate. + pub lr: Option, } diff --git a/burn-train/src/metric/processor/full.rs b/burn-train/src/metric/processor/full.rs index b25870dfb4..d392932692 100644 --- a/burn-train/src/metric/processor/full.rs +++ b/burn-train/src/metric/processor/full.rs @@ -7,94 +7,100 @@ use std::sync::Arc; /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). /// - Render metrics using a [metrics renderer](MetricsRenderer). 
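In processor/base.rs above, `Event` and `LearnerItem` are generic over the processed item type, and `EventProcessor` ties them to its associated `ItemTrain`/`ItemValid` types; the angle brackets were flattened in extraction. A reconstruction of roughly how those signatures read in the Burn sources, again a sketch rather than text of this patch:

pub enum Event<T> {
    /// Signal that an item has been processed.
    ProcessedItem(LearnerItem<T>),
    /// Signal the end of an epoch.
    EndEpoch(usize),
}

pub trait EventProcessor {
    /// The training item.
    type ItemTrain;
    /// The validation item.
    type ItemValid;

    /// Collect a training event.
    fn process_train(&mut self, event: Event<Self::ItemTrain>);
    /// Collect a validation event.
    fn process_valid(&mut self, event: Event<Self::ItemValid>);
}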
pub struct FullEventProcessor { - metrics: Metrics, - renderer: Box, - store: Arc, + metrics: Metrics, + renderer: Box, + store: Arc, } impl FullEventProcessor { - pub(crate) fn new( - metrics: Metrics, - renderer: Box, - store: Arc, - ) -> Self { - Self { - metrics, - renderer, - store, - } + pub(crate) fn new( + metrics: Metrics, + renderer: Box, + store: Arc, + ) -> Self { + Self { + metrics, + renderer, + store, } + } } impl EventProcessor for FullEventProcessor { - type ItemTrain = T; - type ItemValid = V; + type ItemTrain = T; + type ItemValid = V; - fn process_train(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let progress = (&item).into(); - let metadata = (&item).into(); + fn process_train(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let progress = (&item).into(); + let metadata = (&item).into(); - let update = self.metrics.update_train(&item, &metadata); + let update = self.metrics.update_train(&item, &metadata); - self.store - .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); + self + .store + .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); - update - .entries - .into_iter() - .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); - update - .entries_numeric - .into_iter() - .for_each(|(entry, value)| { - self.renderer - .update_train(MetricState::Numeric(entry, value)) - }); + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self + .renderer + .update_train(MetricState::Numeric(entry, value)) + }); - self.renderer.render_train(progress); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_train(); - self.store - .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); - } - } + self.renderer.render_train(progress); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_train(); + self + .store + .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); + } } + } - fn process_valid(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let progress = (&item).into(); - let metadata = (&item).into(); + fn process_valid(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let progress = (&item).into(); + let metadata = (&item).into(); - let update = self.metrics.update_valid(&item, &metadata); + let update = self.metrics.update_valid(&item, &metadata); - self.store - .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); + self + .store + .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); - update - .entries - .into_iter() - .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); - update - .entries_numeric - .into_iter() - .for_each(|(entry, value)| { - self.renderer - .update_valid(MetricState::Numeric(entry, value)) - }); + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self + .renderer + .update_valid(MetricState::Numeric(entry, value)) + }); - self.renderer.render_valid(progress); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_valid(); - self.store - .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); - } - } + self.renderer.render_valid(progress); + } + Event::EndEpoch(epoch) => { + 
self.metrics.end_epoch_valid(); + self + .store + .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); + } } + } } diff --git a/burn-train/src/metric/processor/metrics.rs b/burn-train/src/metric/processor/metrics.rs index e2992f12b0..b5c50d27be 100644 --- a/burn-train/src/metric/processor/metrics.rs +++ b/burn-train/src/metric/processor/metrics.rs @@ -1,200 +1,196 @@ use super::LearnerItem; use crate::{ - metric::{store::MetricsUpdate, Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, - renderer::TrainingProgress, + metric::{store::MetricsUpdate, Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, + renderer::TrainingProgress, }; pub(crate) struct Metrics { - train: Vec>>, - valid: Vec>>, - train_numeric: Vec>>, - valid_numeric: Vec>>, + train: Vec>>, + valid: Vec>>, + train_numeric: Vec>>, + valid_numeric: Vec>>, } impl Default for Metrics { - fn default() -> Self { - Self { - train: Vec::default(), - valid: Vec::default(), - train_numeric: Vec::default(), - valid_numeric: Vec::default(), - } + fn default() -> Self { + Self { + train: Vec::default(), + valid: Vec::default(), + train_numeric: Vec::default(), + valid_numeric: Vec::default(), } + } } impl Metrics { - /// Register a training metric. - pub(crate) fn register_metric_train(&mut self, metric: Me) - where - T: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.train.push(Box::new(metric)) + /// Register a training metric. + pub(crate) fn register_metric_train(&mut self, metric: Me) + where + T: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.train.push(Box::new(metric)) + } + + /// Register a validation metric. + pub(crate) fn register_valid_metric(&mut self, metric: Me) + where + V: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.valid.push(Box::new(metric)) + } + + /// Register a numeric training metric. + pub(crate) fn register_train_metric_numeric(&mut self, metric: Me) + where + T: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.train_numeric.push(Box::new(metric)) + } + + /// Register a numeric validation metric. + pub(crate) fn register_valid_metric_numeric(&mut self, metric: Me) + where + V: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.valid_numeric.push(Box::new(metric)) + } + + /// Update the training information from the training item. + pub(crate) fn update_train( + &mut self, + item: &LearnerItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.train.len()); + let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); + + for metric in self.train.iter_mut() { + let state = metric.update(item, metadata); + entries.push(state); } - /// Register a validation metric. - pub(crate) fn register_valid_metric(&mut self, metric: Me) - where - V: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.valid.push(Box::new(metric)) + for metric in self.train_numeric.iter_mut() { + let (state, value) = metric.update(item, metadata); + entries_numeric.push((state, value)); } - /// Register a numeric training metric. - pub(crate) fn register_train_metric_numeric( - &mut self, - metric: Me, - ) where - T: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.train_numeric.push(Box::new(metric)) + MetricsUpdate::new(entries, entries_numeric) + } + + /// Update the training information from the validation item. 
+ pub(crate) fn update_valid( + &mut self, + item: &LearnerItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.valid.len()); + let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); + + for metric in self.valid.iter_mut() { + let state = metric.update(item, metadata); + entries.push(state); } - /// Register a numeric validation metric. - pub(crate) fn register_valid_metric_numeric( - &mut self, - metric: Me, - ) where - V: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.valid_numeric.push(Box::new(metric)) + for metric in self.valid_numeric.iter_mut() { + let (state, value) = metric.update(item, metadata); + entries_numeric.push((state, value)); } - /// Update the training information from the training item. - pub(crate) fn update_train( - &mut self, - item: &LearnerItem, - metadata: &MetricMetadata, - ) -> MetricsUpdate { - let mut entries = Vec::with_capacity(self.train.len()); - let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); - - for metric in self.train.iter_mut() { - let state = metric.update(item, metadata); - entries.push(state); - } - - for metric in self.train_numeric.iter_mut() { - let (state, value) = metric.update(item, metadata); - entries_numeric.push((state, value)); - } - - MetricsUpdate::new(entries, entries_numeric) - } + MetricsUpdate::new(entries, entries_numeric) + } - /// Update the training information from the validation item. - pub(crate) fn update_valid( - &mut self, - item: &LearnerItem, - metadata: &MetricMetadata, - ) -> MetricsUpdate { - let mut entries = Vec::with_capacity(self.valid.len()); - let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); - - for metric in self.valid.iter_mut() { - let state = metric.update(item, metadata); - entries.push(state); - } - - for metric in self.valid_numeric.iter_mut() { - let (state, value) = metric.update(item, metadata); - entries_numeric.push((state, value)); - } - - MetricsUpdate::new(entries, entries_numeric) + /// Signal the end of a training epoch. + pub(crate) fn end_epoch_train(&mut self) { + for metric in self.train.iter_mut() { + metric.clear(); } - - /// Signal the end of a training epoch. - pub(crate) fn end_epoch_train(&mut self) { - for metric in self.train.iter_mut() { - metric.clear(); - } - for metric in self.train_numeric.iter_mut() { - metric.clear(); - } + for metric in self.train_numeric.iter_mut() { + metric.clear(); } + } - /// Signal the end of a validation epoch. - pub(crate) fn end_epoch_valid(&mut self) { - for metric in self.valid.iter_mut() { - metric.clear(); - } - for metric in self.valid_numeric.iter_mut() { - metric.clear(); - } + /// Signal the end of a validation epoch. 
+ pub(crate) fn end_epoch_valid(&mut self) { + for metric in self.valid.iter_mut() { + metric.clear(); } + for metric in self.valid_numeric.iter_mut() { + metric.clear(); + } + } } impl From<&LearnerItem> for TrainingProgress { - fn from(item: &LearnerItem) -> Self { - Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, - } + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, } + } } impl From<&LearnerItem> for MetricMetadata { - fn from(item: &LearnerItem) -> Self { - Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, - lr: item.lr, - } + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, + lr: item.lr, } + } } trait NumericMetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64); - fn clear(&mut self); + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64); + fn clear(&mut self); } trait MetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry; - fn clear(&mut self); + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry; + fn clear(&mut self); } #[derive(new)] struct MetricWrapper { - metric: M, + metric: M, } impl NumericMetricUpdater for MetricWrapper where - T: 'static, - M: Metric + Numeric + 'static, - T: Adaptor, + T: 'static, + M: Metric + Numeric + 'static, + T: Adaptor, { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64) { - let update = self.metric.update(&item.item.adapt(), metadata); - let numeric = self.metric.value(); + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64) { + let update = self.metric.update(&item.item.adapt(), metadata); + let numeric = self.metric.value(); - (update, numeric) - } + (update, numeric) + } - fn clear(&mut self) { - self.metric.clear() - } + fn clear(&mut self) { + self.metric.clear() + } } impl MetricUpdater for MetricWrapper where - T: 'static, - M: Metric + 'static, - T: Adaptor, + T: 'static, + M: Metric + 'static, + T: Adaptor, { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry { - self.metric.update(&item.item.adapt(), metadata) - } + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry { + self.metric.update(&item.item.adapt(), metadata) + } - fn clear(&mut self) { - self.metric.clear() - } + fn clear(&mut self) { + self.metric.clear() + } } diff --git a/burn-train/src/metric/processor/minimal.rs b/burn-train/src/metric/processor/minimal.rs index bb60713e45..7350ca6e18 100644 --- a/burn-train/src/metric/processor/minimal.rs +++ b/burn-train/src/metric/processor/minimal.rs @@ -6,47 +6,51 @@ use std::sync::Arc; /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). 
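// A minimal, standalone sketch of the type-erasure pattern that `MetricWrapper`
// and the updater traits above rely on: a concrete metric has its own input
// type, an `Adaptor` converts the training item into that input, and a boxed
// wrapper hides the metric's input type so heterogeneous metrics can share one
// `Vec`. All names below are illustrative, not the burn-train API.
trait Metric {
    type Input;
    fn update(&mut self, input: &Self::Input) -> String;
    fn clear(&mut self);
}

trait Adaptor<T> {
    fn adapt(&self) -> T;
}

// Object-safe facade: callers only need to know the item type.
trait MetricUpdater<Item> {
    fn update(&mut self, item: &Item) -> String;
    fn clear(&mut self);
}

struct MetricWrapper<M>(M);

impl<Item, M> MetricUpdater<Item> for MetricWrapper<M>
where
    M: Metric,
    Item: Adaptor<M::Input>,
{
    fn update(&mut self, item: &Item) -> String {
        self.0.update(&item.adapt())
    }

    fn clear(&mut self) {
        self.0.clear()
    }
}

// A toy metric plus an item type that can be adapted into its input.
struct LossMetric;

impl Metric for LossMetric {
    type Input = f64;
    fn update(&mut self, input: &f64) -> String {
        format!("loss: {input}")
    }
    fn clear(&mut self) {}
}

struct TrainItem {
    loss: f64,
}

impl Adaptor<f64> for TrainItem {
    fn adapt(&self) -> f64 {
        self.loss
    }
}

fn main() {
    let mut metrics: Vec<Box<dyn MetricUpdater<TrainItem>>> =
        vec![Box::new(MetricWrapper(LossMetric))];
    let item = TrainItem { loss: 0.42 };
    for metric in metrics.iter_mut() {
        println!("{}", metric.update(&item));
        metric.clear();
    }
}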
#[derive(new)] pub(crate) struct MinimalEventProcessor { - metrics: Metrics, - store: Arc, + metrics: Metrics, + store: Arc, } impl EventProcessor for MinimalEventProcessor { - type ItemTrain = T; - type ItemValid = V; - - fn process_train(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let metadata = (&item).into(); - - let update = self.metrics.update_train(&item, &metadata); - - self.store - .add_event_train(crate::metric::store::Event::MetricsUpdate(update)); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_train(); - self.store - .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); - } - } + type ItemTrain = T; + type ItemValid = V; + + fn process_train(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let metadata = (&item).into(); + + let update = self.metrics.update_train(&item, &metadata); + + self + .store + .add_event_train(crate::metric::store::Event::MetricsUpdate(update)); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_train(); + self + .store + .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); + } } - - fn process_valid(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let metadata = (&item).into(); - - let update = self.metrics.update_valid(&item, &metadata); - - self.store - .add_event_valid(crate::metric::store::Event::MetricsUpdate(update)); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_valid(); - self.store - .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); - } - } + } + + fn process_valid(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let metadata = (&item).into(); + + let update = self.metrics.update_valid(&item, &metadata); + + self + .store + .add_event_valid(crate::metric::store::Event::MetricsUpdate(update)); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_valid(); + self + .store + .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); + } } + } } diff --git a/burn-train/src/metric/processor/mod.rs b/burn-train/src/metric/processor/mod.rs index f889894098..532ad59620 100644 --- a/burn-train/src/metric/processor/mod.rs +++ b/burn-train/src/metric/processor/mod.rs @@ -12,42 +12,42 @@ pub(crate) use minimal::*; #[cfg(test)] pub(crate) mod test_utils { - use crate::metric::{ - processor::{Event, EventProcessor, LearnerItem, MinimalEventProcessor}, - Adaptor, LossInput, - }; - use burn_core::tensor::{backend::Backend, ElementConversion, Tensor}; + use crate::metric::{ + processor::{Event, EventProcessor, LearnerItem, MinimalEventProcessor}, + Adaptor, LossInput, + }; + use burn_core::tensor::{backend::Backend, ElementConversion, Tensor}; - impl Adaptor> for f64 { - fn adapt(&self) -> LossInput { - LossInput::new(Tensor::from_data([self.elem()])) - } + impl Adaptor> for f64 { + fn adapt(&self) -> LossInput { + LossInput::new(Tensor::from_data([self.elem()])) } + } - pub(crate) fn process_train( - processor: &mut MinimalEventProcessor, - value: f64, - epoch: usize, - ) { - let dummy_progress = burn_core::data::dataloader::Progress { - items_processed: 1, - items_total: 10, - }; - let num_epochs = 3; - let dummy_iteration = 1; + pub(crate) fn process_train( + processor: &mut MinimalEventProcessor, + value: f64, + epoch: usize, + ) { + let dummy_progress = burn_core::data::dataloader::Progress { + items_processed: 1, + items_total: 10, + }; + let num_epochs = 3; + let dummy_iteration = 1; - processor.process_train(Event::ProcessedItem(LearnerItem::new( - 
value, - dummy_progress, - epoch, - num_epochs, - dummy_iteration, - None, - ))); - } + processor.process_train(Event::ProcessedItem(LearnerItem::new( + value, + dummy_progress, + epoch, + num_epochs, + dummy_iteration, + None, + ))); + } - pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor, epoch: usize) { - processor.process_train(Event::EndEpoch(epoch)); - processor.process_valid(Event::EndEpoch(epoch)); - } + pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor, epoch: usize) { + processor.process_train(Event::EndEpoch(epoch)); + processor.process_valid(Event::EndEpoch(epoch)); + } } diff --git a/burn-train/src/metric/state.rs b/burn-train/src/metric/state.rs index 9a188198dc..db8b887dfc 100644 --- a/burn-train/src/metric/state.rs +++ b/burn-train/src/metric/state.rs @@ -7,95 +7,95 @@ use crate::metric::{format_float, MetricEntry, Numeric}; /// The numeric metric store values inside floats. /// Even if some metric are integers, their mean are floats. pub struct NumericMetricState { - sum: f64, - count: usize, - current: f64, + sum: f64, + count: usize, + current: f64, } /// Formatting options for the [numeric metric state](NumericMetricState). pub struct FormatOptions { - name: String, - unit: Option, - precision: Option, + name: String, + unit: Option, + precision: Option, } impl FormatOptions { - /// Create the [formatting options](FormatOptions) with a name. - pub fn new(name: &str) -> Self { - Self { - name: name.to_string(), - unit: None, - precision: None, - } + /// Create the [formatting options](FormatOptions) with a name. + pub fn new(name: &str) -> Self { + Self { + name: name.to_string(), + unit: None, + precision: None, } + } - /// Specify the metric unit. - pub fn unit(mut self, unit: &str) -> Self { - self.unit = Some(unit.to_string()); - self - } + /// Specify the metric unit. + pub fn unit(mut self, unit: &str) -> Self { + self.unit = Some(unit.to_string()); + self + } - /// Specify the floating point precision. - pub fn precision(mut self, precision: usize) -> Self { - self.precision = Some(precision); - self - } + /// Specify the floating point precision. + pub fn precision(mut self, precision: usize) -> Self { + self.precision = Some(precision); + self + } } impl NumericMetricState { - /// Create a new [numeric metric state](NumericMetricState). - pub fn new() -> Self { - Self { - sum: 0.0, - count: 0, - current: f64::NAN, - } + /// Create a new [numeric metric state](NumericMetricState). + pub fn new() -> Self { + Self { + sum: 0.0, + count: 0, + current: f64::NAN, } + } - /// Reset the state. - pub fn reset(&mut self) { - self.sum = 0.0; - self.count = 0; - self.current = f64::NAN; - } + /// Reset the state. + pub fn reset(&mut self) { + self.sum = 0.0; + self.count = 0; + self.current = f64::NAN; + } - /// Update the state. - pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry { - self.sum += value * batch_size as f64; - self.count += batch_size; - self.current = value; + /// Update the state. 
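// A standalone sketch of how the builder-style `FormatOptions` above is meant
// to be used at a call site: optional unit and precision are attached with
// chained calls and consumed when formatting. The `format` helper and the
// metric name are illustrative only.
struct FormatOptions {
    name: String,
    unit: Option<String>,
    precision: Option<usize>,
}

impl FormatOptions {
    fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            unit: None,
            precision: None,
        }
    }

    fn unit(mut self, unit: &str) -> Self {
        self.unit = Some(unit.to_string());
        self
    }

    fn precision(mut self, precision: usize) -> Self {
        self.precision = Some(precision);
        self
    }

    // Illustrative consumer of the options, standing in for the update method.
    fn format(&self, value: f64) -> String {
        let value = match self.precision {
            Some(precision) => format!("{value:.precision$}"),
            None => format!("{value}"),
        };
        match &self.unit {
            Some(unit) => format!("{}: {value} {unit}", self.name),
            None => format!("{}: {value}", self.name),
        }
    }
}

fn main() {
    let options = FormatOptions::new("Loss").unit("nats").precision(3);
    assert_eq!(options.format(0.12345), "Loss: 0.123 nats");
}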
+ pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry { + self.sum += value * batch_size as f64; + self.count += batch_size; + self.current = value; - let value_current = value; - let value_running = self.sum / self.count as f64; - let serialized = value_current.to_string(); + let value_current = value; + let value_running = self.sum / self.count as f64; + let serialized = value_current.to_string(); - let (formatted_current, formatted_running) = match format.precision { - Some(precision) => ( - format_float(value_current, precision), - format_float(value_running, precision), - ), - None => (format!("{value_current}"), format!("{value_running}")), - }; + let (formatted_current, formatted_running) = match format.precision { + Some(precision) => ( + format_float(value_current, precision), + format_float(value_running, precision), + ), + None => (format!("{value_current}"), format!("{value_running}")), + }; - let formatted = match format.unit { - Some(unit) => { - format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}") - } - None => format!("epoch {formatted_running} - batch {formatted_current}"), - }; + let formatted = match format.unit { + Some(unit) => { + format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}") + } + None => format!("epoch {formatted_running} - batch {formatted_current}"), + }; - MetricEntry::new(format.name, formatted, serialized) - } + MetricEntry::new(format.name, formatted, serialized) + } } impl Numeric for NumericMetricState { - fn value(&self) -> f64 { - self.current - } + fn value(&self) -> f64 { + self.current + } } impl Default for NumericMetricState { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } diff --git a/burn-train/src/metric/store/aggregate.rs b/burn-train/src/metric/store/aggregate.rs index 679f6fa22e..2579f23bda 100644 --- a/burn-train/src/metric/store/aggregate.rs +++ b/burn-train/src/metric/store/aggregate.rs @@ -6,157 +6,157 @@ use super::{Aggregate, Direction}; /// Type that can be used to fetch and use numeric metric aggregates. 
#[derive(Default, Debug)] pub(crate) struct NumericMetricsAggregate { - value_for_each_epoch: HashMap, + value_for_each_epoch: HashMap, } #[derive(new, Hash, PartialEq, Eq, Debug)] struct Key { - name: String, - epoch: usize, - aggregate: Aggregate, + name: String, + epoch: usize, + aggregate: Aggregate, } impl NumericMetricsAggregate { - pub(crate) fn aggregate( - &mut self, - name: &str, - epoch: usize, - aggregate: Aggregate, - loggers: &mut [Box], - ) -> Option { - let key = Key::new(name.to_string(), epoch, aggregate); - - if let Some(value) = self.value_for_each_epoch.get(&key) { - return Some(*value); - } + pub(crate) fn aggregate( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + loggers: &mut [Box], + ) -> Option { + let key = Key::new(name.to_string(), epoch, aggregate); - let points = || { - let mut errors = Vec::new(); - for logger in loggers { - match logger.read_numeric(name, epoch) { - Ok(points) => return Ok(points), - Err(err) => errors.push(err), - }; - } + if let Some(value) = self.value_for_each_epoch.get(&key) { + return Some(*value); + } - Err(errors.join(" ")) + let points = || { + let mut errors = Vec::new(); + for logger in loggers { + match logger.read_numeric(name, epoch) { + Ok(points) => return Ok(points), + Err(err) => errors.push(err), }; + } - let points = points().expect("Can read values"); + Err(errors.join(" ")) + }; - if points.is_empty() { - return None; - } - - let num_points = points.len(); - let sum = points.into_iter().sum::(); - let value = match aggregate { - Aggregate::Mean => sum / num_points as f64, - }; + let points = points().expect("Can read values"); - self.value_for_each_epoch.insert(key, value); - Some(value) + if points.is_empty() { + return None; } - pub(crate) fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - loggers: &mut [Box], - ) -> Option { - let mut data = Vec::new(); - let mut current_epoch = 1; - - while let Some(value) = self.aggregate(name, current_epoch, aggregate, loggers) { - data.push(value); - current_epoch += 1; - } + let num_points = points.len(); + let sum = points.into_iter().sum::(); + let value = match aggregate { + Aggregate::Mean => sum / num_points as f64, + }; - if data.is_empty() { - return None; - } + self.value_for_each_epoch.insert(key, value); + Some(value) + } - let mut current_value = match &direction { - Direction::Lowest => f64::MAX, - Direction::Highest => f64::MIN, - }; + pub(crate) fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + loggers: &mut [Box], + ) -> Option { + let mut data = Vec::new(); + let mut current_epoch = 1; + + while let Some(value) = self.aggregate(name, current_epoch, aggregate, loggers) { + data.push(value); + current_epoch += 1; + } - for (i, value) in data.into_iter().enumerate() { - match &direction { - Direction::Lowest => { - if value < current_value { - current_value = value; - current_epoch = i + 1; - } - } - Direction::Highest => { - if value > current_value { - current_value = value; - current_epoch = i + 1; - } - } - } - } + if data.is_empty() { + return None; + } - Some(current_epoch) + let mut current_value = match &direction { + Direction::Lowest => f64::MAX, + Direction::Highest => f64::MIN, + }; + + for (i, value) in data.into_iter().enumerate() { + match &direction { + Direction::Lowest => { + if value < current_value { + current_value = value; + current_epoch = i + 1; + } + } + Direction::Highest => { + if value > current_value { + current_value = value; + 
current_epoch = i + 1; + } + } + } } + + Some(current_epoch) + } } #[cfg(test)] mod tests { - use crate::{logger::FileMetricLogger, metric::MetricEntry}; + use crate::{logger::FileMetricLogger, metric::MetricEntry}; - use super::*; + use super::*; - struct TestLogger { - logger: FileMetricLogger, - epoch: usize, + struct TestLogger { + logger: FileMetricLogger, + epoch: usize, + } + const NAME: &str = "test-logger"; + + impl TestLogger { + fn new() -> Self { + Self { + logger: FileMetricLogger::new("/tmp"), + epoch: 1, + } } - const NAME: &str = "test-logger"; - - impl TestLogger { - fn new() -> Self { - Self { - logger: FileMetricLogger::new("/tmp"), - epoch: 1, - } - } - fn log(&mut self, num: f64) { - self.logger.log(&MetricEntry::new( - NAME.into(), - num.to_string(), - num.to_string(), - )); - } - fn new_epoch(&mut self) { - self.logger.end_epoch(self.epoch); - self.epoch += 1; - } + fn log(&mut self, num: f64) { + self.logger.log(&MetricEntry::new( + NAME.into(), + num.to_string(), + num.to_string(), + )); } - - #[test] - fn should_find_epoch() { - let mut logger = TestLogger::new(); - let mut aggregate = NumericMetricsAggregate::default(); - - logger.log(500.); // Epoch 1 - logger.log(1000.); // Epoch 1 - logger.new_epoch(); - logger.log(200.); // Epoch 2 - logger.log(1000.); // Epoch 2 - logger.new_epoch(); - logger.log(10000.); // Epoch 3 - - let value = aggregate - .find_epoch( - NAME, - Aggregate::Mean, - Direction::Lowest, - &mut [Box::new(logger.logger)], - ) - .unwrap(); - - assert_eq!(value, 2); + fn new_epoch(&mut self) { + self.logger.end_epoch(self.epoch); + self.epoch += 1; } + } + + #[test] + fn should_find_epoch() { + let mut logger = TestLogger::new(); + let mut aggregate = NumericMetricsAggregate::default(); + + logger.log(500.); // Epoch 1 + logger.log(1000.); // Epoch 1 + logger.new_epoch(); + logger.log(200.); // Epoch 2 + logger.log(1000.); // Epoch 2 + logger.new_epoch(); + logger.log(10000.); // Epoch 3 + + let value = aggregate + .find_epoch( + NAME, + Aggregate::Mean, + Direction::Lowest, + &mut [Box::new(logger.logger)], + ) + .unwrap(); + + assert_eq!(value, 2); + } } diff --git a/burn-train/src/metric/store/base.rs b/burn-train/src/metric/store/base.rs index 51592a683c..6039d2dcc6 100644 --- a/burn-train/src/metric/store/base.rs +++ b/burn-train/src/metric/store/base.rs @@ -2,68 +2,68 @@ use crate::metric::MetricEntry; /// Event happening during the training/validation process. pub enum Event { - /// Signal that metrics have been updated. - MetricsUpdate(MetricsUpdate), - /// Signal the end of an epoch. - EndEpoch(usize), + /// Signal that metrics have been updated. + MetricsUpdate(MetricsUpdate), + /// Signal the end of an epoch. + EndEpoch(usize), } /// Contains all metric information. #[derive(new, Clone)] pub struct MetricsUpdate { - /// Metrics information related to non-numeric metrics. - pub entries: Vec, - /// Metrics information related to numeric metrics. - pub entries_numeric: Vec<(MetricEntry, f64)>, + /// Metrics information related to non-numeric metrics. + pub entries: Vec, + /// Metrics information related to numeric metrics. + pub entries_numeric: Vec<(MetricEntry, f64)>, } /// Defines how training and validation events are collected and searched. /// /// This trait also exposes methods that uses the collected data to compute useful information. pub trait EventStore: Send { - /// Collect a training/validation event. - fn add_event(&mut self, event: Event, split: Split); + /// Collect a training/validation event. 
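// A standalone sketch of the epoch search implemented by `find_epoch` above
// and exercised by `should_find_epoch`: aggregate each epoch's points with a
// mean, then pick the 1-based epoch whose aggregate is lowest (or highest).
// Types and data here are illustrative, not the burn-train API.
#[derive(Copy, Clone)]
enum Direction {
    Lowest,
    Highest,
}

fn find_epoch(epochs: &[Vec<f64>], direction: Direction) -> Option<usize> {
    let mut best: Option<(usize, f64)> = None;

    for (i, points) in epochs.iter().enumerate() {
        if points.is_empty() {
            continue;
        }
        let mean = points.iter().sum::<f64>() / points.len() as f64;
        let better = match (best, direction) {
            (None, _) => true,
            (Some((_, current)), Direction::Lowest) => mean < current,
            (Some((_, current)), Direction::Highest) => mean > current,
        };
        if better {
            // Epochs are reported 1-based, as in the test above.
            best = Some((i + 1, mean));
        }
    }

    best.map(|(epoch, _)| epoch)
}

fn main() {
    // Mirrors the test above: the per-epoch means are 750, 600 and 10_000.
    let epochs = vec![vec![500.0, 1000.0], vec![200.0, 1000.0], vec![10_000.0]];
    assert_eq!(find_epoch(&epochs, Direction::Lowest), Some(2));
    assert_eq!(find_epoch(&epochs, Direction::Highest), Some(3));
}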
+ fn add_event(&mut self, event: Event, split: Split); - /// Find the epoch following the given criteria from the collected data. - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option; + /// Find the epoch following the given criteria from the collected data. + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option; - /// Find the metric value for the current epoch following the given criteria. - fn find_metric( - &mut self, - name: &str, - epoch: usize, - aggregate: Aggregate, - split: Split, - ) -> Option; + /// Find the metric value for the current epoch following the given criteria. + fn find_metric( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option; } #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)] /// How to aggregate the metric. pub enum Aggregate { - /// Compute the average. - Mean, + /// Compute the average. + Mean, } #[derive(Copy, Clone)] /// The split to use. pub enum Split { - /// The training split. - Train, - /// The validation split. - Valid, + /// The training split. + Train, + /// The validation split. + Valid, } #[derive(Copy, Clone)] /// The direction of the query. pub enum Direction { - /// Lower is better. - Lowest, - /// Higher is better. - Highest, + /// Lower is better. + Lowest, + /// Higher is better. + Highest, } diff --git a/burn-train/src/metric/store/client.rs b/burn-train/src/metric/store/client.rs index 74ba83ab74..192f4f2abd 100644 --- a/burn-train/src/metric/store/client.rs +++ b/burn-train/src/metric/store/client.rs @@ -4,156 +4,161 @@ use std::{sync::mpsc, thread::JoinHandle}; /// Type that allows to communicate with an [event store](EventStore). pub struct EventStoreClient { - sender: mpsc::Sender, - handler: Option>, + sender: mpsc::Sender, + handler: Option>, } impl EventStoreClient { - /// Create a new [event store](EventStore) client. - pub(crate) fn new(store: C) -> Self - where - C: EventStore + 'static, - { - let (sender, receiver) = mpsc::channel(); - let thread = WorkerThread::new(store, receiver); + /// Create a new [event store](EventStore) client. + pub(crate) fn new(store: C) -> Self + where + C: EventStore + 'static, + { + let (sender, receiver) = mpsc::channel(); + let thread = WorkerThread::new(store, receiver); - let handler = std::thread::spawn(move || thread.run()); - let handler = Some(handler); + let handler = std::thread::spawn(move || thread.run()); + let handler = Some(handler); - Self { sender, handler } - } + Self { sender, handler } + } } impl EventStoreClient { - /// Add a training event to the [event store](EventStore). - pub(crate) fn add_event_train(&self, event: Event) { - self.sender - .send(Message::OnEventTrain(event)) - .expect("Can send event to event store thread."); - } + /// Add a training event to the [event store](EventStore). + pub(crate) fn add_event_train(&self, event: Event) { + self + .sender + .send(Message::OnEventTrain(event)) + .expect("Can send event to event store thread."); + } - /// Add a validation event to the [event store](EventStore). - pub(crate) fn add_event_valid(&self, event: Event) { - self.sender - .send(Message::OnEventValid(event)) - .expect("Can send event to event store thread."); - } + /// Add a validation event to the [event store](EventStore). 
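// A standalone sketch of the client/worker split used by the event store
// client above: commands travel over an mpsc channel to a background thread,
// queries carry a bounded reply channel, and shutdown sends an explicit end
// message before joining the thread. The message set is illustrative only.
use std::sync::mpsc;
use std::thread::JoinHandle;

enum Message {
    Add(u64),
    Sum(mpsc::SyncSender<u64>),
    End,
}

struct Client {
    sender: mpsc::Sender<Message>,
    handle: Option<JoinHandle<()>>,
}

impl Client {
    fn new() -> Self {
        let (sender, receiver) = mpsc::channel();
        let handle = std::thread::spawn(move || {
            let mut total = 0u64;
            for msg in receiver.iter() {
                match msg {
                    Message::Add(v) => total += v,
                    Message::Sum(reply) => reply.send(total).expect("reply channel open"),
                    Message::End => return,
                }
            }
        });
        Self {
            sender,
            handle: Some(handle),
        }
    }

    fn add(&self, v: u64) {
        self.sender.send(Message::Add(v)).expect("worker alive");
    }

    fn sum(&self) -> u64 {
        let (reply, receiver) = mpsc::sync_channel(1);
        self.sender.send(Message::Sum(reply)).expect("worker alive");
        receiver.recv().expect("worker answered")
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        // Explicit end message, then join, mirroring the shutdown shown here.
        let _ = self.sender.send(Message::End);
        if let Some(handle) = self.handle.take() {
            handle.join().expect("worker thread stopped cleanly");
        }
    }
}

fn main() {
    let client = Client::new();
    client.add(2);
    client.add(3);
    assert_eq!(client.sum(), 5);
}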
+ pub(crate) fn add_event_valid(&self, event: Event) { + self + .sender + .send(Message::OnEventValid(event)) + .expect("Can send event to event store thread."); + } - /// Find the epoch following the given criteria from the collected data. - pub fn find_epoch( - &self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - let (sender, receiver) = mpsc::sync_channel(1); - self.sender - .send(Message::FindEpoch( - name.to_string(), - aggregate, - direction, - split, - sender, - )) - .expect("Can send event to event store thread."); + /// Find the epoch following the given criteria from the collected data. + pub fn find_epoch( + &self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + let (sender, receiver) = mpsc::sync_channel(1); + self + .sender + .send(Message::FindEpoch( + name.to_string(), + aggregate, + direction, + split, + sender, + )) + .expect("Can send event to event store thread."); - match receiver.recv() { - Ok(value) => value, - Err(err) => panic!("Event store thread crashed: {:?}", err), - } + match receiver.recv() { + Ok(value) => value, + Err(err) => panic!("Event store thread crashed: {:?}", err), } + } - /// Find the metric value for the current epoch following the given criteria. - pub fn find_metric( - &self, - name: &str, - epoch: usize, - aggregate: Aggregate, - split: Split, - ) -> Option { - let (sender, receiver) = mpsc::sync_channel(1); - self.sender - .send(Message::FindMetric( - name.to_string(), - epoch, - aggregate, - split, - sender, - )) - .expect("Can send event to event store thread."); + /// Find the metric value for the current epoch following the given criteria. + pub fn find_metric( + &self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option { + let (sender, receiver) = mpsc::sync_channel(1); + self + .sender + .send(Message::FindMetric( + name.to_string(), + epoch, + aggregate, + split, + sender, + )) + .expect("Can send event to event store thread."); - match receiver.recv() { - Ok(value) => value, - Err(err) => panic!("Event store thread crashed: {:?}", err), - } + match receiver.recv() { + Ok(value) => value, + Err(err) => panic!("Event store thread crashed: {:?}", err), } + } } #[derive(new)] struct WorkerThread { - store: S, - receiver: mpsc::Receiver, + store: S, + receiver: mpsc::Receiver, } impl WorkerThread where - C: EventStore, + C: EventStore, { - fn run(mut self) { - for item in self.receiver.iter() { - match item { - Message::End => { - return; - } - Message::FindEpoch(name, aggregate, direction, split, callback) => { - let response = self.store.find_epoch(&name, aggregate, direction, split); - callback - .send(response) - .expect("Can send response using callback channel."); - } - Message::FindMetric(name, epoch, aggregate, split, callback) => { - let response = self.store.find_metric(&name, epoch, aggregate, split); - callback - .send(response) - .expect("Can send response using callback channel."); - } - Message::OnEventTrain(event) => self.store.add_event(event, Split::Train), - Message::OnEventValid(event) => self.store.add_event(event, Split::Valid), - } + fn run(mut self) { + for item in self.receiver.iter() { + match item { + Message::End => { + return; + } + Message::FindEpoch(name, aggregate, direction, split, callback) => { + let response = self.store.find_epoch(&name, aggregate, direction, split); + callback + .send(response) + .expect("Can send response using callback channel."); } + 
Message::FindMetric(name, epoch, aggregate, split, callback) => { + let response = self.store.find_metric(&name, epoch, aggregate, split); + callback + .send(response) + .expect("Can send response using callback channel."); + } + Message::OnEventTrain(event) => self.store.add_event(event, Split::Train), + Message::OnEventValid(event) => self.store.add_event(event, Split::Valid), + } } + } } enum Message { - OnEventTrain(Event), - OnEventValid(Event), - End, - FindEpoch( - String, - Aggregate, - Direction, - Split, - mpsc::SyncSender>, - ), - FindMetric( - String, - usize, - Aggregate, - Split, - mpsc::SyncSender>, - ), + OnEventTrain(Event), + OnEventValid(Event), + End, + FindEpoch( + String, + Aggregate, + Direction, + Split, + mpsc::SyncSender>, + ), + FindMetric( + String, + usize, + Aggregate, + Split, + mpsc::SyncSender>, + ), } impl Drop for EventStoreClient { - fn drop(&mut self) { - self.sender - .send(Message::End) - .expect("Can send the end message to the event store thread."); - let handler = self.handler.take(); + fn drop(&mut self) { + self + .sender + .send(Message::End) + .expect("Can send the end message to the event store thread."); + let handler = self.handler.take(); - if let Some(handler) = handler { - handler.join().expect("The event store thread should stop."); - } + if let Some(handler) = handler { + handler.join().expect("The event store thread should stop."); } + } } diff --git a/burn-train/src/metric/store/log.rs b/burn-train/src/metric/store/log.rs index 9272e32330..c8b88c6b72 100644 --- a/burn-train/src/metric/store/log.rs +++ b/burn-train/src/metric/store/log.rs @@ -3,99 +3,105 @@ use crate::logger::MetricLogger; #[derive(Default)] pub(crate) struct LogEventStore { - loggers_train: Vec>, - loggers_valid: Vec>, - aggregate_train: NumericMetricsAggregate, - aggregate_valid: NumericMetricsAggregate, + loggers_train: Vec>, + loggers_valid: Vec>, + aggregate_train: NumericMetricsAggregate, + aggregate_valid: NumericMetricsAggregate, } impl EventStore for LogEventStore { - fn add_event(&mut self, event: Event, split: Split) { - match event { - Event::MetricsUpdate(update) => match split { - Split::Train => { - update - .entries - .iter() - .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) - .for_each(|entry| { - self.loggers_train - .iter_mut() - .for_each(|logger| logger.log(entry)); - }); - } - Split::Valid => { - update - .entries - .iter() - .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) - .for_each(|entry| { - self.loggers_valid - .iter_mut() - .for_each(|logger| logger.log(entry)); - }); - } - }, - Event::EndEpoch(epoch) => match split { - Split::Train => self - .loggers_train - .iter_mut() - .for_each(|logger| logger.end_epoch(epoch)), - Split::Valid => self - .loggers_valid - .iter_mut() - .for_each(|logger| logger.end_epoch(epoch + 1)), - }, + fn add_event(&mut self, event: Event, split: Split) { + match event { + Event::MetricsUpdate(update) => match split { + Split::Train => { + update + .entries + .iter() + .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) + .for_each(|entry| { + self + .loggers_train + .iter_mut() + .for_each(|logger| logger.log(entry)); + }); } + Split::Valid => { + update + .entries + .iter() + .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) + .for_each(|entry| { + self + .loggers_valid + .iter_mut() + .for_each(|logger| logger.log(entry)); + }); + } + }, + Event::EndEpoch(epoch) => match split { + Split::Train => self + .loggers_train + .iter_mut() + 
.for_each(|logger| logger.end_epoch(epoch)), + Split::Valid => self + .loggers_valid + .iter_mut() + .for_each(|logger| logger.end_epoch(epoch + 1)), + }, } + } - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - match split { - Split::Train => { - self.aggregate_train - .find_epoch(name, aggregate, direction, &mut self.loggers_train) - } - Split::Valid => { - self.aggregate_valid - .find_epoch(name, aggregate, direction, &mut self.loggers_valid) - } - } + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + match split { + Split::Train => { + self + .aggregate_train + .find_epoch(name, aggregate, direction, &mut self.loggers_train) + } + Split::Valid => { + self + .aggregate_valid + .find_epoch(name, aggregate, direction, &mut self.loggers_valid) + } } + } - fn find_metric( - &mut self, - name: &str, - epoch: usize, - aggregate: Aggregate, - split: Split, - ) -> Option { - match split { - Split::Train => { - self.aggregate_train - .aggregate(name, epoch, aggregate, &mut self.loggers_train) - } - Split::Valid => { - self.aggregate_valid - .aggregate(name, epoch, aggregate, &mut self.loggers_valid) - } - } + fn find_metric( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option { + match split { + Split::Train => { + self + .aggregate_train + .aggregate(name, epoch, aggregate, &mut self.loggers_train) + } + Split::Valid => { + self + .aggregate_valid + .aggregate(name, epoch, aggregate, &mut self.loggers_valid) + } } + } } impl LogEventStore { - /// Register a logger for training metrics. - pub(crate) fn register_logger_train(&mut self, logger: ML) { - self.loggers_train.push(Box::new(logger)); - } + /// Register a logger for training metrics. + pub(crate) fn register_logger_train(&mut self, logger: ML) { + self.loggers_train.push(Box::new(logger)); + } - /// Register a logger for validation metrics. - pub(crate) fn register_logger_valid(&mut self, logger: ML) { - self.loggers_valid.push(Box::new(logger)); - } + /// Register a logger for validation metrics. + pub(crate) fn register_logger_valid(&mut self, logger: ML) { + self.loggers_valid.push(Box::new(logger)); + } } diff --git a/burn-train/src/renderer/base.rs b/burn-train/src/renderer/base.rs index 6cfc2a5eb0..2258c32e46 100644 --- a/burn-train/src/renderer/base.rs +++ b/burn-train/src/renderer/base.rs @@ -4,72 +4,72 @@ use crate::metric::MetricEntry; /// Trait for rendering metrics. pub trait MetricsRenderer: Send + Sync { - /// Updates the training metric state. - /// - /// # Arguments - /// - /// * `state` - The metric state. - fn update_train(&mut self, state: MetricState); + /// Updates the training metric state. + /// + /// # Arguments + /// + /// * `state` - The metric state. + fn update_train(&mut self, state: MetricState); - /// Updates the validation metric state. - /// - /// # Arguments - /// - /// * `state` - The metric state. - fn update_valid(&mut self, state: MetricState); + /// Updates the validation metric state. + /// + /// # Arguments + /// + /// * `state` - The metric state. + fn update_valid(&mut self, state: MetricState); - /// Renders the training progress. - /// - /// # Arguments - /// - /// * `item` - The training progress. - fn render_train(&mut self, item: TrainingProgress); + /// Renders the training progress. + /// + /// # Arguments + /// + /// * `item` - The training progress. 
+ fn render_train(&mut self, item: TrainingProgress); - /// Renders the validation progress. - /// - /// # Arguments - /// - /// * `item` - The validation progress. - fn render_valid(&mut self, item: TrainingProgress); + /// Renders the validation progress. + /// + /// # Arguments + /// + /// * `item` - The validation progress. + fn render_valid(&mut self, item: TrainingProgress); } /// The state of a metric. #[derive(Debug)] pub enum MetricState { - /// A generic metric. - Generic(MetricEntry), + /// A generic metric. + Generic(MetricEntry), - /// A numeric metric. - Numeric(MetricEntry, f64), + /// A numeric metric. + Numeric(MetricEntry, f64), } /// Training progress. #[derive(Debug)] pub struct TrainingProgress { - /// The progress. - pub progress: Progress, + /// The progress. + pub progress: Progress, - /// The epoch. - pub epoch: usize, + /// The epoch. + pub epoch: usize, - /// The total number of epochs. - pub epoch_total: usize, + /// The total number of epochs. + pub epoch_total: usize, - /// The iteration. - pub iteration: usize, + /// The iteration. + pub iteration: usize, } impl TrainingProgress { - /// Creates a new empty training progress. - pub fn none() -> Self { - Self { - progress: Progress { - items_processed: 0, - items_total: 0, - }, - epoch: 0, - epoch_total: 0, - iteration: 0, - } + /// Creates a new empty training progress. + pub fn none() -> Self { + Self { + progress: Progress { + items_processed: 0, + items_total: 0, + }, + epoch: 0, + epoch_total: 0, + iteration: 0, } + } } diff --git a/burn-train/src/renderer/cli.rs b/burn-train/src/renderer/cli.rs index d5a974a51e..1ed3cf3acb 100644 --- a/burn-train/src/renderer/cli.rs +++ b/burn-train/src/renderer/cli.rs @@ -4,22 +4,22 @@ use crate::renderer::{MetricState, MetricsRenderer, TrainingProgress}; pub struct CliMetricsRenderer; impl CliMetricsRenderer { - /// Create a new instance. - pub fn new() -> Self { - Self {} - } + /// Create a new instance. + pub fn new() -> Self { + Self {} + } } impl MetricsRenderer for CliMetricsRenderer { - fn update_train(&mut self, _state: MetricState) {} + fn update_train(&mut self, _state: MetricState) {} - fn update_valid(&mut self, _state: MetricState) {} + fn update_valid(&mut self, _state: MetricState) {} - fn render_train(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_train(&mut self, item: TrainingProgress) { + dbg!(item); + } - fn render_valid(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_valid(&mut self, item: TrainingProgress) { + dbg!(item); + } } diff --git a/burn-train/src/renderer/mod.rs b/burn-train/src/renderer/mod.rs index 9002184326..7e2132bc74 100644 --- a/burn-train/src/renderer/mod.rs +++ b/burn-train/src/renderer/mod.rs @@ -15,12 +15,12 @@ pub use tui::TuiMetricsRenderer as SelectedMetricsRenderer; /// The TUI renderer, or a simple stub if the tui feature is not enabled. 
#[allow(unused_variables)] pub(crate) fn default_renderer( - interuptor: TrainingInterrupter, - checkpoint: Option, + interuptor: TrainingInterrupter, + checkpoint: Option, ) -> SelectedMetricsRenderer { - #[cfg(feature = "tui")] - return SelectedMetricsRenderer::new(interuptor, checkpoint); + #[cfg(feature = "tui")] + return SelectedMetricsRenderer::new(interuptor, checkpoint); - #[cfg(not(feature = "tui"))] - return SelectedMetricsRenderer::new(); + #[cfg(not(feature = "tui"))] + return SelectedMetricsRenderer::new(); } diff --git a/burn-train/src/renderer/tui/base.rs b/burn-train/src/renderer/tui/base.rs index 38d6c25b31..04cce37144 100644 --- a/burn-train/src/renderer/tui/base.rs +++ b/burn-train/src/renderer/tui/base.rs @@ -1,45 +1,45 @@ use super::{ - ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView, + ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView, }; use ratatui::prelude::{Constraint, Direction, Layout, Rect}; #[derive(new)] pub(crate) struct MetricsView<'a> { - metric_numeric: NumericMetricView<'a>, - metric_text: TextMetricView, - progress: ProgressBarView, - controls: ControlsView, - status: StatusView, + metric_numeric: NumericMetricView<'a>, + metric_text: TextMetricView, + progress: ProgressBarView, + controls: ControlsView, + status: StatusView, } impl<'a> MetricsView<'a> { - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Min(16), Constraint::Max(3)].as_ref()) - .split(size); - let size_other = chunks[0]; - let size_progress = chunks[1]; + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Min(16), Constraint::Max(3)].as_ref()) + .split(size); + let size_other = chunks[0]; + let size_progress = chunks[1]; - let chunks = Layout::default() - .direction(Direction::Horizontal) - .constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref()) - .split(size_other); - let size_other = chunks[0]; - let size_metric_numeric = chunks[1]; + let chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref()) + .split(size_other); + let size_other = chunks[0]; + let size_metric_numeric = chunks[1]; - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref()) - .split(size_other); - let size_controls = chunks[0]; - let size_metric_text = chunks[1]; - let size_status = chunks[2]; + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref()) + .split(size_other); + let size_controls = chunks[0]; + let size_metric_text = chunks[1]; + let size_status = chunks[2]; - self.metric_numeric.render(frame, size_metric_numeric); - self.metric_text.render(frame, size_metric_text); - self.controls.render(frame, size_controls); - self.progress.render(frame, size_progress); - self.status.render(frame, size_status); - } + self.metric_numeric.render(frame, size_metric_numeric); + self.metric_text.render(frame, size_metric_text); + self.controls.render(frame, size_controls); + self.progress.render(frame, size_progress); + self.status.render(frame, size_status); + } } diff --git 
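// A standalone sketch of the compile-time renderer selection done by
// `default_renderer` above: one shared name resolves to a different concrete
// type depending on a Cargo feature. The `fancy` feature name and the renderer
// types are assumptions for this sketch; in the crate the feature is `tui` and
// the shared alias is `SelectedMetricsRenderer`.
#[cfg(feature = "fancy")]
struct FancyRenderer;
#[cfg(feature = "fancy")]
impl FancyRenderer {
    fn render(&self, msg: &str) {
        println!("[fancy] {msg}");
    }
}

#[cfg(not(feature = "fancy"))]
struct PlainRenderer;
#[cfg(not(feature = "fancy"))]
impl PlainRenderer {
    fn render(&self, msg: &str) {
        println!("[plain] {msg}");
    }
}

// One shared name, resolved at compile time.
#[cfg(feature = "fancy")]
type SelectedRenderer = FancyRenderer;
#[cfg(not(feature = "fancy"))]
type SelectedRenderer = PlainRenderer;

fn default_renderer() -> SelectedRenderer {
    #[cfg(feature = "fancy")]
    return FancyRenderer;

    #[cfg(not(feature = "fancy"))]
    return PlainRenderer;
}

fn main() {
    let renderer = default_renderer();
    renderer.render("epoch 1 done");
}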
a/burn-train/src/renderer/tui/controls.rs b/burn-train/src/renderer/tui/controls.rs index e48778034b..50ed7491dc 100644 --- a/burn-train/src/renderer/tui/controls.rs +++ b/burn-train/src/renderer/tui/controls.rs @@ -1,46 +1,46 @@ use super::TerminalFrame; use ratatui::{ - prelude::{Alignment, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; /// Controls view. pub(crate) struct ControlsView; impl ControlsView { - /// Render the view. - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let lines = vec![ - vec![ - Span::from(" Quit : ").yellow().bold(), - Span::from("q ").bold(), - Span::from(" Stop the training.").italic(), - ], - vec![ - Span::from(" Plots Metrics : ").yellow().bold(), - Span::from("⬅ ➡").bold(), - Span::from(" Switch between metrics.").italic(), - ], - vec![ - Span::from(" Plots Type : ").yellow().bold(), - Span::from("⬆ ⬇").bold(), - Span::from(" Switch between types.").italic(), - ], - ]; - let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::>()) - .alignment(Alignment::Left) - .wrap(Wrap { trim: false }) - .style(Style::default().fg(Color::Gray)) - .block( - Block::default() - .borders(Borders::ALL) - .style(Style::default().fg(Color::Gray)) - .title_alignment(Alignment::Left) - .title("Controls"), - ); + /// Render the view. + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let lines = vec![ + vec![ + Span::from(" Quit : ").yellow().bold(), + Span::from("q ").bold(), + Span::from(" Stop the training.").italic(), + ], + vec![ + Span::from(" Plots Metrics : ").yellow().bold(), + Span::from("⬅ ➡").bold(), + Span::from(" Switch between metrics.").italic(), + ], + vec![ + Span::from(" Plots Type : ").yellow().bold(), + Span::from("⬆ ⬇").bold(), + Span::from(" Switch between types.").italic(), + ], + ]; + let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::>()) + .alignment(Alignment::Left) + .wrap(Wrap { trim: false }) + .style(Style::default().fg(Color::Gray)) + .block( + Block::default() + .borders(Borders::ALL) + .style(Style::default().fg(Color::Gray)) + .title_alignment(Alignment::Left) + .title("Controls"), + ); - frame.render_widget(paragraph, size); - } + frame.render_widget(paragraph, size); + } } diff --git a/burn-train/src/renderer/tui/full_history.rs b/burn-train/src/renderer/tui/full_history.rs index 3c2e4e90e7..d33641ad19 100644 --- a/burn-train/src/renderer/tui/full_history.rs +++ b/burn-train/src/renderer/tui/full_history.rs @@ -1,216 +1,216 @@ use super::PlotAxes; use ratatui::{ - style::{Color, Style, Stylize}, - symbols, - widgets::{Dataset, GraphType}, + style::{Color, Style, Stylize}, + symbols, + widgets::{Dataset, GraphType}, }; /// A plot that shows the full history at a reduced resolution. 
pub(crate) struct FullHistoryPlot { - pub(crate) axes: PlotAxes, - train: FullHistoryPoints, - valid: FullHistoryPoints, - next_x_state: usize, + pub(crate) axes: PlotAxes, + train: FullHistoryPoints, + valid: FullHistoryPoints, + next_x_state: usize, } struct FullHistoryPoints { - min_x: f64, - max_x: f64, - min_y: f64, - max_y: f64, - points: Vec<(f64, f64)>, - max_samples: usize, - step_size: usize, + min_x: f64, + max_x: f64, + min_y: f64, + max_y: f64, + points: Vec<(f64, f64)>, + max_samples: usize, + step_size: usize, } impl FullHistoryPlot { - /// Create a new history plot. - pub(crate) fn new(max_samples: usize) -> Self { - Self { - axes: PlotAxes::default(), - train: FullHistoryPoints::new(max_samples), - valid: FullHistoryPoints::new(max_samples), - next_x_state: 0, - } + /// Create a new history plot. + pub(crate) fn new(max_samples: usize) -> Self { + Self { + axes: PlotAxes::default(), + train: FullHistoryPoints::new(max_samples), + valid: FullHistoryPoints::new(max_samples), + next_x_state: 0, } - - /// Update the maximum amount of sample to display for the validation points. - /// - /// This is necessary if we want the validation line to have the same point density as the - /// training line. - pub(crate) fn update_max_sample_valid(&mut self, ratio_train: f64) { - if self.valid.step_size == 1 { - self.valid.max_samples = (ratio_train * self.train.max_samples as f64) as usize; - } - } - - /// Register a training data point. - pub(crate) fn push_train(&mut self, data: f64) { - let x_current = self.next_x(); - self.train.push((x_current, data)); - - self.update_bounds(); + } + + /// Update the maximum amount of sample to display for the validation points. + /// + /// This is necessary if we want the validation line to have the same point density as the + /// training line. + pub(crate) fn update_max_sample_valid(&mut self, ratio_train: f64) { + if self.valid.step_size == 1 { + self.valid.max_samples = (ratio_train * self.train.max_samples as f64) as usize; } + } - /// Register a validation data point. - pub(crate) fn push_valid(&mut self, data: f64) { - let x_current = self.next_x(); + /// Register a training data point. + pub(crate) fn push_train(&mut self, data: f64) { + let x_current = self.next_x(); + self.train.push((x_current, data)); - self.valid.push((x_current, data)); + self.update_bounds(); + } - self.update_bounds(); - } + /// Register a validation data point. + pub(crate) fn push_valid(&mut self, data: f64) { + let x_current = self.next_x(); - /// Create the training and validation datasets from the data points. - pub(crate) fn datasets(&self) -> Vec> { - let mut datasets = Vec::with_capacity(2); + self.valid.push((x_current, data)); - if !self.train.is_empty() { - datasets.push(self.train.dataset("Train", Color::LightRed)); - } + self.update_bounds(); + } - if !self.valid.is_empty() { - datasets.push(self.valid.dataset("Valid", Color::LightBlue)); - } + /// Create the training and validation datasets from the data points. 
+ pub(crate) fn datasets(&self) -> Vec> { + let mut datasets = Vec::with_capacity(2); - datasets + if !self.train.is_empty() { + datasets.push(self.train.dataset("Train", Color::LightRed)); } - fn next_x(&mut self) -> f64 { - let value = self.next_x_state; - self.next_x_state += 1; - value as f64 + if !self.valid.is_empty() { + datasets.push(self.valid.dataset("Valid", Color::LightBlue)); } - fn update_bounds(&mut self) { - self.axes.update_bounds( - (self.train.min_x, self.train.max_x), - (self.valid.min_x, self.valid.max_x), - (self.train.min_y, self.train.max_y), - (self.valid.min_y, self.valid.max_y), - ); - } + datasets + } + + fn next_x(&mut self) -> f64 { + let value = self.next_x_state; + self.next_x_state += 1; + value as f64 + } + + fn update_bounds(&mut self) { + self.axes.update_bounds( + (self.train.min_x, self.train.max_x), + (self.valid.min_x, self.valid.max_x), + (self.train.min_y, self.train.max_y), + (self.valid.min_y, self.valid.max_y), + ); + } } impl FullHistoryPoints { - fn new(max_samples: usize) -> Self { - Self { - min_x: 0., - max_x: 0., - min_y: f64::MAX, - max_y: f64::MIN, - points: Vec::with_capacity(max_samples), - max_samples, - step_size: 1, - } + fn new(max_samples: usize) -> Self { + Self { + min_x: 0., + max_x: 0., + min_y: f64::MAX, + max_y: f64::MIN, + points: Vec::with_capacity(max_samples), + max_samples, + step_size: 1, } + } - fn push(&mut self, (x, y): (f64, f64)) { - if x as usize % self.step_size != 0 { - return; - } - - if x > self.max_x { - self.max_x = x; - } - if x < self.min_x { - self.min_x = x; - } - if y > self.max_y { - self.max_y = y; - } - if y < self.min_y { - self.min_y = y - } - - self.points.push((x, y)); - - if self.points.len() > self.max_samples { - self.resize(); - } + fn push(&mut self, (x, y): (f64, f64)) { + if x as usize % self.step_size != 0 { + return; } - /// We keep only half the points and we double the step size. - /// - /// This ensure that we have the same amount of points across the X axis. - fn resize(&mut self) { - let mut points = Vec::with_capacity(self.max_samples / 2); - let mut max_x = f64::MIN; - let mut max_y = f64::MIN; - let mut min_x = f64::MAX; - let mut min_y = f64::MAX; - - for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() { - if i % 2 == 0 { - if x > max_x { - max_x = x; - } - if x < min_x { - min_x = x; - } - if y > max_y { - max_y = y; - } - if y < min_y { - min_y = y; - } - - points.push((x, y)); - } - } + if x > self.max_x { + self.max_x = x; + } + if x < self.min_x { + self.min_x = x; + } + if y > self.max_y { + self.max_y = y; + } + if y < self.min_y { + self.min_y = y + } - self.points = points; - self.step_size *= 2; + self.points.push((x, y)); - self.min_x = min_x; - self.max_x = max_x; - self.min_y = min_y; - self.max_y = max_y; + if self.points.len() > self.max_samples { + self.resize(); } + } + + /// We keep only half the points and we double the step size. + /// + /// This ensure that we have the same amount of points across the X axis. 
+ fn resize(&mut self) { + let mut points = Vec::with_capacity(self.max_samples / 2); + let mut max_x = f64::MIN; + let mut max_y = f64::MIN; + let mut min_x = f64::MAX; + let mut min_y = f64::MAX; + + for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() { + if i % 2 == 0 { + if x > max_x { + max_x = x; + } + if x < min_x { + min_x = x; + } + if y > max_y { + max_y = y; + } + if y < min_y { + min_y = y; + } - fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { - Dataset::default() - .name(name) - .marker(symbols::Marker::Braille) - .style(Style::default().fg(color).bold()) - .graph_type(GraphType::Line) - .data(&self.points) + points.push((x, y)); + } } - fn is_empty(&self) -> bool { - self.points.is_empty() - } + self.points = points; + self.step_size *= 2; + + self.min_x = min_x; + self.max_x = max_x; + self.min_y = min_y; + self.max_y = max_y; + } + + fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { + Dataset::default() + .name(name) + .marker(symbols::Marker::Braille) + .style(Style::default().fg(color).bold()) + .graph_type(GraphType::Line) + .data(&self.points) + } + + fn is_empty(&self) -> bool { + self.points.is_empty() + } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_points() { - let mut chart = FullHistoryPlot::new(10); - chart.update_max_sample_valid(0.6); + use super::*; - for i in 0..100 { - chart.push_train(i as f64); - } - for i in 0..60 { - chart.push_valid(i as f64); - } + #[test] + fn test_points() { + let mut chart = FullHistoryPlot::new(10); + chart.update_max_sample_valid(0.6); - let expected_train = vec![ - (0.0, 0.0), - (16.0, 16.0), - (32.0, 32.0), - (48.0, 48.0), - (64.0, 64.0), - (80.0, 80.0), - (96.0, 96.0), - ]; - - let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)]; - - assert_eq!(chart.train.points, expected_train); - assert_eq!(chart.valid.points, expected_valid); + for i in 0..100 { + chart.push_train(i as f64); } + for i in 0..60 { + chart.push_valid(i as f64); + } + + let expected_train = vec![ + (0.0, 0.0), + (16.0, 16.0), + (32.0, 32.0), + (48.0, 48.0), + (64.0, 64.0), + (80.0, 80.0), + (96.0, 96.0), + ]; + + let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)]; + + assert_eq!(chart.train.points, expected_train); + assert_eq!(chart.valid.points, expected_valid); + } } diff --git a/burn-train/src/renderer/tui/metric_numeric.rs b/burn-train/src/renderer/tui/metric_numeric.rs index ccae8e295c..5a27182147 100644 --- a/burn-train/src/renderer/tui/metric_numeric.rs +++ b/burn-train/src/renderer/tui/metric_numeric.rs @@ -3,10 +3,10 @@ use crate::renderer::TrainingProgress; use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame}; use crossterm::event::{Event, KeyCode}; use ratatui::{ - prelude::{Alignment, Constraint, Direction, Layout, Rect}, - style::{Color, Modifier, Style, Stylize}, - text::Line, - widgets::{Axis, Block, Borders, Chart, Paragraph, Tabs}, + prelude::{Alignment, Constraint, Direction, Layout, Rect}, + style::{Color, Modifier, Style, Stylize}, + text::Line, + widgets::{Axis, Block, Borders, Chart, Paragraph, Tabs}, }; use std::collections::HashMap; @@ -19,214 +19,213 @@ const MAX_NUM_SAMPLES_FULL: usize = 250; /// Numeric metrics state that handles creating plots. 
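// A standalone sketch of the bounded-history downsampling used by the full
// history plot above: once the buffer exceeds `max_samples`, keep every other
// retained point and double the step size, so later pushes are skipped unless
// their x value lands on the coarser grid. Bounds tracking is omitted for
// brevity; the final x grid matches `test_points` above (0, 16, ..., 96).
struct Downsampled {
    points: Vec<(f64, f64)>,
    max_samples: usize,
    step_size: usize,
}

impl Downsampled {
    fn new(max_samples: usize) -> Self {
        Self {
            points: Vec::with_capacity(max_samples),
            max_samples,
            step_size: 1,
        }
    }

    fn push(&mut self, x: f64, y: f64) {
        // Points that fall between the current grid steps are dropped up front.
        if x as usize % self.step_size != 0 {
            return;
        }
        self.points.push((x, y));
        if self.points.len() > self.max_samples {
            // Keep every other retained point and coarsen the grid.
            let kept: Vec<_> = self.points.iter().copied().step_by(2).collect();
            self.points = kept;
            self.step_size *= 2;
        }
    }
}

fn main() {
    let mut plot = Downsampled::new(10);
    for i in 0..100 {
        plot.push(i as f64, i as f64);
    }
    // After repeated halving the step size is 16, so only x values
    // 0, 16, 32, 48, 64, 80 and 96 survive.
    assert_eq!(plot.step_size, 16);
    assert_eq!(plot.points.len(), 7);
    assert_eq!(plot.points.first(), Some(&(0.0, 0.0)));
    assert_eq!(plot.points.last(), Some(&(96.0, 96.0)));
}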
#[derive(Default)] pub(crate) struct NumericMetricsState { - data: HashMap, - names: Vec, - selected: usize, - kind: PlotKind, - num_samples_train: Option, - num_samples_valid: Option, + data: HashMap, + names: Vec, + selected: usize, + kind: PlotKind, + num_samples_train: Option, + num_samples_valid: Option, } /// The kind of plot to display. #[derive(Default, Clone, Copy)] pub(crate) enum PlotKind { - /// Display the full history of the metric with reduced resolution. - #[default] - Full, - /// Display only the recent history of the metric, but with more resolution. - Recent, + /// Display the full history of the metric with reduced resolution. + #[default] + Full, + /// Display only the recent history of the metric, but with more resolution. + Recent, } impl NumericMetricsState { - /// Register a new training value for the metric with the given name. - pub(crate) fn push_train(&mut self, name: String, data: f64) { - if let Some((recent, full)) = self.data.get_mut(&name) { - recent.push_train(data); - full.push_train(data); - } else { - let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); - let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); - - recent.push_train(data); - full.push_train(data); - - self.names.push(name.clone()); - self.data.insert(name, (recent, full)); - } + /// Register a new training value for the metric with the given name. + pub(crate) fn push_train(&mut self, name: String, data: f64) { + if let Some((recent, full)) = self.data.get_mut(&name) { + recent.push_train(data); + full.push_train(data); + } else { + let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); + let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); + + recent.push_train(data); + full.push_train(data); + + self.names.push(name.clone()); + self.data.insert(name, (recent, full)); } + } - /// Register a new validation value for the metric with the given name. - pub(crate) fn push_valid(&mut self, key: String, data: f64) { - if let Some((recent, full)) = self.data.get_mut(&key) { - recent.push_valid(data); - full.push_valid(data); - } else { - let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); - let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); + /// Register a new validation value for the metric with the given name. + pub(crate) fn push_valid(&mut self, key: String, data: f64) { + if let Some((recent, full)) = self.data.get_mut(&key) { + recent.push_valid(data); + full.push_valid(data); + } else { + let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); + let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); - recent.push_valid(data); - full.push_valid(data); + recent.push_valid(data); + full.push_valid(data); - self.data.insert(key, (recent, full)); - } + self.data.insert(key, (recent, full)); } + } - /// Update the state with the training progress. - pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) { - if self.num_samples_train.is_some() { - return; - } - - self.num_samples_train = Some(progress.progress.items_total); + /// Update the state with the training progress. + pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) { + if self.num_samples_train.is_some() { + return; } - /// Update the state with the validation progress. 
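// A standalone sketch of the lazy per-metric registration done by `push_train`
// and `push_valid` above: the first value seen under a metric name creates its
// plot state, later values update it in place. `HashMap::entry` expresses the
// same get-or-insert flow; the `PlotPair` type and metric names are illustrative.
use std::collections::HashMap;

#[derive(Default)]
struct PlotPair {
    recent: Vec<f64>,
    full: Vec<f64>,
}

#[derive(Default)]
struct MetricsState {
    data: HashMap<String, PlotPair>,
    names: Vec<String>,
}

impl MetricsState {
    fn push_train(&mut self, name: &str, value: f64) {
        if !self.data.contains_key(name) {
            // First sighting: remember the name so tabs keep a stable order.
            self.names.push(name.to_string());
        }
        let pair = self.data.entry(name.to_string()).or_default();
        pair.recent.push(value);
        pair.full.push(value);
    }
}

fn main() {
    let mut state = MetricsState::default();
    state.push_train("Loss", 0.9);
    state.push_train("Loss", 0.7);
    state.push_train("Accuracy", 0.5);
    assert_eq!(state.names, vec!["Loss", "Accuracy"]);
    assert_eq!(state.data["Loss"].full, vec![0.9, 0.7]);
}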
- pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) { - if self.num_samples_valid.is_some() { - return; - } - - if let Some(num_sample_train) = self.num_samples_train { - for (_, (_recent, full)) in self.data.iter_mut() { - let ratio = progress.progress.items_total as f64 / num_sample_train as f64; - full.update_max_sample_valid(ratio); - } - } + self.num_samples_train = Some(progress.progress.items_total); + } - self.num_samples_valid = Some(progress.progress.items_total); + /// Update the state with the validation progress. + pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) { + if self.num_samples_valid.is_some() { + return; } - /// Create a view to display the numeric metrics. - pub(crate) fn view(&self) -> NumericMetricView<'_> { - match self.names.is_empty() { - true => NumericMetricView::None, - false => NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind), - } + if let Some(num_sample_train) = self.num_samples_train { + for (_, (_recent, full)) in self.data.iter_mut() { + let ratio = progress.progress.items_total as f64 / num_sample_train as f64; + full.update_max_sample_valid(ratio); + } } - /// Handle the current event. - pub(crate) fn on_event(&mut self, event: &Event) { - if let Event::Key(key) = event { - match key.code { - KeyCode::Right => self.next_metric(), - KeyCode::Left => self.previous_metric(), - KeyCode::Up => self.switch_kind(), - KeyCode::Down => self.switch_kind(), - _ => {} - } - } - } + self.num_samples_valid = Some(progress.progress.items_total); + } - fn switch_kind(&mut self) { - self.kind = match self.kind { - PlotKind::Full => PlotKind::Recent, - PlotKind::Recent => PlotKind::Full, - }; + /// Create a view to display the numeric metrics. + pub(crate) fn view(&self) -> NumericMetricView<'_> { + match self.names.is_empty() { + true => NumericMetricView::None, + false => NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind), } - - fn next_metric(&mut self) { - self.selected = (self.selected + 1) % { - let this = &self; - this.data.len() - }; + } + + /// Handle the current event. 
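Metric selection wraps around in both directions; the `{ let this = &self; this.data.len() }` blocks in the hunk are just a verbose spelling of `self.data.len()`. A minimal sketch of the same wrap-around logic, assuming at least one metric is registered (`Carousel` is an illustrative name, not a type in burn-train):

struct Carousel {
    selected: usize,
    len: usize,
}

impl Carousel {
    fn next(&mut self) {
        // Wrap back to the first entry after the last one.
        self.selected = (self.selected + 1) % self.len;
    }

    fn previous(&mut self) {
        // Wrap to the last entry when stepping back from the first.
        self.selected = if self.selected == 0 {
            self.len - 1
        } else {
            self.selected - 1
        };
    }
}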
+ pub(crate) fn on_event(&mut self, event: &Event) { + if let Event::Key(key) = event { + match key.code { + KeyCode::Right => self.next_metric(), + KeyCode::Left => self.previous_metric(), + KeyCode::Up => self.switch_kind(), + KeyCode::Down => self.switch_kind(), + _ => {} + } } - - fn previous_metric(&mut self) { - if self.selected > 0 { - self.selected -= 1; - } else { - self.selected = ({ - let this = &self; - this.data.len() - }) - 1; - } - } - - fn chart<'a>(&'a self) -> Chart<'a> { - let name = self.names.get(self.selected).unwrap(); - let (recent, full) = self.data.get(name).unwrap(); - - let (datasets, axes) = match self.kind { - PlotKind::Full => (full.datasets(), &full.axes), - PlotKind::Recent => (recent.datasets(), &recent.axes), - }; - - Chart::<'a>::new(datasets) - .block(Block::default()) - .x_axis( - Axis::default() - .style(Style::default().fg(Color::DarkGray)) - .title("Iteration") - .labels(axes.labels_x.iter().map(|s| s.bold()).collect()) - .bounds(axes.bounds_x), - ) - .y_axis( - Axis::default() - .style(Style::default().fg(Color::DarkGray)) - .labels(axes.labels_y.iter().map(|s| s.bold()).collect()) - .bounds(axes.bounds_y), - ) + } + + fn switch_kind(&mut self) { + self.kind = match self.kind { + PlotKind::Full => PlotKind::Recent, + PlotKind::Recent => PlotKind::Full, + }; + } + + fn next_metric(&mut self) { + self.selected = (self.selected + 1) % { + let this = &self; + this.data.len() + }; + } + + fn previous_metric(&mut self) { + if self.selected > 0 { + self.selected -= 1; + } else { + self.selected = ({ + let this = &self; + this.data.len() + }) - 1; } + } + + fn chart<'a>(&'a self) -> Chart<'a> { + let name = self.names.get(self.selected).unwrap(); + let (recent, full) = self.data.get(name).unwrap(); + + let (datasets, axes) = match self.kind { + PlotKind::Full => (full.datasets(), &full.axes), + PlotKind::Recent => (recent.datasets(), &recent.axes), + }; + + Chart::<'a>::new(datasets) + .block(Block::default()) + .x_axis( + Axis::default() + .style(Style::default().fg(Color::DarkGray)) + .title("Iteration") + .labels(axes.labels_x.iter().map(|s| s.bold()).collect()) + .bounds(axes.bounds_x), + ) + .y_axis( + Axis::default() + .style(Style::default().fg(Color::DarkGray)) + .labels(axes.labels_y.iter().map(|s| s.bold()).collect()) + .bounds(axes.bounds_y), + ) + } } #[derive(new)] pub(crate) enum NumericMetricView<'a> { - Plots(&'a [String], usize, Chart<'a>, PlotKind), - None, + Plots(&'a [String], usize, Chart<'a>, PlotKind), + None, } impl<'a> NumericMetricView<'a> { - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - match self { - Self::Plots(titles, selected, chart, kind) => { - let block = Block::default() - .borders(Borders::ALL) - .title("Plots") - .title_alignment(Alignment::Left); - let size_new = block.inner(size); - frame.render_widget(block, size); - - let size = size_new; - - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints( - [ - Constraint::Length(2), - Constraint::Length(1), - Constraint::Min(0), - ] - .as_ref(), - ) - .split(size); - - let titles = titles - .iter() - .map(|i| Line::from(vec![i.yellow()])) - .collect(); - - let tabs = Tabs::new(titles) - .select(selected) - .style(Style::default()) - .highlight_style( - Style::default() - .add_modifier(Modifier::BOLD) - .add_modifier(Modifier::UNDERLINED) - .fg(Color::LightYellow), - ); - let title = match kind { - PlotKind::Full => "Full History", - PlotKind::Recent => "Recent History", - }; - - let plot_type = - 
Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center); - - frame.render_widget(tabs, chunks[0]); - frame.render_widget(plot_type, chunks[1]); - frame.render_widget(chart, chunks[2]); - } - Self::None => {} + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + match self { + Self::Plots(titles, selected, chart, kind) => { + let block = Block::default() + .borders(Borders::ALL) + .title("Plots") + .title_alignment(Alignment::Left); + let size_new = block.inner(size); + frame.render_widget(block, size); + + let size = size_new; + + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints( + [ + Constraint::Length(2), + Constraint::Length(1), + Constraint::Min(0), + ] + .as_ref(), + ) + .split(size); + + let titles = titles + .iter() + .map(|i| Line::from(vec![i.yellow()])) + .collect(); + + let tabs = Tabs::new(titles) + .select(selected) + .style(Style::default()) + .highlight_style( + Style::default() + .add_modifier(Modifier::BOLD) + .add_modifier(Modifier::UNDERLINED) + .fg(Color::LightYellow), + ); + let title = match kind { + PlotKind::Full => "Full History", + PlotKind::Recent => "Recent History", }; - } + + let plot_type = Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center); + + frame.render_widget(tabs, chunks[0]); + frame.render_widget(plot_type, chunks[1]); + frame.render_widget(chart, chunks[2]); + } + Self::None => {} + }; + } } diff --git a/burn-train/src/renderer/tui/metric_text.rs b/burn-train/src/renderer/tui/metric_text.rs index c6d6b021c1..338877b808 100644 --- a/burn-train/src/renderer/tui/metric_text.rs +++ b/burn-train/src/renderer/tui/metric_text.rs @@ -1,101 +1,101 @@ use super::TerminalFrame; use crate::metric::MetricEntry; use ratatui::{ - prelude::{Alignment, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; use std::collections::HashMap; #[derive(Default)] pub(crate) struct TextMetricsState { - data: HashMap, - names: Vec, + data: HashMap, + names: Vec, } #[derive(new)] pub(crate) struct MetricData { - train: Option, - valid: Option, + train: Option, + valid: Option, } impl TextMetricsState { - pub(crate) fn update_train(&mut self, metric: MetricEntry) { - if let Some(existing) = self.data.get_mut(&metric.name) { - existing.train = Some(metric); - } else { - let key = metric.name.clone(); - let value = MetricData::new(Some(metric), None); - - self.names.push(key.clone()); - self.data.insert(key, value); - } + pub(crate) fn update_train(&mut self, metric: MetricEntry) { + if let Some(existing) = self.data.get_mut(&metric.name) { + existing.train = Some(metric); + } else { + let key = metric.name.clone(); + let value = MetricData::new(Some(metric), None); + + self.names.push(key.clone()); + self.data.insert(key, value); } - pub(crate) fn update_valid(&mut self, metric: MetricEntry) { - if let Some(existing) = self.data.get_mut(&metric.name) { - existing.valid = Some(metric); - } else { - let key = metric.name.clone(); - let value = MetricData::new(None, Some(metric)); - - self.names.push(key.clone()); - self.data.insert(key, value); - } - } - pub(crate) fn view(&self) -> TextMetricView { - TextMetricView::new(&self.names, &self.data) + } + pub(crate) fn update_valid(&mut self, metric: MetricEntry) { + if let Some(existing) = self.data.get_mut(&metric.name) { + existing.valid = Some(metric); + } else 
{ + let key = metric.name.clone(); + let value = MetricData::new(None, Some(metric)); + + self.names.push(key.clone()); + self.data.insert(key, value); } + } + pub(crate) fn view(&self) -> TextMetricView { + TextMetricView::new(&self.names, &self.data) + } } pub(crate) struct TextMetricView { - lines: Vec>>, + lines: Vec>>, } impl TextMetricView { - fn new(names: &[String], data: &HashMap) -> Self { - let mut lines = Vec::with_capacity(names.len() * 4); - - let start_line = |title: &str| vec![Span::from(format!(" {title} ")).bold().yellow()]; - let train_line = |formatted: &str| { - vec![ - Span::from(" Train ").bold(), - Span::from(formatted.to_string()).italic(), - ] - }; - let valid_line = |formatted: &str| { - vec![ - Span::from(" Valid ").bold(), - Span::from(formatted.to_string()).italic(), - ] - }; - - for name in names { - lines.push(start_line(name)); - - let entry = data.get(name).unwrap(); - - if let Some(entry) = &entry.train { - lines.push(train_line(&entry.formatted)); - } - - if let Some(entry) = &entry.valid { - lines.push(valid_line(&entry.formatted)); - } - - lines.push(vec![Span::from("")]); - } - - Self { lines } + fn new(names: &[String], data: &HashMap) -> Self { + let mut lines = Vec::with_capacity(names.len() * 4); + + let start_line = |title: &str| vec![Span::from(format!(" {title} ")).bold().yellow()]; + let train_line = |formatted: &str| { + vec![ + Span::from(" Train ").bold(), + Span::from(formatted.to_string()).italic(), + ] + }; + let valid_line = |formatted: &str| { + vec![ + Span::from(" Valid ").bold(), + Span::from(formatted.to_string()).italic(), + ] + }; + + for name in names { + lines.push(start_line(name)); + + let entry = data.get(name).unwrap(); + + if let Some(entry) = &entry.train { + lines.push(train_line(&entry.formatted)); + } + + if let Some(entry) = &entry.valid { + lines.push(valid_line(&entry.formatted)); + } + + lines.push(vec![Span::from("")]); } - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) - .alignment(Alignment::Left) - .wrap(Wrap { trim: false }) - .block(Block::default().borders(Borders::ALL).title("Metrics")) - .style(Style::default().fg(Color::Gray)); + Self { lines } + } - frame.render_widget(paragraph, size); - } + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) + .alignment(Alignment::Left) + .wrap(Wrap { trim: false }) + .block(Block::default().borders(Borders::ALL).title("Metrics")) + .style(Style::default().fg(Color::Gray)); + + frame.render_widget(paragraph, size); + } } diff --git a/burn-train/src/renderer/tui/plot_utils.rs b/burn-train/src/renderer/tui/plot_utils.rs index 615cad8184..30207ced9a 100644 --- a/burn-train/src/renderer/tui/plot_utils.rs +++ b/burn-train/src/renderer/tui/plot_utils.rs @@ -4,45 +4,45 @@ const AXIS_TITLE_PRECISION: usize = 2; /// The data describing both X and Y axes. 
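Both update_train and update_valid above follow the same pattern: look the metric up by name, update the matching slot, otherwise insert a new entry while recording the name so the display order stays stable. A sketch of that pattern using the entry API, simplified to strings instead of MetricEntry (an alternative spelling, not what the patch itself does):

use std::collections::HashMap;

fn update_train(
    data: &mut HashMap<String, (Option<String>, Option<String>)>,
    names: &mut Vec<String>,
    name: String,
    formatted: String,
) {
    let slot = data.entry(name.clone()).or_insert_with(|| {
        // First time this metric is seen: remember its display order.
        names.push(name.clone());
        (None, None)
    });
    // The first slot stands in for the train entry, the second for valid.
    slot.0 = Some(formatted);
}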
pub(crate) struct PlotAxes { - pub(crate) labels_x: Vec, - pub(crate) labels_y: Vec, - pub(crate) bounds_x: [f64; 2], - pub(crate) bounds_y: [f64; 2], + pub(crate) labels_x: Vec, + pub(crate) labels_y: Vec, + pub(crate) bounds_x: [f64; 2], + pub(crate) bounds_y: [f64; 2], } impl Default for PlotAxes { - fn default() -> Self { - Self { - bounds_x: [f64::MAX, f64::MIN], - bounds_y: [f64::MAX, f64::MIN], - labels_x: Vec::new(), - labels_y: Vec::new(), - } + fn default() -> Self { + Self { + bounds_x: [f64::MAX, f64::MIN], + bounds_y: [f64::MAX, f64::MIN], + labels_x: Vec::new(), + labels_y: Vec::new(), } + } } impl PlotAxes { - /// Update the bounds based on the min max of each X and Y axes with both train and valid data. - pub(crate) fn update_bounds( - &mut self, - (x_train_min, x_train_max): (f64, f64), - (x_valid_min, x_valid_max): (f64, f64), - (y_train_min, y_train_max): (f64, f64), - (y_valid_min, y_valid_max): (f64, f64), - ) { - let x_min = f64::min(x_train_min, x_valid_min); - let x_max = f64::max(x_train_max, x_valid_max); - let y_min = f64::min(y_train_min, y_valid_min); - let y_max = f64::max(y_train_max, y_valid_max); + /// Update the bounds based on the min max of each X and Y axes with both train and valid data. + pub(crate) fn update_bounds( + &mut self, + (x_train_min, x_train_max): (f64, f64), + (x_valid_min, x_valid_max): (f64, f64), + (y_train_min, y_train_max): (f64, f64), + (y_valid_min, y_valid_max): (f64, f64), + ) { + let x_min = f64::min(x_train_min, x_valid_min); + let x_max = f64::max(x_train_max, x_valid_max); + let y_min = f64::min(y_train_min, y_valid_min); + let y_max = f64::max(y_train_max, y_valid_max); - self.bounds_x = [x_min, x_max]; - self.bounds_y = [y_min, y_max]; + self.bounds_x = [x_min, x_max]; + self.bounds_y = [y_min, y_max]; - // We know x are integers. - self.labels_x = vec![format!("{x_min}"), format!("{x_max}")]; - self.labels_y = vec![ - format_float(y_min, AXIS_TITLE_PRECISION), - format_float(y_max, AXIS_TITLE_PRECISION), - ]; - } + // We know x are integers. + self.labels_x = vec![format!("{x_min}"), format!("{x_max}")]; + self.labels_y = vec![ + format_float(y_min, AXIS_TITLE_PRECISION), + format_float(y_max, AXIS_TITLE_PRECISION), + ]; + } } diff --git a/burn-train/src/renderer/tui/popup.rs b/burn-train/src/renderer/tui/popup.rs index b39a9c8f66..b75773d0e7 100644 --- a/burn-train/src/renderer/tui/popup.rs +++ b/burn-train/src/renderer/tui/popup.rs @@ -1,144 +1,144 @@ use crossterm::event::{Event, KeyCode}; use ratatui::{ - prelude::{Alignment, Constraint, Direction, Layout, Rect}, - style::{Color, Modifier, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Constraint, Direction, Layout, Rect}, + style::{Color, Modifier, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; use super::TerminalFrame; /// Popup callback function. pub(crate) trait CallbackFn: Send + Sync { - /// Call the function and return if the popup state should be reset. - fn call(&self) -> bool; + /// Call the function and return if the popup state should be reset. + fn call(&self) -> bool; } /// Popup callback. pub(crate) struct Callback { - title: String, - description: String, - trigger: char, - callback: Box, + title: String, + description: String, + trigger: char, + callback: Box, } impl Callback { - /// Create a new popup. 
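update_bounds above takes, per axis, the union of the train and valid extents, which is why the Default impl starts from the inverted range [f64::MAX, f64::MIN]: any real point immediately tightens it. The merge itself reduces to this (illustrative helper name):

fn merge_bounds(train: (f64, f64), valid: (f64, f64)) -> [f64; 2] {
    // (min, max) per series; the plot bounds are the union of both.
    [f64::min(train.0, valid.0), f64::max(train.1, valid.1)]
}

// For example, merge_bounds((0.0, 96.0), (100.0, 144.0)) == [0.0, 144.0].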
- pub(crate) fn new(title: T, description: D, trigger: char, callback: C) -> Self - where - T: Into, - D: Into, - C: CallbackFn + 'static, - { - Self { - title: title.into(), - description: description.into(), - trigger, - callback: Box::new(callback), - } + /// Create a new popup. + pub(crate) fn new(title: T, description: D, trigger: char, callback: C) -> Self + where + T: Into, + D: Into, + C: CallbackFn + 'static, + { + Self { + title: title.into(), + description: description.into(), + trigger, + callback: Box::new(callback), } + } } /// Popup state. pub(crate) enum PopupState { - Empty, - Full(String, Vec), + Empty, + Full(String, Vec), } impl PopupState { - /// If the popup is empty. - pub(crate) fn is_empty(&self) -> bool { - matches!(&self, PopupState::Empty) - } - /// Handle popup events. - pub(crate) fn on_event(&mut self, event: &Event) { - let mut reset = false; + /// If the popup is empty. + pub(crate) fn is_empty(&self) -> bool { + matches!(&self, PopupState::Empty) + } + /// Handle popup events. + pub(crate) fn on_event(&mut self, event: &Event) { + let mut reset = false; - match self { - PopupState::Empty => {} - PopupState::Full(_, callbacks) => { - for callback in callbacks.iter() { - if let Event::Key(key) = event { - if let KeyCode::Char(key) = &key.code { - if &callback.trigger == key && callback.callback.call() { - reset = true; - } - } - } - } + match self { + PopupState::Empty => {} + PopupState::Full(_, callbacks) => { + for callback in callbacks.iter() { + if let Event::Key(key) = event { + if let KeyCode::Char(key) = &key.code { + if &callback.trigger == key && callback.callback.call() { + reset = true; + } } - }; - - if reset { - *self = Self::Empty; + } } + } + }; + + if reset { + *self = Self::Empty; } - /// Create the popup view. - pub(crate) fn view(&self) -> Option> { - match self { - PopupState::Empty => None, - PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)), - } + } + /// Create the popup view. + pub(crate) fn view(&self) -> Option> { + match self { + PopupState::Empty => None, + PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)), } + } } #[derive(new)] pub(crate) struct PopupView<'a> { - title: &'a String, - callbacks: &'a [Callback], + title: &'a String, + callbacks: &'a [Callback], } impl<'a> PopupView<'a> { - /// Render the view. - pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) { - let lines = self - .callbacks - .iter() - .flat_map(|callback| { - vec![ - Line::from(vec![ - Span::from(format!("[{}] ", callback.trigger)).bold(), - Span::from(format!("{} ", callback.title)).yellow().bold(), - ]), - Line::from(Span::from("")), - Line::from(Span::from(callback.description.to_string()).italic()), - Line::from(Span::from("")), - ] - }) - .collect::>(); + /// Render the view. 
+ pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) { + let lines = self + .callbacks + .iter() + .flat_map(|callback| { + vec![ + Line::from(vec![ + Span::from(format!("[{}] ", callback.trigger)).bold(), + Span::from(format!("{} ", callback.title)).yellow().bold(), + ]), + Line::from(Span::from("")), + Line::from(Span::from(callback.description.to_string()).italic()), + Line::from(Span::from("")), + ] + }) + .collect::>(); - let paragraph = Paragraph::new(lines) - .alignment(Alignment::Left) - .wrap(Wrap { trim: false }) - .style(Style::default().fg(Color::Gray)) - .block( - Block::default() - .borders(Borders::ALL) - .title_alignment(Alignment::Center) - .style(Style::default().fg(Color::Gray)) - .title(Span::styled( - self.title, - Style::default().add_modifier(Modifier::BOLD), - )), - ); + let paragraph = Paragraph::new(lines) + .alignment(Alignment::Left) + .wrap(Wrap { trim: false }) + .style(Style::default().fg(Color::Gray)) + .block( + Block::default() + .borders(Borders::ALL) + .title_alignment(Alignment::Center) + .style(Style::default().fg(Color::Gray)) + .title(Span::styled( + self.title, + Style::default().add_modifier(Modifier::BOLD), + )), + ); - let area = centered_percent(20, size, Direction::Horizontal); - let area = centered_percent(20, area, Direction::Vertical); + let area = centered_percent(20, size, Direction::Horizontal); + let area = centered_percent(20, area, Direction::Vertical); - frame.render_widget(paragraph, area); - } + frame.render_widget(paragraph, area); + } } /// The percent represents the amount of space that will be taken by each side. fn centered_percent(percent: u16, size: Rect, direction: Direction) -> Rect { - let center = 100 - (percent * 2); + let center = 100 - (percent * 2); - Layout::default() - .direction(direction) - .constraints([ - Constraint::Percentage(percent), - Constraint::Percentage(center), - Constraint::Percentage(percent), - ]) - .split(size)[1] + Layout::default() + .direction(direction) + .constraints([ + Constraint::Percentage(percent), + Constraint::Percentage(center), + Constraint::Percentage(percent), + ]) + .split(size)[1] } diff --git a/burn-train/src/renderer/tui/progress.rs b/burn-train/src/renderer/tui/progress.rs index 5921db2438..647eeab38e 100644 --- a/burn-train/src/renderer/tui/progress.rs +++ b/burn-train/src/renderer/tui/progress.rs @@ -1,10 +1,10 @@ use super::TerminalFrame; use crate::renderer::TrainingProgress; use ratatui::{ - prelude::{Alignment, Constraint, Direction, Layout, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Gauge, Paragraph}, + prelude::{Alignment, Constraint, Direction, Layout, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Gauge, Paragraph}, }; use std::time::{Duration, Instant}; @@ -12,9 +12,9 @@ use std::time::{Duration, Instant}; /// /// We currently ignore the time taken for the validation part. pub(crate) struct ProgressBarState { - progress_train: f64, // Progress for total training. - starting_epoch: usize, - estimate: ProgressEstimate, + progress_train: f64, // Progress for total training. + starting_epoch: usize, + estimate: ProgressEstimate, } const MINUTE: u64 = 60; @@ -22,243 +22,243 @@ const HOUR: u64 = 60 * 60; const DAY: u64 = 24 * 60 * 60; impl ProgressBarState { - pub fn new(checkpoint: Option) -> Self { - Self { - progress_train: 0.0, - estimate: ProgressEstimate::new(), - starting_epoch: checkpoint.unwrap_or(0), - } - } - /// Update the training progress. 
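For the popup sizing above: centered_percent reserves `percent` of the space on each side, so the middle constraint is 100 - 2 * 20 = 60 when called with 20. Applying it once per direction leaves a popup spanning 60% of the width and 60% of the height, roughly a third (0.6 * 0.6 = 0.36) of the terminal area.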
- pub(crate) fn update_train(&mut self, progress: &TrainingProgress) { - self.progress_train = calculate_progress(progress, 0, 0); - self.estimate.update(progress, self.starting_epoch); - } - - /// Update the validation progress. - pub(crate) fn update_valid(&mut self, _progress: &TrainingProgress) { - // We don't use the validation for the progress yet. - } - - /// Create a view for the current progress. - pub(crate) fn view(&self) -> ProgressBarView { - const NO_ETA: &str = "---"; - - let eta = match self.estimate.secs() { - Some(eta) => format_eta(eta), - None => NO_ETA.to_string(), - }; - ProgressBarView::new(self.progress_train, eta) + pub fn new(checkpoint: Option) -> Self { + Self { + progress_train: 0.0, + estimate: ProgressEstimate::new(), + starting_epoch: checkpoint.unwrap_or(0), } + } + /// Update the training progress. + pub(crate) fn update_train(&mut self, progress: &TrainingProgress) { + self.progress_train = calculate_progress(progress, 0, 0); + self.estimate.update(progress, self.starting_epoch); + } + + /// Update the validation progress. + pub(crate) fn update_valid(&mut self, _progress: &TrainingProgress) { + // We don't use the validation for the progress yet. + } + + /// Create a view for the current progress. + pub(crate) fn view(&self) -> ProgressBarView { + const NO_ETA: &str = "---"; + + let eta = match self.estimate.secs() { + Some(eta) => format_eta(eta), + None => NO_ETA.to_string(), + }; + ProgressBarView::new(self.progress_train, eta) + } } #[derive(new)] pub(crate) struct ProgressBarView { - progress: f64, - eta: String, + progress: f64, + eta: String, } impl ProgressBarView { - /// Render the view. - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let block = Block::default() - .borders(Borders::ALL) - .title("Progress") - .title_alignment(Alignment::Left); - let size_new = block.inner(size); - frame.render_widget(block, size); - let size = size_new; - - let chunks = Layout::default() - .direction(Direction::Horizontal) - .constraints( - [ - Constraint::Length(1), // Empty space - Constraint::Min(0), - Constraint::Length(self.eta.len() as u16 + 4), - ] - .as_ref(), - ) - .split(size); - - let size_gauge = chunks[1]; - let size_eta = chunks[2]; - - let iteration = Gauge::default() - .gauge_style(Style::default().fg(Color::Yellow)) - .ratio(self.progress); - let eta = Paragraph::new(Line::from(vec![ - Span::from(" ("), - Span::from(self.eta).italic(), - Span::from(") "), - ])); - - frame.render_widget(iteration, size_gauge); - frame.render_widget(eta, size_eta); - } + /// Render the view. 
+ pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let block = Block::default() + .borders(Borders::ALL) + .title("Progress") + .title_alignment(Alignment::Left); + let size_new = block.inner(size); + frame.render_widget(block, size); + let size = size_new; + + let chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints( + [ + Constraint::Length(1), // Empty space + Constraint::Min(0), + Constraint::Length(self.eta.len() as u16 + 4), + ] + .as_ref(), + ) + .split(size); + + let size_gauge = chunks[1]; + let size_eta = chunks[2]; + + let iteration = Gauge::default() + .gauge_style(Style::default().fg(Color::Yellow)) + .ratio(self.progress); + let eta = Paragraph::new(Line::from(vec![ + Span::from(" ("), + Span::from(self.eta).italic(), + Span::from(") "), + ])); + + frame.render_widget(iteration, size_gauge); + frame.render_widget(eta, size_eta); + } } struct ProgressEstimate { - started: Instant, - started_after_warmup: Option, - warmup_num_items: usize, - progress: f64, + started: Instant, + started_after_warmup: Option, + warmup_num_items: usize, + progress: f64, } impl ProgressEstimate { - fn new() -> Self { - Self { - started: Instant::now(), - started_after_warmup: None, - warmup_num_items: 0, - progress: 0.0, - } + fn new() -> Self { + Self { + started: Instant::now(), + started_after_warmup: None, + warmup_num_items: 0, + progress: 0.0, } + } - fn secs(&self) -> Option { - let eta = match self.started_after_warmup { - Some(started) => started.elapsed(), - None => return None, - }; - - let total_estimated = (eta.as_secs() as f64) / self.progress; - - if total_estimated.is_normal() { - let remaining = 1.0 - self.progress; - let eta = (total_estimated * remaining) as u64; - Some(eta) - } else { - None - } + fn secs(&self) -> Option { + let eta = match self.started_after_warmup { + Some(started) => started.elapsed(), + None => return None, + }; + + let total_estimated = (eta.as_secs() as f64) / self.progress; + + if total_estimated.is_normal() { + let remaining = 1.0 - self.progress; + let eta = (total_estimated * remaining) as u64; + Some(eta) + } else { + None } + } - fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) { - if self.started_after_warmup.is_some() { - self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); - return; - } - - const WARMUP_NUM_ITERATION: usize = 10; - - // When the training has started since 30 seconds. - if self.started.elapsed() > Duration::from_secs(30) { - self.init(progress, starting_epoch); - return; - } - - // When the training has started since at least 10 seconds and completed 10 iterations. 
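The ETA above is a straight extrapolation: once the warmup window has passed, the elapsed time divided by the fraction of work done estimates the total run time, and the unfinished fraction of that total is the ETA. A self-contained sketch of that calculation (illustrative function name):

use std::time::Duration;

fn eta_secs(elapsed_since_warmup: Duration, progress: f64) -> Option<u64> {
    // Extrapolate the total duration from the fraction already completed.
    let total_estimated = elapsed_since_warmup.as_secs() as f64 / progress;
    if total_estimated.is_normal() {
        // The remaining fraction of the estimated total is the ETA.
        Some((total_estimated * (1.0 - progress)) as u64)
    } else {
        // Zero progress (or zero elapsed time) gives no usable estimate yet.
        None
    }
}

For instance, eta_secs(Duration::from_secs(30), 0.25) extrapolates a 120-second run and returns roughly 90 seconds remaining.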
- if progress.iteration >= WARMUP_NUM_ITERATION - && self.started.elapsed() > Duration::from_secs(10) - { - self.init(progress, starting_epoch); - } + fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) { + if self.started_after_warmup.is_some() { + self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); + return; } - fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) { - let epoch = progress.epoch - starting_epoch; - let epoch_items = (epoch - 1) * progress.progress.items_total; - let iteration_items = progress.progress.items_processed; + const WARMUP_NUM_ITERATION: usize = 10; - self.warmup_num_items = epoch_items + iteration_items; - self.started_after_warmup = Some(Instant::now()); - self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); + // When the training has started since 30 seconds. + if self.started.elapsed() > Duration::from_secs(30) { + self.init(progress, starting_epoch); + return; } + + // When the training has started since at least 10 seconds and completed 10 iterations. + if progress.iteration >= WARMUP_NUM_ITERATION + && self.started.elapsed() > Duration::from_secs(10) + { + self.init(progress, starting_epoch); + } + } + + fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) { + let epoch = progress.epoch - starting_epoch; + let epoch_items = (epoch - 1) * progress.progress.items_total; + let iteration_items = progress.progress.items_processed; + + self.warmup_num_items = epoch_items + iteration_items; + self.started_after_warmup = Some(Instant::now()); + self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); + } } fn calculate_progress( - progress: &TrainingProgress, - starting_epoch: usize, - ignore_num_items: usize, + progress: &TrainingProgress, + starting_epoch: usize, + ignore_num_items: usize, ) -> f64 { - let epoch_total = progress.epoch_total - starting_epoch; - let epoch = progress.epoch - starting_epoch; + let epoch_total = progress.epoch_total - starting_epoch; + let epoch = progress.epoch - starting_epoch; - let total_items = progress.progress.items_total * epoch_total; - let epoch_items = (epoch - 1) * progress.progress.items_total; - let iteration_items = progress.progress.items_processed; - let num_items = epoch_items + iteration_items - ignore_num_items; + let total_items = progress.progress.items_total * epoch_total; + let epoch_items = (epoch - 1) * progress.progress.items_total; + let iteration_items = progress.progress.items_processed; + let num_items = epoch_items + iteration_items - ignore_num_items; - num_items as f64 / total_items as f64 + num_items as f64 / total_items as f64 } fn format_eta(eta_secs: u64) -> String { - let seconds = eta_secs % 60; - let minutes = eta_secs / MINUTE % 60; - let hours = eta_secs / HOUR % 24; - let days = eta_secs / DAY; - - if days > 1 { - format!("{days} days") - } else if days == 1 { - "1 day".to_string() - } else if hours > 1 { - format!("{hours} hours") - } else if hours == 1 { - "1 hour".to_string() - } else if minutes > 1 { - format!("{minutes} mins") - } else if minutes == 1 { - "1 min".to_string() - } else if seconds > 1 { - format!("{seconds} secs") - } else { - "1 sec".to_string() - } + let seconds = eta_secs % 60; + let minutes = eta_secs / MINUTE % 60; + let hours = eta_secs / HOUR % 24; + let days = eta_secs / DAY; + + if days > 1 { + format!("{days} days") + } else if days == 1 { + "1 day".to_string() + } else if hours > 1 { + format!("{hours} hours") + 
} else if hours == 1 { + "1 hour".to_string() + } else if minutes > 1 { + format!("{minutes} mins") + } else if minutes == 1 { + "1 min".to_string() + } else if seconds > 1 { + format!("{seconds} secs") + } else { + "1 sec".to_string() + } } #[cfg(test)] mod tests { - use super::*; - use burn_core::data::dataloader::Progress; - - #[test] - fn test_format_eta() { - assert_eq!("55 secs", format_eta(55), "Less than 1 minutes"); - assert_eq!("1 min", format_eta(61), "More than 1 minutes"); - assert_eq!("2 mins", format_eta(2 * 61), "More than 2 minutes"); - assert_eq!("1 hour", format_eta(3601), "More than 1 hour"); - assert_eq!("2 hours", format_eta(2 * 3601), "More than 2 hour"); - assert_eq!("1 day", format_eta(24 * 3601), "More than 1 day"); - assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day"); - } - - #[test] - fn calculate_progress_for_eta() { - let half = Progress { - items_processed: 5, - items_total: 10, - }; - let progress = TrainingProgress { - progress: half, - epoch: 9, - epoch_total: 10, - iteration: 500, - }; - - let starting_epoch = 8; - let progress = calculate_progress(&progress, starting_epoch, 0); - - // Two epochs remaining while the first is half done. - assert_eq!(0.25, progress); - } - - #[test] - fn calculate_progress_for_eta_with_warmup() { - let half = Progress { - items_processed: 110, - items_total: 1000, - }; - let progress = TrainingProgress { - progress: half, - epoch: 9, - epoch_total: 10, - iteration: 500, - }; - - let starting_epoch = 8; - let progress = calculate_progress(&progress, starting_epoch, 10); - - // Two epochs remaining while the first is half done. - assert_eq!(0.05, progress); - } + use super::*; + use burn_core::data::dataloader::Progress; + + #[test] + fn test_format_eta() { + assert_eq!("55 secs", format_eta(55), "Less than 1 minutes"); + assert_eq!("1 min", format_eta(61), "More than 1 minutes"); + assert_eq!("2 mins", format_eta(2 * 61), "More than 2 minutes"); + assert_eq!("1 hour", format_eta(3601), "More than 1 hour"); + assert_eq!("2 hours", format_eta(2 * 3601), "More than 2 hour"); + assert_eq!("1 day", format_eta(24 * 3601), "More than 1 day"); + assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day"); + } + + #[test] + fn calculate_progress_for_eta() { + let half = Progress { + items_processed: 5, + items_total: 10, + }; + let progress = TrainingProgress { + progress: half, + epoch: 9, + epoch_total: 10, + iteration: 500, + }; + + let starting_epoch = 8; + let progress = calculate_progress(&progress, starting_epoch, 0); + + // Two epochs remaining while the first is half done. + assert_eq!(0.25, progress); + } + + #[test] + fn calculate_progress_for_eta_with_warmup() { + let half = Progress { + items_processed: 110, + items_total: 1000, + }; + let progress = TrainingProgress { + progress: half, + epoch: 9, + epoch_total: 10, + iteration: 500, + }; + + let starting_epoch = 8; + let progress = calculate_progress(&progress, starting_epoch, 10); + + // Two epochs remaining while the first is half done. 
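        // Worked through with the numbers above (an illustration, not part of the patch):
        // epoch_total - starting_epoch = 10 - 8 = 2, so total_items = 2 * 1000 = 2000;
        // epoch - starting_epoch = 1, so epoch_items = (1 - 1) * 1000 = 0;
        // num_items = 0 + 110 - 10 warmup items = 100, and 100 / 2000 = 0.05 as asserted below.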
+ assert_eq!(0.05, progress); + } } diff --git a/burn-train/src/renderer/tui/recent_history.rs b/burn-train/src/renderer/tui/recent_history.rs index ac91d60888..60b1ce77c3 100644 --- a/burn-train/src/renderer/tui/recent_history.rs +++ b/burn-train/src/renderer/tui/recent_history.rs @@ -1,244 +1,244 @@ use super::PlotAxes; use ratatui::{ - style::{Color, Style, Stylize}, - symbols, - widgets::{Dataset, GraphType}, + style::{Color, Style, Stylize}, + symbols, + widgets::{Dataset, GraphType}, }; const FACTOR_BEFORE_RESIZE: usize = 2; /// A plot that shows the recent history at full resolution. pub(crate) struct RecentHistoryPlot { - pub(crate) axes: PlotAxes, - train: RecentHistoryPoints, - valid: RecentHistoryPoints, - max_samples: usize, + pub(crate) axes: PlotAxes, + train: RecentHistoryPoints, + valid: RecentHistoryPoints, + max_samples: usize, } struct RecentHistoryPoints { - min_x: f64, - max_x: f64, - min_y: f64, - max_y: f64, - cursor: usize, - points: Vec<(f64, f64)>, - max_samples: usize, - factor_before_resize: usize, + min_x: f64, + max_x: f64, + min_y: f64, + max_y: f64, + cursor: usize, + points: Vec<(f64, f64)>, + max_samples: usize, + factor_before_resize: usize, } impl RecentHistoryPlot { - pub(crate) fn new(max_samples: usize) -> Self { - Self { - axes: PlotAxes::default(), - train: RecentHistoryPoints::new(max_samples), - valid: RecentHistoryPoints::new(max_samples), - max_samples, - } + pub(crate) fn new(max_samples: usize) -> Self { + Self { + axes: PlotAxes::default(), + train: RecentHistoryPoints::new(max_samples), + valid: RecentHistoryPoints::new(max_samples), + max_samples, } + } - pub(crate) fn push_train(&mut self, data: f64) { - let (x_min, x_current) = self.x(); + pub(crate) fn push_train(&mut self, data: f64) { + let (x_min, x_current) = self.x(); - self.train.push((x_current, data)); - self.train.update_cursor(x_min); - self.valid.update_cursor(x_min); + self.train.push((x_current, data)); + self.train.update_cursor(x_min); + self.valid.update_cursor(x_min); - self.update_bounds(); - } - - pub(crate) fn push_valid(&mut self, data: f64) { - let (x_min, x_current) = self.x(); + self.update_bounds(); + } - self.valid.push((x_current, data)); - self.valid.update_cursor(x_min); - self.train.update_cursor(x_min); + pub(crate) fn push_valid(&mut self, data: f64) { + let (x_min, x_current) = self.x(); - self.update_bounds(); - } + self.valid.push((x_current, data)); + self.valid.update_cursor(x_min); + self.train.update_cursor(x_min); - pub(crate) fn datasets(&self) -> Vec> { - let mut datasets = Vec::with_capacity(2); + self.update_bounds(); + } - if self.train.num_visible_points() > 0 { - datasets.push(self.train.dataset("Train", Color::LightRed)); - } + pub(crate) fn datasets(&self) -> Vec> { + let mut datasets = Vec::with_capacity(2); - if self.valid.num_visible_points() > 0 { - datasets.push(self.valid.dataset("Valid", Color::LightBlue)); - } + if self.train.num_visible_points() > 0 { + datasets.push(self.train.dataset("Train", Color::LightRed)); + } - datasets + if self.valid.num_visible_points() > 0 { + datasets.push(self.valid.dataset("Valid", Color::LightBlue)); } - fn x(&mut self) -> (f64, f64) { - let x_current = f64::max(self.train.max_x, self.valid.max_x) + 1.0; - let mut x_min = f64::min(self.train.min_x, self.valid.min_x); - if x_current - x_min >= self.max_samples as f64 { - x_min += 1.0; - } + datasets + } - (x_min, x_current) + fn x(&mut self) -> (f64, f64) { + let x_current = f64::max(self.train.max_x, self.valid.max_x) + 1.0; + let mut x_min = 
f64::min(self.train.min_x, self.valid.min_x); + if x_current - x_min >= self.max_samples as f64 { + x_min += 1.0; } - fn update_bounds(&mut self) { - self.axes.update_bounds( - (self.train.min_x, self.train.max_x), - (self.valid.min_x, self.valid.max_x), - (self.train.min_y, self.train.max_y), - (self.valid.min_y, self.valid.max_y), - ); - } + (x_min, x_current) + } + + fn update_bounds(&mut self) { + self.axes.update_bounds( + (self.train.min_x, self.train.max_x), + (self.valid.min_x, self.valid.max_x), + (self.train.min_y, self.train.max_y), + (self.valid.min_y, self.valid.max_y), + ); + } } impl RecentHistoryPoints { - fn new(max_samples: usize) -> Self { - let factor_before_resize = FACTOR_BEFORE_RESIZE; + fn new(max_samples: usize) -> Self { + let factor_before_resize = FACTOR_BEFORE_RESIZE; - Self { - min_x: 0., - max_x: 0., - min_y: f64::MAX, - max_y: f64::MIN, - points: Vec::with_capacity(factor_before_resize * max_samples), - cursor: 0, - max_samples, - factor_before_resize, - } + Self { + min_x: 0., + max_x: 0., + min_y: f64::MAX, + max_y: f64::MIN, + points: Vec::with_capacity(factor_before_resize * max_samples), + cursor: 0, + max_samples, + factor_before_resize, } + } - fn num_visible_points(&self) -> usize { - self.points.len() - } + fn num_visible_points(&self) -> usize { + self.points.len() + } - fn push(&mut self, (x, y): (f64, f64)) { - if x > self.max_x { - self.max_x = x; - } - if x < self.min_x { - self.min_x = x; - } - if y > self.max_y { - self.max_y = y; - } - if y < self.min_y { - self.min_y = y - } - self.points.push((x, y)); + fn push(&mut self, (x, y): (f64, f64)) { + if x > self.max_x { + self.max_x = x; } + if x < self.min_x { + self.min_x = x; + } + if y > self.max_y { + self.max_y = y; + } + if y < self.min_y { + self.min_y = y + } + self.points.push((x, y)); + } - fn update_cursor(&mut self, min_x: f64) { - if self.min_x >= min_x { - return; - } - self.min_x = min_x; - - let mut update_y_max = false; - let mut update_y_min = false; + fn update_cursor(&mut self, min_x: f64) { + if self.min_x >= min_x { + return; + } + self.min_x = min_x; - while let Some((x, y)) = self.points.get(self.cursor) { - if *x >= self.min_x { - break; - } + let mut update_y_max = false; + let mut update_y_min = false; - if *y == self.max_y { - update_y_max = true - } - if *y == self.min_y { - update_y_min = true; - } + while let Some((x, y)) = self.points.get(self.cursor) { + if *x >= self.min_x { + break; + } - self.cursor += 1; - } + if *y == self.max_y { + update_y_max = true + } + if *y == self.min_y { + update_y_min = true; + } - if update_y_max { - self.max_y = self.calculate_max_y(); - } + self.cursor += 1; + } - if update_y_min { - self.min_y = self.calculate_min_y(); - } + if update_y_max { + self.max_y = self.calculate_max_y(); + } - if self.points.len() >= self.max_samples * self.factor_before_resize { - self.resize(); - } + if update_y_min { + self.min_y = self.calculate_min_y(); } - fn slice(&self) -> &[(f64, f64)] { - &self.points[self.cursor..self.points.len()] + if self.points.len() >= self.max_samples * self.factor_before_resize { + self.resize(); } + } - fn calculate_max_y(&self) -> f64 { - let mut max_y = f64::MIN; + fn slice(&self) -> &[(f64, f64)] { + &self.points[self.cursor..self.points.len()] + } - for (_x, y) in self.slice() { - if *y > max_y { - max_y = *y; - } - } + fn calculate_max_y(&self) -> f64 { + let mut max_y = f64::MIN; - max_y + for (_x, y) in self.slice() { + if *y > max_y { + max_y = *y; + } } - fn calculate_min_y(&self) -> f64 { - let mut min_y 
= f64::MAX; + max_y + } - for (_x, y) in self.slice() { - if *y < min_y { - min_y = *y; - } - } + fn calculate_min_y(&self) -> f64 { + let mut min_y = f64::MAX; - min_y + for (_x, y) in self.slice() { + if *y < min_y { + min_y = *y; + } } - fn resize(&mut self) { - let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize); + min_y + } - for i in self.cursor..self.points.len() { - points.push(self.points[i]); - } + fn resize(&mut self) { + let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize); - self.points = points; - self.cursor = 0; + for i in self.cursor..self.points.len() { + points.push(self.points[i]); } - fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { - let data = &self.points[self.cursor..self.points.len()]; + self.points = points; + self.cursor = 0; + } - Dataset::default() - .name(name) - .marker(symbols::Marker::Braille) - .style(Style::default().fg(color).bold()) - .graph_type(GraphType::Scatter) - .data(data) - } + fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { + let data = &self.points[self.cursor..self.points.len()]; + + Dataset::default() + .name(name) + .marker(symbols::Marker::Braille) + .style(Style::default().fg(color).bold()) + .graph_type(GraphType::Scatter) + .data(data) + } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_push_update_bounds_max_y() { - let mut chart = RecentHistoryPlot::new(3); - chart.push_train(15.0); - chart.push_train(10.0); - chart.push_train(14.0); - - assert_eq!(chart.axes.bounds_y[1], 15.); - chart.push_train(10.0); - assert_eq!(chart.axes.bounds_y[1], 14.); - } - - #[test] - fn test_push_update_bounds_min_y() { - let mut chart = RecentHistoryPlot::new(3); - chart.push_train(5.0); - chart.push_train(10.0); - chart.push_train(14.0); - - assert_eq!(chart.axes.bounds_y[0], 5.); - chart.push_train(10.0); - assert_eq!(chart.axes.bounds_y[0], 10.); - } + use super::*; + + #[test] + fn test_push_update_bounds_max_y() { + let mut chart = RecentHistoryPlot::new(3); + chart.push_train(15.0); + chart.push_train(10.0); + chart.push_train(14.0); + + assert_eq!(chart.axes.bounds_y[1], 15.); + chart.push_train(10.0); + assert_eq!(chart.axes.bounds_y[1], 14.); + } + + #[test] + fn test_push_update_bounds_min_y() { + let mut chart = RecentHistoryPlot::new(3); + chart.push_train(5.0); + chart.push_train(10.0); + chart.push_train(14.0); + + assert_eq!(chart.axes.bounds_y[0], 5.); + chart.push_train(10.0); + assert_eq!(chart.axes.bounds_y[0], 10.); + } } diff --git a/burn-train/src/renderer/tui/renderer.rs b/burn-train/src/renderer/tui/renderer.rs index 015e3e88ca..ae09c9885c 100644 --- a/burn-train/src/renderer/tui/renderer.rs +++ b/burn-train/src/renderer/tui/renderer.rs @@ -2,20 +2,20 @@ use crate::renderer::{tui::NumericMetricsState, MetricsRenderer}; use crate::renderer::{MetricState, TrainingProgress}; use crate::TrainingInterrupter; use crossterm::{ - event::{self, Event, KeyCode}, - execute, - terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, + event::{self, Event, KeyCode}, + execute, + terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; use ratatui::{prelude::*, Terminal}; use std::{ - error::Error, - io::{self, Stdout}, - time::{Duration, Instant}, + error::Error, + io::{self, Stdout}, + time::{Duration, Instant}, }; use super::{ - Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState, - TextMetricsState, + Callback, 
CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState, + TextMetricsState, }; /// The current terminal backend. @@ -27,124 +27,124 @@ const MAX_REFRESH_RATE_MILLIS: u64 = 100; /// The terminal UI metrics renderer. pub struct TuiMetricsRenderer { - terminal: Terminal, - last_update: std::time::Instant, - progress: ProgressBarState, - metrics_numeric: NumericMetricsState, - metrics_text: TextMetricsState, - status: StatusState, - interuptor: TrainingInterrupter, - popup: PopupState, + terminal: Terminal, + last_update: std::time::Instant, + progress: ProgressBarState, + metrics_numeric: NumericMetricsState, + metrics_text: TextMetricsState, + status: StatusState, + interuptor: TrainingInterrupter, + popup: PopupState, } impl MetricsRenderer for TuiMetricsRenderer { - fn update_train(&mut self, state: MetricState) { - match state { - MetricState::Generic(entry) => { - self.metrics_text.update_train(entry); - } - MetricState::Numeric(entry, value) => { - self.metrics_numeric.push_train(entry.name.clone(), value); - self.metrics_text.update_train(entry); - } - }; - } - - fn update_valid(&mut self, state: MetricState) { - match state { - MetricState::Generic(entry) => { - self.metrics_text.update_valid(entry); - } - MetricState::Numeric(entry, value) => { - self.metrics_numeric.push_valid(entry.name.clone(), value); - self.metrics_text.update_valid(entry); - } - }; - } - - fn render_train(&mut self, item: TrainingProgress) { - self.progress.update_train(&item); - self.metrics_numeric.update_progress_train(&item); - self.status.update_train(item); - self.render().unwrap(); - } - - fn render_valid(&mut self, item: TrainingProgress) { - self.progress.update_valid(&item); - self.metrics_numeric.update_progress_valid(&item); - self.status.update_valid(item); - self.render().unwrap(); - } + fn update_train(&mut self, state: MetricState) { + match state { + MetricState::Generic(entry) => { + self.metrics_text.update_train(entry); + } + MetricState::Numeric(entry, value) => { + self.metrics_numeric.push_train(entry.name.clone(), value); + self.metrics_text.update_train(entry); + } + }; + } + + fn update_valid(&mut self, state: MetricState) { + match state { + MetricState::Generic(entry) => { + self.metrics_text.update_valid(entry); + } + MetricState::Numeric(entry, value) => { + self.metrics_numeric.push_valid(entry.name.clone(), value); + self.metrics_text.update_valid(entry); + } + }; + } + + fn render_train(&mut self, item: TrainingProgress) { + self.progress.update_train(&item); + self.metrics_numeric.update_progress_train(&item); + self.status.update_train(item); + self.render().unwrap(); + } + + fn render_valid(&mut self, item: TrainingProgress) { + self.progress.update_valid(&item); + self.metrics_numeric.update_progress_valid(&item); + self.status.update_valid(item); + self.render().unwrap(); + } } impl TuiMetricsRenderer { - /// Create a new terminal UI renderer. - pub fn new(interuptor: TrainingInterrupter, checkpoint: Option) -> Self { - let mut stdout = io::stdout(); - execute!(stdout, EnterAlternateScreen).unwrap(); - enable_raw_mode().unwrap(); - let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap(); - - Self { - terminal, - last_update: Instant::now(), - progress: ProgressBarState::new(checkpoint), - metrics_numeric: NumericMetricsState::default(), - metrics_text: TextMetricsState::default(), - status: StatusState::default(), - interuptor, - popup: PopupState::Empty, - } + /// Create a new terminal UI renderer. 
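The render path defined just below redraws at most once every MAX_REFRESH_RATE_MILLIS milliseconds, so frequent metric updates do not saturate the terminal. A minimal sketch of that throttle, with the timestamp recorded whenever a draw is allowed (illustrative type name):

use std::time::{Duration, Instant};

struct Throttle {
    last: Instant,
    interval: Duration,
}

impl Throttle {
    fn should_draw(&mut self) -> bool {
        if self.last.elapsed() < self.interval {
            // Too soon since the previous draw: skip this frame entirely.
            return false;
        }
        self.last = Instant::now();
        true
    }
}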
+ pub fn new(interuptor: TrainingInterrupter, checkpoint: Option) -> Self { + let mut stdout = io::stdout(); + execute!(stdout, EnterAlternateScreen).unwrap(); + enable_raw_mode().unwrap(); + let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap(); + + Self { + terminal, + last_update: Instant::now(), + progress: ProgressBarState::new(checkpoint), + metrics_numeric: NumericMetricsState::default(), + metrics_text: TextMetricsState::default(), + status: StatusState::default(), + interuptor, + popup: PopupState::Empty, } + } - fn render(&mut self) -> Result<(), Box> { - let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); - if self.last_update.elapsed() < tick_rate { - return Ok(()); - } + fn render(&mut self) -> Result<(), Box> { + let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); + if self.last_update.elapsed() < tick_rate { + return Ok(()); + } - self.draw()?; - self.handle_events()?; + self.draw()?; + self.handle_events()?; - self.last_update = Instant::now(); + self.last_update = Instant::now(); - Ok(()) - } + Ok(()) + } - fn draw(&mut self) -> Result<(), Box> { - self.terminal.draw(|frame| { - let size = frame.size(); - - match self.popup.view() { - Some(view) => view.render(frame, size), - None => { - let view = MetricsView::new( - self.metrics_numeric.view(), - self.metrics_text.view(), - self.progress.view(), - ControlsView, - self.status.view(), - ); + fn draw(&mut self) -> Result<(), Box> { + self.terminal.draw(|frame| { + let size = frame.size(); - view.render(frame, size); - } - }; - })?; + match self.popup.view() { + Some(view) => view.render(frame, size), + None => { + let view = MetricsView::new( + self.metrics_numeric.view(), + self.metrics_text.view(), + self.progress.view(), + ControlsView, + self.status.view(), + ); - Ok(()) - } + view.render(frame, size); + } + }; + })?; - fn handle_events(&mut self) -> Result<(), Box> { - while event::poll(Duration::from_secs(0))? { - let event = event::read()?; - self.popup.on_event(&event); + Ok(()) + } - if self.popup.is_empty() { - self.metrics_numeric.on_event(&event); + fn handle_events(&mut self) -> Result<(), Box> { + while event::poll(Duration::from_secs(0))? 
{ + let event = event::read()?; + self.popup.on_event(&event); - if let Event::Key(key) = event { - if let KeyCode::Char('q') = key.code { - self.popup = PopupState::Full( + if self.popup.is_empty() { + self.metrics_numeric.on_event(&event); + + if let Event::Key(key) = event { + if let KeyCode::Char('q') = key.code { + self.popup = PopupState::Full( "Quit".to_string(), vec![ Callback::new( @@ -162,13 +162,13 @@ impl TuiMetricsRenderer { Callback::new("Cancel", "Cancel the action, continue the training.", 'c', PopupCancel), ], ); - } - } - } + } } - - Ok(()) + } } + + Ok(()) + } } struct QuitPopupAccept(TrainingInterrupter); @@ -176,28 +176,28 @@ struct KillPopupAccept; struct PopupCancel; impl CallbackFn for KillPopupAccept { - fn call(&self) -> bool { - panic!("Killing training from user input."); - } + fn call(&self) -> bool { + panic!("Killing training from user input."); + } } impl CallbackFn for QuitPopupAccept { - fn call(&self) -> bool { - self.0.stop(); - true - } + fn call(&self) -> bool { + self.0.stop(); + true + } } impl CallbackFn for PopupCancel { - fn call(&self) -> bool { - true - } + fn call(&self) -> bool { + true + } } impl Drop for TuiMetricsRenderer { - fn drop(&mut self) { - disable_raw_mode().ok(); - execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap(); - self.terminal.show_cursor().ok(); - } + fn drop(&mut self) { + disable_raw_mode().ok(); + execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap(); + self.terminal.show_cursor().ok(); + } } diff --git a/burn-train/src/renderer/tui/status.rs b/burn-train/src/renderer/tui/status.rs index 3519d217cf..c067168498 100644 --- a/burn-train/src/renderer/tui/status.rs +++ b/burn-train/src/renderer/tui/status.rs @@ -1,91 +1,91 @@ use super::TerminalFrame; use crate::renderer::TrainingProgress; use ratatui::{ - prelude::{Alignment, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; /// Show the training status with various information. pub(crate) struct StatusState { - progress: TrainingProgress, - mode: Mode, + progress: TrainingProgress, + mode: Mode, } enum Mode { - Valid, - Train, + Valid, + Train, } impl Default for StatusState { - fn default() -> Self { - Self { - progress: TrainingProgress::none(), - mode: Mode::Train, - } + fn default() -> Self { + Self { + progress: TrainingProgress::none(), + mode: Mode::Train, } + } } impl StatusState { - /// Update the training information. - pub(crate) fn update_train(&mut self, progress: TrainingProgress) { - self.progress = progress; - self.mode = Mode::Train; - } - /// Update the validation information. - pub(crate) fn update_valid(&mut self, progress: TrainingProgress) { - self.progress = progress; - self.mode = Mode::Valid; - } - /// Create a view. - pub(crate) fn view(&self) -> StatusView { - StatusView::new(&self.progress, &self.mode) - } + /// Update the training information. + pub(crate) fn update_train(&mut self, progress: TrainingProgress) { + self.progress = progress; + self.mode = Mode::Train; + } + /// Update the validation information. + pub(crate) fn update_valid(&mut self, progress: TrainingProgress) { + self.progress = progress; + self.mode = Mode::Valid; + } + /// Create a view. 
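The renderer pairs its terminal setup with the Drop impl shown above: the alternate screen and raw mode are enabled in new() and restored when the renderer is dropped, so the terminal is recovered even if the training loop exits early. The pairing reads like a small RAII guard; a sketch under that framing (TerminalGuard is an illustrative name, not a type in burn-train):

use std::io::{self, Stdout};

use crossterm::{
    execute,
    terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
};

struct TerminalGuard {
    out: Stdout,
}

impl TerminalGuard {
    fn new() -> Self {
        let mut out = io::stdout();
        // Same setup sequence as TuiMetricsRenderer::new.
        execute!(out, EnterAlternateScreen).unwrap();
        enable_raw_mode().unwrap();
        Self { out }
    }
}

impl Drop for TerminalGuard {
    fn drop(&mut self) {
        // Same teardown as the Drop impl above; best effort, errors ignored.
        disable_raw_mode().ok();
        let _ = execute!(self.out, LeaveAlternateScreen);
    }
}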
+ pub(crate) fn view(&self) -> StatusView { + StatusView::new(&self.progress, &self.mode) + } } pub(crate) struct StatusView { - lines: Vec>>, + lines: Vec>>, } impl StatusView { - fn new(progress: &TrainingProgress, mode: &Mode) -> Self { - let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow(); - let value = |value: String| Span::from(value).italic(); - let mode = match mode { - Mode::Valid => "Validating", - Mode::Train => "Training", - }; + fn new(progress: &TrainingProgress, mode: &Mode) -> Self { + let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow(); + let value = |value: String| Span::from(value).italic(); + let mode = match mode { + Mode::Valid => "Validating", + Mode::Train => "Training", + }; - Self { - lines: vec![ - vec![title("Mode :"), value(mode.to_string())], - vec![ - title("Epoch :"), - value(format!("{}/{}", progress.epoch, progress.epoch_total)), - ], - vec![ - title("Iteration :"), - value(format!("{}", progress.iteration)), - ], - vec![ - title("Items :"), - value(format!( - "{}/{}", - progress.progress.items_processed, progress.progress.items_total - )), - ], - ], - } + Self { + lines: vec![ + vec![title("Mode :"), value(mode.to_string())], + vec![ + title("Epoch :"), + value(format!("{}/{}", progress.epoch, progress.epoch_total)), + ], + vec![ + title("Iteration :"), + value(format!("{}", progress.iteration)), + ], + vec![ + title("Items :"), + value(format!( + "{}/{}", + progress.progress.items_processed, progress.progress.items_total + )), + ], + ], } + } - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) - .alignment(Alignment::Left) - .block(Block::default().borders(Borders::ALL).title("Status")) - .wrap(Wrap { trim: false }) - .style(Style::default().fg(Color::Gray)); + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) + .alignment(Alignment::Left) + .block(Block::default().borders(Borders::ALL).title("Status")) + .wrap(Wrap { trim: false }) + .style(Style::default().fg(Color::Gray)); - frame.render_widget(paragraph, size); - } + frame.render_widget(paragraph, size); + } } diff --git a/burn-wgpu/benches/fused_elemwise.rs b/burn-wgpu/benches/fused_elemwise.rs index ff8443bd2a..417c91433e 100644 --- a/burn-wgpu/benches/fused_elemwise.rs +++ b/burn-wgpu/benches/fused_elemwise.rs @@ -9,66 +9,66 @@ use std::marker::PhantomData; #[derive(new)] struct ElemWiseBenchmark { - shape: Shape<3>, - device: B::Device, - repeat: usize, - _b: PhantomData, + shape: Shape<3>, + device: B::Device, + repeat: usize, + _b: PhantomData, } impl Benchmark for ElemWiseBenchmark { - type Args = (Tensor, Tensor); + type Args = (Tensor, Tensor); - fn name(&self) -> String { - format!( - "Backend {} Shape {:?} Repeat {}", - B::name(), - self.shape.dims, - self.repeat - ) - } + fn name(&self) -> String { + format!( + "Backend {} Shape {:?} Repeat {}", + B::name(), + self.shape.dims, + self.repeat + ) + } - fn num_samples(&self) -> usize { - 10 - } + fn num_samples(&self) -> usize { + 10 + } - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.repeat { - let tmp_0 = lhs.clone() + rhs.clone(); - let tmp_1 = rhs.clone() * tmp_0.clone(); - let tmp_2 = rhs.clone().exp(); - let tmp_3 = tmp_0 * tmp_1; - let _tmp_4 = tmp_2 / tmp_3; - } + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.repeat { + let tmp_0 = lhs.clone() 
+ rhs.clone(); + let tmp_1 = rhs.clone() * tmp_0.clone(); + let tmp_2 = rhs.clone().exp(); + let tmp_3 = tmp_0 * tmp_1; + let _tmp_4 = tmp_2 / tmp_3; } + } - fn prepare(&self) -> Self::Args { - B::seed(10); - let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + B::seed(10); + let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] /// Runs the benchmarks for wgpu matmul implementations pub fn bench(device: &WgpuDevice) { - run_benchmark(ElemWiseBenchmark::::new( - Shape::new([256, 256, 1024]), - device.clone(), - 10, - )); - run_benchmark(ElemWiseBenchmark::>::new( - Shape::new([256, 256, 1024]), - device.clone(), - 10, - )); + run_benchmark(ElemWiseBenchmark::::new( + Shape::new([256, 256, 1024]), + device.clone(), + 10, + )); + run_benchmark(ElemWiseBenchmark::>::new( + Shape::new([256, 256, 1024]), + device.clone(), + 10, + )); } fn main() { - bench(&WgpuDevice::BestAvailable) + bench(&WgpuDevice::BestAvailable) } diff --git a/burn-wgpu/benches/matmul.rs b/burn-wgpu/benches/matmul.rs index 586faf0e9c..f0a5cfa531 100644 --- a/burn-wgpu/benches/matmul.rs +++ b/burn-wgpu/benches/matmul.rs @@ -11,131 +11,129 @@ use derive_new::new; use std::marker::PhantomData; use burn_wgpu::{ - kernel::matmul::{matmul_mem_coalescing_default, matmul_naive_default}, - GraphicsApi, + kernel::matmul::{matmul_mem_coalescing_default, matmul_naive_default}, + GraphicsApi, }; type WTensor = Tensor, D>; #[derive(new)] struct MatmulBenchmark { - shape_lhs: Shape, - shape_rhs: Shape, - num_repeats: usize, - device: B::Device, - matmul: PhantomData, + shape_lhs: Shape, + shape_rhs: Shape, + num_repeats: usize, + device: B::Device, + matmul: PhantomData, } trait MatmulFunction { - fn run(lhs: WTensor, rhs: WTensor) -> WTensor; + fn run(lhs: WTensor, rhs: WTensor) -> WTensor; } impl Benchmark for MatmulBenchmark, F, D> where - F: MatmulFunction, - G: GraphicsApi, + F: MatmulFunction, + G: GraphicsApi, { - type Args = (WTensor, WTensor); + type Args = (WTensor, WTensor); - fn name(&self) -> String { - format!( - "{:?} {:?} x {:?}", - std::any::type_name::(), - self.shape_lhs.dims, - self.shape_rhs.dims - ) - } + fn name(&self) -> String { + format!( + "{:?} {:?} x {:?}", + std::any::type_name::(), + self.shape_lhs.dims, + self.shape_rhs.dims + ) + } - fn num_samples(&self) -> usize { - 10 - } + fn num_samples(&self) -> usize { + 10 + } - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.num_repeats { - F::run(lhs.clone(), rhs.clone()); - } + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.num_repeats { + F::run(lhs.clone(), rhs.clone()); } + } - fn prepare(&self) -> Self::Args { - let lhs = - WTensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); - let rhs = - WTensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + let lhs = WTensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); + let rhs = WTensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - 
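// A std-only reference for the batched matmul these benchmarks exercise:
// row-major [batch, m, k] x [batch, k, n] -> [batch, m, n]. A sketch for
// sanity-checking small shapes, not the kernel under test.
fn batched_matmul_ref(
    lhs: &[f32],
    rhs: &[f32],
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    let mut out = vec![0.0f32; batch * m * n];
    for b in 0..batch {
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += lhs[b * m * k + i * k + p] * rhs[b * k * n + p * n + j];
                }
                out[b * m * n + i * n + j] = acc;
            }
        }
    }
    out
}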
Wgpu::::sync(&self.device) - } + fn sync(&self) { + Wgpu::::sync(&self.device) + } } macro_rules! bench_matmul { - ($benchmark:ident, $matmul_name:ident, $func:expr) => { - struct $matmul_name {} - impl MatmulFunction for $matmul_name { - fn run(lhs: WTensor, rhs: WTensor) -> WTensor { - let lhs = lhs.into_primitive(); - let rhs = rhs.into_primitive(); - let output = init_matmul_output(&lhs, &rhs); - Tensor::from_primitive($func(lhs, rhs, output)) - } - } - type $benchmark = - MatmulBenchmark, $matmul_name, D>; - }; + ($benchmark:ident, $matmul_name:ident, $func:expr) => { + struct $matmul_name {} + impl MatmulFunction for $matmul_name { + fn run(lhs: WTensor, rhs: WTensor) -> WTensor { + let lhs = lhs.into_primitive(); + let rhs = rhs.into_primitive(); + let output = init_matmul_output(&lhs, &rhs); + Tensor::from_primitive($func(lhs, rhs, output)) + } + } + type $benchmark = + MatmulBenchmark, $matmul_name, D>; + }; } bench_matmul!(NaiveMatmulBenchmark, NaiveMatmul, matmul_naive_default); bench_matmul!( - MemCoalescingMatmulBenchmark, - MemCoalescingMatmul, - matmul_mem_coalescing_default + MemCoalescingMatmulBenchmark, + MemCoalescingMatmul, + matmul_mem_coalescing_default ); bench_matmul!( - Tiling2DMatmulVec4LHSBenchmark, - Tiling2DMatmulVec4LHS, - matmul_tiling_2d_vec4_lhs + Tiling2DMatmulVec4LHSBenchmark, + Tiling2DMatmulVec4LHS, + matmul_tiling_2d_vec4_lhs ); bench_matmul!( - Tiling2DMatmulVec4Benchmark, - Tiling2DMatmulVec4, - matmul_tiling_2d_vec4 + Tiling2DMatmulVec4Benchmark, + Tiling2DMatmulVec4, + matmul_tiling_2d_vec4 ); bench_matmul!( - Tiling2DMatmulUnpaddedBenchmark, - Tiling2DMatmulUnpadded, - matmul_tiling_2d_unpadded + Tiling2DMatmulUnpaddedBenchmark, + Tiling2DMatmulUnpadded, + matmul_tiling_2d_unpadded ); #[allow(dead_code)] /// Runs the benchmarks for wgpu matmul implementations pub fn bench(device: &WgpuDevice) { - const D: usize = 3; - let num_repeats = 3; - let batch_size = 3; - let m = 1007; - let k = 1023; - let n = 1005; - let shape_lhs = Shape::new([batch_size, m, k]); - let shape_rhs = Shape::new([batch_size, k, n]); + const D: usize = 3; + let num_repeats = 3; + let batch_size = 3; + let m = 1007; + let k = 1023; + let n = 1005; + let shape_lhs = Shape::new([batch_size, m, k]); + let shape_rhs = Shape::new([batch_size, k, n]); - macro_rules! run_matmul_benchmark { - ($benchmark:ident) => { - run_benchmark($benchmark::new( - shape_lhs.clone(), - shape_rhs.clone(), - num_repeats, - device.clone(), - )); - }; - } - run_matmul_benchmark!(NaiveMatmulBenchmark); - run_matmul_benchmark!(MemCoalescingMatmulBenchmark); - run_matmul_benchmark!(Tiling2DMatmulUnpaddedBenchmark); - run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark); - run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark); + macro_rules! 
run_matmul_benchmark { + ($benchmark:ident) => { + run_benchmark($benchmark::new( + shape_lhs.clone(), + shape_rhs.clone(), + num_repeats, + device.clone(), + )); + }; + } + run_matmul_benchmark!(NaiveMatmulBenchmark); + run_matmul_benchmark!(MemCoalescingMatmulBenchmark); + run_matmul_benchmark!(Tiling2DMatmulUnpaddedBenchmark); + run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark); + run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark); } fn main() { - bench(&WgpuDevice::BestAvailable) + bench(&WgpuDevice::BestAvailable) } diff --git a/burn-wgpu/benches/reduction.rs b/burn-wgpu/benches/reduction.rs index 7eac3440ae..a642192ab8 100644 --- a/burn-wgpu/benches/reduction.rs +++ b/burn-wgpu/benches/reduction.rs @@ -13,96 +13,96 @@ type WTensor = Tensor, D>; #[derive(new)] struct ReduceBenchmark { - shape: Shape, - dim: usize, - num_repeats: usize, - device: B::Device, - reduce: PhantomData, + shape: Shape, + dim: usize, + num_repeats: usize, + device: B::Device, + reduce: PhantomData, } trait ReduceFunction { - fn run(input: WTensor, dim: usize) -> WTensor; + fn run(input: WTensor, dim: usize) -> WTensor; } impl Benchmark for ReduceBenchmark, F, D> where - F: ReduceFunction, - G: GraphicsApi, + F: ReduceFunction, + G: GraphicsApi, { - type Args = WTensor; - - fn name(&self) -> String { - format!( - "{:?} {:?} dim={:?}", - std::any::type_name::(), - self.shape.dims, - self.dim - ) + type Args = WTensor; + + fn name(&self) -> String { + format!( + "{:?} {:?} dim={:?}", + std::any::type_name::(), + self.shape.dims, + self.dim + ) + } + + fn num_samples(&self) -> usize { + 10 + } + + fn execute(&self, input: Self::Args) { + for _ in 0..self.num_repeats { + F::run(input.clone(), self.dim); } + } - fn num_samples(&self) -> usize { - 10 - } - - fn execute(&self, input: Self::Args) { - for _ in 0..self.num_repeats { - F::run(input.clone(), self.dim); - } - } - - fn prepare(&self) -> Self::Args { - WTensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } + fn prepare(&self) -> Self::Args { + WTensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } - fn sync(&self) { - Wgpu::::sync(&self.device) - } + fn sync(&self) { + Wgpu::::sync(&self.device) + } } macro_rules! bench_reduce { - ($benchmark:ident, $reduce_name:ident, $func:expr) => { - struct $reduce_name {} - impl ReduceFunction for $reduce_name { - fn run(input: WTensor, dim: usize) -> WTensor { - let input = input.into_primitive(); - let output = init_reduce_output(&input, dim); - Tensor::from_primitive($func(input, output, dim)) - } - } - type $benchmark = - ReduceBenchmark, $reduce_name, D>; - }; + ($benchmark:ident, $reduce_name:ident, $func:expr) => { + struct $reduce_name {} + impl ReduceFunction for $reduce_name { + fn run(input: WTensor, dim: usize) -> WTensor { + let input = input.into_primitive(); + let output = init_reduce_output(&input, dim); + Tensor::from_primitive($func(input, output, dim)) + } + } + type $benchmark = + ReduceBenchmark, $reduce_name, D>; + }; } bench_reduce!(SumDimBenchmark, SumDim, sum_dim); bench_reduce!( - SumDimSharedMemoryBenchmark, - SumDimSharedMemory, - sum_dim_shared_memory + SumDimSharedMemoryBenchmark, + SumDimSharedMemory, + sum_dim_shared_memory ); #[allow(dead_code)] /// Runs the benchmarks for wgpu matmul implementations pub fn bench(device: &WgpuDevice) { - let num_repeats = 3; - let shape = Shape::new([50, 8000, 50]); - let dim = 1; - - macro_rules! 
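// A std-only reference for sum_dim on a row-major 3D tensor, reducing dim 1 as
// in the benchmark shape [50, 8000, 50] above; illustrative only.
fn sum_dim1_ref(input: &[f32], d0: usize, d1: usize, d2: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; d0 * d2];
    for i in 0..d0 {
        for j in 0..d1 {
            for k in 0..d2 {
                out[i * d2 + k] += input[i * d1 * d2 + j * d2 + k];
            }
        }
    }
    out
}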
run_reduce_benchmark { - ($benchmark:ident) => { - run_benchmark($benchmark::new( - shape.clone(), - dim, - num_repeats, - device.clone(), - )); - }; - } + let num_repeats = 3; + let shape = Shape::new([50, 8000, 50]); + let dim = 1; + + macro_rules! run_reduce_benchmark { + ($benchmark:ident) => { + run_benchmark($benchmark::new( + shape.clone(), + dim, + num_repeats, + device.clone(), + )); + }; + } - run_reduce_benchmark!(SumDimSharedMemoryBenchmark); - run_reduce_benchmark!(SumDimBenchmark); + run_reduce_benchmark!(SumDimSharedMemoryBenchmark); + run_reduce_benchmark!(SumDimBenchmark); } fn main() { - bench(&WgpuDevice::BestAvailable) + bench(&WgpuDevice::BestAvailable) } diff --git a/burn-wgpu/src/backend.rs b/burn-wgpu/src/backend.rs index 39180b9a0f..d81bf2504e 100644 --- a/burn-wgpu/src/backend.rs +++ b/burn-wgpu/src/backend.rs @@ -1,8 +1,8 @@ use crate::{ - compute::compute_client, - element::{FloatElement, IntElement}, - tensor::WgpuTensor, - AutoGraphicsApi, GraphicsApi, WgpuDevice, + compute::compute_client, + element::{FloatElement, IntElement}, + tensor::WgpuTensor, + AutoGraphicsApi, GraphicsApi, WgpuDevice, }; use burn_tensor::backend::Backend; use rand::{rngs::StdRng, SeedableRng}; @@ -22,43 +22,43 @@ pub(crate) static SEED: Mutex> = Mutex::new(None); #[derive(Debug, Default, Clone)] pub struct Wgpu where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - _g: PhantomData, - _f: PhantomData, - _i: PhantomData, + _g: PhantomData, + _f: PhantomData, + _i: PhantomData, } impl Backend for Wgpu { - type Device = WgpuDevice; - type FullPrecisionBackend = Wgpu; + type Device = WgpuDevice; + type FullPrecisionBackend = Wgpu; - type FullPrecisionElem = f32; - type FloatElem = F; - type IntElem = I; + type FullPrecisionElem = f32; + type FloatElem = F; + type IntElem = I; - type TensorPrimitive = WgpuTensor; - type IntTensorPrimitive = WgpuTensor; - type BoolTensorPrimitive = WgpuTensor; + type TensorPrimitive = WgpuTensor; + type IntTensorPrimitive = WgpuTensor; + type BoolTensorPrimitive = WgpuTensor; - fn name() -> String { - String::from("wgpu") - } + fn name() -> String { + String::from("wgpu") + } - fn seed(seed: u64) { - let rng = StdRng::seed_from_u64(seed); - let mut seed = SEED.lock().unwrap(); - *seed = Some(rng); - } + fn seed(seed: u64) { + let rng = StdRng::seed_from_u64(seed); + let mut seed = SEED.lock().unwrap(); + *seed = Some(rng); + } - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn sync(device: &Self::Device) { - let client = compute_client::(device); - client.sync(); - } + fn sync(device: &Self::Device) { + let client = compute_client::(device); + client.sync(); + } } diff --git a/burn-wgpu/src/compute/base.rs b/burn-wgpu/src/compute/base.rs index 864b40151b..b978478791 100644 --- a/burn-wgpu/src/compute/base.rs +++ b/burn-wgpu/src/compute/base.rs @@ -2,11 +2,11 @@ use super::WgpuServer; use crate::{compute::WgpuStorage, GraphicsApi, WgpuDevice}; use alloc::sync::Arc; use burn_compute::{ - channel::MutexComputeChannel, - client::ComputeClient, - memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, - tune::Tuner, - Compute, + channel::MutexComputeChannel, + client::ComputeClient, + memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, + tune::Tuner, + Compute, }; use spin::Mutex; use wgpu::DeviceDescriptor; @@ -26,207 +26,207 @@ static COMPUTE: Compute, Channel> = Com /// Get the [compute client](ComputeClient) for the 
given [device](WgpuDevice). pub fn compute_client(device: &WgpuDevice) -> ComputeClient { - let device = Arc::new(device); + let device = Arc::new(device); - COMPUTE.client(&device, move || { - pollster::block_on(create_client::(&device)) - }) + COMPUTE.client(&device, move || { + pollster::block_on(create_client::(&device)) + }) } /// Init the client async, necessary for wasm. pub async fn init_async(device: &WgpuDevice) { - let device = Arc::new(device); - let client = create_client::(&device).await; + let device = Arc::new(device); + let client = create_client::(&device).await; - COMPUTE.register(&device, client) + COMPUTE.register(&device, client) } async fn create_client(device: &WgpuDevice) -> ComputeClient { - let (device_wgpu, queue, info) = select_device::(device).await; - - log::info!( - "Created wgpu compute server on device {:?} => {:?}", - device, - info - ); - - // TODO: Support a way to modify max_tasks without std. - let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { - Ok(value) => value - .parse::() - .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), - Err(_) => 64, // 64 tasks by default - }; - - let device = Arc::new(device_wgpu); - let storage = WgpuStorage::new(device.clone()); - let memory_management = SimpleMemoryManagement::new( - storage, - DeallocStrategy::new_period_tick(max_tasks * 2), - SliceStrategy::Ratio(0.8), - ); - let server = WgpuServer::new(memory_management, device, queue, max_tasks); - let channel = Channel::new(server); - - ComputeClient::new(channel, Arc::new(Mutex::new(Tuner::new()))) + let (device_wgpu, queue, info) = select_device::(device).await; + + log::info!( + "Created wgpu compute server on device {:?} => {:?}", + device, + info + ); + + // TODO: Support a way to modify max_tasks without std. + let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { + Ok(value) => value + .parse::() + .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), + Err(_) => 64, // 64 tasks by default + }; + + let device = Arc::new(device_wgpu); + let storage = WgpuStorage::new(device.clone()); + let memory_management = SimpleMemoryManagement::new( + storage, + DeallocStrategy::new_period_tick(max_tasks * 2), + SliceStrategy::Ratio(0.8), + ); + let server = WgpuServer::new(memory_management, device, queue, max_tasks); + let channel = Channel::new(server); + + ComputeClient::new(channel, Arc::new(Mutex::new(Tuner::new()))) } /// Select the wgpu device and queue based on the provided [device](WgpuDevice). 
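// A std-only sketch of the env-var-with-default pattern used for
// BURN_WGPU_MAX_TASKS above; the variable name comes from the patch, the
// helper itself is illustrative.
fn max_tasks_from_env() -> usize {
    match std::env::var("BURN_WGPU_MAX_TASKS") {
        Ok(value) => value
            .parse::<usize>()
            .expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
        Err(_) => 64, // default when the variable is unset
    }
}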
pub async fn select_device( - device: &WgpuDevice, + device: &WgpuDevice, ) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { - #[cfg(target_family = "wasm")] - let adapter = select_adapter::(device).await; - - #[cfg(not(target_family = "wasm"))] - let adapter = select_adapter::(device); - - let limits = adapter.limits(); - - let (device, queue) = adapter - .request_device( - &DeviceDescriptor { - label: None, - features: wgpu::Features::empty(), - limits, - }, - None, - ) - .await - .map_err(|err| { - format!( - "Unable to request the device with the adapter {:?}, err {:?}", - adapter.get_info(), - err - ) - }) - .unwrap(); - - (device, queue, adapter.get_info()) + #[cfg(target_family = "wasm")] + let adapter = select_adapter::(device).await; + + #[cfg(not(target_family = "wasm"))] + let adapter = select_adapter::(device); + + let limits = adapter.limits(); + + let (device, queue) = adapter + .request_device( + &DeviceDescriptor { + label: None, + features: wgpu::Features::empty(), + limits, + }, + None, + ) + .await + .map_err(|err| { + format!( + "Unable to request the device with the adapter {:?}, err {:?}", + adapter.get_info(), + err + ) + }) + .unwrap(); + + (device, queue, adapter.get_info()) } #[cfg(target_family = "wasm")] async fn select_adapter(_device: &WgpuDevice) -> wgpu::Adapter { - let instance = wgpu::Instance::default(); + let instance = wgpu::Instance::default(); - instance - .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) - .await - .unwrap() + instance + .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) + .await + .unwrap() } #[cfg(not(target_family = "wasm"))] fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { - use wgpu::DeviceType; - - let instance = wgpu::Instance::default(); - let mut adapters_other = Vec::new(); - let mut adapters = Vec::new(); + use wgpu::DeviceType; + + let instance = wgpu::Instance::default(); + let mut adapters_other = Vec::new(); + let mut adapters = Vec::new(); + + instance + .enumerate_adapters(G::backend().into()) + .for_each(|adapter| { + let device_type = adapter.get_info().device_type; + + if let DeviceType::Other = device_type { + adapters_other.push(adapter); + return; + } + + let is_same_type = match device { + WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, + WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, + WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, + WgpuDevice::Cpu => device_type == DeviceType::Cpu, + WgpuDevice::BestAvailable => true, + }; + + if is_same_type { + adapters.push(adapter); + } + }); + + fn select( + num: usize, + error: &str, + mut adapters: Vec, + mut adapters_other: Vec, + ) -> wgpu::Adapter { + if adapters.len() <= num { + if adapters_other.len() <= num { + panic!( + "{}, adapters {:?}, other adapters {:?}", + error, + adapters + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + adapters_other + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + ); + } else { + return adapters_other.remove(num); + } + } - instance - .enumerate_adapters(G::backend().into()) + adapters.remove(num) + } + + let adapter = match device { + WgpuDevice::DiscreteGpu(num) => select( + *num, + "No Discrete GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::IntegratedGpu(num) => select( + *num, + "No Integrated GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::VirtualGpu(num) => select( + *num, + "No Virtual GPU device found", + adapters, + adapters_other, + ), 
+ WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), + WgpuDevice::BestAvailable => { + let mut most_performant_adapter = None; + let mut current_score = -1; + + adapters + .into_iter() + .chain(adapters_other) .for_each(|adapter| { - let device_type = adapter.get_info().device_type; - - if let DeviceType::Other = device_type { - adapters_other.push(adapter); - return; - } - - let is_same_type = match device { - WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, - WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, - WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, - WgpuDevice::Cpu => device_type == DeviceType::Cpu, - WgpuDevice::BestAvailable => true, - }; - - if is_same_type { - adapters.push(adapter); - } + let info = adapter.get_info(); + let score = match info.device_type { + DeviceType::DiscreteGpu => 5, + DeviceType::Other => 4, // Let's be optimistic with the Other device, it's + // often a Discrete Gpu. + DeviceType::IntegratedGpu => 3, + DeviceType::VirtualGpu => 2, + DeviceType::Cpu => 1, + }; + + if score > current_score { + most_performant_adapter = Some(adapter); + current_score = score; + } }); - fn select( - num: usize, - error: &str, - mut adapters: Vec, - mut adapters_other: Vec, - ) -> wgpu::Adapter { - if adapters.len() <= num { - if adapters_other.len() <= num { - panic!( - "{}, adapters {:?}, other adapters {:?}", - error, - adapters - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - adapters_other - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - ); - } else { - return adapters_other.remove(num); - } - } - - adapters.remove(num) + if let Some(adapter) = most_performant_adapter { + adapter + } else { + panic!("No adapter found for graphics API {:?}", G::default()); + } } + }; + + log::info!("Using adapter {:?}", adapter.get_info()); - let adapter = match device { - WgpuDevice::DiscreteGpu(num) => select( - *num, - "No Discrete GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::IntegratedGpu(num) => select( - *num, - "No Integrated GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::VirtualGpu(num) => select( - *num, - "No Virtual GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), - WgpuDevice::BestAvailable => { - let mut most_performant_adapter = None; - let mut current_score = -1; - - adapters - .into_iter() - .chain(adapters_other) - .for_each(|adapter| { - let info = adapter.get_info(); - let score = match info.device_type { - DeviceType::DiscreteGpu => 5, - DeviceType::Other => 4, // Let's be optimistic with the Other device, it's - // often a Discrete Gpu. 
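// A std-only sketch of the "best available" scoring used here: discrete GPUs
// win, `Other` is treated optimistically, CPU is the last resort. The variant
// scores mirror the patch; the enum and argmax helper are stand-ins.
#[derive(Debug, Clone, Copy)]
enum DeviceKind {
    DiscreteGpu,
    Other,
    IntegratedGpu,
    VirtualGpu,
    Cpu,
}

fn score(kind: DeviceKind) -> i32 {
    match kind {
        DeviceKind::DiscreteGpu => 5,
        DeviceKind::Other => 4, // often a discrete GPU in practice
        DeviceKind::IntegratedGpu => 3,
        DeviceKind::VirtualGpu => 2,
        DeviceKind::Cpu => 1,
    }
}

fn best_available(kinds: &[DeviceKind]) -> Option<DeviceKind> {
    kinds.iter().copied().max_by_key(|k| score(*k))
}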
- DeviceType::IntegratedGpu => 3, - DeviceType::VirtualGpu => 2, - DeviceType::Cpu => 1, - }; - - if score > current_score { - most_performant_adapter = Some(adapter); - current_score = score; - } - }); - - if let Some(adapter) = most_performant_adapter { - adapter - } else { - panic!("No adapter found for graphics API {:?}", G::default()); - } - } - }; - - log::info!("Using adapter {:?}", adapter.get_info()); - - adapter + adapter } diff --git a/burn-wgpu/src/compute/kernel.rs b/burn-wgpu/src/compute/kernel.rs index 332e725073..8254908ba2 100644 --- a/burn-wgpu/src/compute/kernel.rs +++ b/burn-wgpu/src/compute/kernel.rs @@ -5,102 +5,101 @@ use core::marker::PhantomData; /// Provides launch information specifying the number of work groups to be used by a compute shader. #[derive(new, Clone, Debug)] pub struct WorkGroup { - /// Work groups for the x axis. - pub x: u32, - /// Work groups for the y axis. - pub y: u32, - /// Work groups for the z axis. - pub z: u32, + /// Work groups for the x axis. + pub x: u32, + /// Work groups for the y axis. + pub y: u32, + /// Work groups for the z axis. + pub z: u32, } impl WorkGroup { - /// Calculate the number of invocations of a compute shader. - pub fn num_invocations(&self) -> usize { - (self.x * self.y * self.z) as usize - } + /// Calculate the number of invocations of a compute shader. + pub fn num_invocations(&self) -> usize { + (self.x * self.y * self.z) as usize + } } /// Wraps a [dynamic kernel source](DynamicKernelSource) into a [kernel](Kernel) with launch /// information such as [workgroup](WorkGroup). #[derive(new)] pub struct DynamicKernel { - kernel: K, - workgroup: WorkGroup, + kernel: K, + workgroup: WorkGroup, } /// Wraps a [static kernel source](StaticKernelSource) into a [kernel](Kernel) with launch /// information such as [workgroup](WorkGroup). 
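// A std-only sketch of how a launch is typically sized for the WorkGroup type
// above: ceil-divide the element count by the workgroup size on x, keep y and z
// at 1. The 256 is an assumed workgroup size, not a value from the patch.
fn workgroup_for(num_elems: u32) -> (u32, u32, u32) {
    const WORKGROUP_SIZE_X: u32 = 256; // assumption for illustration
    let x = (num_elems + WORKGROUP_SIZE_X - 1) / WORKGROUP_SIZE_X;
    (x, 1, 1)
}

fn num_invocations((x, y, z): (u32, u32, u32)) -> usize {
    // mirrors WorkGroup::num_invocations above: total shader invocations launched
    (x * y * z) as usize
}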
#[derive(new)] pub struct StaticKernel { - workgroup: WorkGroup, - _kernel: PhantomData, + workgroup: WorkGroup, + _kernel: PhantomData, } impl Kernel for DynamicKernel where - K: DynamicKernelSource + 'static, + K: DynamicKernelSource + 'static, { - fn source(&self) -> SourceTemplate { - self.kernel.source() - } + fn source(&self) -> SourceTemplate { + self.kernel.source() + } - fn id(&self) -> String { - self.kernel.id() - } + fn id(&self) -> String { + self.kernel.id() + } - fn workgroup(&self) -> WorkGroup { - self.workgroup.clone() - } + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } } impl Kernel for StaticKernel where - K: StaticKernelSource + 'static, + K: StaticKernelSource + 'static, { - fn source(&self) -> SourceTemplate { - K::source() - } + fn source(&self) -> SourceTemplate { + K::source() + } - fn id(&self) -> String { - format!("{:?}", core::any::TypeId::of::()) - } + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } - fn workgroup(&self) -> WorkGroup { - self.workgroup.clone() - } + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi, - WgpuDevice, - }; + use super::*; + use crate::{ + binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi, WgpuDevice, + }; - #[test] - fn can_run_kernel() { - binary_elemwise!(Add, "+"); + #[test] + fn can_run_kernel() { + binary_elemwise!(Add, "+"); - let client = compute_client::(&WgpuDevice::default()); + let client = compute_client::(&WgpuDevice::default()); - let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; - let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; - let info: Vec = vec![1, 1, 1, 1, 8, 8, 8]; + let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; + let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; + let info: Vec = vec![1, 1, 1, 1, 8, 8, 8]; - let lhs = client.create(bytemuck::cast_slice(&lhs)); - let rhs = client.create(bytemuck::cast_slice(&rhs)); - let out = client.empty(core::mem::size_of::() * 8); - let info = client.create(bytemuck::cast_slice(&info)); + let lhs = client.create(bytemuck::cast_slice(&lhs)); + let rhs = client.create(bytemuck::cast_slice(&rhs)); + let out = client.empty(core::mem::size_of::() * 8); + let info = client.create(bytemuck::cast_slice(&info)); - type Kernel = KernelSettings; - let kernel = Box::new(StaticKernel::::new(WorkGroup::new(1, 1, 1))); + type Kernel = KernelSettings; + let kernel = Box::new(StaticKernel::::new(WorkGroup::new(1, 1, 1))); - client.execute(kernel, &[&lhs, &rhs, &out, &info]); + client.execute(kernel, &[&lhs, &rhs, &out, &info]); - let data = client.read(&out).read_sync().unwrap(); - let output: &[f32] = bytemuck::cast_slice(&data); + let data = client.read(&out).read_sync().unwrap(); + let output: &[f32] = bytemuck::cast_slice(&data); - assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); - } + assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); + } } diff --git a/burn-wgpu/src/compute/server.rs b/burn-wgpu/src/compute/server.rs index 1789f2f00c..ca47e7b0ab 100644 --- a/burn-wgpu/src/compute/server.rs +++ b/burn-wgpu/src/compute/server.rs @@ -2,35 +2,35 @@ use super::{WgpuAutotuneKey, WgpuStorage, WorkGroup}; use crate::kernel::SourceTemplate; use alloc::{borrow::Cow, sync::Arc}; use burn_compute::{ - memory_management::MemoryManagement, - server::{self, ComputeServer}, + memory_management::MemoryManagement, + server::{self, 
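// A quick CPU check of the expected output asserted in the `can_run_kernel`
// test below: element-wise addition of the two input vectors. Illustrative only.
fn add_ref() {
    let lhs = [0., 1., 2., 3., 4., 5., 6., 7.];
    let rhs = [10., 11., 12., 6., 7., 3., 1., 0.];
    let out: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect();
    assert_eq!(out, vec![10., 12., 14., 9., 11., 8., 7., 7.]);
}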
ComputeServer}, }; use burn_tensor::Reader; use hashbrown::HashMap; use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, + util::{BufferInitDescriptor, DeviceExt}, + BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, }; /// Wgpu compute server. #[derive(Debug)] pub struct WgpuServer> { - memory_management: MM, - device: Arc, - queue: wgpu::Queue, - encoder: CommandEncoder, - pipelines: HashMap>, - tasks: Vec, - max_tasks: usize, - manual_available: HashMap>>, - manual_taken: Vec<(usize, server::Handle)>, + memory_management: MM, + device: Arc, + queue: wgpu::Queue, + encoder: CommandEncoder, + pipelines: HashMap>, + tasks: Vec, + max_tasks: usize, + manual_available: HashMap>>, + manual_taken: Vec<(usize, server::Handle)>, } #[derive(new, Debug)] struct ComputeTask { - pipeline: Arc, - bind_group: BindGroup, - work_group: WorkGroup, + pipeline: Arc, + bind_group: BindGroup, + work_group: WorkGroup, } /// Kernel trait with the [source](SourceTemplate) that will be compiled and cached based on the @@ -38,308 +38,306 @@ struct ComputeTask { /// /// The kernel will be launched with the given [workgroup](WorkGroup). pub trait Kernel: 'static + Send + Sync { - /// Source template for the kernel. - fn source(&self) -> SourceTemplate; - /// Identifier for the kernel, used for caching kernel compilation. - fn id(&self) -> String; - /// Launch information. - fn workgroup(&self) -> WorkGroup; + /// Source template for the kernel. + fn source(&self) -> SourceTemplate; + /// Identifier for the kernel, used for caching kernel compilation. + fn id(&self) -> String; + /// Launch information. + fn workgroup(&self) -> WorkGroup; } impl WgpuServer where - MM: MemoryManagement, + MM: MemoryManagement, { - /// Create a new server. - pub fn new( - memory_management: MM, - device: Arc, - queue: wgpu::Queue, - max_tasks: usize, - ) -> Self { - let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Command Encoder"), - }); - - Self { - memory_management, - device, - queue, - encoder, - pipelines: HashMap::new(), - tasks: Vec::new(), - max_tasks, - manual_available: HashMap::new(), - manual_taken: Vec::new(), - } - } - - fn submit(&mut self) { - assert!( - self.tasks.is_empty(), - "Tasks should be completed before submitting the current encoder." - ); - let mut new_encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - core::mem::swap(&mut new_encoder, &mut self.encoder); - - self.queue.submit(Some(new_encoder.finish())); - - // Cleanup allocations and deallocations. - self.free_manual_allocations(); - self.memory_management.storage().perform_deallocations(); + /// Create a new server. + pub fn new( + memory_management: MM, + device: Arc, + queue: wgpu::Queue, + max_tasks: usize, + ) -> Self { + let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Command Encoder"), + }); + + Self { + memory_management, + device, + queue, + encoder, + pipelines: HashMap::new(), + tasks: Vec::new(), + max_tasks, + manual_available: HashMap::new(), + manual_taken: Vec::new(), } - - fn free_manual_allocations(&mut self) { - let mut manual_taken_tmp = Vec::new(); - core::mem::swap(&mut manual_taken_tmp, &mut self.manual_taken); - - for (size, handle) in manual_taken_tmp.drain(..) 
{ - if handle.can_mut() { - self.register_manual(size, handle); - } else { - self.manual_taken.push((size, handle)); - } - } + } + + fn submit(&mut self) { + assert!( + self.tasks.is_empty(), + "Tasks should be completed before submitting the current encoder." + ); + let mut new_encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + core::mem::swap(&mut new_encoder, &mut self.encoder); + + self.queue.submit(Some(new_encoder.finish())); + + // Cleanup allocations and deallocations. + self.free_manual_allocations(); + self.memory_management.storage().perform_deallocations(); + } + + fn free_manual_allocations(&mut self) { + let mut manual_taken_tmp = Vec::new(); + core::mem::swap(&mut manual_taken_tmp, &mut self.manual_taken); + + for (size, handle) in manual_taken_tmp.drain(..) { + if handle.can_mut() { + self.register_manual(size, handle); + } else { + self.manual_taken.push((size, handle)); + } } - - // Finds a free, manually-added handle of specified size, or creates it if none is found - fn manual_reserve(&mut self, size: usize) -> server::Handle { - let handle = self - .manual_available - .get_mut(&size) - .and_then(|h| h.pop()) - .unwrap_or_else(|| { - let memory = self.memory_management.alloc(size); - server::Handle::new(memory) - }); - - self.manual_taken.push((size, handle.clone())); - - handle + } + + // Finds a free, manually-added handle of specified size, or creates it if none is found + fn manual_reserve(&mut self, size: usize) -> server::Handle { + let handle = self + .manual_available + .get_mut(&size) + .and_then(|h| h.pop()) + .unwrap_or_else(|| { + let memory = self.memory_management.alloc(size); + server::Handle::new(memory) + }); + + self.manual_taken.push((size, handle.clone())); + + handle + } + + // Manually adds a handle of given size + fn register_manual(&mut self, size: usize, handle: server::Handle) { + if let Some(handles) = self.manual_available.get_mut(&size) { + handles.push(handle); + } else { + self.manual_available.insert(size, [handle].into()); } + } - // Manually adds a handle of given size - fn register_manual(&mut self, size: usize, handle: server::Handle) { - if let Some(handles) = self.manual_available.get_mut(&size) { - handles.push(handle); - } else { - self.manual_available.insert(size, [handle].into()); - } + fn register_tasks(&mut self) { + if self.tasks.is_empty() { + return; } - fn register_tasks(&mut self) { - if self.tasks.is_empty() { - return; - } + let mut compute = self + .encoder + .begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); - let mut compute = self - .encoder - .begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); - - for task in self.tasks.iter() { - compute.set_pipeline(&task.pipeline); - compute.set_bind_group(0, &task.bind_group, &[]); - compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z); - } - - std::mem::drop(compute); - self.tasks.clear(); + for task in self.tasks.iter() { + compute.set_pipeline(&task.pipeline); + compute.set_bind_group(0, &task.bind_group, &[]); + compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z); } - fn pipeline(&mut self, kernel: Box) -> Arc { - let kernel_id = kernel.id(); - if let Some(pipeline) = self.pipelines.get(&kernel_id) { - return pipeline.clone(); - } - - let source = kernel.source().complete(); - log::trace!("Compiling kernel {kernel_id}:\n {source}"); - let pipeline = self.compile_source(&source); - self.pipelines.insert(kernel_id.clone(), 
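// A std-only sketch of the size-keyed free list behind manual_reserve /
// register_manual above: reuse a free handle of the exact size if one exists,
// otherwise allocate a new one. `u64` stands in for the crate's handle type.
use std::collections::HashMap;

struct Pool {
    available: HashMap<usize, Vec<u64>>, // size -> free handle ids (stand-in)
    next_id: u64,
}

impl Pool {
    fn reserve(&mut self, size: usize) -> u64 {
        if let Some(handle) = self.available.get_mut(&size).and_then(|h| h.pop()) {
            return handle; // reuse a previously registered handle of this size
        }
        let handle = self.next_id; // "allocate" a new handle
        self.next_id += 1;
        handle
    }

    fn register(&mut self, size: usize, handle: u64) {
        self.available.entry(size).or_default().push(handle);
    }
}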
pipeline.clone()); - - pipeline - } + std::mem::drop(compute); + self.tasks.clear(); + } - fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); - - Arc::new( - self.device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: "main", - }), - ) + fn pipeline(&mut self, kernel: Box) -> Arc { + let kernel_id = kernel.id(); + if let Some(pipeline) = self.pipelines.get(&kernel_id) { + return pipeline.clone(); } - fn buffer_reader(&mut self, handle: &server::Handle) -> BufferReader { - // Register previous tasks before reading the buffer so that it is up to date. - self.register_tasks(); - - let resource = self.memory_management.get(&handle.memory); - - let size = resource.size(); - let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, - mapped_at_creation: false, - }); - - self.encoder.copy_buffer_to_buffer( - &resource.buffer, - resource.offset(), - &buffer_dest, - 0, - size, - ); - - self.submit(); - - BufferReader::new(buffer_dest) - } + let source = kernel.source().complete(); + log::trace!("Compiling kernel {kernel_id}:\n {source}"); + let pipeline = self.compile_source(&source); + self.pipelines.insert(kernel_id.clone(), pipeline.clone()); + + pipeline + } + + fn compile_source(&self, source: &str) -> Arc { + let module = self.device.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }); + + Arc::new( + self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "main", + }), + ) + } + + fn buffer_reader(&mut self, handle: &server::Handle) -> BufferReader { + // Register previous tasks before reading the buffer so that it is up to date. 
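// A std-only sketch of the id-keyed pipeline cache in `pipeline()` above:
// compile once per kernel id, then hand out clones of the cached Arc.
// `compile` and the String payload are placeholders for WGSL module and
// ComputePipeline creation.
use std::{collections::HashMap, sync::Arc};

struct PipelineCache {
    pipelines: HashMap<String, Arc<String>>, // id -> "pipeline" stand-in
}

impl PipelineCache {
    fn get_or_compile(&mut self, id: &str, source: &str) -> Arc<String> {
        if let Some(pipeline) = self.pipelines.get(id) {
            return pipeline.clone();
        }
        let pipeline = Arc::new(compile(source));
        self.pipelines.insert(id.to_string(), pipeline.clone());
        pipeline
    }
}

fn compile(source: &str) -> String {
    // placeholder: the real server builds a wgpu ComputePipeline from the WGSL source
    format!("compiled({source})")
}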
+ self.register_tasks(); + + let resource = self.memory_management.get(&handle.memory); + + let size = resource.size(); + let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + self + .encoder + .copy_buffer_to_buffer(&resource.buffer, resource.offset(), &buffer_dest, 0, size); + + self.submit(); + + BufferReader::new(buffer_dest) + } } #[derive(new)] struct BufferReader { - buffer: wgpu::Buffer, + buffer: wgpu::Buffer, } impl BufferReader { - #[cfg(target_family = "wasm")] - async fn read(self, device: alloc::sync::Arc) -> Vec { - self.read_async(&device).await - } - - #[cfg(not(target_family = "wasm"))] - fn read(self, device: &wgpu::Device) -> Vec { - pollster::block_on(self.read_async(device)) - } - - async fn read_async(&self, device: &wgpu::Device) -> Vec { - let buffer_slice = self.buffer.slice(..); - let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - - device.poll(wgpu::Maintain::Wait); - - let result = receiver.receive().await; - - if let Some(Ok(())) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); - - drop(data); - self.buffer.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) - } + #[cfg(target_family = "wasm")] + async fn read(self, device: alloc::sync::Arc) -> Vec { + self.read_async(&device).await + } + + #[cfg(not(target_family = "wasm"))] + fn read(self, device: &wgpu::Device) -> Vec { + pollster::block_on(self.read_async(device)) + } + + async fn read_async(&self, device: &wgpu::Device) -> Vec { + let buffer_slice = self.buffer.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .send(v) + .expect("Unable to send buffer slice result to async channel.") + }); + + device.poll(wgpu::Maintain::Wait); + + let result = receiver.receive().await; + + if let Some(Ok(())) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + self.buffer.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) } + } } impl ComputeServer for WgpuServer where - MM: MemoryManagement, + MM: MemoryManagement, { - type Kernel = Box; - type Storage = WgpuStorage; - type MemoryManagement = MM; - type AutotuneKey = WgpuAutotuneKey; - - fn read(&mut self, handle: &server::Handle) -> Reader> { - #[cfg(target_family = "wasm")] - { - let future = self.buffer_reader(handle).read(self.device.clone()); - return Reader::Future(Box::pin(future)); - } - - #[cfg(not(target_family = "wasm"))] - Reader::Concrete(self.buffer_reader(handle).read(&self.device)) - } + type Kernel = Box; + type Storage = WgpuStorage; + type MemoryManagement = MM; + type AutotuneKey = WgpuAutotuneKey; - /// When we create a new handle from existing data, we use custom allocations so that we don't - /// have to execute the current pending tasks. - /// - /// This is important, otherwise the compute passes are going to be too small and we won't be able to - /// fully utilize the GPU. 
- fn create(&mut self, data: &[u8]) -> server::Handle { - let handle = self.manual_reserve(data.len()); - - let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { - label: Some("Buffer Src"), - contents: data, - usage: wgpu::BufferUsages::COPY_SRC, - })); - - let resource = self.memory_management.get(&handle.memory); - - self.encoder.copy_buffer_to_buffer( - &buffer_src, - 0, - &resource.buffer, - resource.offset(), - buffer_src.size(), - ); - - handle + fn read(&mut self, handle: &server::Handle) -> Reader> { + #[cfg(target_family = "wasm")] + { + let future = self.buffer_reader(handle).read(self.device.clone()); + return Reader::Future(Box::pin(future)); } - fn empty(&mut self, size: usize) -> server::Handle { - server::Handle::new(self.memory_management.reserve(size)) + #[cfg(not(target_family = "wasm"))] + Reader::Concrete(self.buffer_reader(handle).read(&self.device)) + } + + /// When we create a new handle from existing data, we use custom allocations so that we don't + /// have to execute the current pending tasks. + /// + /// This is important, otherwise the compute passes are going to be too small and we won't be able to + /// fully utilize the GPU. + fn create(&mut self, data: &[u8]) -> server::Handle { + let handle = self.manual_reserve(data.len()); + + let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { + label: Some("Buffer Src"), + contents: data, + usage: wgpu::BufferUsages::COPY_SRC, + })); + + let resource = self.memory_management.get(&handle.memory); + + self.encoder.copy_buffer_to_buffer( + &buffer_src, + 0, + &resource.buffer, + resource.offset(), + buffer_src.size(), + ); + + handle + } + + fn empty(&mut self, size: usize) -> server::Handle { + server::Handle::new(self.memory_management.reserve(size)) + } + + fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle]) { + let work_group = kernel.workgroup(); + let pipeline = self.pipeline(kernel); + let group_layout = pipeline.get_bind_group_layout(0); + + let handles = handles + .iter() + .map(|handle| self.memory_management.get(&handle.memory)) + .collect::>(); + + let entries = handles + .iter() + .enumerate() + .map(|(i, buffer)| wgpu::BindGroupEntry { + binding: i as u32, + resource: buffer.as_binding(), + }) + .collect::>(); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &group_layout, + entries: &entries, + }); + + self + .tasks + .push(ComputeTask::new(pipeline, bind_group, work_group)); + + if self.tasks.len() >= self.max_tasks { + self.register_tasks(); + self.submit(); } + } - fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle]) { - let work_group = kernel.workgroup(); - let pipeline = self.pipeline(kernel); - let group_layout = pipeline.get_bind_group_layout(0); - - let handles = handles - .iter() - .map(|handle| self.memory_management.get(&handle.memory)) - .collect::>(); - - let entries = handles - .iter() - .enumerate() - .map(|(i, buffer)| wgpu::BindGroupEntry { - binding: i as u32, - resource: buffer.as_binding(), - }) - .collect::>(); - - let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { - label: None, - layout: &group_layout, - entries: &entries, - }); - - self.tasks - .push(ComputeTask::new(pipeline, bind_group, work_group)); - - if self.tasks.len() >= self.max_tasks { - self.register_tasks(); - self.submit(); - } + fn sync(&mut self) { + if !self.tasks.is_empty() { + self.register_tasks(); + self.submit(); } - fn sync(&mut self) 
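// A std-only sketch of the task-batching policy in `execute`/`sync` above:
// queue work, flush a full batch automatically once `max_tasks` is reached,
// and flush the remainder on sync. Types are stand-ins for ComputeTask and the
// encoder/queue.
struct Batcher {
    tasks: Vec<String>, // stand-in for ComputeTask
    max_tasks: usize,
}

impl Batcher {
    fn execute(&mut self, task: String) {
        self.tasks.push(task);
        if self.tasks.len() >= self.max_tasks {
            self.flush();
        }
    }

    fn sync(&mut self) {
        if !self.tasks.is_empty() {
            self.flush();
        }
        // the real server also waits on the device here (poll with Maintain::Wait)
    }

    fn flush(&mut self) {
        // the real server records all tasks into one compute pass, then submits the encoder
        self.tasks.clear();
    }
}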
{ - if !self.tasks.is_empty() { - self.register_tasks(); - self.submit(); - } - - self.device.poll(wgpu::Maintain::Wait); - } + self.device.poll(wgpu::Maintain::Wait); + } } diff --git a/burn-wgpu/src/compute/storage.rs b/burn-wgpu/src/compute/storage.rs index ef74a927a3..11314cc14b 100644 --- a/burn-wgpu/src/compute/storage.rs +++ b/burn-wgpu/src/compute/storage.rs @@ -4,121 +4,119 @@ use std::{num::NonZeroU64, sync::Arc}; /// Buffer storage for wgpu. pub struct WgpuStorage { - memory: HashMap>, - deallocations: Vec, - device: Arc, + memory: HashMap>, + deallocations: Vec, + device: Arc, } impl core::fmt::Debug for WgpuStorage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) + } } /// The memory resource that can be allocated for wgpu. #[derive(new, Debug)] pub struct WgpuResource { - /// The wgpu buffer. - pub buffer: Arc, - /// How the resource is used. - pub kind: WgpuResourceKind, + /// The wgpu buffer. + pub buffer: Arc, + /// How the resource is used. + pub kind: WgpuResourceKind, } impl WgpuResource { - /// Return the binding view of the buffer. - pub fn as_binding(&self) -> wgpu::BindingResource { - let binding = match &self.kind { - WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), - WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { - buffer: &self.buffer, - offset: *offs, - size: Some(*size), - }, - }; - wgpu::BindingResource::Buffer(binding) - } - - /// Return the buffer size. - pub fn size(&self) -> u64 { - match self.kind { - WgpuResourceKind::Full => self.buffer.size(), - WgpuResourceKind::Slice(_, size) => size.get(), - } + /// Return the binding view of the buffer. + pub fn as_binding(&self) -> wgpu::BindingResource { + let binding = match &self.kind { + WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), + WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { + buffer: &self.buffer, + offset: *offs, + size: Some(*size), + }, + }; + wgpu::BindingResource::Buffer(binding) + } + + /// Return the buffer size. + pub fn size(&self) -> u64 { + match self.kind { + WgpuResourceKind::Full => self.buffer.size(), + WgpuResourceKind::Slice(_, size) => size.get(), } + } - /// Return the buffer offset. - pub fn offset(&self) -> u64 { - match self.kind { - WgpuResourceKind::Full => 0, - WgpuResourceKind::Slice(offset, _) => offset, - } + /// Return the buffer offset. + pub fn offset(&self) -> u64 { + match self.kind { + WgpuResourceKind::Full => 0, + WgpuResourceKind::Slice(offset, _) => offset, } + } } /// How the resource is used, either as a slice or fully. #[derive(Debug)] pub enum WgpuResourceKind { - /// Represents an entire buffer. - Full, - /// A slice over a buffer. - Slice(wgpu::BufferAddress, wgpu::BufferSize), + /// Represents an entire buffer. + Full, + /// A slice over a buffer. + Slice(wgpu::BufferAddress, wgpu::BufferSize), } /// Keeps actual wgpu buffer references in a hashmap with ids as key. impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). - pub fn new(device: Arc) -> Self { - Self { - memory: HashMap::new(), - deallocations: Vec::new(), - device, - } + /// Create a new storage on the given [device](wgpu::Device). 
+ pub fn new(device: Arc) -> Self { + Self { + memory: HashMap::new(), + deallocations: Vec::new(), + device, } - - /// Actually deallocates buffers tagged to be deallocated. - pub fn perform_deallocations(&mut self) { - for id in self.deallocations.drain(..) { - if let Some(buffer) = self.memory.remove(&id) { - buffer.destroy() - } - } + } + + /// Actually deallocates buffers tagged to be deallocated. + pub fn perform_deallocations(&mut self) { + for id in self.deallocations.drain(..) { + if let Some(buffer) = self.memory.remove(&id) { + buffer.destroy() + } } + } } impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; - - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let buffer = self.memory.get(&handle.id).unwrap(); - - match handle.utilization { - StorageUtilization::Full(_) => { - WgpuResource::new(buffer.clone(), WgpuResourceKind::Full) - } - StorageUtilization::Slice(offset, size) => WgpuResource::new( - buffer.clone(), - WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), - ), - } - } - - fn alloc(&mut self, size: usize) -> StorageHandle { - let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size: size as u64, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC, - mapped_at_creation: false, - })); + type Resource = WgpuResource; - self.memory.insert(id.clone(), buffer); - - StorageHandle::new(id, StorageUtilization::Full(size)) - } + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let buffer = self.memory.get(&handle.id).unwrap(); - fn dealloc(&mut self, id: StorageId) { - self.deallocations.push(id); + match handle.utilization { + StorageUtilization::Full(_) => WgpuResource::new(buffer.clone(), WgpuResourceKind::Full), + StorageUtilization::Slice(offset, size) => WgpuResource::new( + buffer.clone(), + WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), + ), } + } + + fn alloc(&mut self, size: usize) -> StorageHandle { + let id = StorageId::new(); + let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: size as u64, + usage: wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + })); + + self.memory.insert(id.clone(), buffer); + + StorageHandle::new(id, StorageUtilization::Full(size)) + } + + fn dealloc(&mut self, id: StorageId) { + self.deallocations.push(id); + } } diff --git a/burn-wgpu/src/compute/tune_key.rs b/burn-wgpu/src/compute/tune_key.rs index 2b2ce25018..3015b19697 100644 --- a/burn-wgpu/src/compute/tune_key.rs +++ b/burn-wgpu/src/compute/tune_key.rs @@ -7,22 +7,22 @@ use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey}; #[derive(Hash, Eq, PartialEq, Debug, Clone)] /// Key for all autotune-enabled operations pub enum WgpuAutotuneKey { - /// Key for matmul operation - Matmul(MatmulAutotuneKey), - /// Key for sum_dim operations - SumDim(ReduceAutotuneKey), - /// Key for mean_dim operations - MeanDim(ReduceAutotuneKey), + /// Key for matmul operation + Matmul(MatmulAutotuneKey), + /// Key for sum_dim operations + SumDim(ReduceAutotuneKey), + /// Key for mean_dim operations + MeanDim(ReduceAutotuneKey), } impl Display for WgpuAutotuneKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), - 
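// A std-only sketch of the deferred-deallocation scheme above: `dealloc` only
// records an id, and buffers are actually released later in
// `perform_deallocations`, once submitted work no longer needs them.
// The Vec<u8> payload stands in for a wgpu buffer.
use std::collections::HashMap;

struct Storage {
    memory: HashMap<u64, Vec<u8>>, // id -> buffer stand-in
    deallocations: Vec<u64>,
}

impl Storage {
    fn dealloc(&mut self, id: u64) {
        self.deallocations.push(id); // tag only; nothing is freed yet
    }

    fn perform_deallocations(&mut self) {
        for id in self.deallocations.drain(..) {
            self.memory.remove(&id); // the real storage calls buffer.destroy()
        }
    }
}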
WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), - WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), + WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), } + } } impl AutotuneKey for WgpuAutotuneKey {} diff --git a/burn-wgpu/src/device.rs b/burn-wgpu/src/device.rs index c9d6657fb8..1e9eacc5eb 100644 --- a/burn-wgpu/src/device.rs +++ b/burn-wgpu/src/device.rs @@ -12,39 +12,39 @@ /// ``` #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum WgpuDevice { - /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list - /// of all discrete GPUs found on the system. - DiscreteGpu(usize), + /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list + /// of all discrete GPUs found on the system. + DiscreteGpu(usize), - /// Integrated GPU with the given index. The index is the index of the integrated GPU in the - /// list of all integrated GPUs found on the system. - IntegratedGpu(usize), + /// Integrated GPU with the given index. The index is the index of the integrated GPU in the + /// list of all integrated GPUs found on the system. + IntegratedGpu(usize), - /// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of - /// all virtual GPUs found on the system. - VirtualGpu(usize), + /// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of + /// all virtual GPUs found on the system. + VirtualGpu(usize), - /// CPU. - Cpu, + /// CPU. + Cpu, - /// The best available device found with the current [graphics API](crate::GraphicsApi). - /// - /// Priority - /// - /// 1. DiscreteGpu - /// 2. IntegratedGpu - /// 3. VirtualGpu - /// 4. Cpu - /// - /// # Notes - /// - /// A device might be identified as [Other](wgpu::DeviceType::Other) by [wgpu](wgpu), in this case, we chose this device over - /// `IntegratedGpu` since it's often a discrete GPU. - BestAvailable, + /// The best available device found with the current [graphics API](crate::GraphicsApi). + /// + /// Priority + /// + /// 1. DiscreteGpu + /// 2. IntegratedGpu + /// 3. VirtualGpu + /// 4. Cpu + /// + /// # Notes + /// + /// A device might be identified as [Other](wgpu::DeviceType::Other) by [wgpu](wgpu), in this case, we chose this device over + /// `IntegratedGpu` since it's often a discrete GPU. + BestAvailable, } impl Default for WgpuDevice { - fn default() -> Self { - Self::BestAvailable - } + fn default() -> Self { + Self::BestAvailable + } } diff --git a/burn-wgpu/src/element.rs b/burn-wgpu/src/element.rs index b14ddfe8a8..baf8613cb4 100644 --- a/burn-wgpu/src/element.rs +++ b/burn-wgpu/src/element.rs @@ -2,13 +2,13 @@ use burn_tensor::Element; /// The base element trait for the wgou backend. 
pub trait WgpuElement: - burn_tensor::Element + core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod + burn_tensor::Element + core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod where - Self: Sized, + Self: Sized, { - fn type_name() -> &'static str; - fn as_bytes(slice: &[Self]) -> &[u8]; - fn from_bytes(bytes: &[u8]) -> &[Self]; + fn type_name() -> &'static str; + fn as_bytes(slice: &[Self]) -> &[u8]; + fn from_bytes(bytes: &[u8]) -> &[Self]; } /// The float element type for the wgpu backend. @@ -18,39 +18,39 @@ pub trait FloatElement: WgpuElement + Element {} pub trait IntElement: WgpuElement + Element {} impl WgpuElement for u32 { - fn type_name() -> &'static str { - "u32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } + fn type_name() -> &'static str { + "u32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } } impl WgpuElement for i32 { - fn type_name() -> &'static str { - "i32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } + fn type_name() -> &'static str { + "i32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } } impl WgpuElement for f32 { - fn type_name() -> &'static str { - "f32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } + fn type_name() -> &'static str { + "f32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } } impl FloatElement for f32 {} diff --git a/burn-wgpu/src/fusion/base.rs b/burn-wgpu/src/fusion/base.rs index b1e1f864c5..1c1274e26a 100644 --- a/burn-wgpu/src/fusion/base.rs +++ b/burn-wgpu/src/fusion/base.rs @@ -1,144 +1,144 @@ use crate::{ - compute::{WgpuComputeClient, WgpuHandle}, - element::WgpuElement, - fusion::FloatElementWiseFusionOps, - tensor::WgpuTensor, - FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice, + compute::{WgpuComputeClient, WgpuHandle}, + element::WgpuElement, + fusion::FloatElementWiseFusionOps, + tensor::WgpuTensor, + FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice, }; use burn_fusion::{ - client::MutexFusionClient, graph::GreedyGraphExecution, DeviceId, FusionBackend, FusionDevice, + client::MutexFusionClient, graph::GreedyGraphExecution, DeviceId, FusionBackend, FusionDevice, }; use burn_tensor::Shape; use core::marker::PhantomData; impl FusionDevice for WgpuDevice { - fn id(&self) -> DeviceId { - match self { - WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), - WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), - WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), - WgpuDevice::Cpu => DeviceId::new(3, 0), - WgpuDevice::BestAvailable => DeviceId::new(4, 0), - } + fn id(&self) -> DeviceId { + match self { + WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), + WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), + WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), + WgpuDevice::Cpu => DeviceId::new(3, 0), + WgpuDevice::BestAvailable => DeviceId::new(4, 0), } + } } impl 
FusionBackend for Wgpu where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - type FusionDevice = WgpuDevice; - type Handle = WgpuFusionHandle; - type FusionClient = MutexFusionClient; - - fn operations(device: &WgpuDevice) -> Vec>> { - vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))] - } - - fn float_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::TensorPrimitive { - handle.into_tensor(shape) - } - - fn int_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::IntTensorPrimitive { - handle.into_tensor(shape) - } - - fn bool_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::BoolTensorPrimitive { - handle.into_tensor(shape) - } - - fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle { - tensor.into() - } - - fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle { - tensor.into() - } - - fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle { - tensor.into() - } + type FusionDevice = WgpuDevice; + type Handle = WgpuFusionHandle; + type FusionClient = MutexFusionClient; + + fn operations(device: &WgpuDevice) -> Vec>> { + vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))] + } + + fn float_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::TensorPrimitive { + handle.into_tensor(shape) + } + + fn int_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::IntTensorPrimitive { + handle.into_tensor(shape) + } + + fn bool_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::BoolTensorPrimitive { + handle.into_tensor(shape) + } + + fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle { + tensor.into() + } + + fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle { + tensor.into() + } + + fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle { + tensor.into() + } } pub fn strides_dyn_rank(shape: &[usize]) -> Vec { - let mut strides = vec![0; shape.len()]; + let mut strides = vec![0; shape.len()]; - let mut current = 1; - shape.iter().enumerate().rev().for_each(|(index, val)| { - strides[index] = current; - current *= val; - }); + let mut current = 1; + shape.iter().enumerate().rev().for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); - strides + strides } pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize { - let mut num_elems = 1; - for i in shape.iter() { - num_elems *= i; - } - num_elems + let mut num_elems = 1; + for i in shape.iter() { + num_elems *= i; + } + num_elems } #[derive(new, Debug, Clone)] /// Handle to be used when fusing operations. pub struct WgpuFusionHandle { - /// Compute client for wgpu. - pub client: WgpuComputeClient, - /// The buffer where the data are stored. - pub handle: WgpuHandle, - /// The device of the current tensor. - pub device: WgpuDevice, - pub(crate) strides: Vec, + /// Compute client for wgpu. + pub client: WgpuComputeClient, + /// The buffer where the data are stored. + pub handle: WgpuHandle, + /// The device of the current tensor. 
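// The strides stored in this handle are row-major, as produced by
// `strides_dyn_rank` above: for a shape of `[2, 3, 4]` the strides are
// `[12, 4, 1]` (the last dimension is contiguous), and
// `calculate_num_elems_dyn_rank` returns `2 * 3 * 4 = 24` for the same shape.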
+ pub device: WgpuDevice, + pub(crate) strides: Vec, } impl WgpuFusionHandle { - pub(crate) fn into_tensor( - self, - shape: Shape, - ) -> WgpuTensor { - WgpuTensor { - client: self.client, - handle: self.handle, - device: self.device, - shape, - strides: self.strides.try_into().expect("Wrong dimension"), - elem: PhantomData, - } + pub(crate) fn into_tensor( + self, + shape: Shape, + ) -> WgpuTensor { + WgpuTensor { + client: self.client, + handle: self.handle, + device: self.device, + shape, + strides: self.strides.try_into().expect("Wrong dimension"), + elem: PhantomData, } + } } impl From> for WgpuFusionHandle { - fn from(value: WgpuTensor) -> Self { - Self { - client: value.client, - handle: value.handle, - device: value.device, - strides: value.strides.into(), - } + fn from(value: WgpuTensor) -> Self { + Self { + client: value.client, + handle: value.handle, + device: value.device, + strides: value.strides.into(), } + } } #[cfg(test)] mod tests { - use super::*; - use burn_fusion::Fusion; + use super::*; + use burn_fusion::Fusion; - pub type TestBackend = Fusion; - pub type TestTensor = burn_tensor::Tensor; - pub type TestTensorInt = burn_tensor::Tensor; + pub type TestBackend = Fusion; + pub type TestTensor = burn_tensor::Tensor; + pub type TestTensorInt = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!(); + burn_autodiff::testgen_all!(); } diff --git a/burn-wgpu/src/fusion/codegen/body.rs b/burn-wgpu/src/fusion/codegen/body.rs index cab35bf75d..a08cf100ad 100644 --- a/burn-wgpu/src/fusion/codegen/body.rs +++ b/burn-wgpu/src/fusion/codegen/body.rs @@ -7,21 +7,19 @@ use std::fmt::Display; /// X and Y, but with Z=1. #[derive(Hash, new)] pub struct Body { - operators: Vec, + operators: Vec, } impl Display for Body { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str( - "let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n", - )?; - f.write_str("let rank: u32 = info[0];\n\n")?; + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n")?; + f.write_str("let rank: u32 = info[0];\n\n")?; - for ops in self.operators.iter() { - f.write_fmt(format_args!("{ops}"))?; - f.write_str("\n")?; - } - - Ok(()) + for ops in self.operators.iter() { + f.write_fmt(format_args!("{ops}"))?; + f.write_str("\n")?; } + + Ok(()) + } } diff --git a/burn-wgpu/src/fusion/codegen/function.rs b/burn-wgpu/src/fusion/codegen/function.rs index fceae4e399..fa3c27d1f6 100644 --- a/burn-wgpu/src/fusion/codegen/function.rs +++ b/burn-wgpu/src/fusion/codegen/function.rs @@ -4,22 +4,22 @@ use std::fmt::Display; /// Not all functions are native to WGSL, so this struct allows to support more functions. 
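// Formatting one of these variants emits the full WGSL helper definition that
// the fused kernel can then call: `Function::Powf(Elem::F32)` renders to
// source containing `fn powf(lhs: f32, rhs: f32) -> f32 { ... }`, and
// `Function::Erf(Elem::F32)` to an approximation of the error function, as
// shown by the `Display` implementation below.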
#[derive(Hash, PartialEq, Eq, Clone)] pub enum Function { - Powf(Elem), - Erf(Elem), + Powf(Elem), + Erf(Elem), } impl Display for Function { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Function::Powf(elem) => format_powf(f, elem), - Function::Erf(elem) => format_erf(f, elem), - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Function::Powf(elem) => format_powf(f, elem), + Function::Erf(elem) => format_erf(f, elem), } + } } fn format_powf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result { - f.write_fmt(format_args!( - " + f.write_fmt(format_args!( + " fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{ let modulo = rhs % 2.0; @@ -35,11 +35,11 @@ fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{ }} }} " - )) + )) } fn format_erf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result { - f.write_fmt(format_args!( + f.write_fmt(format_args!( " /// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations /// diff --git a/burn-wgpu/src/fusion/codegen/operator.rs b/burn-wgpu/src/fusion/codegen/operator.rs index 922f7cf81d..474cefeedf 100644 --- a/burn-wgpu/src/fusion/codegen/operator.rs +++ b/burn-wgpu/src/fusion/codegen/operator.rs @@ -4,133 +4,119 @@ use std::fmt::Display; /// All operators that can be fused in a WGSL compute shader. #[derive(Debug, Hash, Clone)] pub enum Operator { - Add { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Sub { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Mul { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Div { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Abs { - input: Variable, - out: Variable, - }, - Exp { - input: Variable, - out: Variable, - }, - Log { - input: Variable, - out: Variable, - }, - Log1p { - input: Variable, - out: Variable, - }, - Cos { - input: Variable, - out: Variable, - }, - Sin { - input: Variable, - out: Variable, - }, - Tanh { - input: Variable, - out: Variable, - }, - Powf { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Erf { - input: Variable, - out: Variable, - }, - Recip { - input: Variable, - out: Variable, - }, - AssignGlobal { - input: Variable, - out: Variable, - }, - ReadGlobal { - variable: Variable, - position: usize, - position_out: usize, - }, + Add { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Sub { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Mul { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Div { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Abs { + input: Variable, + out: Variable, + }, + Exp { + input: Variable, + out: Variable, + }, + Log { + input: Variable, + out: Variable, + }, + Log1p { + input: Variable, + out: Variable, + }, + Cos { + input: Variable, + out: Variable, + }, + Sin { + input: Variable, + out: Variable, + }, + Tanh { + input: Variable, + out: Variable, + }, + Powf { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Erf { + input: Variable, + out: Variable, + }, + Recip { + input: Variable, + out: Variable, + }, + AssignGlobal { + input: Variable, + out: Variable, + }, + ReadGlobal { + variable: Variable, + position: usize, + position_out: usize, + }, } impl Display for Operator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Operator::Add { lhs, rhs, out } => { - f.write_fmt(format_args!("let {out} = {lhs} + {rhs};")) - } - Operator::Sub { lhs, rhs, out } => { - 
f.write_fmt(format_args!("let {out} = {lhs} - {rhs};")) - } - Operator::Mul { lhs, rhs, out } => { - f.write_fmt(format_args!("let {out} = {lhs} * {rhs};")) - } - Operator::Div { lhs, rhs, out } => { - f.write_fmt(format_args!("let {out} = {lhs} / {rhs};")) - } - Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")), - Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")), - Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")), - Operator::Powf { lhs, rhs, out } => { - f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});")) - } - Operator::Log1p { input, out } => { - f.write_fmt(format_args!("let {out} = log({input} + 1.0);")) - } - Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")), - Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")), - Operator::Tanh { input, out } => { - f.write_fmt(format_args!("let {out} = tanh({input});")) - } - Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")), - Operator::Recip { input, out } => { - f.write_fmt(format_args!("let {out} = 1.0 / {input};")) - } - Operator::AssignGlobal { input, out } => { - f.write_fmt(format_args!("{out}_global[id] = {input};")) - } - Operator::ReadGlobal { - variable, - position, - position_out, - } => { - let (global, local) = match variable { - Variable::Input(number) => { - (format!("input_{number}_global"), format!("input_{number}")) - } - Variable::Local(_) => panic!("can't read globala local variable."), - Variable::Output(number) => ( - format!("output_{number}_global"), - format!("output_{number}"), - ), - Variable::Scalar(_, _) => panic!("Can't read global scalar variable."), - }; + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Operator::Add { lhs, rhs, out } => f.write_fmt(format_args!("let {out} = {lhs} + {rhs};")), + Operator::Sub { lhs, rhs, out } => f.write_fmt(format_args!("let {out} = {lhs} - {rhs};")), + Operator::Mul { lhs, rhs, out } => f.write_fmt(format_args!("let {out} = {lhs} * {rhs};")), + Operator::Div { lhs, rhs, out } => f.write_fmt(format_args!("let {out} = {lhs} / {rhs};")), + Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")), + Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")), + Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")), + Operator::Powf { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});")) + } + Operator::Log1p { input, out } => { + f.write_fmt(format_args!("let {out} = log({input} + 1.0);")) + } + Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")), + Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")), + Operator::Tanh { input, out } => f.write_fmt(format_args!("let {out} = tanh({input});")), + Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")), + Operator::Recip { input, out } => f.write_fmt(format_args!("let {out} = 1.0 / {input};")), + Operator::AssignGlobal { input, out } => { + f.write_fmt(format_args!("{out}_global[id] = {input};")) + } + Operator::ReadGlobal { + variable, + position, + position_out, + } => { + let (global, local) = match variable { + Variable::Input(number) => (format!("input_{number}_global"), format!("input_{number}")), + Variable::Local(_) => panic!("can't read globala 
local variable."), + Variable::Output(number) => ( + format!("output_{number}_global"), + format!("output_{number}"), + ), + Variable::Scalar(_, _) => panic!("Can't read global scalar variable."), + }; - f.write_fmt(format_args!( - " + f.write_fmt(format_args!( + " var index_{local}: u32 = 0u; for (var i: u32 = 1u; i <= rank; i++) {{ @@ -146,8 +132,8 @@ for (var i: u32 = 1u; i <= rank; i++) {{ let {local} = {global}[index_{local}]; " - )) - } - } + )) + } } + } } diff --git a/burn-wgpu/src/fusion/codegen/shader.rs b/burn-wgpu/src/fusion/codegen/shader.rs index 8ce3999784..ee7f8f330b 100644 --- a/burn-wgpu/src/fusion/codegen/shader.rs +++ b/burn-wgpu/src/fusion/codegen/shader.rs @@ -1,201 +1,201 @@ use super::{Body, Function}; use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT}; use std::{ - collections::hash_map::DefaultHasher, - fmt::Display, - hash::{Hash, Hasher}, + collections::hash_map::DefaultHasher, + fmt::Display, + hash::{Hash, Hasher}, }; #[derive(Hash, PartialEq, Eq)] pub enum Location { - Storage, - #[allow(dead_code)] - Workgroup, + Storage, + #[allow(dead_code)] + Workgroup, } #[derive(Hash, PartialEq, Eq)] pub enum Visibility { - Read, - ReadWrite, + Read, + ReadWrite, } #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Elem { - F32, - #[allow(dead_code)] - I32, - U32, + F32, + #[allow(dead_code)] + I32, + U32, } #[derive(Hash, PartialEq, Eq)] pub struct Binding { - pub location: Location, - pub visibility: Visibility, - pub elem: Elem, - pub size: Option, + pub location: Location, + pub visibility: Visibility, + pub elem: Elem, + pub size: Option, } #[derive(Hash, PartialEq, Eq)] pub struct WorkgroupSize { - pub x: usize, - pub y: usize, - pub z: usize, + pub x: usize, + pub y: usize, + pub z: usize, } impl Default for WorkgroupSize { - fn default() -> Self { - Self { - x: WORKGROUP_DEFAULT, - y: WORKGROUP_DEFAULT, - z: 1, - } + fn default() -> Self { + Self { + x: WORKGROUP_DEFAULT, + y: WORKGROUP_DEFAULT, + z: 1, } + } } #[derive(Hash)] pub struct ComputeShader { - pub inputs: Vec, - pub outputs: Vec, - pub named: Vec<(String, Binding)>, - pub workgroup_size: WorkgroupSize, - pub global_invocation_id: bool, - pub num_workgroups: bool, - pub body: Body, - pub functions: Vec, + pub inputs: Vec, + pub outputs: Vec, + pub named: Vec<(String, Binding)>, + pub workgroup_size: WorkgroupSize, + pub global_invocation_id: bool, + pub num_workgroups: bool, + pub body: Body, + pub functions: Vec, } impl DynamicKernelSource for ComputeShader { - fn source(&self) -> SourceTemplate { - SourceTemplate::new(self.to_string()) - } + fn source(&self) -> SourceTemplate { + SourceTemplate::new(self.to_string()) + } - fn id(&self) -> String { - let mut s = DefaultHasher::new(); - self.hash(&mut s); + fn id(&self) -> String { + let mut s = DefaultHasher::new(); + self.hash(&mut s); - s.finish().to_string() - } + s.finish().to_string() + } } impl Display for ComputeShader { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Self::format_bindings(f, "input", &self.inputs, 0)?; - Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?; - - for (i, (name, binding)) in self.named.iter().enumerate() { - Self::format_binding( - f, - name.as_str(), - binding, - self.inputs.len() + self.outputs.len() + i, - )?; - } - - f.write_fmt(format_args!( - "const WORKGROUP_SIZE_X = {}u; + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Self::format_bindings(f, "input", &self.inputs, 0)?; + Self::format_bindings(f, "output", 
&self.outputs, self.inputs.len())?; + + for (i, (name, binding)) in self.named.iter().enumerate() { + Self::format_binding( + f, + name.as_str(), + binding, + self.inputs.len() + self.outputs.len() + i, + )?; + } + + f.write_fmt(format_args!( + "const WORKGROUP_SIZE_X = {}u; const WORKGROUP_SIZE_Y = {}u; const WORKGROUP_SIZE_Z = {}u;\n", - self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z - ))?; + self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z + ))?; - f.write_fmt(format_args!( - " + f.write_fmt(format_args!( + " @compute @workgroup_size({}, {}, {}) fn main( ", - self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z - ))?; + self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z + ))?; - if self.global_invocation_id { - f.write_str(" @builtin(global_invocation_id) global_id: vec3,\n")?; - } + if self.global_invocation_id { + f.write_str(" @builtin(global_invocation_id) global_id: vec3,\n")?; + } - if self.num_workgroups { - f.write_str(" @builtin(num_workgroups) num_workgroups: vec3,\n")?; - } + if self.num_workgroups { + f.write_str(" @builtin(num_workgroups) num_workgroups: vec3,\n")?; + } - f.write_fmt(format_args!( - ") {{ + f.write_fmt(format_args!( + ") {{ {} }}", - self.body - ))?; - - for function in self.functions.iter() { - f.write_fmt(format_args!("{function}\n\n"))?; - } + self.body + ))?; - Ok(()) + for function in self.functions.iter() { + f.write_fmt(format_args!("{function}\n\n"))?; } + + Ok(()) + } } impl ComputeShader { - fn format_bindings( - f: &mut core::fmt::Formatter<'_>, - prefix: &str, - bindings: &[Binding], - num_entry: usize, - ) -> core::fmt::Result { - for (i, binding) in bindings.iter().enumerate() { - Self::format_binding( - f, - format!("{prefix}_{i}_global").as_str(), - binding, - num_entry + i, - )?; - } - - Ok(()) + fn format_bindings( + f: &mut core::fmt::Formatter<'_>, + prefix: &str, + bindings: &[Binding], + num_entry: usize, + ) -> core::fmt::Result { + for (i, binding) in bindings.iter().enumerate() { + Self::format_binding( + f, + format!("{prefix}_{i}_global").as_str(), + binding, + num_entry + i, + )?; } - fn format_binding( - f: &mut core::fmt::Formatter<'_>, - name: &str, - binding: &Binding, - num_entry: usize, - ) -> core::fmt::Result { - let ty = match binding.size { - Some(size) => format!("array<{}, {}>", binding.elem, size), - None => format!("array<{}>", binding.elem), - }; - - f.write_fmt(format_args!( - "@group(0) + Ok(()) + } + + fn format_binding( + f: &mut core::fmt::Formatter<'_>, + name: &str, + binding: &Binding, + num_entry: usize, + ) -> core::fmt::Result { + let ty = match binding.size { + Some(size) => format!("array<{}, {}>", binding.elem, size), + None => format!("array<{}>", binding.elem), + }; + + f.write_fmt(format_args!( + "@group(0) @binding({}) var<{}, {}> {}: {}; \n", - num_entry, binding.location, binding.visibility, name, ty - ))?; + num_entry, binding.location, binding.visibility, name, ty + ))?; - Ok(()) - } + Ok(()) + } } impl Display for Location { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Location::Storage => f.write_str("storage"), - Location::Workgroup => f.write_str("workgroup"), - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Location::Storage => f.write_str("storage"), + Location::Workgroup => f.write_str("workgroup"), } + } } impl Display for Elem { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Elem::F32 => 
f.write_str("f32"), - Elem::I32 => f.write_str("i32"), - Elem::U32 => f.write_str("u32"), - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Elem::F32 => f.write_str("f32"), + Elem::I32 => f.write_str("i32"), + Elem::U32 => f.write_str("u32"), } + } } impl Display for Visibility { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Visibility::Read => f.write_str("read"), - Visibility::ReadWrite => f.write_str("read_write"), - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Visibility::Read => f.write_str("read"), + Visibility::ReadWrite => f.write_str("read_write"), } + } } diff --git a/burn-wgpu/src/fusion/codegen/variable.rs b/burn-wgpu/src/fusion/codegen/variable.rs index b74c4dbb80..f827bcef6d 100644 --- a/burn-wgpu/src/fusion/codegen/variable.rs +++ b/burn-wgpu/src/fusion/codegen/variable.rs @@ -3,19 +3,19 @@ use std::fmt::Display; #[derive(Debug, Hash, Clone)] pub enum Variable { - Input(u16), - Scalar(u16, Elem), - Local(u16), - Output(u16), + Input(u16), + Scalar(u16, Elem), + Local(u16), + Output(u16), } impl Display for Variable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Variable::Input(number) => f.write_fmt(format_args!("input_{number}")), - Variable::Local(number) => f.write_fmt(format_args!("local_{number}")), - Variable::Output(number) => f.write_fmt(format_args!("output_{number}")), - Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")), - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Variable::Input(number) => f.write_fmt(format_args!("input_{number}")), + Variable::Local(number) => f.write_fmt(format_args!("local_{number}")), + Variable::Output(number) => f.write_fmt(format_args!("output_{number}")), + Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")), } + } } diff --git a/burn-wgpu/src/fusion/elemwise/ops.rs b/burn-wgpu/src/fusion/elemwise/ops.rs index 78dec96fd6..87db9e7b99 100644 --- a/burn-wgpu/src/fusion/elemwise/ops.rs +++ b/burn-wgpu/src/fusion/elemwise/ops.rs @@ -1,15 +1,15 @@ use crate::{ - fusion::codegen::{Elem, Operator, Variable}, - fusion::kernel::FusionKernel, - FloatElement, GraphicsApi, IntElement, Wgpu, + fusion::codegen::{Elem, Operator, Variable}, + fusion::kernel::FusionKernel, + FloatElement, GraphicsApi, IntElement, Wgpu, }; use burn_fusion::{ - graph::{ - BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription, - TensorOpsDescription, UnaryOpsDescription, - }, - FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, - TensorId, + graph::{ + BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription, + TensorOpsDescription, UnaryOpsDescription, + }, + FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, + TensorId, }; use burn_tensor::{Device, Element}; use hashbrown::HashMap; @@ -18,453 +18,454 @@ use std::sync::Arc; /// Fused element wise operations that are normally memory bound. 
pub struct FloatElementWiseFusionOps where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - pub(crate) inputs: Vec, - pub(crate) locals: HashMap, - pub(crate) tensors: HashMap, - pub(crate) scalars_f32: Vec, - pub(crate) operators: Vec, - pub(crate) properties: FusionProperties, - pub(crate) current_output_shape: Vec, - device: Device>, + pub(crate) inputs: Vec, + pub(crate) locals: HashMap, + pub(crate) tensors: HashMap, + pub(crate) scalars_f32: Vec, + pub(crate) operators: Vec, + pub(crate) properties: FusionProperties, + pub(crate) current_output_shape: Vec, + device: Device>, } impl FusionOps> - for FloatElementWiseFusionOps + for FloatElementWiseFusionOps { - fn register(&mut self, ops: Arc>>) -> FusionStatus { - match ops.as_ref() { - TensorOpsDescription::FloatOps(ops) => { - if !self.register_float(ops) { - return FusionStatus::Closed(self.properties); - } - } - TensorOpsDescription::NumericOpsFloat(ops) => { - if !self.register_numeric(ops) { - return FusionStatus::Closed(self.properties); - } - } - _ => { - return FusionStatus::Closed(self.properties); - } - }; - - self.properties.score += 1; - self.properties.ready = self.operators.len() > 1; - - FusionStatus::Open(self.properties) - } - - fn execute(&mut self, handles: &mut HandleContainer>) { - let inputs = self.input_descriptions(); - let outputs = self.output_descriptions(); - let locals = outputs - .iter() - .map(|out| *self.locals.get(&out.id).unwrap()) - .collect::>(); - - FusionKernel::new(&self.device) - .inputs(&inputs, &self.scalars_f32) - .body(&self.operators) - .outputs(&outputs, &locals) - .execute(handles); - } - - fn reset(&mut self) { - self.inputs.clear(); - self.locals.drain(); - self.tensors.clear(); - self.scalars_f32.clear(); - self.operators.clear(); - self.properties = FusionProperties::default(); - self.current_output_shape.clear(); - } - - fn len(&self) -> usize { - self.operators.len() - } + fn register(&mut self, ops: Arc>>) -> FusionStatus { + match ops.as_ref() { + TensorOpsDescription::FloatOps(ops) => { + if !self.register_float(ops) { + return FusionStatus::Closed(self.properties); + } + } + TensorOpsDescription::NumericOpsFloat(ops) => { + if !self.register_numeric(ops) { + return FusionStatus::Closed(self.properties); + } + } + _ => { + return FusionStatus::Closed(self.properties); + } + }; + + self.properties.score += 1; + self.properties.ready = self.operators.len() > 1; + + FusionStatus::Open(self.properties) + } + + fn execute(&mut self, handles: &mut HandleContainer>) { + let inputs = self.input_descriptions(); + let outputs = self.output_descriptions(); + let locals = outputs + .iter() + .map(|out| *self.locals.get(&out.id).unwrap()) + .collect::>(); + + FusionKernel::new(&self.device) + .inputs(&inputs, &self.scalars_f32) + .body(&self.operators) + .outputs(&outputs, &locals) + .execute(handles); + } + + fn reset(&mut self) { + self.inputs.clear(); + self.locals.drain(); + self.tensors.clear(); + self.scalars_f32.clear(); + self.operators.clear(); + self.properties = FusionProperties::default(); + self.current_output_shape.clear(); + } + + fn len(&self) -> usize { + self.operators.len() + } } impl FloatElementWiseFusionOps where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - pub fn new(device: Device>) -> Self { - Self { - inputs: Vec::new(), - locals: HashMap::new(), - tensors: HashMap::new(), - scalars_f32: Vec::new(), - operators: Vec::new(), - 
current_output_shape: Vec::new(), - properties: FusionProperties::default(), - device, - } + pub fn new(device: Device>) -> Self { + Self { + inputs: Vec::new(), + locals: HashMap::new(), + tensors: HashMap::new(), + scalars_f32: Vec::new(), + operators: Vec::new(), + current_output_shape: Vec::new(), + properties: FusionProperties::default(), + device, } - - fn input_descriptions(&self) -> Vec<&TensorDescription> { - self.inputs - .iter() - .map(|input| { - let updated_tensor = self.tensors.get(&input.id).unwrap(); - updated_tensor - }) - .collect::>() + } + + fn input_descriptions(&self) -> Vec<&TensorDescription> { + self + .inputs + .iter() + .map(|input| { + let updated_tensor = self.tensors.get(&input.id).unwrap(); + updated_tensor + }) + .collect::>() + } + + fn output_descriptions(&self) -> Vec<&TensorDescription> { + let mut outputs = Vec::new(); + let mut local_tensor_ids_input = Vec::new(); + let mut local_tensor_ids_output = Vec::new(); + + // Mark a variable to the provided list of tensor ids using the variable list. + // + // Only local variables can become outputs. + let mark = |var: &Variable, list: &mut Vec| { + if let Variable::Local(index) = var { + if let Some((id, _)) = self + .locals + .iter() + .find(|(_id, position)| *position == index) + { + if !list.contains(id) { + list.push(id.clone()); + } + } + } + }; + + // For all operators, mark their local tensor id in the proper set. + for ops in self.operators.iter() { + match ops { + Operator::AssignGlobal { input: _, out: _ } => { + // Nothing to do here. + } + Operator::ReadGlobal { + variable: _, + position: _, + position_out: _, + } => { + // Nothing to do here. + } + Operator::Add { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Sub { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Mul { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Div { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Exp { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Abs { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Erf { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Log { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Log1p { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Cos { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Sin { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Tanh { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Powf { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + 
Operator::Recip { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + } } - fn output_descriptions(&self) -> Vec<&TensorDescription> { - let mut outputs = Vec::new(); - let mut local_tensor_ids_input = Vec::new(); - let mut local_tensor_ids_output = Vec::new(); - - // Mark a variable to the provided list of tensor ids using the variable list. - // - // Only local variables can become outputs. - let mark = |var: &Variable, list: &mut Vec| { - if let Variable::Local(index) = var { - if let Some((id, _)) = self - .locals - .iter() - .find(|(_id, position)| *position == index) - { - if !list.contains(id) { - list.push(id.clone()); - } - } - } - }; - - // For all operators, mark their local tensor id in the proper set. - for ops in self.operators.iter() { - match ops { - Operator::AssignGlobal { input: _, out: _ } => { - // Nothing to do here. - } - Operator::ReadGlobal { - variable: _, - position: _, - position_out: _, - } => { - // Nothing to do here. - } - Operator::Add { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Sub { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Mul { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Div { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Exp { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Abs { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Erf { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Log { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Log1p { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Cos { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Sin { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Tanh { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Powf { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Recip { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - } - } + // All output tensors that are never read by a following operation should be written to + // since they are essentially the "logical" output of the shader. + for out in local_tensor_ids_output { + let is_read = local_tensor_ids_input.contains(&out); - // All output tensors that are never read by a following operation should be written to - // since they are essentially the "logical" output of the shader. 
- for out in local_tensor_ids_output { - let is_read = local_tensor_ids_input.contains(&out); + if !is_read { + outputs.push(self.tensors.get(&out).unwrap()); + } + } - if !is_read { - outputs.push(self.tensors.get(&out).unwrap()); - } + // All tensors where their latest description is read only should be written to since they + // are going to be used after the fused kernel by other operations. + for tensor in self.tensors.values() { + if let burn_fusion::TensorStatus::ReadOnly = tensor.status { + if self.locals.contains_key(&tensor.id) { + outputs.push(tensor); } + } + } - // All tensors where their latest description is read only should be written to since they - // are going to be used after the fused kernel by other operations. - for tensor in self.tensors.values() { - if let burn_fusion::TensorStatus::ReadOnly = tensor.status { - if self.locals.contains_key(&tensor.id) { - outputs.push(tensor); - } - } + outputs + } + + fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable { + let already_exists = self.tensors.contains_key(&tensor.id); + + let variable = match already_exists { + false => { + // New input + let var = Variable::Input(self.inputs.len() as u16); + self.inputs.push(tensor.clone()); + var + } + true => match self.locals.get(&tensor.id) { + // Is a local variable. + Some(local_index) => Variable::Local(*local_index), + // Isn't a local variable, so must be an existing input. + None => { + let input = self + .inputs + .iter() + .enumerate() + .find(|(_, input)| input.id == tensor.id) + .unwrap(); + let input_index = input.0; + Variable::Input(input_index as u16) } + }, + }; - outputs - } - - fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable { - let already_exists = self.tensors.contains_key(&tensor.id); - - let variable = match already_exists { - false => { - // New input - let var = Variable::Input(self.inputs.len() as u16); - self.inputs.push(tensor.clone()); - var - } - true => match self.locals.get(&tensor.id) { - // Is a local variable. - Some(local_index) => Variable::Local(*local_index), - // Isn't a local variable, so must be an existing input. - None => { - let input = self - .inputs - .iter() - .enumerate() - .find(|(_, input)| input.id == tensor.id) - .unwrap(); - let input_index = input.0; - Variable::Input(input_index as u16) - } - }, - }; - - // Update the tensor description with the new version. - self.tensors.insert(tensor.id.clone(), tensor.clone()); - - variable - } + // Update the tensor description with the new version. + self.tensors.insert(tensor.id.clone(), tensor.clone()); - fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable { - // Update the tensor description to the new version. - self.tensors.insert(tensor.id.clone(), tensor.clone()); + variable + } - // Output already registered as a local variable. - if let Some(index) = self.locals.get(&tensor.id) { - return Variable::Local(*index); - } + fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable { + // Update the tensor description to the new version. + self.tensors.insert(tensor.id.clone(), tensor.clone()); - // New local variable. - let local_index = self.locals.len() as u16; - self.locals.insert(tensor.id.clone(), local_index); - Variable::Local(local_index) + // Output already registered as a local variable. 
+ if let Some(index) = self.locals.get(&tensor.id) { + return Variable::Local(*index); } - fn register_float(&mut self, ops: &FloatOpsDescription) -> bool { - match ops { - FloatOpsDescription::Exp(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Exp { input, out }) - } - FloatOpsDescription::Log(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Log { input, out }) - } - FloatOpsDescription::Log1p(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out }) - } - FloatOpsDescription::Cos(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Cos { input, out }) - } - FloatOpsDescription::Sin(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Sin { input, out }) - } - FloatOpsDescription::Powf(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out }) - } - FloatOpsDescription::Tanh(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out }) - } - FloatOpsDescription::Erf(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Erf { input, out }) - } - FloatOpsDescription::Recip(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Recip { input, out }) - } - _ => false, - } + // New local variable. + let local_index = self.locals.len() as u16; + self.locals.insert(tensor.id.clone(), local_index); + Variable::Local(local_index) + } + + fn register_float(&mut self, ops: &FloatOpsDescription) -> bool { + match ops { + FloatOpsDescription::Exp(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Exp { input, out }) + } + FloatOpsDescription::Log(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Log { input, out }) + } + FloatOpsDescription::Log1p(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out }) + } + FloatOpsDescription::Cos(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Cos { input, out }) + } + FloatOpsDescription::Sin(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Sin { input, out }) + } + FloatOpsDescription::Powf(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out }) + } + FloatOpsDescription::Tanh(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out }) + } + FloatOpsDescription::Erf(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Erf { input, out }) + } + FloatOpsDescription::Recip(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Recip { input, out }) + } + _ => false, } - - fn register_numeric( - &mut self, - ops: &NumericOpsDescription, - ) -> bool { - match ops { - NumericOpsDescription::Add(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) - } - NumericOpsDescription::AddScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) - } - NumericOpsDescription::Sub(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) - } - NumericOpsDescription::SubScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) - } - NumericOpsDescription::Mul(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) - } - NumericOpsDescription::MulScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) - } - 
NumericOpsDescription::Div(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) - } - NumericOpsDescription::DivScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) - } - NumericOpsDescription::Abs(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Abs { input, out }) - } - _ => false, - } + } + + fn register_numeric( + &mut self, + ops: &NumericOpsDescription, + ) -> bool { + match ops { + NumericOpsDescription::Add(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) + } + NumericOpsDescription::AddScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) + } + NumericOpsDescription::Sub(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) + } + NumericOpsDescription::SubScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) + } + NumericOpsDescription::Mul(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) + } + NumericOpsDescription::MulScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) + } + NumericOpsDescription::Div(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) + } + NumericOpsDescription::DivScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) + } + NumericOpsDescription::Abs(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Abs { input, out }) + } + _ => false, + } + } + + fn register_binary_ops(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool + where + Func: Fn(Variable, Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; } - fn register_binary_ops(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool - where - Func: Fn(Variable, Variable, Variable) -> Operator, - { - if !self.output_is_compatible(&desc.out) { - return false; - } + let lhs = self.input_to_var(&desc.lhs); + let rhs = self.input_to_var(&desc.rhs); + let out = self.output_to_var(&desc.out); - let lhs = self.input_to_var(&desc.lhs); - let rhs = self.input_to_var(&desc.rhs); - let out = self.output_to_var(&desc.out); + self.operators.push(func(lhs, rhs, out)); - self.operators.push(func(lhs, rhs, out)); + true + } - true + fn register_unary_ops(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool + where + Func: Fn(Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; } - fn register_unary_ops(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool - where - Func: Fn(Variable, Variable) -> Operator, - { - if !self.output_is_compatible(&desc.out) { - return false; - } + let input = self.input_to_var(&desc.input); + let out = self.output_to_var(&desc.out); - let input = self.input_to_var(&desc.input); - let out = self.output_to_var(&desc.out); + self.operators.push(func(input, out)); - self.operators.push(func(input, out)); + true + } - true + fn register_scalar_ops( + &mut self, + desc: &ScalarOpsDescription, + func: Func, + ) -> bool + where + Func: Fn(Variable, Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; } - fn register_scalar_ops( - &mut self, - desc: &ScalarOpsDescription, - func: Func, - ) -> bool - where - Func: Fn(Variable, Variable, Variable) -> 
Operator, - { - if !self.output_is_compatible(&desc.out) { - return false; - } + let lhs = self.input_to_var(&desc.lhs); + let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32); + self.scalars_f32.push(desc.rhs.elem()); + let out = self.output_to_var(&desc.out); - let lhs = self.input_to_var(&desc.lhs); - let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32); - self.scalars_f32.push(desc.rhs.elem()); - let out = self.output_to_var(&desc.out); + self.operators.push(func(lhs, rhs, out)); - self.operators.push(func(lhs, rhs, out)); + true + } - true + fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { + if self.current_output_shape.is_empty() { + self.current_output_shape = out.shape.clone(); + } else if self.current_output_shape != out.shape { + return false; } - fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { - if self.current_output_shape.is_empty() { - self.current_output_shape = out.shape.clone(); - } else if self.current_output_shape != out.shape { - return false; - } - - true - } + true + } } #[cfg(test)] mod tests { - use super::*; - use burn_fusion::graph::{BinaryOpsDescription, Ops}; - use burn_fusion::Fusion; - use burn_tensor::Tensor; + use super::*; + use burn_fusion::graph::{BinaryOpsDescription, Ops}; + use burn_fusion::Fusion; + use burn_tensor::Tensor; - struct FakeAddOps; + struct FakeAddOps; - impl Ops for FakeAddOps { - type Args = BinaryOpsDescription; - - fn execute(&self, _: &Self::Args, _: &mut HandleContainer) { - todo!() - } - } + impl Ops for FakeAddOps { + type Args = BinaryOpsDescription; - #[test] - fn test_fusion_same_behavior() { - type Backend = Wgpu; - type FusedBackend = Fusion; - - let data_1 = - Tensor::::random([1, 32], burn_tensor::Distribution::Default).into_data(); - let data_2 = - Tensor::::random([32, 32], burn_tensor::Distribution::Default).into_data(); - - let tensor_1 = Tensor::::from_data(data_1.clone()); - let tensor_2 = Tensor::::from_data(data_2.clone()); - let tensor_3 = tensor_1.clone() + tensor_2; - let tensor_4 = tensor_3.clone() - tensor_1; - let tensor_5 = tensor_4 + 5.0; - let tensor_6 = tensor_5 + tensor_3; - let result_ref = tensor_6.recip().into_data(); - - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - let tensor_3 = tensor_1.clone() + tensor_2; - let tensor_4 = tensor_3.clone() - tensor_1; - let tensor_5 = tensor_4 + 5.0; - let tensor_6 = tensor_5 + tensor_3; - let result_fused = tensor_6.recip().into_data(); - - result_fused.assert_approx_eq(&result_ref, 3); + fn execute(&self, _: &Self::Args, _: &mut HandleContainer) { + todo!() } + } + + #[test] + fn test_fusion_same_behavior() { + type Backend = Wgpu; + type FusedBackend = Fusion; + + let data_1 = + Tensor::::random([1, 32], burn_tensor::Distribution::Default).into_data(); + let data_2 = + Tensor::::random([32, 32], burn_tensor::Distribution::Default).into_data(); + + let tensor_1 = Tensor::::from_data(data_1.clone()); + let tensor_2 = Tensor::::from_data(data_2.clone()); + let tensor_3 = tensor_1.clone() + tensor_2; + let tensor_4 = tensor_3.clone() - tensor_1; + let tensor_5 = tensor_4 + 5.0; + let tensor_6 = tensor_5 + tensor_3; + let result_ref = tensor_6.recip().into_data(); + + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + let tensor_3 = tensor_1.clone() + tensor_2; + let tensor_4 = tensor_3.clone() - tensor_1; + let tensor_5 = tensor_4 + 5.0; + let tensor_6 = tensor_5 + tensor_3; + let result_fused = 
tensor_6.recip().into_data(); + + result_fused.assert_approx_eq(&result_ref, 3); + } } diff --git a/burn-wgpu/src/fusion/kernel.rs b/burn-wgpu/src/fusion/kernel.rs index 31a0353cb7..d1abc04111 100644 --- a/burn-wgpu/src/fusion/kernel.rs +++ b/burn-wgpu/src/fusion/kernel.rs @@ -3,10 +3,10 @@ use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient}; use crate::fusion::codegen::Function; use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank}; use crate::fusion::{ - codegen::{ - Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize, - }, - WgpuFusionHandle, + codegen::{ + Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize, + }, + WgpuFusionHandle, }; use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT}; use crate::{FloatElement, GraphicsApi, IntElement, Wgpu}; @@ -42,279 +42,279 @@ pub struct ExecutionPhase; /// handles provided. pub struct FusionKernel where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - operations: Vec, - input_bindings: Vec<(Binding, TensorDescription)>, - output_bindings: Vec<(Binding, TensorDescription)>, - named_bindings: Vec<(String, Binding, DataBuffer)>, - functions: Vec, - num_elems_output: usize, - device: Device>, - client: WgpuComputeClient, - _phase: PhantomData, + operations: Vec, + input_bindings: Vec<(Binding, TensorDescription)>, + output_bindings: Vec<(Binding, TensorDescription)>, + named_bindings: Vec<(String, Binding, DataBuffer)>, + functions: Vec, + num_elems_output: usize, + device: Device>, + client: WgpuComputeClient, + _phase: PhantomData, } enum DataBuffer { - F32(Vec), - U32(Vec), + F32(Vec), + U32(Vec), } impl FusionKernel { - /// Create a new fusion kernel on the given device. - pub fn new(device: &Device>) -> Self { - let client = compute_client::(device); + /// Create a new fusion kernel on the given device. + pub fn new(device: &Device>) -> Self { + let client = compute_client::(device); - Self { - operations: Vec::new(), - input_bindings: Vec::new(), - output_bindings: Vec::new(), - named_bindings: Vec::new(), - functions: Vec::new(), - num_elems_output: 0, - device: device.clone(), - client, - _phase: PhantomData, - } + Self { + operations: Vec::new(), + input_bindings: Vec::new(), + output_bindings: Vec::new(), + named_bindings: Vec::new(), + functions: Vec::new(), + num_elems_output: 0, + device: device.clone(), + client, + _phase: PhantomData, } + } - /// Register the inputs used by the kernel. - pub fn inputs( - mut self, - inputs_tensor: &[&TensorDescription], - inputs_scalar_f32: &[f32], - ) -> FusionKernel { - for (i, input) in inputs_tensor.iter().enumerate() { - self.input_bindings.push(( - Binding { - elem: Elem::F32, - visibility: Visibility::Read, - location: Location::Storage, - size: None, - }, - (*input).clone(), - )); + /// Register the inputs used by the kernel. 
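// The phase marker types (such as `BodyPhase` and `ExecutionPhase`) make this
// a type-state builder: each step consumes the kernel and returns it in the
// next phase, so the calls can only be made in order. A typical call site, as
// in the fused element-wise ops above, looks roughly like this (variable names
// are illustrative):
//
//     FusionKernel::new(&device)
//         .inputs(&input_descriptions, &scalars_f32)
//         .body(&operators)
//         .outputs(&output_descriptions, &locals)
//         .execute(handles);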
+ pub fn inputs( + mut self, + inputs_tensor: &[&TensorDescription], + inputs_scalar_f32: &[f32], + ) -> FusionKernel { + for (i, input) in inputs_tensor.iter().enumerate() { + self.input_bindings.push(( + Binding { + elem: Elem::F32, + visibility: Visibility::Read, + location: Location::Storage, + size: None, + }, + (*input).clone(), + )); - self.operations.push(Operator::ReadGlobal { - variable: Variable::Input(i as u16), - position: i, - position_out: inputs_tensor.len(), // First output - }); - } + self.operations.push(Operator::ReadGlobal { + variable: Variable::Input(i as u16), + position: i, + position_out: inputs_tensor.len(), // First output + }); + } - if !inputs_scalar_f32.is_empty() { - self.named_bindings.push(( - "scalars_f32".to_string(), - Binding { - elem: Elem::F32, - visibility: Visibility::Read, - location: Location::Storage, - size: Some(inputs_scalar_f32.len()), - }, - DataBuffer::F32(inputs_scalar_f32.to_vec()), - )); - } + if !inputs_scalar_f32.is_empty() { + self.named_bindings.push(( + "scalars_f32".to_string(), + Binding { + elem: Elem::F32, + visibility: Visibility::Read, + location: Location::Storage, + size: Some(inputs_scalar_f32.len()), + }, + DataBuffer::F32(inputs_scalar_f32.to_vec()), + )); + } - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, - } + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, } + } } impl FusionKernel { - /// Register the [operators](Operator) that the kernel must execute in the order provided. - pub fn body(mut self, operators: &[Operator]) -> FusionKernel { - let mut register_function = |function: Function| { - if !self.functions.contains(&function) { - self.functions.push(function); - } - }; + /// Register the [operators](Operator) that the kernel must execute in the order provided. + pub fn body(mut self, operators: &[Operator]) -> FusionKernel { + let mut register_function = |function: Function| { + if !self.functions.contains(&function) { + self.functions.push(function); + } + }; - // Since not all operators are native to WGSL, we need to add the custom ones. - for ops in operators.iter() { - match ops { - Operator::Powf { - lhs: _, - rhs: _, - out: _, - } => { - register_function(Function::Powf(Elem::F32)); - } - Operator::Erf { input: _, out: _ } => { - register_function(Function::Erf(Elem::F32)); - } - _ => {} - } - self.operations.push(ops.clone()); + // Since not all operators are native to WGSL, we need to add the custom ones. 
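// Only `powf` and `erf` currently need generated helper functions; every other
// fused operator lowers to a WGSL builtin or a plain arithmetic expression, so
// nothing extra is appended to the shader source for them.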
+ for ops in operators.iter() { + match ops { + Operator::Powf { + lhs: _, + rhs: _, + out: _, + } => { + register_function(Function::Powf(Elem::F32)); } - - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, + Operator::Erf { input: _, out: _ } => { + register_function(Function::Erf(Elem::F32)); } + _ => {} + } + self.operations.push(ops.clone()); } + + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, + } + } } impl FusionKernel { - /// Register the outputs with their local variable index. - /// - /// Note that the index corresponds to the registered [operator](Operator) number at the - /// [body phase](BodyPhase). - /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0). - pub fn outputs( - mut self, - outputs: &[&TensorDescription], - locals: &[u16], - ) -> FusionKernel { - let mut num_elems_launch_option = 0; + /// Register the outputs with their local variable index. + /// + /// Note that the index corresponds to the registered [operator](Operator) number at the + /// [body phase](BodyPhase). + /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0). + pub fn outputs( + mut self, + outputs: &[&TensorDescription], + locals: &[u16], + ) -> FusionKernel { + let mut num_elems_launch_option = 0; - for (i, (output, local)) in outputs.iter().zip(locals).enumerate() { - let num_elems_output = calculate_num_elems_dyn_rank(&output.shape); - if num_elems_output > num_elems_launch_option { - num_elems_launch_option = num_elems_output; - } + for (i, (output, local)) in outputs.iter().zip(locals).enumerate() { + let num_elems_output = calculate_num_elems_dyn_rank(&output.shape); + if num_elems_output > num_elems_launch_option { + num_elems_launch_option = num_elems_output; + } - self.output_bindings.push(( - Binding { - elem: Elem::F32, - visibility: Visibility::ReadWrite, - location: Location::Storage, - size: None, - }, - (*output).clone(), - )); + self.output_bindings.push(( + Binding { + elem: Elem::F32, + visibility: Visibility::ReadWrite, + location: Location::Storage, + size: None, + }, + (*output).clone(), + )); - self.operations.push(Operator::AssignGlobal { - input: Variable::Local(*local), - out: Variable::Output(i as u16), - }); - } + self.operations.push(Operator::AssignGlobal { + input: Variable::Local(*local), + out: Variable::Output(i as u16), + }); + } - self.num_elems_output = num_elems_launch_option; + self.num_elems_output = num_elems_launch_option; - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, - } + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + 
device: self.device, + client: self.client, + _phase: PhantomData, } + } } impl FusionKernel { - /// Execute the kernel on the provided [handles](HandleContainer). - pub fn execute(mut self, handle_container: &mut HandleContainer>) { - let mut inputs = Vec::with_capacity(self.input_bindings.len()); - let mut outputs = Vec::with_capacity(self.output_bindings.len()); - let mut named = Vec::with_capacity(2); - let mut info = Vec::new(); - let mut handles = - Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity()); + /// Execute the kernel on the provided [handles](HandleContainer). + pub fn execute(mut self, handle_container: &mut HandleContainer>) { + let mut inputs = Vec::with_capacity(self.input_bindings.len()); + let mut outputs = Vec::with_capacity(self.output_bindings.len()); + let mut named = Vec::with_capacity(2); + let mut info = Vec::new(); + let mut handles = Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity()); - // Inner function to fill the info buffer. - let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| { - if info.is_empty() { - info.push(handle.strides.len() as u32); - } + // Inner function to fill the info buffer. + let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| { + if info.is_empty() { + info.push(handle.strides.len() as u32); + } - for s in handle.strides.iter() { - info.push(*s as u32); - } - for s in tensor.shape.iter() { - info.push(*s as u32); - } - }; + for s in handle.strides.iter() { + info.push(*s as u32); + } + for s in tensor.shape.iter() { + info.push(*s as u32); + } + }; - // We start by registering the inputs. - for (binding, tensor) in self.input_bindings.into_iter() { - let handle = handle_container.get_handle(&tensor); - register_info_tensor(&tensor, &handle); + // We start by registering the inputs. + for (binding, tensor) in self.input_bindings.into_iter() { + let handle = handle_container.get_handle(&tensor); + register_info_tensor(&tensor, &handle); - inputs.push(binding); - handles.push(handle.handle); - } + inputs.push(binding); + handles.push(handle.handle); + } - // Then we follow with the outputs. - for (binding, tensor) in self.output_bindings { - let num_elems = calculate_num_elems_dyn_rank(&tensor.shape); - let handle_fusion = WgpuFusionHandle { - client: self.client.clone(), - device: self.device.clone(), - strides: strides_dyn_rank(&tensor.shape), - handle: self.client.empty(core::mem::size_of::() * num_elems), - }; - register_info_tensor(&tensor, &handle_fusion); + // Then we follow with the outputs. + for (binding, tensor) in self.output_bindings { + let num_elems = calculate_num_elems_dyn_rank(&tensor.shape); + let handle_fusion = WgpuFusionHandle { + client: self.client.clone(), + device: self.device.clone(), + strides: strides_dyn_rank(&tensor.shape), + handle: self.client.empty(core::mem::size_of::() * num_elems), + }; + register_info_tensor(&tensor, &handle_fusion); - handles.push(handle_fusion.handle.clone()); - handle_container.register_handle(tensor.id, handle_fusion); - outputs.push(binding); - } + handles.push(handle_fusion.handle.clone()); + handle_container.register_handle(tensor.id, handle_fusion); + outputs.push(binding); + } - // Now we can create the info handle. - Self::build_info_handle(&mut self.named_bindings, info); + // Now we can create the info handle. + Self::build_info_handle(&mut self.named_bindings, info); - // Finally we finish with the named bindings. 
- for (name, binding, data) in self.named_bindings { - let handle = self.client.create(match &data { - DataBuffer::F32(values) => bytemuck::cast_slice(values), - DataBuffer::U32(values) => bytemuck::cast_slice(values), - }); - named.push((name, binding)); - handles.push(handle); - } + // Finally we finish with the named bindings. + for (name, binding, data) in self.named_bindings { + let handle = self.client.create(match &data { + DataBuffer::F32(values) => bytemuck::cast_slice(values), + DataBuffer::U32(values) => bytemuck::cast_slice(values), + }); + named.push((name, binding)); + handles.push(handle); + } - // We create the shader codegen type and launch the kernel. - let kernel = ComputeShader { - inputs, - outputs, - named, - workgroup_size: WorkgroupSize::default(), - body: Body::new(self.operations), - num_workgroups: true, - global_invocation_id: true, - functions: self.functions, - }; + // We create the shader codegen type and launch the kernel. + let kernel = ComputeShader { + inputs, + outputs, + named, + workgroup_size: WorkgroupSize::default(), + body: Body::new(self.operations), + num_workgroups: true, + global_invocation_id: true, + functions: self.functions, + }; - let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT); - let kernel = Box::new(DynamicKernel::new(kernel, workgroup)); + let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT); + let kernel = Box::new(DynamicKernel::new(kernel, workgroup)); - self.client - .execute(kernel, &handles.iter().collect::>()); - } + self + .client + .execute(kernel, &handles.iter().collect::>()); + } - fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec) { - named_bindings.push(( - "info".to_string(), - Binding { - elem: Elem::U32, - visibility: Visibility::Read, - location: Location::Storage, - size: None, // We avoid putting the length here since it will force a new kernel - // for each tensor rank. - }, - DataBuffer::U32(info), - )); - } + fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec) { + named_bindings.push(( + "info".to_string(), + Binding { + elem: Elem::U32, + visibility: Visibility::Read, + location: Location::Storage, + size: None, // We avoid putting the length here since it will force a new kernel + // for each tensor rank. + }, + DataBuffer::U32(info), + )); + } } diff --git a/burn-wgpu/src/graphics.rs b/burn-wgpu/src/graphics.rs index c3b3c01f19..8d709c6e4b 100644 --- a/burn-wgpu/src/graphics.rs +++ b/burn-wgpu/src/graphics.rs @@ -8,8 +8,8 @@ /// - [DirectX 12](Dx12) /// - [WebGpu](WebGpu) pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static { - /// The wgpu backend. - fn backend() -> wgpu::Backend; + /// The wgpu backend. + fn backend() -> wgpu::Backend; } /// Vulkan graphics API. 
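As a minimal sketch of how this trait is meant to be used (the `MyApi` marker below is hypothetical and only illustrates the pattern that the impls in the next hunk follow), any unit type satisfying the trait bounds can select a wgpu backend:

  #[derive(Default, Clone, Debug)]
  pub struct MyApi;

  impl GraphicsApi for MyApi {
    fn backend() -> wgpu::Backend {
      wgpu::Backend::Vulkan
    }
  }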
@@ -41,46 +41,46 @@ pub struct WebGpu; pub struct AutoGraphicsApi; impl GraphicsApi for Vulkan { - fn backend() -> wgpu::Backend { - wgpu::Backend::Vulkan - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Vulkan + } } impl GraphicsApi for Metal { - fn backend() -> wgpu::Backend { - wgpu::Backend::Metal - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Metal + } } impl GraphicsApi for OpenGl { - fn backend() -> wgpu::Backend { - wgpu::Backend::Gl - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Gl + } } impl GraphicsApi for Dx11 { - fn backend() -> wgpu::Backend { - wgpu::Backend::Dx11 - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Dx11 + } } impl GraphicsApi for Dx12 { - fn backend() -> wgpu::Backend { - wgpu::Backend::Dx12 - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Dx12 + } } impl GraphicsApi for WebGpu { - fn backend() -> wgpu::Backend { - wgpu::Backend::BrowserWebGpu - } + fn backend() -> wgpu::Backend { + wgpu::Backend::BrowserWebGpu + } } impl GraphicsApi for AutoGraphicsApi { - fn backend() -> wgpu::Backend { - #[cfg(target_os = "macos")] - return wgpu::Backend::Metal; - #[cfg(not(target_os = "macos"))] - wgpu::Backend::Vulkan - } + fn backend() -> wgpu::Backend { + #[cfg(target_os = "macos")] + return wgpu::Backend::Metal; + #[cfg(not(target_os = "macos"))] + wgpu::Backend::Vulkan + } } diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs index 133015f49d..aca2333233 100644 --- a/burn-wgpu/src/kernel/base.rs +++ b/burn-wgpu/src/kernel/base.rs @@ -1,9 +1,9 @@ use super::SourceTemplate; use crate::{ - compute::{StaticKernel, WgpuComputeClient, WgpuHandle, WorkGroup}, - element::WgpuElement, - kernel, - tensor::WgpuTensor, + compute::{StaticKernel, WgpuComputeClient, WgpuHandle, WorkGroup}, + element::WgpuElement, + kernel, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -14,169 +14,169 @@ pub(crate) const WORKGROUP_DEFAULT: usize = 32; /// Static wgpu kernel to create a [source template](SourceTemplate). pub trait StaticKernelSource: Send + 'static + Sync { - /// Source template for the kernel. - fn source() -> SourceTemplate; + /// Source template for the kernel. + fn source() -> SourceTemplate; } /// Dynamic wgpu kernel to create a [source template](SourceTemplate). pub trait DynamicKernelSource: Send + Sync { - /// Source template for the kernel. - fn source(&self) -> SourceTemplate; - /// Identifier for the kernel, used for caching kernel compilation. - fn id(&self) -> String; + /// Source template for the kernel. + fn source(&self) -> SourceTemplate; + /// Identifier for the kernel, used for caching kernel compilation. + fn id(&self) -> String; } /// Generates kernel source code by replacing some information using templating. #[macro_export] macro_rules! kernel_wgsl { - ( + ( $struct:ident, $file:expr ) => { - /// Generated kernel from wgsl file. - #[derive(new)] - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::SourceTemplate::new(include_str!($file)) - } - } - }; + /// Generated kernel from wgsl file. + #[derive(new)] + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::SourceTemplate::new(include_str!($file)) + } + } + }; } kernel_wgsl!(ContiguousRaw, "../template/contiguous.wgsl"); /// Make a wgpu tensor contiguous. 
pub fn into_contiguous( - tensor: WgpuTensor, + tensor: WgpuTensor, ) -> WgpuTensor { - if tensor.is_contiguous() { - return tensor; - } - - let num_elems = tensor.shape.num_elements(); - let handle = tensor.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - handle, - ); - let info = build_info(&[&tensor, &output]); - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); - - tensor - .client - .execute(kernel, &[&tensor.handle, &output.handle, &info_handle]); - - output + if tensor.is_contiguous() { + return tensor; + } + + let num_elems = tensor.shape.num_elements(); + let handle = tensor.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + handle, + ); + let info = build_info(&[&tensor, &output]); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = Box::new(StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); + + tensor + .client + .execute(kernel, &[&tensor.handle, &output.handle, &info_handle]); + + output } /// Similar to [into contiguous](into_contiguous) but with dynamic rank. pub fn into_contiguous_dyn( - client: WgpuComputeClient, - input: WgpuHandle, - input_shape: &[usize], - input_strides: &[usize], - output_shape: &[usize], - output_strides: &[usize], - num_elems: usize, + client: WgpuComputeClient, + input: WgpuHandle, + input_shape: &[usize], + input_strides: &[usize], + output_shape: &[usize], + output_strides: &[usize], + num_elems: usize, ) -> WgpuHandle { - let handle = client.empty(num_elems * core::mem::size_of::()); - let info = kernel::build_info_dyn::( - &[input_shape, output_shape], - &[input_strides, output_strides], - ); + let handle = client.empty(num_elems * core::mem::size_of::()); + let info = kernel::build_info_dyn::( + &[input_shape, output_shape], + &[input_strides, output_strides], + ); - let info_handle = client.create(bytemuck::cast_slice(&info)); + let info_handle = client.create(bytemuck::cast_slice(&info)); - let kernel = Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); + let kernel = Box::new(StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); - client.execute(kernel, &[&input, &handle, &info_handle]); + client.execute(kernel, &[&input, &handle, &info_handle]); - handle + handle } /// Generates kernel source code by replacing some information using templating. 
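/// The generated source fills the `workgroup_size_x`, `workgroup_size_y`, `workgroup_size_z`,
/// combined `workgroup_size`, `elem`, and `int` placeholders of the wrapped kernel template.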
pub struct KernelSettings< + K: StaticKernelSource, + E: WgpuElement, + I: WgpuElement, + const WORKGROUP_X_SIZE: usize, + const WORKGROUP_Y_SIZE: usize, + const WORKGROUP_Z_SIZE: usize, +> { + _k: PhantomData, + _e: PhantomData, + _i: PhantomData, +} + +impl< K: StaticKernelSource, E: WgpuElement, I: WgpuElement, const WORKGROUP_X_SIZE: usize, const WORKGROUP_Y_SIZE: usize, const WORKGROUP_Z_SIZE: usize, -> { - _k: PhantomData, - _e: PhantomData, - _i: PhantomData, -} - -impl< - K: StaticKernelSource, - E: WgpuElement, - I: WgpuElement, - const WORKGROUP_X_SIZE: usize, - const WORKGROUP_Y_SIZE: usize, - const WORKGROUP_Z_SIZE: usize, - > StaticKernelSource - for KernelSettings + > StaticKernelSource + for KernelSettings { - fn source() -> SourceTemplate { - K::source() - .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string()) - .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string()) - .register( - "workgroup_size", - (WORKGROUP_X_SIZE * WORKGROUP_Y_SIZE * WORKGROUP_Z_SIZE).to_string(), - ) - .register("elem", E::type_name()) - .register("int", I::type_name()) - } + fn source() -> SourceTemplate { + K::source() + .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string()) + .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string()) + .register( + "workgroup_size", + (WORKGROUP_X_SIZE * WORKGROUP_Y_SIZE * WORKGROUP_Z_SIZE).to_string(), + ) + .register("elem", E::type_name()) + .register("int", I::type_name()) + } } /// Generate kernel source code by replacing some information using templating. #[derive(new)] pub struct DynamicKernelSettings { - workgroup_x_size: usize, - workgroup_y_size: usize, - workgroup_z_size: usize, - _k: PhantomData, - _e: PhantomData, - _i: PhantomData, + workgroup_x_size: usize, + workgroup_y_size: usize, + workgroup_z_size: usize, + _k: PhantomData, + _e: PhantomData, + _i: PhantomData, } impl DynamicKernelSource - for DynamicKernelSettings + for DynamicKernelSettings { - fn source(&self) -> SourceTemplate { - K::source() - .register("workgroup_size_x", self.workgroup_x_size.to_string()) - .register("workgroup_size_y", self.workgroup_y_size.to_string()) - .register("workgroup_size_z", self.workgroup_z_size.to_string()) - .register( - "workgroup_size", - (self.workgroup_x_size * self.workgroup_y_size * self.workgroup_z_size).to_string(), - ) - .register("elem", E::type_name()) - .register("int", I::type_name()) - } - - fn id(&self) -> String { - let id = core::any::TypeId::of::(); - - format!( - "{:?}-dyn-settings{}-{}-{}", - id, self.workgroup_x_size, self.workgroup_y_size, self.workgroup_z_size - ) - } + fn source(&self) -> SourceTemplate { + K::source() + .register("workgroup_size_x", self.workgroup_x_size.to_string()) + .register("workgroup_size_y", self.workgroup_y_size.to_string()) + .register("workgroup_size_z", self.workgroup_z_size.to_string()) + .register( + "workgroup_size", + (self.workgroup_x_size * self.workgroup_y_size * self.workgroup_z_size).to_string(), + ) + .register("elem", E::type_name()) + .register("int", I::type_name()) + } + + fn id(&self) -> String { + let id = core::any::TypeId::of::(); + + format!( + "{:?}-dyn-settings{}-{}-{}", + id, self.workgroup_x_size, self.workgroup_y_size, self.workgroup_z_size + ) + } } /// Create a vector containing the dimension, strides and shape of tensors. 
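/// For example, with two rank-2 tensors `lhs` and `rhs`, the layout described below expands to
///   info = [2, lhs.strides[0], lhs.strides[1], rhs.strides[0], rhs.strides[1],
///           lhs.shape[0], lhs.shape[1], rhs.shape[0], rhs.shape[1]].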
@@ -193,84 +193,84 @@ impl DynamicKernelSource /// | (2 * D + 1)..(3 * D + 1) | lhs shape | /// | (3 * D + 1)..(4 * D + 1) | rhs shape | pub fn build_info(tensors: &[&WgpuTensor]) -> Vec { - let mut info: Vec = vec![0; tensors.len() * 2 * D + 1]; - info[0] = D as u32; - - let mut current = 1; - for tensor in tensors.iter() { - for d in 0..D { - info[current] = tensor.strides[d] as u32; - current += 1; - } + let mut info: Vec = vec![0; tensors.len() * 2 * D + 1]; + info[0] = D as u32; + + let mut current = 1; + for tensor in tensors.iter() { + for d in 0..D { + info[current] = tensor.strides[d] as u32; + current += 1; } - for tensor in tensors.iter() { - for d in 0..D { - info[current] = tensor.shape.dims[d] as u32; - current += 1; - } + } + for tensor in tensors.iter() { + for d in 0..D { + info[current] = tensor.shape.dims[d] as u32; + current += 1; } - info + } + info } /// Similar to [build info](build_info) but with dynamic rank. pub fn build_info_dyn(shapes: &[&[usize]], strides: &[&[usize]]) -> Vec { - let rank = shapes.get(0).unwrap().len(); - let mut info: Vec = vec![0; shapes.len() * 2 * rank + 1]; - info[0] = rank as u32; - - let mut current = 1; - for stride in strides.iter() { - for d in 0..rank { - info[current] = stride[d] as u32; - current += 1; - } + let rank = shapes.get(0).unwrap().len(); + let mut info: Vec = vec![0; shapes.len() * 2 * rank + 1]; + info[0] = rank as u32; + + let mut current = 1; + for stride in strides.iter() { + for d in 0..rank { + info[current] = stride[d] as u32; + current += 1; } - for shape in shapes.iter() { - for d in 0..rank { - info[current] = shape[d] as u32; - current += 1; - } + } + for shape in shapes.iter() { + for d in 0..rank { + info[current] = shape[d] as u32; + current += 1; } - info + } + info } pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup { - let num_elem_per_invocation = workgroup_size * workgroup_size; - let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32); - let workgroup_x = f32::ceil(f32::sqrt(workgroups)); - let workgroup_y = f32::ceil(num_elems as f32 / (workgroup_x * num_elem_per_invocation as f32)); + let num_elem_per_invocation = workgroup_size * workgroup_size; + let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32); + let workgroup_x = f32::ceil(f32::sqrt(workgroups)); + let workgroup_y = f32::ceil(num_elems as f32 / (workgroup_x * num_elem_per_invocation as f32)); - WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) + WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) } pub(crate) fn prng_workgroup( - num_elems: usize, - workgroup_size: usize, - n_values_per_thread: usize, + num_elems: usize, + workgroup_size: usize, + n_values_per_thread: usize, ) -> WorkGroup { - let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32); - let num_elem_per_invocation = workgroup_size * workgroup_size; - let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32); - let workgroup_x = f32::ceil(f32::sqrt(num_invocations)); - let workgroup_y = f32::ceil(num_invocations / workgroup_x); + let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32); + let num_elem_per_invocation = workgroup_size * workgroup_size; + let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32); + let workgroup_x = f32::ceil(f32::sqrt(num_invocations)); + let workgroup_y = f32::ceil(num_invocations / workgroup_x); - WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) + 
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) } #[cfg(test)] mod tests { - use super::*; - use core::any::TypeId; + use super::*; + use core::any::TypeId; - #[test] - fn test_kernel_type_id() { - kernel_wgsl!(Add, "../template/binary_elemwise.wgsl"); + #[test] + fn test_kernel_type_id() { + kernel_wgsl!(Add, "../template/binary_elemwise.wgsl"); - let type_id_1 = TypeId::of::>(); - let type_id_2 = TypeId::of::>(); - let type_id_3 = TypeId::of::>(); + let type_id_1 = TypeId::of::>(); + let type_id_2 = TypeId::of::>(); + let type_id_3 = TypeId::of::>(); - assert_ne!(type_id_1, type_id_2); - assert_eq!(type_id_1, type_id_3); - } + assert_ne!(type_id_1, type_id_2); + assert_eq!(type_id_1, type_id_3); + } } diff --git a/burn-wgpu/src/kernel/binary_elemwise.rs b/burn-wgpu/src/kernel/binary_elemwise.rs index 0fc4a0e93d..05b2324dea 100644 --- a/burn-wgpu/src/kernel/binary_elemwise.rs +++ b/burn-wgpu/src/kernel/binary_elemwise.rs @@ -1,5 +1,5 @@ use super::{ - build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, + build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, }; use crate::compute::StaticKernel; use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; @@ -7,175 +7,177 @@ use burn_tensor::Shape; kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl"); kernel_wgsl!( - BinaryElemwiseInplaceRaw, - "../template/binary_elemwise_inplace.wgsl" + BinaryElemwiseInplaceRaw, + "../template/binary_elemwise_inplace.wgsl" ); /// Creates a binary elementwise kernel. #[macro_export] macro_rules! binary_elemwise { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::BinaryElemwiseRaw::source().register( - "body", - format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops), - ) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::BinaryElemwiseRaw::source().register( + "body", + format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops), + ) + } + } + }; } /// Creates a binary elementwise inplace kernel. #[macro_export] macro_rules! binary_elemwise_inplace { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::BinaryElemwiseInplaceRaw::source().register( - "body", - format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops), - ) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::BinaryElemwiseInplaceRaw::source().register( + "body", + format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops), + ) + } + } + }; } /// Execute a binary kernel using the default settings. pub fn binary_elemwise_default( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise::(lhs, rhs) + binary_elemwise::(lhs, rhs) } /// Execute a binary kernel using the provided WORKGROUP. 
pub fn binary_elemwise< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); - - let mut shape_out = [0; D]; - lhs.shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - - let shape_out = Shape::new(shape_out); - let num_elems = shape_out.num_elements(); - - let handle = lhs.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle); - - let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], - ); - - output + lhs.assert_is_on_same_device(&rhs); + + let mut shape_out = [0; D]; + lhs + .shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::new(shape_out); + let num_elems = shape_out.num_elements(); + + let handle = lhs.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle); + + let info = build_info(&[&lhs, &rhs, &output]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + ); + + output } /// Execute a binary inplace kernel using the default settings. pub fn binary_elemwise_inplace_default( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise_inplace::(lhs, rhs) + binary_elemwise_inplace::(lhs, rhs) } /// Execute a binary inplace kernel using the provided WORKGROUP. 
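/// The in-place variant below writes the result into `lhs`, so callers are expected to make
/// sure `lhs` already has the broadcast output shape (compare the `can_mut_broadcast` checks
/// in the comparison kernels later in this patch).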
pub fn binary_elemwise_inplace< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); + lhs.assert_is_on_same_device(&rhs); - let info = build_info(&[&lhs, &rhs]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::>::new( - elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP), - ); + let info = build_info(&[&lhs, &rhs]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::>::new( + elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP), + ); - lhs.client - .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); + lhs + .client + .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); - lhs + lhs } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; - - binary_elemwise!(TestKernel, "*"); - binary_elemwise_inplace!(TestKernelInplace, "*"); - - #[test] - fn binary_should_work_with_multiple_invocations() { - let lhs = Tensor::::random([6, 256], Distribution::Default); - let rhs = Tensor::::random([6, 256], Distribution::Default); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - let rhs_ref = Tensor::::from_data(rhs.to_data()); - - let actual = - binary_elemwise::(lhs.into_primitive(), rhs.into_primitive()); - let expected = lhs_ref * rhs_ref; - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; + + binary_elemwise!(TestKernel, "*"); + binary_elemwise_inplace!(TestKernelInplace, "*"); + + #[test] + fn binary_should_work_with_multiple_invocations() { + let lhs = Tensor::::random([6, 256], Distribution::Default); + let rhs = Tensor::::random([6, 256], Distribution::Default); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + let rhs_ref = Tensor::::from_data(rhs.to_data()); + + let actual = + binary_elemwise::(lhs.into_primitive(), rhs.into_primitive()); + let expected = lhs_ref * rhs_ref; + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } + + #[test] + fn binary_inplace_should_work_with_multiple_invocations() { + let lhs = Tensor::::random([6, 256], Distribution::Default); + let rhs = Tensor::::random([6, 256], Distribution::Default); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + let rhs_ref = Tensor::::from_data(rhs.to_data()); + + let actual = binary_elemwise_inplace::( + lhs.into_primitive(), + rhs.into_primitive(), + ); + let expected = lhs_ref * rhs_ref; - #[test] - fn binary_inplace_should_work_with_multiple_invocations() { - let lhs = Tensor::::random([6, 256], Distribution::Default); - let rhs = Tensor::::random([6, 256], Distribution::Default); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - let rhs_ref = Tensor::::from_data(rhs.to_data()); - - let actual = binary_elemwise_inplace::( - lhs.into_primitive(), - rhs.into_primitive(), - ); - let expected = lhs_ref * rhs_ref; - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } } diff --git 
a/burn-wgpu/src/kernel/cast.rs b/burn-wgpu/src/kernel/cast.rs index d840c5268e..8daf33b83b 100644 --- a/burn-wgpu/src/kernel/cast.rs +++ b/burn-wgpu/src/kernel/cast.rs @@ -1,84 +1,77 @@ use super::{KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}; use crate::{ - compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, + tensor::WgpuTensor, }; use std::{any::TypeId, marker::PhantomData}; kernel_wgsl!(CastRaw, "../template/cast.wgsl"); struct Cast { - _i: PhantomData, - _o: PhantomData, + _i: PhantomData, + _o: PhantomData, } impl StaticKernelSource - for Cast + for Cast { - fn source() -> SourceTemplate { - CastRaw::source() - .register("input_elem", InputElem::type_name()) - .register("output_elem", OutputElem::type_name()) - } + fn source() -> SourceTemplate { + CastRaw::source() + .register("input_elem", InputElem::type_name()) + .register("output_elem", OutputElem::type_name()) + } } /// Cast a tensor to the given element type. pub fn cast( - tensor: WgpuTensor, + tensor: WgpuTensor, ) -> WgpuTensor { - if TypeId::of::() == TypeId::of::() { - return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); - } + if TypeId::of::() == TypeId::of::() { + return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); + } - let num_elems = tensor.shape.num_elements(); - let kernel = StaticKernel::< - KernelSettings< - Cast, - f32, - i32, - WORKGROUP_DEFAULT, - WORKGROUP_DEFAULT, - 1, - >, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let num_elems = tensor.shape.num_elements(); + let kernel = StaticKernel::< + KernelSettings, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let handle = tensor - .client - .empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - tensor.client.clone(), - tensor.device, - tensor.shape.clone(), - handle, - ); + let handle = tensor + .client + .empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + tensor.client.clone(), + tensor.device, + tensor.shape.clone(), + handle, + ); - tensor - .client - .execute(Box::new(kernel), &[&tensor.handle, &output.handle]); + tensor + .client + .execute(Box::new(kernel), &[&tensor.handle, &output.handle]); - output + output } #[cfg(test)] mod tests { - use super::*; - use crate::tests::TestBackend; - use burn_tensor::{Int, Tensor}; + use super::*; + use crate::tests::TestBackend; + use burn_tensor::{Int, Tensor}; - #[test] - fn should_cast_int_to_float() { - const START: usize = 0; - const END: usize = 100; + #[test] + fn should_cast_int_to_float() { + const START: usize = 0; + const END: usize = 100; - let tensor = Tensor::::arange(START..END); - let tensor_float = cast::(tensor.clone().into_primitive()); + let tensor = Tensor::::arange(START..END); + let tensor_float = cast::(tensor.clone().into_primitive()); - let data_int = tensor.into_data(); - let data_float = Tensor::::from_primitive(tensor_float).into_data(); + let data_int = tensor.into_data(); + let data_float = Tensor::::from_primitive(tensor_float).into_data(); - for i in START..END { - assert_eq!(data_int.value[i], i as i32); - assert_eq!(data_float.value[i], i as f32); - } + for i in START..END { + assert_eq!(data_int.value[i], i as i32); + assert_eq!(data_float.value[i], i as f32); } + } } diff --git a/burn-wgpu/src/kernel/cat.rs 
b/burn-wgpu/src/kernel/cat.rs index 4271541153..e4c4b44cac 100644 --- a/burn-wgpu/src/kernel/cat.rs +++ b/burn-wgpu/src/kernel/cat.rs @@ -1,9 +1,9 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings}, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings}, + kernel_wgsl, + tensor::WgpuTensor, }; use super::WORKGROUP_DEFAULT; @@ -11,84 +11,82 @@ use super::WORKGROUP_DEFAULT; kernel_wgsl!(Cat, "../template/cat.wgsl"); pub fn cat( - inputs: Vec>, - dim: usize, + inputs: Vec>, + dim: usize, ) -> WgpuTensor { - let first_input = inputs.get(0).unwrap(); - let client = &first_input.client; - let mut shape_output = first_input.shape.clone(); - shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum(); - - let buffer = first_input - .client - .empty(shape_output.num_elements() * std::mem::size_of::()); - - let output = WgpuTensor::new( - client.clone(), - first_input.device.clone(), - shape_output, - buffer, + let first_input = inputs.get(0).unwrap(); + let client = &first_input.client; + let mut shape_output = first_input.shape.clone(); + shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum(); + + let buffer = first_input + .client + .empty(shape_output.num_elements() * std::mem::size_of::()); + + let output = WgpuTensor::new( + client.clone(), + first_input.device.clone(), + shape_output, + buffer, + ); + + let mut dim_cat_index = 0; + + for input in inputs.iter() { + let mut info = build_info(&[input, &output]); + info.push(dim as u32); + info.push(dim_cat_index as u32); + dim_cat_index += input.shape.dims[dim]; + let info_buffer = client.create(bytemuck::cast_slice(&info)); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT), + ); + + client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_buffer], ); + } - let mut dim_cat_index = 0; - - for input in inputs.iter() { - let mut info = build_info(&[input, &output]); - info.push(dim as u32); - info.push(dim_cat_index as u32); - dim_cat_index += input.shape.dims[dim]; - let info_buffer = client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - input.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_buffer], - ); - } - - output + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Tensor}; - - #[test] - fn cat_should_support_multiple_invocations_dim0() { - test_same_as_reference([6, 256], 2, 0); - } - - #[test] - fn cat_should_support_multiple_invocations_dim1() { - test_same_as_reference([6, 256], 2, 1); - } - - #[test] - fn cat_should_support_uneven_launch() { - test_same_as_reference([1, 137], 2, 0); - } - - fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) { - TestBackend::seed(0); - let tensors = (0..num_tensors) - .map(|_| Tensor::::random(shape, Distribution::Default)) - .collect::>(); - let tensors_ref = tensors - .iter() - .map(|tensor| Tensor::::from_data(tensor.to_data())) - .collect::>(); - - let tensor = Tensor::::cat(tensors, dim); - let tensor_ref = Tensor::::cat(tensors_ref, dim); - - tensor - .into_data() - .assert_approx_eq(&tensor_ref.into_data(), 3); - } + use crate::tests::{ReferenceBackend, 
TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Tensor}; + + #[test] + fn cat_should_support_multiple_invocations_dim0() { + test_same_as_reference([6, 256], 2, 0); + } + + #[test] + fn cat_should_support_multiple_invocations_dim1() { + test_same_as_reference([6, 256], 2, 1); + } + + #[test] + fn cat_should_support_uneven_launch() { + test_same_as_reference([1, 137], 2, 0); + } + + fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) { + TestBackend::seed(0); + let tensors = (0..num_tensors) + .map(|_| Tensor::::random(shape, Distribution::Default)) + .collect::>(); + let tensors_ref = tensors + .iter() + .map(|tensor| Tensor::::from_data(tensor.to_data())) + .collect::>(); + + let tensor = Tensor::::cat(tensors, dim); + let tensor_ref = Tensor::::cat(tensors_ref, dim); + + tensor + .into_data() + .assert_approx_eq(&tensor_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/clamp.rs b/burn-wgpu/src/kernel/clamp.rs index dcd774d8e0..b3195a7de7 100644 --- a/burn-wgpu/src/kernel/clamp.rs +++ b/burn-wgpu/src/kernel/clamp.rs @@ -1,11 +1,11 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, - unary_scalar, unary_scalar_inplace, + compute::StaticKernel, + element::WgpuElement, + kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, + unary_scalar, unary_scalar_inplace, }; use super::{elemwise_workgroup, KernelSettings}; @@ -14,105 +14,106 @@ kernel_wgsl!(Clamp, "../template/clamp/clamp.wgsl"); kernel_wgsl!(ClampInplace, "../template/clamp/clamp_inplace.wgsl"); pub(crate) fn clamp_min( - input: WgpuTensor, - min_value: E, + input: WgpuTensor, + min_value: E, ) -> WgpuTensor { - unary_scalar!(ClampMin, func "max"); - unary_scalar_inplace!(ClampMinInplace, func "max"); + unary_scalar!(ClampMin, func "max"); + unary_scalar_inplace!(ClampMinInplace, func "max"); - if input.can_mut() { - return unary_scalar_inplace_default::(input, min_value); - } + if input.can_mut() { + return unary_scalar_inplace_default::(input, min_value); + } - unary_scalar::(input, min_value) + unary_scalar::(input, min_value) } pub(crate) fn clamp_max( - input: WgpuTensor, - max_value: E, + input: WgpuTensor, + max_value: E, ) -> WgpuTensor { - unary_scalar!(ClampMax, func "min"); - unary_scalar_inplace!(ClampMaxInPlace, func "min"); + unary_scalar!(ClampMax, func "min"); + unary_scalar_inplace!(ClampMaxInPlace, func "min"); - if input.can_mut() { - return unary_scalar_inplace_default::(input, max_value); - } + if input.can_mut() { + return unary_scalar_inplace_default::(input, max_value); + } - unary_scalar::(input, max_value) + unary_scalar::(input, max_value) } pub(crate) fn clamp( - input: WgpuTensor, - min_value: E, - max_value: E, + input: WgpuTensor, + min_value: E, + max_value: E, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let min_handle = input.client.create(E::as_bytes(&[min_value])); - let max_handle = input.client.create(E::as_bytes(&[max_value])); + let num_elems = input.shape.num_elements(); + let min_handle = input.client.create(E::as_bytes(&[min_value])); + let max_handle = input.client.create(E::as_bytes(&[max_value])); - if input.can_mut() { - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - input - .client - .execute(Box::new(kernel), 
&[&input.handle, &min_handle, &max_handle]); - - return input; - } - - let output = empty_device(input.client.clone(), input.device.clone(), input.shape); + if input.can_mut() { let kernel = StaticKernel::< - KernelSettings, + KernelSettings, >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &min_handle, &max_handle], + input + .client + .execute(Box::new(kernel), &[&input.handle, &min_handle, &max_handle]); + + return input; + } + + let output = empty_device(input.client.clone(), input.device.clone(), input.shape); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), ); - output + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &min_handle, &max_handle], + ); + + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; - #[test] - fn clamp_min_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); + #[test] + fn clamp_min_should_match_reference() { + let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); - let output = input.clamp_min(0.5); + let output = input.clamp_min(0.5); - output - .into_data() - .assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3); + } - #[test] - fn clamp_max_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); + #[test] + fn clamp_max_should_match_reference() { + let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); - let output = input.clamp_max(0.5); + let output = input.clamp_max(0.5); - output - .into_data() - .assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3); + } - #[test] - fn clamp_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); + #[test] + fn clamp_should_match_reference() { + let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); - let output = input.clamp(0.3, 0.7); + let output = input.clamp(0.3, 0.7); - output - .into_data() - .assert_approx_eq(&input_ref.clamp(0.3, 0.7).into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&input_ref.clamp(0.3, 0.7).into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/comparison/base.rs b/burn-wgpu/src/kernel/comparison/base.rs index 570b5deedd..8db310f587 100644 --- a/burn-wgpu/src/kernel/comparison/base.rs +++ b/burn-wgpu/src/kernel/comparison/base.rs @@ -1,8 +1,8 @@ use crate::{ - comparison, comparison_elem, comparison_elem_inplace, comparison_inplace, - element::WgpuElement, - kernel::{comparison, comparison_elem, comparison_elem_inplace, comparison_inplace}, - tensor::WgpuTensor, + comparison, comparison_elem, comparison_elem_inplace, comparison_inplace, + element::WgpuElement, + kernel::{comparison, comparison_elem, comparison_elem_inplace, comparison_inplace}, + 
tensor::WgpuTensor, }; use std::mem; @@ -31,136 +31,136 @@ comparison_elem_inplace!(LowerElemInplace, "<"); comparison_elem_inplace!(LowerEqualElemInplace, "<="); pub fn equal( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn greater( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn greater_equal( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn lower( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn lower_equal( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if 
can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn equal_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn greater_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn lower_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn greater_equal_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn lower_equal_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } diff --git a/burn-wgpu/src/kernel/comparison/binary.rs b/burn-wgpu/src/kernel/comparison/binary.rs index 257585dd2a..9c4b8c7178 100644 --- a/burn-wgpu/src/kernel/comparison/binary.rs +++ b/burn-wgpu/src/kernel/comparison/binary.rs @@ -1,177 +1,172 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{ - build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, - }, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::Shape; kernel_wgsl!(ComparisonRaw, "../../template/comparison/binary.wgsl"); kernel_wgsl!( - ComparisonInplaceRaw, - "../../template/comparison/binary_inplace.wgsl" + ComparisonInplaceRaw, + "../../template/comparison/binary_inplace.wgsl" ); /// Creates a comparison kernel. #[macro_export] macro_rules! 
comparison { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonRaw::source().register( - "body", - format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops), - ) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonRaw::source().register( + "body", + format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops), + ) + } + } + }; } /// Creates a comparison inplace kernel. #[macro_export] macro_rules! comparison_inplace { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonInplaceRaw::source() - .register( - "body", - "lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);", - ) - .add_template(format!( - "{}return {{{{ elem }}}}(lhs {} rhs);{}", - "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", - $ops, - "\n}\n" - )) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonInplaceRaw::source() + .register( + "body", + "lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);", + ) + .add_template(format!( + "{}return {{{{ elem }}}}(lhs {} rhs);{}", + "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", $ops, "\n}\n" + )) + } + } + }; } pub fn comparison( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); - let mut shape_out = [0; D]; - lhs.shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - - let shape_out = Shape::new(shape_out); - let num_elems = shape_out.num_elements(); - - let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); - - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), - ); - let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + lhs.assert_is_on_same_device(&rhs); + let mut shape_out = [0; D]; + lhs + .shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::new(shape_out); + let num_elems = shape_out.num_elements(); + + let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); + + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), ); + let info = build_info(&[&lhs, &rhs, &output]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - WgpuTensor::new(output.client, output.device, output.shape, output.handle) + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + ); + + WgpuTensor::new(output.client, output.device, output.shape, output.handle) } pub fn comparison_inplace( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); + 
lhs.assert_is_on_same_device(&rhs); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), - ); - let info = build_info(&[&lhs, &rhs]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), + ); + let info = build_info(&[&lhs, &rhs]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - lhs.client - .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); + lhs + .client + .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); - WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; - - comparison!(LowerEqual, "<="); - comparison_inplace!(LowerEqualInplace, "<="); - - #[test] - fn comparison_should_work_with_multiple_invocations() { - let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); - - let value = Tensor::::from_primitive( - comparison::(lhs.into_primitive(), rhs.into_primitive()), - ); - - let value_ref = lhs_ref.lower_equal(rhs_ref); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[test] - fn comparison_inplace_should_work_with_multiple_invocations() { - let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); - - let value = Tensor::::from_primitive(comparison_inplace::< - LowerEqualInplace, - f32, - 3, - >( - lhs.into_primitive(), - rhs.into_primitive(), - )); - - let value_ref = lhs_ref.lower_equal(rhs_ref); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs() -> ( - Tensor, - Tensor, - Tensor, - Tensor, - ) { - TestBackend::seed(0); - let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); - let rhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - let rhs_ref = Tensor::::from_data(rhs.to_data()); - - (lhs, rhs, lhs_ref, rhs_ref) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; + + comparison!(LowerEqual, "<="); + comparison_inplace!(LowerEqualInplace, "<="); + + #[test] + fn comparison_should_work_with_multiple_invocations() { + let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); + + let value = Tensor::::from_primitive(comparison::( + lhs.into_primitive(), + rhs.into_primitive(), + )); + + let value_ref = lhs_ref.lower_equal(rhs_ref); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[test] + fn comparison_inplace_should_work_with_multiple_invocations() { + let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); + + let value = + Tensor::::from_primitive( + comparison_inplace::(lhs.into_primitive(), rhs.into_primitive()), + ); + + let value_ref = lhs_ref.lower_equal(rhs_ref); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs() -> ( + Tensor, + Tensor, + Tensor, + Tensor, + ) { + TestBackend::seed(0); + let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); + let rhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + let rhs_ref = Tensor::::from_data(rhs.to_data()); + + (lhs, rhs, lhs_ref, rhs_ref) + 
} } diff --git a/burn-wgpu/src/kernel/comparison/elem.rs b/burn-wgpu/src/kernel/comparison/elem.rs index b358307222..56d8b2b6a0 100644 --- a/burn-wgpu/src/kernel/comparison/elem.rs +++ b/burn-wgpu/src/kernel/comparison/elem.rs @@ -1,141 +1,138 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!(ComparisonElemRaw, "../../template/comparison/elem.wgsl"); kernel_wgsl!( - ComparisonElemInplaceRaw, - "../../template/comparison/elem_inplace.wgsl" + ComparisonElemInplaceRaw, + "../../template/comparison/elem_inplace.wgsl" ); /// Creates a comparison elementwise kernel. #[macro_export] macro_rules! comparison_elem { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonElemRaw::source() - .register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonElemRaw::source() + .register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops)) + } + } + }; } /// Creates a comparison elementwise inplace kernel. #[macro_export] macro_rules! comparison_elem_inplace { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonElemInplaceRaw::source() - .register("body", "lhs[id] = compare(lhs[id], rhs);") - .add_template(format!( - "{}return {{{{ elem }}}}(lhs {} rhs);{}", - "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", - $ops, - "\n}\n" - )) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonElemInplaceRaw::source() + .register("body", "lhs[id] = compare(lhs[id], rhs);") + .add_template(format!( + "{}return {{{{ elem }}}}(lhs {} rhs);{}", + "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", $ops, "\n}\n" + )) + } + } + }; } pub fn comparison_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); + let num_elems = lhs.shape.num_elements(); - let handle = lhs.client.empty(num_elems * core::mem::size_of::()); - let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), - ); + let handle = lhs.client.empty(num_elems * core::mem::size_of::()); + let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), + ); - lhs.client - .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]); + lhs + .client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]); - WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle) + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle) } pub fn comparison_elem_inplace( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - let kernel = - StaticKernel::>::new( - 
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); - lhs.client - .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); - - WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) + let kernel = + StaticKernel::>::new( + elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), + ); + let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); + lhs + .client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); + + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; - - comparison_elem!(LowerEqual, "<="); - comparison_elem_inplace!(LowerEqualInplace, "<="); - - #[test] - fn comparison_elem_should_work_with_multiple_invocations() { - let (lhs, lhs_ref, rhs) = inputs(); - - let value = - Tensor::::from_primitive(comparison_elem::( - lhs.into_primitive(), - rhs, - )); - - let value_ref = lhs_ref.lower_equal_elem(rhs); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[test] - fn comparison_elem_inplace_should_work_with_multiple_invocations() { - let (lhs, lhs_ref, rhs) = inputs(); - - let value = - Tensor::::from_primitive(comparison_elem_inplace::< - LowerEqualInplace, - f32, - 3, - >(lhs.into_primitive(), rhs)); - - let value_ref = lhs_ref.lower_equal_elem(rhs); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs() -> (Tensor, Tensor, f32) { - TestBackend::seed(0); - let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - - (lhs, lhs_ref, 5.0) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; + + comparison_elem!(LowerEqual, "<="); + comparison_elem_inplace!(LowerEqualInplace, "<="); + + #[test] + fn comparison_elem_should_work_with_multiple_invocations() { + let (lhs, lhs_ref, rhs) = inputs(); + + let value = Tensor::::from_primitive( + comparison_elem::(lhs.into_primitive(), rhs), + ); + + let value_ref = lhs_ref.lower_equal_elem(rhs); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[test] + fn comparison_elem_inplace_should_work_with_multiple_invocations() { + let (lhs, lhs_ref, rhs) = inputs(); + + let value = Tensor::::from_primitive(comparison_elem_inplace::< + LowerEqualInplace, + f32, + 3, + >(lhs.into_primitive(), rhs)); + + let value_ref = lhs_ref.lower_equal_elem(rhs); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs() -> (Tensor, Tensor, f32) { + TestBackend::seed(0); + let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + + (lhs, lhs_ref, 5.0) + } } diff --git a/burn-wgpu/src/kernel/conv/conv2d.rs b/burn-wgpu/src/kernel/conv/conv2d.rs index 39b0ecf45d..86c03f303d 100644 --- a/burn-wgpu/src/kernel/conv/conv2d.rs +++ b/burn-wgpu/src/kernel/conv/conv2d.rs @@ -1,108 +1,106 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, 
elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions}, - Element, ElementConversion, Shape, + ops::{conv::calculate_conv_output_size, ConvOptions}, + Element, ElementConversion, Shape, }; kernel_wgsl!(Conv2d, "../../template/conv/conv2d.wgsl"); pub(crate) fn conv2d( - input: WgpuTensor, - weight: WgpuTensor, - bias: Option>, - options: ConvOptions<2>, + input: WgpuTensor, + weight: WgpuTensor, + bias: Option>, + options: ConvOptions<2>, ) -> WgpuTensor { - let input = kernel::into_contiguous(input); - let weight = kernel::into_contiguous(weight); - let [batch_size, _, in_height, in_width] = input.shape.dims; - let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims; + let input = kernel::into_contiguous(input); + let weight = kernel::into_contiguous(weight); + let [batch_size, _, in_height, in_width] = input.shape.dims; + let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims; - let out_0 = calculate_conv_output_size( - kernel_0, - options.stride[0], - options.padding[0], - options.dilation[0], - in_height, - ); - let out_1 = calculate_conv_output_size( - kernel_1, - options.stride[1], - options.padding[1], - options.dilation[1], - in_width, - ); + let out_0 = calculate_conv_output_size( + kernel_0, + options.stride[0], + options.padding[0], + options.dilation[0], + in_height, + ); + let out_1 = calculate_conv_output_size( + kernel_1, + options.stride[1], + options.padding[1], + options.dilation[1], + in_width, + ); - let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]); - - let output = empty_device( - input.client.clone(), - input.device.clone(), - shape_out.clone(), - ); + let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]); - let mut info = build_info(&[&input, &output, &weight]); - info.push(options.stride[0] as u32); - info.push(options.stride[1] as u32); - info.push(options.padding[0] as u32); - info.push(options.padding[1] as u32); - info.push(options.dilation[0] as u32); - info.push(options.dilation[1] as u32); - info.push(options.groups as u32); + let output = empty_device( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + ); - let bias_handle = bias - .map(|bias| bias.handle) - .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); + let mut info = build_info(&[&input, &output, &weight]); + info.push(options.stride[0] as u32); + info.push(options.stride[1] as u32); + info.push(options.padding[0] as u32); + info.push(options.padding[1] as u32); + info.push(options.dilation[0] as u32); + info.push(options.dilation[1] as u32); + info.push(options.groups as u32); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); + let bias_handle = bias + .map(|bias| bias.handle) + .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &weight.handle, - &bias_handle, - &output.handle, - &info_handle, - ], + let kernel = + StaticKernel::>::new( + elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), ); - output + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &weight.handle, + &bias_handle, + &output.handle, + &info_handle, + ], + ); + + output } 
#[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{module, Distribution, Tensor}; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{module, Distribution, Tensor}; - #[test] - fn conv2d_should_work_with_multiple_invocations() { - let input = Tensor::::random([6, 16, 32, 32], Distribution::Default); - let weight = Tensor::::random([12, 8, 3, 3], Distribution::Default); - let bias = Tensor::::random([12], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); - let weight_ref = Tensor::::from_data(weight.to_data()); - let bias_ref = Tensor::::from_data(bias.to_data()); - let options = burn_tensor::ops::ConvOptions::new([2, 3], [2, 3], [2, 3], 2); + #[test] + fn conv2d_should_work_with_multiple_invocations() { + let input = Tensor::::random([6, 16, 32, 32], Distribution::Default); + let weight = Tensor::::random([12, 8, 3, 3], Distribution::Default); + let bias = Tensor::::random([12], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); + let weight_ref = Tensor::::from_data(weight.to_data()); + let bias_ref = Tensor::::from_data(bias.to_data()); + let options = burn_tensor::ops::ConvOptions::new([2, 3], [2, 3], [2, 3], 2); - let output = module::conv2d(input, weight, Some(bias), options.clone()); - let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options); + let output = module::conv2d(input, weight, Some(bias), options.clone()); + let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options); - output - .into_data() - .assert_approx_eq(&output_ref.into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&output_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/conv/conv_transpose2d.rs b/burn-wgpu/src/kernel/conv/conv_transpose2d.rs index d59a2cb6f8..b9eae1c861 100644 --- a/burn-wgpu/src/kernel/conv/conv_transpose2d.rs +++ b/burn-wgpu/src/kernel/conv/conv_transpose2d.rs @@ -1,120 +1,119 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape}; kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl"); pub(crate) fn conv_transpose2d( - input: WgpuTensor, - weight: WgpuTensor, - bias: Option>, - options: ConvTransposeOptions<2>, + input: WgpuTensor, + weight: WgpuTensor, + bias: Option>, + options: ConvTransposeOptions<2>, ) -> WgpuTensor { - let input = kernel::into_contiguous(input); - let weight = kernel::into_contiguous(weight); - let [batch_size, _, in_height, in_width] = input.shape.dims; - let [_, out_channels, kernel_0, kernel_1] = weight.shape.dims; - - let out_0 = (in_height - 1) * options.stride[0] - + options.dilation[0] * (kernel_0 - 1) - + options.padding_out[0] - - 2 * options.padding[0] - + 1; - let out_1 = (in_width - 1) * options.stride[1] - + options.dilation[1] * (kernel_1 - 1) - + options.padding_out[1] - - 2 * options.padding[1] - + 1; - - let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]); - let num_elems = shape_out.num_elements(); - - let output = empty_device( - input.client.clone(), - 
input.device.clone(), - shape_out.clone(), - ); - let mut info = build_info(&[&input, &output, &weight]); - - info.push(options.stride[0] as u32); - info.push(options.stride[1] as u32); - info.push(options.padding[0] as u32); - info.push(options.padding[1] as u32); - info.push(options.dilation[0] as u32); - info.push(options.dilation[1] as u32); - info.push(options.groups as u32); - - let bias_handle = bias - .map(|bias| bias.handle) - .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); - - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &weight.handle, - &bias_handle, - &output.handle, - &info_handle, - ], - ); - - output + let input = kernel::into_contiguous(input); + let weight = kernel::into_contiguous(weight); + let [batch_size, _, in_height, in_width] = input.shape.dims; + let [_, out_channels, kernel_0, kernel_1] = weight.shape.dims; + + let out_0 = (in_height - 1) * options.stride[0] + + options.dilation[0] * (kernel_0 - 1) + + options.padding_out[0] + - 2 * options.padding[0] + + 1; + let out_1 = (in_width - 1) * options.stride[1] + + options.dilation[1] * (kernel_1 - 1) + + options.padding_out[1] + - 2 * options.padding[1] + + 1; + + let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]); + let num_elems = shape_out.num_elements(); + + let output = empty_device( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + ); + let mut info = build_info(&[&input, &output, &weight]); + + info.push(options.stride[0] as u32); + info.push(options.stride[1] as u32); + info.push(options.padding[0] as u32); + info.push(options.padding[1] as u32); + info.push(options.dilation[0] as u32); + info.push(options.dilation[1] as u32); + info.push(options.groups as u32); + + let bias_handle = bias + .map(|bias| bias.handle) + .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); + + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &weight.handle, + &bias_handle, + &output.handle, + &info_handle, + ], + ); + + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, module, Distribution, Tensor}; - - #[test] - fn conv_transpose2d_should_work_with_multiple_invocations() { - TestBackend::seed(0); - - let height = 8; - let width = 8; - let in_channels = 8; - let out_channels = 8; - let batch_size = 32; - let kernel_size_0 = 3; - let kernel_size_1 = 3; - let options = - burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1); - - let input = Tensor::::random( - [batch_size, in_channels, height, width], - Distribution::Default, - ); - let weight = Tensor::::random( - [ - in_channels, - out_channels / options.groups, - kernel_size_0, - kernel_size_1, - ], - Distribution::Default, - ); - let bias = Tensor::::random([out_channels], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); - let weight_ref = Tensor::::from_data(weight.to_data()); - let bias_ref = Tensor::::from_data(bias.to_data()); - - let output = module::conv_transpose2d(input, weight, Some(bias), options.clone()); - let output_ref = module::conv_transpose2d(input_ref, 
weight_ref, Some(bias_ref), options); - - output - .into_data() - .assert_approx_eq(&output_ref.into_data(), 3); - } + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, module, Distribution, Tensor}; + + #[test] + fn conv_transpose2d_should_work_with_multiple_invocations() { + TestBackend::seed(0); + + let height = 8; + let width = 8; + let in_channels = 8; + let out_channels = 8; + let batch_size = 32; + let kernel_size_0 = 3; + let kernel_size_1 = 3; + let options = burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1); + + let input = Tensor::::random( + [batch_size, in_channels, height, width], + Distribution::Default, + ); + let weight = Tensor::::random( + [ + in_channels, + out_channels / options.groups, + kernel_size_0, + kernel_size_1, + ], + Distribution::Default, + ); + let bias = Tensor::::random([out_channels], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); + let weight_ref = Tensor::::from_data(weight.to_data()); + let bias_ref = Tensor::::from_data(bias.to_data()); + + let output = module::conv_transpose2d(input, weight, Some(bias), options.clone()); + let output_ref = module::conv_transpose2d(input_ref, weight_ref, Some(bias_ref), options); + + output + .into_data() + .assert_approx_eq(&output_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/index/gather.rs b/burn-wgpu/src/kernel/index/gather.rs index 48c7c5baab..a1da2533c0 100644 --- a/burn-wgpu/src/kernel/index/gather.rs +++ b/burn-wgpu/src/kernel/index/gather.rs @@ -1,88 +1,87 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(Gather, "../../template/index/gather.wgsl"); pub(crate) fn gather( - dim: usize, - tensor: WgpuTensor, - indices: WgpuTensor, + dim: usize, + tensor: WgpuTensor, + indices: WgpuTensor, ) -> WgpuTensor { - let shape_output = indices.shape.clone(); - let num_elems = shape_output.num_elements(); - let indices = kernel::into_contiguous(indices); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + let shape_output = indices.shape.clone(); + let num_elems = shape_output.num_elements(); + let indices = kernel::into_contiguous(indices); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - let mut info = build_info(&[&tensor, &output]); - info.push(dim as u32); - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + let mut info = build_info(&[&tensor, &output]); + info.push(dim as u32); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - tensor.client.execute( - Box::new(kernel), - &[ - &tensor.handle, - &indices.handle, - &output.handle, - &info_handle, - ], - ); + tensor.client.execute( + Box::new(kernel), + &[ + &tensor.handle, + &indices.handle, + &output.handle, + &info_handle, + ], + ); - output + output } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, 
TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Int, Shape, Tensor}; + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Int, Shape, Tensor}; - #[test] - fn gather_should_work_with_multiple_workgroups_dim0() { - test_same_as_ref([6, 256], 0); - } + #[test] + fn gather_should_work_with_multiple_workgroups_dim0() { + test_same_as_ref([6, 256], 0); + } - #[test] - fn gather_should_work_with_multiple_workgroups_dim1() { - test_same_as_ref([6, 256], 1); - } + #[test] + fn gather_should_work_with_multiple_workgroups_dim1() { + test_same_as_ref([6, 256], 1); + } - fn test_same_as_ref(shape: [usize; D], dim: usize) { - TestBackend::seed(0); - let max = shape[dim]; - let shape = Shape::new(shape); - let tensor = Tensor::::random(shape.clone(), Distribution::Default); - let indices = Tensor::::from_data( - Tensor::::random( - [shape.num_elements()], - Distribution::Uniform(0., max as f32), - ) - .into_data() - .convert(), - ) - .reshape(shape); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let indices_ref = - Tensor::::from_data(indices.to_data().convert()); + fn test_same_as_ref(shape: [usize; D], dim: usize) { + TestBackend::seed(0); + let max = shape[dim]; + let shape = Shape::new(shape); + let tensor = Tensor::::random(shape.clone(), Distribution::Default); + let indices = Tensor::::from_data( + Tensor::::random( + [shape.num_elements()], + Distribution::Uniform(0., max as f32), + ) + .into_data() + .convert(), + ) + .reshape(shape); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let indices_ref = Tensor::::from_data(indices.to_data().convert()); - let actual = Tensor::::from_primitive(gather( - dim, - tensor.into_primitive(), - indices.into_primitive(), - )); - let expected = tensor_ref.gather(dim, indices_ref); + let actual = Tensor::::from_primitive(gather( + dim, + tensor.into_primitive(), + indices.into_primitive(), + )); + let expected = tensor_ref.gather(dim, indices_ref); - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/index/scatter.rs b/burn-wgpu/src/kernel/index/scatter.rs index 2c09a4f1fc..5b8489f782 100644 --- a/burn-wgpu/src/kernel/index/scatter.rs +++ b/burn-wgpu/src/kernel/index/scatter.rs @@ -1,141 +1,138 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!(Scatter, "../../template/index/scatter.wgsl"); pub(crate) fn scatter( - dim: usize, - tensor: WgpuTensor, - indices: WgpuTensor, - value: WgpuTensor, + dim: usize, + tensor: WgpuTensor, + indices: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - let indices = kernel::into_contiguous(indices); - let tensor = kernel::into_contiguous(tensor); - let value = kernel::into_contiguous(value); - - let tensor = match tensor.can_mut() { - true => tensor, - false => tensor.copy(), - }; - - let mut info = build_info(&[&tensor, &value]); - let mut strides = [0; D]; - let mut current = 1; - let mut num_elems_per_workgroup = 1; - - tensor - .shape - .dims - .iter() - .enumerate() - .rev() - .filter(|(index, _val)| *index != dim) - .for_each(|(index, val)| { - strides[index] = 
current; - current *= val; - num_elems_per_workgroup *= tensor.shape.dims[index]; - }); - - strides - .into_iter() - .for_each(|stride| info.push(stride as u32)); - - info.push(dim as u32); - - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - num_elems_per_workgroup, - WORKGROUP_DEFAULT, - )); - - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &indices.handle, &value.handle, &info_handle], + let indices = kernel::into_contiguous(indices); + let tensor = kernel::into_contiguous(tensor); + let value = kernel::into_contiguous(value); + + let tensor = match tensor.can_mut() { + true => tensor, + false => tensor.copy(), + }; + + let mut info = build_info(&[&tensor, &value]); + let mut strides = [0; D]; + let mut current = 1; + let mut num_elems_per_workgroup = 1; + + tensor + .shape + .dims + .iter() + .enumerate() + .rev() + .filter(|(index, _val)| *index != dim) + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + num_elems_per_workgroup *= tensor.shape.dims[index]; + }); + + strides + .into_iter() + .for_each(|stride| info.push(stride as u32)); + + info.push(dim as u32); + + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems_per_workgroup, WORKGROUP_DEFAULT), ); - tensor + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &indices.handle, &value.handle, &info_handle], + ); + + tensor } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; - - #[test] - fn scatter_should_work_with_multiple_workgroups_2d_dim0() { - same_as_reference_same_shape(0, [256, 32]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_2d_dim1() { - same_as_reference_same_shape(1, [32, 256]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_3d_dim0() { - same_as_reference_same_shape(0, [256, 6, 6]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_3d_dim1() { - same_as_reference_same_shape(1, [6, 256, 6]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_3d_dim2() { - same_as_reference_same_shape(2, [6, 6, 256]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_diff_shapes() { - same_as_reference_diff_shape(1, [32, 128], [32, 1]); - } - - fn same_as_reference_diff_shape( - dim: usize, - shape1: [usize; D], - shape2: [usize; D], - ) { - TestBackend::seed(0); - let tensor = Tensor::::random(shape1, Distribution::Default); - let value = Tensor::::random(shape2, Distribution::Default); - let indices = Tensor::::from_data( - Tensor::::random( - [shape2.iter().product()], - Distribution::Uniform(0., shape2[dim] as f32), - ) - .into_data() - .convert(), - ) - .reshape(shape2); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - let indices_ref = - Tensor::::from_data(indices.to_data().convert()); - - let actual = Tensor::::from_primitive(scatter( - dim, - tensor.into_primitive(), - indices.into_primitive(), - value.into_primitive(), - )); - let expected = tensor_ref.scatter(dim, indices_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - fn same_as_reference_same_shape(dim: usize, shape: [usize; D]) { - same_as_reference_diff_shape(dim, shape, shape); - } + use super::*; + use 
crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; + + #[test] + fn scatter_should_work_with_multiple_workgroups_2d_dim0() { + same_as_reference_same_shape(0, [256, 32]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_2d_dim1() { + same_as_reference_same_shape(1, [32, 256]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_3d_dim0() { + same_as_reference_same_shape(0, [256, 6, 6]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_3d_dim1() { + same_as_reference_same_shape(1, [6, 256, 6]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_3d_dim2() { + same_as_reference_same_shape(2, [6, 6, 256]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_diff_shapes() { + same_as_reference_diff_shape(1, [32, 128], [32, 1]); + } + + fn same_as_reference_diff_shape( + dim: usize, + shape1: [usize; D], + shape2: [usize; D], + ) { + TestBackend::seed(0); + let tensor = Tensor::::random(shape1, Distribution::Default); + let value = Tensor::::random(shape2, Distribution::Default); + let indices = Tensor::::from_data( + Tensor::::random( + [shape2.iter().product()], + Distribution::Uniform(0., shape2[dim] as f32), + ) + .into_data() + .convert(), + ) + .reshape(shape2); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + let indices_ref = Tensor::::from_data(indices.to_data().convert()); + + let actual = Tensor::::from_primitive(scatter( + dim, + tensor.into_primitive(), + indices.into_primitive(), + value.into_primitive(), + )); + let expected = tensor_ref.scatter(dim, indices_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + fn same_as_reference_same_shape(dim: usize, shape: [usize; D]) { + same_as_reference_diff_shape(dim, shape, shape); + } } diff --git a/burn-wgpu/src/kernel/index/select.rs b/burn-wgpu/src/kernel/index/select.rs index 5000b90608..228be8bd05 100644 --- a/burn-wgpu/src/kernel/index/select.rs +++ b/burn-wgpu/src/kernel/index/select.rs @@ -1,177 +1,172 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(IndexSelect, "../../template/index/select.wgsl"); kernel_wgsl!( - SelectAssignInplace, - "../../template/index/select_assign_inplace.wgsl" + SelectAssignInplace, + "../../template/index/select_assign_inplace.wgsl" ); pub(crate) fn select( - tensor: WgpuTensor, - dim: usize, - indices: WgpuTensor, + tensor: WgpuTensor, + dim: usize, + indices: WgpuTensor, ) -> WgpuTensor { - let mut output_shape = tensor.shape.clone(); - output_shape.dims[dim] = indices.shape.dims[0]; - - let num_elems = output_shape.num_elements(); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), output_shape); - - let mut info = build_info(&[&tensor, &output]); - info.push(dim as u32); - - let info_handle = output.client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - tensor.client.execute( - Box::new(kernel), - &[ - &tensor.handle, - &indices.handle, - &output.handle, - 
&info_handle, - ], - ); - - output + let mut output_shape = tensor.shape.clone(); + output_shape.dims[dim] = indices.shape.dims[0]; + + let num_elems = output_shape.num_elements(); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), output_shape); + + let mut info = build_info(&[&tensor, &output]); + info.push(dim as u32); + + let info_handle = output.client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + + tensor.client.execute( + Box::new(kernel), + &[ + &tensor.handle, + &indices.handle, + &output.handle, + &info_handle, + ], + ); + + output } pub(crate) fn select_assign( - tensor: WgpuTensor, - dim: usize, - indices: WgpuTensor, - value: WgpuTensor, + tensor: WgpuTensor, + dim: usize, + indices: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - let tensor = match tensor.can_mut() { - true => tensor, - false => tensor.copy(), - }; - - let mut info = build_info(&[&tensor, &value]); - let mut strides = [0; D]; - let mut current = 1; - let mut num_elems_per_workgroup = 1; - - tensor - .shape - .dims - .iter() - .enumerate() - .rev() - .filter(|(index, _val)| *index != dim) - .for_each(|(index, val)| { - strides[index] = current; - current *= val; - num_elems_per_workgroup *= tensor.shape.dims[index]; - }); - - strides - .into_iter() - .for_each(|stride| info.push(stride as u32)); - - info.push(dim as u32); - - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - num_elems_per_workgroup, - WORKGROUP_DEFAULT, - )); - - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &indices.handle, &value.handle, &info_handle], - ); - - tensor + let tensor = match tensor.can_mut() { + true => tensor, + false => tensor.copy(), + }; + + let mut info = build_info(&[&tensor, &value]); + let mut strides = [0; D]; + let mut current = 1; + let mut num_elems_per_workgroup = 1; + + tensor + .shape + .dims + .iter() + .enumerate() + .rev() + .filter(|(index, _val)| *index != dim) + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + num_elems_per_workgroup *= tensor.shape.dims[index]; + }); + + strides + .into_iter() + .for_each(|stride| info.push(stride as u32)); + + info.push(dim as u32); + + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + num_elems_per_workgroup, + WORKGROUP_DEFAULT, + )); + + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &indices.handle, &value.handle, &info_handle], + ); + + tensor } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; - - #[test] - fn select_should_work_with_multiple_workgroups() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let indices = Tensor::::arange(0..100); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let indices_ref = - Tensor::::from_data(indices.to_data().convert()); - - let actual = select(tensor.into_primitive(), 1, indices.into_primitive()); - let expected = tensor_ref.select(1, indices_ref); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_2d_dim0() { - select_assign_same_as_ref(0, [256, 6]); - } - - #[test] - fn 
select_assign_should_work_with_multiple_workgroups_2d_dim1() { - select_assign_same_as_ref(1, [6, 256]); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_3d_dim0() { - select_assign_same_as_ref(0, [256, 6, 6]); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_3d_dim1() { - select_assign_same_as_ref(1, [6, 256, 6]); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_3d_dim2() { - select_assign_same_as_ref(2, [6, 6, 256]); - } - - fn select_assign_same_as_ref(dim: usize, shape: [usize; D]) { - TestBackend::seed(0); - let tensor = Tensor::::random(shape, Distribution::Default); - let value = Tensor::::random(shape, Distribution::Default); - let indices = Tensor::::from_data( - Tensor::::random( - [shape[dim]], - Distribution::Uniform(0., shape[dim] as f32), - ) - .into_data() - .convert(), - ); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - let indices_ref = - Tensor::::from_data(indices.to_data().convert()); - - let actual = Tensor::::from_primitive(select_assign( - tensor.into_primitive(), - dim, - indices.into_primitive(), - value.into_primitive(), - )); - let expected = tensor_ref.select_assign(dim, indices_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; + + #[test] + fn select_should_work_with_multiple_workgroups() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let indices = Tensor::::arange(0..100); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let indices_ref = Tensor::::from_data(indices.to_data().convert()); + + let actual = select(tensor.into_primitive(), 1, indices.into_primitive()); + let expected = tensor_ref.select(1, indices_ref); + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_2d_dim0() { + select_assign_same_as_ref(0, [256, 6]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_2d_dim1() { + select_assign_same_as_ref(1, [6, 256]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_3d_dim0() { + select_assign_same_as_ref(0, [256, 6, 6]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_3d_dim1() { + select_assign_same_as_ref(1, [6, 256, 6]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_3d_dim2() { + select_assign_same_as_ref(2, [6, 6, 256]); + } + + fn select_assign_same_as_ref(dim: usize, shape: [usize; D]) { + TestBackend::seed(0); + let tensor = Tensor::::random(shape, Distribution::Default); + let value = Tensor::::random(shape, Distribution::Default); + let indices = Tensor::::from_data( + Tensor::::random([shape[dim]], Distribution::Uniform(0., shape[dim] as f32)) + .into_data() + .convert(), + ); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + let indices_ref = Tensor::::from_data(indices.to_data().convert()); + + let actual = Tensor::::from_primitive(select_assign( + tensor.into_primitive(), + dim, + indices.into_primitive(), + value.into_primitive(), + )); + let expected = tensor_ref.select_assign(dim, indices_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } } diff --git 
a/burn-wgpu/src/kernel/index/slice.rs b/burn-wgpu/src/kernel/index/slice.rs index e431168b68..62d59d06b6 100644 --- a/burn-wgpu/src/kernel/index/slice.rs +++ b/burn-wgpu/src/kernel/index/slice.rs @@ -1,132 +1,130 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::Shape; use std::ops::Range; kernel_wgsl!(IndexRaw, "../../template/index/slice.wgsl"); kernel_wgsl!( - IndexAssignInplaceRaw, - "../../template/index/slice_assign_inplace.wgsl" + IndexAssignInplaceRaw, + "../../template/index/slice_assign_inplace.wgsl" ); pub(crate) fn slice( - tensor: WgpuTensor, - indices: [Range; D2], + tensor: WgpuTensor, + indices: [Range; D2], ) -> WgpuTensor { - let mut dims = tensor.shape.dims; - for i in 0..D2 { - dims[i] = indices[i].end - indices[i].start; - } - let shape_output = Shape::new(dims); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - slice_on_output(tensor, output, indices) + let mut dims = tensor.shape.dims; + for i in 0..D2 { + dims[i] = indices[i].end - indices[i].start; + } + let shape_output = Shape::new(dims); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + slice_on_output(tensor, output, indices) } pub(crate) fn slice_on_output( - tensor: WgpuTensor, - output: WgpuTensor, - indices: [Range; D2], + tensor: WgpuTensor, + output: WgpuTensor, + indices: [Range; D2], ) -> WgpuTensor { - let mut info = build_info(&[&tensor, &output]); + let mut info = build_info(&[&tensor, &output]); - for i in 0..D1 { - let start = indices.get(i).map(|index| index.start).unwrap_or(0); - info.push(start as u32); - } + for i in 0..D1 { + let start = indices.get(i).map(|index| index.start).unwrap_or(0); + info.push(start as u32); + } - let info_handle = output.client.create(bytemuck::cast_slice(&info)); + let info_handle = output.client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &output.handle, &info_handle], + let kernel = + StaticKernel::>::new( + elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), ); - output + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &output.handle, &info_handle], + ); + + output } pub(crate) fn slice_assign( - tensor: WgpuTensor, - indices: [Range; D2], - value: WgpuTensor, + tensor: WgpuTensor, + indices: [Range; D2], + value: WgpuTensor, ) -> WgpuTensor { - let tensor = match tensor.can_mut() { - true => tensor, - false => tensor.copy(), - }; - let num_elems = tensor.shape.num_elements(); - let mut info = build_info(&[&tensor, &value]); - - for i in 0..D1 { - let start = indices.get(i).map(|index| index.start).unwrap_or(0); - info.push(start as u32); - } - - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &value.handle, &info_handle], - ); - - tensor + let tensor = match tensor.can_mut() { + true => tensor, 
+ false => tensor.copy(), + }; + let num_elems = tensor.shape.num_elements(); + let mut info = build_info(&[&tensor, &value]); + + for i in 0..D1 { + let start = indices.get(i).map(|index| index.start).unwrap_or(0); + info.push(start as u32); + } + + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &value.handle, &info_handle], + ); + + tensor } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; - - #[test] - fn slice_should_work_with_multiple_workgroups() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let indices = [3..5, 45..256]; - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let actual = slice(tensor.into_primitive(), indices.clone()); - let expected = tensor_ref.slice(indices); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } - - #[test] - fn slice_assign_should_work_with_multiple_workgroups() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let value = Tensor::::random([2, 211], Distribution::Default); - let indices = [3..5, 45..256]; - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - - let actual = slice_assign( - tensor.into_primitive(), - indices.clone(), - value.into_primitive(), - ); - let expected = tensor_ref.slice_assign(indices, value_ref); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; + + #[test] + fn slice_should_work_with_multiple_workgroups() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let indices = [3..5, 45..256]; + let tensor_ref = Tensor::::from_data(tensor.to_data()); + + let actual = slice(tensor.into_primitive(), indices.clone()); + let expected = tensor_ref.slice(indices); + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } + + #[test] + fn slice_assign_should_work_with_multiple_workgroups() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let value = Tensor::::random([2, 211], Distribution::Default); + let indices = [3..5, 45..256]; + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + + let actual = slice_assign( + tensor.into_primitive(), + indices.clone(), + value.into_primitive(), + ); + let expected = tensor_ref.slice_assign(indices, value_ref); + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } } diff --git a/burn-wgpu/src/kernel/mask/base.rs b/burn-wgpu/src/kernel/mask/base.rs index cfd2fe765f..44b63f73ba 100644 --- a/burn-wgpu/src/kernel/mask/base.rs +++ b/burn-wgpu/src/kernel/mask/base.rs @@ -2,29 +2,29 @@ use crate::{element::WgpuElement, tensor::WgpuTensor}; /// Execute the mask fill kernel. 
pub fn mask_fill( - tensor: WgpuTensor, - mask: WgpuTensor, - value: E, + tensor: WgpuTensor, + mask: WgpuTensor, + value: E, ) -> WgpuTensor { - if tensor.can_mut() { - return super::mask_fill::mask_fill_inplace(tensor, mask, value); - } + if tensor.can_mut() { + return super::mask_fill::mask_fill_inplace(tensor, mask, value); + } - super::mask_fill::mask_fill(tensor, mask, value) + super::mask_fill::mask_fill(tensor, mask, value) } /// Execute the mask where kernel. pub fn mask_where( - tensor: WgpuTensor, - mask: WgpuTensor, - value: WgpuTensor, + tensor: WgpuTensor, + mask: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - if tensor.can_mut_broadcast(&value) { - return super::mask_where::mask_where_inplace(tensor, mask, value, false); - } - if value.can_mut_broadcast(&tensor) { - return super::mask_where::mask_where_inplace(value, mask, tensor, true); - } + if tensor.can_mut_broadcast(&value) { + return super::mask_where::mask_where_inplace(tensor, mask, value, false); + } + if value.can_mut_broadcast(&tensor) { + return super::mask_where::mask_where_inplace(value, mask, tensor, true); + } - super::mask_where::mask_where(tensor, mask, value) + super::mask_where::mask_where(tensor, mask, value) } diff --git a/burn-wgpu/src/kernel/mask/mask_fill.rs b/burn-wgpu/src/kernel/mask/mask_fill.rs index aed9c33593..22679853bd 100644 --- a/burn-wgpu/src/kernel/mask/mask_fill.rs +++ b/burn-wgpu/src/kernel/mask/mask_fill.rs @@ -1,122 +1,122 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(MaskFill, "../../template/mask/fill.wgsl"); kernel_wgsl!(MaskFillInplace, "../../template/mask/fill_inplace.wgsl"); pub fn mask_fill( - input: WgpuTensor, - mask: WgpuTensor, - value: E, + input: WgpuTensor, + mask: WgpuTensor, + value: E, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let output = empty_device( - input.client.clone(), - input.device.clone(), - input.shape.clone(), - ); - - let value_handle = output.client.create(E::as_bytes(&[value])); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &mask, &output]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &value_handle, - &mask.handle, - &output.handle, - &info_handle, - ], - ); - - output + let num_elems = input.shape.num_elements(); + let output = empty_device( + input.client.clone(), + input.device.clone(), + input.shape.clone(), + ); + + let value_handle = output.client.create(E::as_bytes(&[value])); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &mask, &output]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &value_handle, + &mask.handle, + &output.handle, + &info_handle, + ], + ); + + output } pub fn mask_fill_inplace( - input: 
WgpuTensor, - mask: WgpuTensor, - value: E, + input: WgpuTensor, + mask: WgpuTensor, + value: E, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let value_handle = input.client.create(E::as_bytes(&[value])); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &mask]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &value_handle, &mask.handle, &info_handle], - ); - - input + let num_elems = input.shape.num_elements(); + let value_handle = input.client.create(E::as_bytes(&[value])); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &mask]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &value_handle, &mask.handle, &info_handle], + ); + + input } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Bool, Distribution, Tensor}; - - #[test] - fn mask_fill_should_work_with_multiple_invocations() { - let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); - - let actual = Tensor::::from_primitive(mask_fill::( - tensor.into_primitive(), - mask.into_primitive(), - 4.0, - )); - let expected = tensor_ref.mask_fill(mask_ref, 4.0); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[test] - fn mask_fill_inplace_should_work_with_multiple_invocations() { - let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); - - let actual = Tensor::::from_primitive(mask_fill_inplace::( - tensor.into_primitive(), - mask.into_primitive(), - 4.0, - )); - let expected = tensor_ref.mask_fill(mask_ref, 4.0); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs_mask_fill() -> ( - Tensor, - Tensor, - Tensor, - Tensor, - ) { - let tensor = Tensor::::random([2, 6, 256], Distribution::Default); - let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) - .lower_equal_elem(0.5); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let mask_ref = Tensor::::from_data(mask.to_data()); - - (tensor, mask, tensor_ref, mask_ref) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Bool, Distribution, Tensor}; + + #[test] + fn mask_fill_should_work_with_multiple_invocations() { + let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); + + let actual = Tensor::::from_primitive(mask_fill::( + tensor.into_primitive(), + mask.into_primitive(), + 4.0, + )); + let expected = tensor_ref.mask_fill(mask_ref, 4.0); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[test] + fn mask_fill_inplace_should_work_with_multiple_invocations() { + let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); + + let actual = Tensor::::from_primitive(mask_fill_inplace::( + tensor.into_primitive(), + mask.into_primitive(), + 4.0, + )); + let expected = tensor_ref.mask_fill(mask_ref, 4.0); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs_mask_fill() -> ( + Tensor, + Tensor, + 
Tensor, + Tensor, + ) { + let tensor = Tensor::::random([2, 6, 256], Distribution::Default); + let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) + .lower_equal_elem(0.5); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let mask_ref = Tensor::::from_data(mask.to_data()); + + (tensor, mask, tensor_ref, mask_ref) + } } diff --git a/burn-wgpu/src/kernel/mask/mask_where.rs b/burn-wgpu/src/kernel/mask/mask_where.rs index 9972554ab8..9775ed242d 100644 --- a/burn-wgpu/src/kernel/mask/mask_where.rs +++ b/burn-wgpu/src/kernel/mask/mask_where.rs @@ -1,150 +1,150 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(MaskWhere, "../../template/mask/where.wgsl"); kernel_wgsl!(MaskWhereInplace, "../../template/mask/where_inplace.wgsl"); pub fn mask_where( - input: WgpuTensor, - mask: WgpuTensor, - value: WgpuTensor, + input: WgpuTensor, + mask: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let output = empty_device( - input.client.clone(), - input.device.clone(), - input.shape.clone(), - ); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &value, &mask, &output]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &value.handle, - &mask.handle, - &output.handle, - &info_handle, - ], - ); - - output + let num_elems = input.shape.num_elements(); + let output = empty_device( + input.client.clone(), + input.device.clone(), + input.shape.clone(), + ); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &value, &mask, &output]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &value.handle, + &mask.handle, + &output.handle, + &info_handle, + ], + ); + + output } pub fn mask_where_inplace( - input: WgpuTensor, - mask: WgpuTensor, - value: WgpuTensor, - reverse: bool, + input: WgpuTensor, + mask: WgpuTensor, + value: WgpuTensor, + reverse: bool, ) -> WgpuTensor { - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - input.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let mut info = build_info(&[&input, &value, &mask]); - info.push(match reverse { - true => 1, - false => 0, - }); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &value.handle, &mask.handle, &info_handle], - ); - - input + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + input.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let mut info = build_info(&[&input, &value, &mask]); 
+ info.push(match reverse { + true => 1, + false => 0, + }); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &value.handle, &mask.handle, &info_handle], + ); + + input } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; - - #[test] - fn mask_where_should_work_with_multiple_invocations() { - let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - - let actual = Tensor::::from_primitive(mask_where::( - tensor.into_primitive(), - mask.into_primitive(), - value.into_primitive(), - )); - let expected = tensor_ref.mask_where(mask_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - #[test] - fn mask_where_inplace_direction_1_should_work_with_multiple_invocations() { - let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - - let actual = Tensor::::from_primitive(mask_where_inplace::( - tensor.into_primitive(), - mask.into_primitive(), - value.into_primitive(), - false, - )); - let expected = tensor_ref.mask_where(mask_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[test] - fn mask_where_inplace_direction_0_should_work_with_multiple_invocation() { - let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - - let actual = Tensor::::from_primitive(mask_where_inplace::( - value.into_primitive(), - mask.into_primitive(), - tensor.into_primitive(), - true, - )); - let expected = tensor_ref.mask_where(mask_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs_mask_where() -> ( - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - ) { - TestBackend::seed(0); - let tensor = Tensor::::random([2, 6, 256], Distribution::Default); - let value = Tensor::::random([2, 6, 256], Distribution::Default); - let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) - .lower_equal_elem(0.5); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - let mask_ref = Tensor::::from_data(mask.to_data()); - assert_eq!(mask.to_data(), mask_ref.to_data()); - - (tensor, value, mask, tensor_ref, value_ref, mask_ref) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; + + #[test] + fn mask_where_should_work_with_multiple_invocations() { + let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); + + let actual = Tensor::::from_primitive(mask_where::( + tensor.into_primitive(), + mask.into_primitive(), + value.into_primitive(), + )); + let expected = tensor_ref.mask_where(mask_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + #[test] + fn mask_where_inplace_direction_1_should_work_with_multiple_invocations() { + let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); + + let actual = Tensor::::from_primitive(mask_where_inplace::( + tensor.into_primitive(), + mask.into_primitive(), + value.into_primitive(), + false, + )); + let expected = tensor_ref.mask_where(mask_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[test] + fn 
mask_where_inplace_direction_0_should_work_with_multiple_invocation() { + let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); + + let actual = Tensor::::from_primitive(mask_where_inplace::( + value.into_primitive(), + mask.into_primitive(), + tensor.into_primitive(), + true, + )); + let expected = tensor_ref.mask_where(mask_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs_mask_where() -> ( + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + ) { + TestBackend::seed(0); + let tensor = Tensor::::random([2, 6, 256], Distribution::Default); + let value = Tensor::::random([2, 6, 256], Distribution::Default); + let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) + .lower_equal_elem(0.5); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + let mask_ref = Tensor::::from_data(mask.to_data()); + assert_eq!(mask.to_data(), mask_ref.to_data()); + + (tensor, value, mask, tensor_ref, value_ref, mask_ref) + } } diff --git a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs index c57270a12b..4d2e0f71e8 100644 --- a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs +++ b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs @@ -2,201 +2,201 @@ use burn_tensor::Shape; use std::marker::PhantomData; use crate::{ - compute::{DynamicKernel, Kernel, WorkGroup}, - element::WgpuElement, - kernel::{ - build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, - WORKGROUP_DEFAULT, - }, - kernel_wgsl, - tensor::WgpuTensor, + compute::{DynamicKernel, Kernel, WorkGroup}, + element::WgpuElement, + kernel::{ + build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, + WORKGROUP_DEFAULT, + }, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!( - MatmulMemCoalescingRaw, - "../../template/matmul/mem_coalescing.wgsl" + MatmulMemCoalescingRaw, + "../../template/matmul/mem_coalescing.wgsl" ); #[derive(new, Debug)] struct MatmulMemCoalescing { - workgroup_size_x: usize, - workgroup_size_y: usize, - _elem: PhantomData, + workgroup_size_x: usize, + workgroup_size_y: usize, + _elem: PhantomData, } impl DynamicKernelSource for MatmulMemCoalescing { - fn source(&self) -> SourceTemplate { - MatmulMemCoalescingRaw::source() - .register("workgroup_size_x", self.workgroup_size_x.to_string()) - .register("workgroup_size_y", self.workgroup_size_y.to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulMemCoalescingRaw::source() + .register("workgroup_size_x", self.workgroup_size_x.to_string()) + .register("workgroup_size_y", self.workgroup_size_y.to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using memory coalescing algorithm with workgroups of size 16 pub fn matmul_mem_coalescing_default( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT) + matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT) } /// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes pub fn 
matmul_mem_coalescing( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, - workgroup_size_x: usize, - workgroup_size_y: usize, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, + workgroup_size_x: usize, + workgroup_size_y: usize, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); + lhs.assert_is_on_same_device(&rhs); - let lhs = into_contiguous(lhs); - let rhs = into_contiguous(rhs); + let lhs = into_contiguous(lhs); + let rhs = into_contiguous(rhs); - let info = build_info(&[&lhs, &rhs, &output]); + let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - let kernel = matmul_mem_coalescing_kernel::( - &lhs.shape, - &rhs.shape, - &output.shape, - workgroup_size_x, - workgroup_size_y, - ); + let kernel = matmul_mem_coalescing_kernel::( + &lhs.shape, + &rhs.shape, + &output.shape, + workgroup_size_x, + workgroup_size_y, + ); - lhs.client.execute( - kernel, - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], - ); + lhs.client.execute( + kernel, + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + ); - output + output } fn matmul_mem_coalescing_kernel( - lhs_shape: &Shape, - rhs_shape: &Shape, - output_shape: &Shape, - workgroup_size_x: usize, - workgroup_size_y: usize, + lhs_shape: &Shape, + rhs_shape: &Shape, + output_shape: &Shape, + workgroup_size_x: usize, + workgroup_size_y: usize, ) -> Box { - let num_rows = lhs_shape.dims[D - 2]; - let num_cols = rhs_shape.dims[D - 1]; - - // set number of workgroups - let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output_shape.dims[i]; - } - - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); - - Box::new(DynamicKernel::new( - MatmulMemCoalescing::::new(workgroup_size_x, workgroup_size_y), - workgroup, - )) + let num_rows = lhs_shape.dims[D - 2]; + let num_cols = rhs_shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output_shape.dims[i]; + } + + let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); + + Box::new(DynamicKernel::new( + MatmulMemCoalescing::::new(workgroup_size_x, workgroup_size_y), + workgroup, + )) } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_mem_coalescing_straightforward() { - test_with_params::<2, 2>(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_shapes_smaller_than_blocks() { - test_with_params::<16, 16>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_n_smaller_than_m() { - test_with_params::<2, 2>(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_m_smaller_than_n() { - test_with_params::<2, 2>(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_k_smaller_than_m_n() { - test_with_params::<2, 2>(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_k_larger_than_m_n() { - test_with_params::<2, 2>(8, 48, 8, 1, 1); - } - - #[test] - pub fn 
test_matmul_mem_coalescing_multibatch_1_dim() { - test_with_params::<2, 2>(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_multibatch_2_dims() { - test_with_params::<2, 2>(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_mem_coalescing_blocks_divide_shapes_unevenly() { - test_with_params::<3, 3>(7, 7, 7, 1, 1); - } - - fn test_with_params( - m: usize, - k: usize, - n: usize, - batch_1: usize, - batch_2: usize, - ) { - let func = |lhs, rhs, out| { - matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y) - }; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_batches_no_padding() { - let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_col_no_padding() { - let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_with_batch_no_padding() { - let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + #[test] + pub fn test_matmul_mem_coalescing_straightforward() { + test_with_params::<2, 2>(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_shapes_smaller_than_blocks() { + test_with_params::<16, 16>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_n_smaller_than_m() { + test_with_params::<2, 2>(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_m_smaller_than_n() { + test_with_params::<2, 2>(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_k_smaller_than_m_n() { + test_with_params::<2, 2>(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_k_larger_than_m_n() { + test_with_params::<2, 2>(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_multibatch_1_dim() { + test_with_params::<2, 2>(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_multibatch_2_dims() { + test_with_params::<2, 2>(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_mem_coalescing_blocks_divide_shapes_unevenly() { + test_with_params::<3, 3>(7, 7, 7, 1, 1); + } + + fn test_with_params( + m: usize, + k: usize, + n: usize, + batch_1: usize, + batch_2: usize, + ) { + let func = |lhs, rhs, out| { + matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y) + }; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_batches_no_padding() { + let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, 
swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_col_no_padding() { + let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_with_batch_no_padding() { + let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/naive.rs b/burn-wgpu/src/kernel/matmul/naive.rs index 8d7d95b51a..a0f2055809 100644 --- a/burn-wgpu/src/kernel/matmul/naive.rs +++ b/burn-wgpu/src/kernel/matmul/naive.rs @@ -1,9 +1,9 @@ use crate::{ - compute::{StaticKernel, WorkGroup}, - element::WgpuElement, - kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource}, - kernel_wgsl, - tensor::WgpuTensor, + compute::{StaticKernel, WorkGroup}, + element::WgpuElement, + kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl"); @@ -11,164 +11,164 @@ kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl"); struct MatmulNaive; impl StaticKernelSource - for MatmulNaive + for MatmulNaive { - fn source() -> SourceTemplate { - MatmulNaiveRaw::source() - .register("block_size_m", WORKGROUP_SIZE_X.to_string()) - .register("block_size_n", WORKGROUP_SIZE_Y.to_string()) - } + fn source() -> SourceTemplate { + MatmulNaiveRaw::source() + .register("block_size_m", WORKGROUP_SIZE_X.to_string()) + .register("block_size_n", WORKGROUP_SIZE_Y.to_string()) + } } /// Matrix multiplication using naive algorithm with workgroups of size 16 pub fn matmul_naive_default( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, ) -> WgpuTensor { - matmul_naive::(lhs, rhs, output) + matmul_naive::(lhs, rhs, output) } /// Matrix multiplication using naive algorithm with custom workgroup sizes pub fn matmul_naive< - E: WgpuElement, - const D: usize, - const WORKGROUP_SIZE_X: usize, - const WORKGROUP_SIZE_Y: usize, + E: WgpuElement, + const D: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, >( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); - - let lhs = into_contiguous(lhs); - let rhs = into_contiguous(rhs); - - let num_rows = lhs.shape.dims[D - 2]; - let num_cols = rhs.shape.dims[D - 1]; - - // set number of workgroups - let blocks_needed_in_x = f32::ceil(num_rows as f32 / WORKGROUP_SIZE_X as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / WORKGROUP_SIZE_Y as f32) as u32; - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output.shape.dims[i]; - } - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); - - let kernel = StaticKernel::< - KernelSettings< - MatmulNaive, - E, - i32, - WORKGROUP_SIZE_X, - WORKGROUP_SIZE_Y, - 1, - >, - >::new(workgroup); - - let info = build_info(&[&lhs, &rhs, &output]); - - let info_handle = 
lhs.client.create(bytemuck::cast_slice(&info)); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], - ); - - output + lhs.assert_is_on_same_device(&rhs); + + let lhs = into_contiguous(lhs); + let rhs = into_contiguous(rhs); + + let num_rows = lhs.shape.dims[D - 2]; + let num_cols = rhs.shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / WORKGROUP_SIZE_X as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / WORKGROUP_SIZE_Y as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output.shape.dims[i]; + } + let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); + + let kernel = StaticKernel::< + KernelSettings< + MatmulNaive, + E, + i32, + WORKGROUP_SIZE_X, + WORKGROUP_SIZE_Y, + 1, + >, + >::new(workgroup); + + let info = build_info(&[&lhs, &rhs, &output]); + + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + ); + + output } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_naive_straightforward() { - test_with_params::<2, 2>(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_naive_shapes_smaller_than_blocks() { - test_with_params::<16, 16>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_n_smaller_than_m() { - test_with_params::<2, 2>(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_naive_m_smaller_than_n() { - test_with_params::<2, 2>(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_k_smaller_than_m_n() { - test_with_params::<2, 2>(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_k_larger_than_m_n() { - test_with_params::<2, 2>(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_multibatch_1_dim() { - test_with_params::<2, 2>(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_naive_multibatch_2_dims() { - test_with_params::<2, 2>(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_naive_blocks_divide_shapes_unevenly() { - test_with_params::<3, 3>(7, 7, 7, 1, 1); - } - - fn test_with_params( - m: usize, - k: usize, - n: usize, - batch_1: usize, - batch_2: usize, - ) { - let func = matmul_naive::; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_batches_no_padding() { - let matmul_func = matmul_naive::; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_col_no_padding() { - let matmul_func = matmul_naive::; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_naive::; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + 
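The tests below sweep shapes that divide the workgroup size evenly and shapes that do not; the dispatch math they exercise is the same ceil-division used in matmul_naive above. A minimal, self-contained sketch of that arithmetic (the helper name blocks_needed is illustrative only, not an item in this crate):

    // Mirrors the ceil-division used to size the dispatch grid, written with
    // integer arithmetic instead of f32::ceil to avoid rounding concerns.
    fn blocks_needed(dim: usize, workgroup_size: usize) -> u32 {
        ((dim + workgroup_size - 1) / workgroup_size) as u32
    }

    fn main() {
        // 7 rows with 3-wide workgroups need 3 blocks; the tail block only
        // partially maps onto real output elements, which is the situation the
        // *_blocks_divide_shapes_unevenly cases below exercise.
        assert_eq!(blocks_needed(7, 3), 3);
        // 8 rows with 16-wide workgroups still dispatch one (mostly idle) block,
        // the situation covered by *_shapes_smaller_than_blocks.
        assert_eq!(blocks_needed(8, 16), 1);
    }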
#[test] + pub fn test_matmul_naive_straightforward() { + test_with_params::<2, 2>(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_naive_shapes_smaller_than_blocks() { + test_with_params::<16, 16>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_n_smaller_than_m() { + test_with_params::<2, 2>(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_naive_m_smaller_than_n() { + test_with_params::<2, 2>(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_k_smaller_than_m_n() { + test_with_params::<2, 2>(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_k_larger_than_m_n() { + test_with_params::<2, 2>(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_multibatch_1_dim() { + test_with_params::<2, 2>(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_naive_multibatch_2_dims() { + test_with_params::<2, 2>(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_naive_blocks_divide_shapes_unevenly() { + test_with_params::<3, 3>(7, 7, 7, 1, 1); + } + + fn test_with_params( + m: usize, + k: usize, + n: usize, + batch_1: usize, + batch_2: usize, + ) { + let func = matmul_naive::; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_batches_no_padding() { + let matmul_func = matmul_naive::; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_col_no_padding() { + let matmul_func = matmul_naive::; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_naive::; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs index 0ce7808d5c..8700bc8336 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs @@ -1,10 +1,10 @@ use super::padding::{crop, pad_round, PaddingOutput}; use crate::{ - compute::{DynamicKernel, WgpuHandle, WorkGroup}, - element::WgpuElement, - kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource}, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::{DynamicKernel, WgpuHandle, WorkGroup}, + element::WgpuElement, + kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource}, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::{Element, Shape}; @@ -14,75 +14,75 @@ pub(crate) const B_K: usize = 32; pub(crate) const WORKGROUP_SIZE: usize = 16; pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { - let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; - let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; - let mut num_blocks_z = 1; - for i in 0..D - 2 { - num_blocks_z *= output_shape.dims[i]; - } + let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; + let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; 
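    // The dispatch grid is laid out as: x covers row blocks of size B_M, y covers
    // column blocks of size B_N, and z (computed below) covers the product of all
    // leading batch dimensions. Assuming, for illustration only, B_M = B_N = 64 and
    // an output of shape [2, 3, 100, 100], this gives x = 2, y = 2 and z = 2 * 3 = 6.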
+ let mut num_blocks_z = 1; + for i in 0..D - 2 { + num_blocks_z *= output_shape.dims[i]; + } - WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) + WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) } pub(super) fn make_info_handle( - lhs: &WgpuTensor, - rhs: &WgpuTensor, - output: &WgpuTensor, + lhs: &WgpuTensor, + rhs: &WgpuTensor, + output: &WgpuTensor, ) -> WgpuHandle { - let info = build_info(&[lhs, rhs, output]); - rhs.client.create(bytemuck::cast_slice(&info)) + let info = build_info(&[lhs, rhs, output]); + rhs.client.create(bytemuck::cast_slice(&info)) } #[allow(clippy::too_many_arguments)] pub(super) fn matmul_tiling_2d_launch< - E: WgpuElement + Element, - const D: usize, - K: DynamicKernelSource + 'static, + E: WgpuElement + Element, + const D: usize, + K: DynamicKernelSource + 'static, >( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, - kernel: K, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, + kernel: K, ) -> WgpuTensor { - // A tensor may need to be padded, in which case it will implicitly become contiguous - // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. - // If batches were swapped among themselves, or if the last two dims are transposed, the underlying - // kernel handles it without needing to turn it into contiguous. - let round_lhs = pad_round(lhs, B_M, B_K); - let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_lhs.into_tensor(), - }; - let round_rhs = pad_round(rhs, B_K, B_N); - let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_rhs.into_tensor(), - }; + // A tensor may need to be padded, in which case it will implicitly become contiguous + // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. + // If batches were swapped among themselves, or if the last two dims are transposed, the underlying + // kernel handles it without needing to turn it into contiguous. 
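    // Rounding example (values assumed for illustration): with B_M = 64 and B_K = 32,
    // an lhs whose two innermost dims are [100, 70] is zero-padded by pad_round to
    // [128, 96]; the matching rhs is padded along k and n in the same way, and the
    // crop() call at the end trims the rounded result back to the requested output shape.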
+ let round_lhs = pad_round(lhs, B_M, B_K); + let lhs = match round_lhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_lhs.into_tensor(), + }; + let round_rhs = pad_round(rhs, B_K, B_N); + let rhs = match round_rhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_rhs.into_tensor(), + }; - let rounded_output_shape = shape_out(&lhs, &rhs); + let rounded_output_shape = shape_out(&lhs, &rhs); - let rounded_output = empty_device( - rhs.client.clone(), - rhs.device.clone(), - rounded_output_shape.clone(), - ); + let rounded_output = empty_device( + rhs.client.clone(), + rhs.device.clone(), + rounded_output_shape.clone(), + ); - let workgroup = make_workgroup(&rounded_output_shape); - let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); + let workgroup = make_workgroup(&rounded_output_shape); + let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); - lhs.client.execute( - Box::new(DynamicKernel::new(kernel, workgroup)), - &[ - &lhs.handle, - &rhs.handle, - &rounded_output.handle, - &info_handle, - ], - ); + lhs.client.execute( + Box::new(DynamicKernel::new(kernel, workgroup)), + &[ + &lhs.handle, + &rhs.handle, + &rounded_output.handle, + &info_handle, + ], + ); - crop(rounded_output, output) + crop(rounded_output, output) } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs index 290d340783..b159ea1f59 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs @@ -3,25 +3,25 @@ use std::ops::Range; use burn_tensor::{Element, Shape}; use crate::{ - element::WgpuElement, - kernel::{slice_assign, slice_on_output}, - ops::numeric::zeros_device, - tensor::WgpuTensor, + element::WgpuElement, + kernel::{slice_assign, slice_on_output}, + ops::numeric::zeros_device, + tensor::WgpuTensor, }; // Output of the pad_round function. 
Allows to know explicitly if early return occurred pub(super) enum PaddingOutput { - Padded(WgpuTensor), - Unchanged(WgpuTensor), + Padded(WgpuTensor), + Unchanged(WgpuTensor), } impl PaddingOutput { - pub fn into_tensor(self) -> WgpuTensor { - match self { - PaddingOutput::Padded(tensor) => tensor, - PaddingOutput::Unchanged(tensor) => tensor, - } + pub fn into_tensor(self) -> WgpuTensor { + match self { + PaddingOutput::Padded(tensor) => tensor, + PaddingOutput::Unchanged(tensor) => tensor, } + } } /// Pads tensor with zeros to make tensor number of rows and columns @@ -29,280 +29,279 @@ impl PaddingOutput { /// For instance tensor of shape [1000, 1000] with divisors 64 and 64 /// will be padded to [1024, 1024] with the last 24 elements being zeros pub(super) fn pad_round( - tensor: WgpuTensor, - row_divisor: usize, - col_divisor: usize, + tensor: WgpuTensor, + row_divisor: usize, + col_divisor: usize, ) -> PaddingOutput { - let previous_row_dim = tensor.shape.dims[D - 2]; - let previous_col_dim = tensor.shape.dims[D - 1]; - let row_modulo = previous_row_dim % row_divisor; - let col_modulo = previous_col_dim % col_divisor; - - let new_row_dim = match row_modulo { - 0 => previous_row_dim, - _ => previous_row_dim + row_divisor - row_modulo, - }; - let new_col_dim = match col_modulo { - 0 => previous_col_dim, - _ => previous_col_dim + col_divisor - col_modulo, - }; - if previous_row_dim == new_row_dim && previous_col_dim == new_col_dim { - return PaddingOutput::Unchanged(tensor); - } - - let mut padded_shape = Vec::with_capacity(D); - for i in 0..D - 2 { - padded_shape.push(tensor.shape.dims[i]); - } - padded_shape.push(new_row_dim); - padded_shape.push(new_col_dim); - - PaddingOutput::Padded(padding::(tensor, padded_shape.into())) + let previous_row_dim = tensor.shape.dims[D - 2]; + let previous_col_dim = tensor.shape.dims[D - 1]; + let row_modulo = previous_row_dim % row_divisor; + let col_modulo = previous_col_dim % col_divisor; + + let new_row_dim = match row_modulo { + 0 => previous_row_dim, + _ => previous_row_dim + row_divisor - row_modulo, + }; + let new_col_dim = match col_modulo { + 0 => previous_col_dim, + _ => previous_col_dim + col_divisor - col_modulo, + }; + if previous_row_dim == new_row_dim && previous_col_dim == new_col_dim { + return PaddingOutput::Unchanged(tensor); + } + + let mut padded_shape = Vec::with_capacity(D); + for i in 0..D - 2 { + padded_shape.push(tensor.shape.dims[i]); + } + padded_shape.push(new_row_dim); + padded_shape.push(new_col_dim); + + PaddingOutput::Padded(padding::(tensor, padded_shape.into())) } /// Pads tensor by adding zeros when padded dim is larger than tensor dim fn padding( - tensor: WgpuTensor, - padded_shape: Shape, + tensor: WgpuTensor, + padded_shape: Shape, ) -> WgpuTensor { - let ranges = padded_shape - .dims - .iter() - .map(|dim| 0..*dim) - .collect::>>() - .try_into() - .unwrap(); - - slice_assign::( - zeros_device(tensor.client.clone(), tensor.device.clone(), padded_shape), - ranges, - tensor, - ) + let ranges = padded_shape + .dims + .iter() + .map(|dim| 0..*dim) + .collect::>>() + .try_into() + .unwrap(); + + slice_assign::( + zeros_device(tensor.client.clone(), tensor.device.clone(), padded_shape), + ranges, + tensor, + ) } /// Crops tensor by deleting values when cropped dim is smaller than tensor dim pub(super) fn crop( - tensor: WgpuTensor, - output: WgpuTensor, + tensor: WgpuTensor, + output: WgpuTensor, ) -> WgpuTensor { - let ranges = output - .shape - .dims - .iter() - .map(|dim| 0..*dim) - .collect::>>() - .try_into() - 
.unwrap(); - slice_on_output::(tensor, output, ranges) + let ranges = output + .shape + .dims + .iter() + .map(|dim| 0..*dim) + .collect::>>() + .try_into() + .unwrap(); + slice_on_output::(tensor, output, ranges) } #[cfg(test)] mod tests { - use super::*; - use crate::tests::TestTensor; - - #[test] - fn padding_already_round_should_have_same_shape() { - let row = 10; - let row_divisor = 5; - let col = 12; - let col_divisor = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [row, col].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); + use super::*; + use crate::tests::TestTensor; + + #[test] + fn padding_already_round_should_have_same_shape() { + let row = 10; + let row_divisor = 5; + let col = 12; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row, col].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_already_round_should_have_same_values() { + let row = 10; + let row_divisor = 5; + let col = 12; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor); + + let padded = TestTensor::from_primitive(padded.into_tensor()); + padded.into_data().assert_approx_eq(&tensor.into_data(), 3); + } + + #[test] + fn padding_not_round_should_have_rounded_shape() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [12, 15].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_not_round_should_have_same_values() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor).into_tensor(); + + let padded = TestTensor::from_primitive(padded).to_data(); + let tensor = tensor.into_data(); + for i in 0..row { + for j in 0..col { + assert!(padded.value[i * 15 + j] == tensor.value[i * col + j]); + } } - - #[test] - fn padding_already_round_should_have_same_values() { - let row = 10; - let row_divisor = 5; - let col = 12; - let col_divisor = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor); - - let padded = TestTensor::from_primitive(padded.into_tensor()); - padded.into_data().assert_approx_eq(&tensor.into_data(), 3); + } + + #[test] + fn padding_not_round_should_have_zero_padding() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + let padded = TestTensor::from_primitive(padded).to_data(); + + // check right of matrix + for i in 0..row { + for j in col..15 { + assert!(padded.value[i * 15 + j] == 0.0); + } } - - #[test] - fn 
padding_not_round_should_have_rounded_shape() { - let row = 10; - let row_divisor = 6; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [12, 15].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); + // check below matrix, including bottom right + for i in row..12 { + for j in 0..15 { + assert!(padded.value[i * 15 + j] == 0.0); + } } - - #[test] - fn padding_not_round_should_have_same_values() { - let row = 10; - let row_divisor = 6; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let padded = - pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor).into_tensor(); - - let padded = TestTensor::from_primitive(padded).to_data(); - let tensor = tensor.into_data(); - for i in 0..row { - for j in 0..col { - assert!(padded.value[i * 15 + j] == tensor.value[i * col + j]); - } - } + } + + #[test] + fn padding_works_with_batch() { + let row = 10; + let row_divisor = 4; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([2, 3, row, col], burn_tensor::Distribution::Default); + let expected_shape = [2, 3, 12, 15].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_with_row_divisor_larger_than_row() { + let row = 10; + let row_divisor = 32; + let col = 4; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row_divisor, 2 * col_divisor].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_with_row_divisor_equal_to_row_but_col_must_be_padded() { + let row = 32; + let row_divisor = 32; + let col = 4; + let col_divisor = 64; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [32, 64].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn crop_same_shape_should_be_unchanged_shape() { + let row = 10; + let col = 12; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row, col].into(); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([row, col]).into_primitive(), + ); + + assert!(unpadded.shape == expected_shape); + } + + #[test] + fn crop_same_shape_should_have_unchanged_values() { + let row = 10; + let col = 12; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([row, col]).into_primitive(), + ); + + let unpadded = TestTensor::from_primitive(unpadded).to_data(); + let tensor = tensor.into_data(); + for i in 0..row { + for j in 0..col { + assert!(unpadded.value[i * col + j] == tensor.value[i * col + j]); + } } - - #[test] - fn padding_not_round_should_have_zero_padding() { - let row = 10; - let row_divisor = 6; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - let padded = 
TestTensor::from_primitive(padded).to_data(); - - // check right of matrix - for i in 0..row { - for j in col..15 { - assert!(padded.value[i * 15 + j] == 0.0); - } - } - // check below matrix, including bottom right - for i in row..12 { - for j in 0..15 { - assert!(padded.value[i * 15 + j] == 0.0); - } - } - } - - #[test] - fn padding_works_with_batch() { - let row = 10; - let row_divisor = 4; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([2, 3, row, col], burn_tensor::Distribution::Default); - let expected_shape = [2, 3, 12, 15].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn padding_with_row_divisor_larger_than_row() { - let row = 10; - let row_divisor = 32; - let col = 4; - let col_divisor = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [row_divisor, 2 * col_divisor].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn padding_with_row_divisor_equal_to_row_but_col_must_be_padded() { - let row = 32; - let row_divisor = 32; - let col = 4; - let col_divisor = 64; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [32, 64].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn crop_same_shape_should_be_unchanged_shape() { - let row = 10; - let col = 12; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [row, col].into(); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([row, col]).into_primitive(), - ); - - assert!(unpadded.shape == expected_shape); - } - - #[test] - fn crop_same_shape_should_have_unchanged_values() { - let row = 10; - let col = 12; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([row, col]).into_primitive(), - ); - - let unpadded = TestTensor::from_primitive(unpadded).to_data(); - let tensor = tensor.into_data(); - for i in 0..row { - for j in 0..col { - assert!(unpadded.value[i * col + j] == tensor.value[i * col + j]); - } - } - } - - #[test] - fn crop_should_decrease_shape() { - let row = 10; - let keep_rows = 8; - let col = 12; - let keep_cols = 10; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [keep_rows, keep_cols].into(); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([keep_rows, keep_cols]).into_primitive(), - ); - - assert!(unpadded.shape == expected_shape); - } - - #[test] - fn crop_should_keep_same_values() { - let row = 4; - let keep_rows = 3; - let col = 4; - let keep_cols = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([keep_rows, keep_cols]).into_primitive(), - ); - - let unpadded = TestTensor::from_primitive(unpadded).to_data(); - let tensor = tensor.into_data(); - - for i in 0..keep_rows { - for j in 0..keep_cols { - assert!(unpadded.value[i * keep_cols + j] == tensor.value[i * col + j]); - } - } + } + + #[test] + fn crop_should_decrease_shape() { + let row = 10; + let 
keep_rows = 8; + let col = 12; + let keep_cols = 10; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [keep_rows, keep_cols].into(); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([keep_rows, keep_cols]).into_primitive(), + ); + + assert!(unpadded.shape == expected_shape); + } + + #[test] + fn crop_should_keep_same_values() { + let row = 4; + let keep_rows = 3; + let col = 4; + let keep_cols = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([keep_rows, keep_cols]).into_primitive(), + ); + + let unpadded = TestTensor::from_primitive(unpadded).to_data(); + let tensor = tensor.into_data(); + + for i in 0..keep_rows { + for j in 0..keep_cols { + assert!(unpadded.value[i * keep_cols + j] == tensor.value[i * col + j]); + } } + } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs b/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs index b444bd20bc..ed8e68ec52 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs @@ -1,10 +1,10 @@ use burn_tensor::Element; use crate::{ - compute::DynamicKernel, - element::WgpuElement, - kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::WgpuTensor, + compute::DynamicKernel, + element::WgpuElement, + kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -13,183 +13,183 @@ use crate::kernel_wgsl; use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE}; kernel_wgsl!( - MatmulTiling2DUnpaddedRaw, - "../../../template/matmul/blocktiling_2d/unpadded.wgsl" + MatmulTiling2DUnpaddedRaw, + "../../../template/matmul/blocktiling_2d/unpadded.wgsl" ); #[derive(new, Debug)] struct MatmulTiling2DUnpadded { - _elem: PhantomData, + _elem: PhantomData, } impl DynamicKernelSource for MatmulTiling2DUnpadded { - fn source(&self) -> SourceTemplate { - MatmulTiling2DUnpaddedRaw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulTiling2DUnpaddedRaw::source() + .register("b_m", B_M.to_string()) + .register("b_n", B_N.to_string()) + .register("b_k", B_K.to_string()) + .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) + .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) + .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_z", "1".to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using tiling 2d algorithm with /// vec4 primitive on both lhs and rhs, with no padding needed pub fn matmul_tiling_2d_unpadded( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - 
let lhs = match lhs.batch_swapped_with_row_col() { - true => into_contiguous(lhs), - false => lhs, - }; - let rhs = match rhs.batch_swapped_with_row_col() { - true => into_contiguous(rhs), - false => rhs, - }; - - let workgroup = make_workgroup(&out.shape); - let info_handle = make_info_handle(&lhs, &rhs, &out); - - lhs.client.execute( - Box::new(DynamicKernel::new( - MatmulTiling2DUnpadded::::new(), - workgroup, - )), - &[&lhs.handle, &rhs.handle, &out.handle, &info_handle], - ); - - out + let lhs = match lhs.batch_swapped_with_row_col() { + true => into_contiguous(lhs), + false => lhs, + }; + let rhs = match rhs.batch_swapped_with_row_col() { + true => into_contiguous(rhs), + false => rhs, + }; + + let workgroup = make_workgroup(&out.shape); + let info_handle = make_info_handle(&lhs, &rhs, &out); + + lhs.client.execute( + Box::new(DynamicKernel::new( + MatmulTiling2DUnpadded::::new(), + workgroup, + )), + &[&lhs.handle, &rhs.handle, &out.handle, &info_handle], + ); + + out } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_unpadded_straightforward() { - test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_shapes_equal_blocks() { - test_with_params(64, 32, 64, 2, 2); - } - - #[test] - pub fn test_matmul_unpadded_m_exceeds_block() { - test_with_params(75, 32, 64, 2, 2); - } - - #[test] - pub fn test_matmul_unpadded_k_exceeds_block() { - test_with_params(64, 33, 32, 1, 1); - } - - #[test] - pub fn test_matmul_irregular_shape() { - test_with_params(123, 255, 72, 3, 5); - } - - #[test] - pub fn test64_matmul_unpadded_n_exceeds_block() { - test_with_params(64, 32, 75, 2, 2); - } - - #[test] - pub fn test_matmul_unpadded_n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_unpadded_multibatch_2_dims() { - test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_unpadded_blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_large() { - test_with_params(134, 242, 250, 1, 1); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let func = matmul_tiling_2d_unpadded; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_primitive_swapped_batches_no_padding() { - let matmul_func = matmul_tiling_2d_unpadded; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_primitive_swapped_row_col_no_padding() { - let matmul_func = matmul_tiling_2d_unpadded; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 
4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_primitive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_tiling_2d_unpadded; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + #[test] + pub fn test_matmul_unpadded_straightforward() { + test_with_params(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_shapes_smaller_than_blocks() { + test_with_params(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_shapes_equal_blocks() { + test_with_params(64, 32, 64, 2, 2); + } + + #[test] + pub fn test_matmul_unpadded_m_exceeds_block() { + test_with_params(75, 32, 64, 2, 2); + } + + #[test] + pub fn test_matmul_unpadded_k_exceeds_block() { + test_with_params(64, 33, 32, 1, 1); + } + + #[test] + pub fn test_matmul_irregular_shape() { + test_with_params(123, 255, 72, 3, 5); + } + + #[test] + pub fn test64_matmul_unpadded_n_exceeds_block() { + test_with_params(64, 32, 75, 2, 2); + } + + #[test] + pub fn test_matmul_unpadded_n_smaller_than_m() { + test_with_params(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_m_smaller_than_n() { + test_with_params(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_k_smaller_than_m_n() { + test_with_params(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_k_larger_than_m_n() { + test_with_params(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_multibatch_1_dim() { + test_with_params(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_unpadded_multibatch_2_dims() { + test_with_params(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_unpadded_blocks_divide_shapes_unevenly() { + test_with_params(7, 7, 7, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_medium() { + test_with_params(17, 16, 16, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_large() { + test_with_params(134, 242, 250, 1, 1); + } + + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { + let func = matmul_tiling_2d_unpadded; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_primitive_swapped_batches_no_padding() { + let matmul_func = matmul_tiling_2d_unpadded; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_primitive_swapped_row_col_no_padding() { + let matmul_func = matmul_tiling_2d_unpadded; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_primitive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_tiling_2d_unpadded; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs 
b/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs index 1130ccd742..587a15bb2a 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs @@ -1,9 +1,9 @@ use burn_tensor::Element; use crate::{ - element::WgpuElement, - kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::WgpuTensor, + element::WgpuElement, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -12,139 +12,139 @@ use crate::kernel_wgsl; use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE}; kernel_wgsl!( - MatmulTiling2Dvec4Raw, - "../../../template/matmul/blocktiling_2d/vec4.wgsl" + MatmulTiling2Dvec4Raw, + "../../../template/matmul/blocktiling_2d/vec4.wgsl" ); #[derive(new, Debug)] struct MatmulTiling2Dvec4 { - _elem: PhantomData, + _elem: PhantomData, } impl DynamicKernelSource for MatmulTiling2Dvec4 { - fn source(&self) -> SourceTemplate { - MatmulTiling2Dvec4Raw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulTiling2Dvec4Raw::source() + .register("b_m", B_M.to_string()) + .register("b_n", B_N.to_string()) + .register("b_k", B_K.to_string()) + .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) + .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) + .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_z", "1".to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using tiling 2d algorithm with /// vec4 primitive on both lhs and rhs pub fn matmul_tiling_2d_vec4( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - let kernel = MatmulTiling2Dvec4::::new(); - matmul_tiling_2d_launch(lhs, rhs, out, kernel) + let kernel = MatmulTiling2Dvec4::::new(); + matmul_tiling_2d_launch(lhs, rhs, out, kernel) } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_vec4_primitive_straightforward() { - test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_2_dims() { - 
test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_large() { - test_with_params(134, 242, 250, 1, 1); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let func = matmul_tiling_2d_vec4; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { - let matmul_func = matmul_tiling_2d_vec4; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { - let matmul_func = matmul_tiling_2d_vec4; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_tiling_2d_vec4; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + #[test] + pub fn test_matmul_vec4_primitive_straightforward() { + test_with_params(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { + test_with_params(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_n_smaller_than_m() { + test_with_params(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_m_smaller_than_n() { + test_with_params(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { + test_with_params(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { + test_with_params(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_1_dim() { + test_with_params(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_2_dims() { + test_with_params(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { + test_with_params(7, 7, 7, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_medium() { + test_with_params(17, 16, 16, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_large() { + test_with_params(134, 242, 250, 1, 1); + } + + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { + let func = matmul_tiling_2d_vec4; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { + let matmul_func = matmul_tiling_2d_vec4; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn 
test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { + let matmul_func = matmul_tiling_2d_vec4; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_tiling_2d_vec4; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs b/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs index e00db5ec89..3dd1a77861 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs @@ -1,9 +1,9 @@ use burn_tensor::Element; use crate::{ - element::WgpuElement, - kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::WgpuTensor, + element::WgpuElement, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -12,140 +12,140 @@ use crate::kernel_wgsl; use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE}; kernel_wgsl!( - MatmulTiling2DVec4LhsRaw, - "../../../template/matmul/blocktiling_2d/vec4_lhs.wgsl" + MatmulTiling2DVec4LhsRaw, + "../../../template/matmul/blocktiling_2d/vec4_lhs.wgsl" ); #[derive(new, Debug)] struct MatmulTiling2DVec4Lhs { - _elem: PhantomData, + _elem: PhantomData, } impl DynamicKernelSource for MatmulTiling2DVec4Lhs { - fn source(&self) -> SourceTemplate { - MatmulTiling2DVec4LhsRaw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn", (B_K * B_N).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulTiling2DVec4LhsRaw::source() + .register("b_m", B_M.to_string()) + .register("b_n", B_N.to_string()) + .register("b_k", B_K.to_string()) + .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) + .register("bk_x_bn", (B_K * B_N).to_string()) + .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_z", "1".to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using tiling 2d algorithm with /// vec4 primitive on lhs only pub fn matmul_tiling_2d_vec4_lhs( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - let kernel = MatmulTiling2DVec4Lhs::::new(); - matmul_tiling_2d_launch(lhs, rhs, out, kernel) + let kernel = MatmulTiling2DVec4Lhs::::new(); + matmul_tiling_2d_launch(lhs, rhs, out, kernel) } #[cfg(test)] mod tests { - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - use super::matmul_tiling_2d_vec4_lhs; - - #[test] - pub fn test_matmul_vec4_primitive_straightforward() { - 
test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_2_dims() { - test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_large() { - test_with_params(134, 242, 250, 1, 1); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let func = matmul_tiling_2d_vec4_lhs; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { - let matmul_func = matmul_tiling_2d_vec4_lhs; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { - let matmul_func = matmul_tiling_2d_vec4_lhs; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_tiling_2d_vec4_lhs; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + use super::matmul_tiling_2d_vec4_lhs; + + #[test] + pub fn test_matmul_vec4_primitive_straightforward() { + test_with_params(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { + test_with_params(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_n_smaller_than_m() { + test_with_params(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_m_smaller_than_n() { + test_with_params(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { + test_with_params(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { + test_with_params(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_1_dim() { + test_with_params(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_2_dims() { + test_with_params(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { + test_with_params(7, 7, 7, 1, 1); + } + + #[test] + pub fn 
test_matmul_vec4_primitive_medium() { + test_with_params(17, 16, 16, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_large() { + test_with_params(134, 242, 250, 1, 1); + } + + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { + let func = matmul_tiling_2d_vec4_lhs; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { + let matmul_func = matmul_tiling_2d_vec4_lhs; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { + let matmul_func = matmul_tiling_2d_vec4_lhs; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_tiling_2d_vec4_lhs; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tune/base.rs b/burn-wgpu/src/kernel/matmul/tune/base.rs index 2699e233c4..5ae2c261ec 100644 --- a/burn-wgpu/src/kernel/matmul/tune/base.rs +++ b/burn-wgpu/src/kernel/matmul/tune/base.rs @@ -2,11 +2,11 @@ use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; use crate::{ - compute::WgpuAutotuneKey, - element::WgpuElement, - kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use super::key::MatmulAutotuneKey; @@ -14,162 +14,162 @@ use super::key::MatmulAutotuneKey; /// Set of matmul implementations available for autotune /// Autotune key is given by concatenating the closest upper power of 2 of m, k and n pub struct MatmulAutotuneOperationSet { - key: WgpuAutotuneKey, - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + key: WgpuAutotuneKey, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, } impl MatmulAutotuneOperationSet { - fn new(lhs: WgpuTensor, rhs: WgpuTensor, out: WgpuTensor) -> Self { - Self { - key: WgpuAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)), - lhs, - rhs, - out, - } + fn new(lhs: WgpuTensor, rhs: WgpuTensor, out: WgpuTensor) -> Self { + Self { + key: WgpuAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)), + lhs, + rhs, + out, } + } } impl AutotuneOperationSet - for MatmulAutotuneOperationSet + for MatmulAutotuneOperationSet { - fn key(&self) -> WgpuAutotuneKey { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); - let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1); - let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1); - - let out = empty_device( - self.out.client.clone(), - self.out.device.clone(), - self.out.shape.clone(), - ); - - vec![ 
- Box::new(MemoryCoalescingMatmulDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(MemoryCoalescingMatmulW16x16::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(Vec4TilingMatmulDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(Vec4TilingMatmulUnpaddedDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - match fastest_index { - 0 => Box::new(MemoryCoalescingMatmulDefault::::new( - self.lhs, self.rhs, self.out, - )), - 1 => Box::new(MemoryCoalescingMatmulW16x16::::new( - self.lhs, self.rhs, self.out, - )), - 2 => Box::new(Vec4TilingMatmulDefault::::new( - self.lhs, self.rhs, self.out, - )), - 3 => Box::new(Vec4TilingMatmulUnpaddedDefault::::new( - self.lhs, self.rhs, self.out, - )), - 4 => Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( - self.lhs, self.rhs, self.out, - )), - _ => panic!("Fastest index is out of bound"), - } + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1); + let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1); + + let out = empty_device( + self.out.client.clone(), + self.out.device.clone(), + self.out.shape.clone(), + ); + + vec![ + Box::new(MemoryCoalescingMatmulDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(MemoryCoalescingMatmulW16x16::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(Vec4TilingMatmulDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(Vec4TilingMatmulUnpaddedDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + match fastest_index { + 0 => Box::new(MemoryCoalescingMatmulDefault::::new( + self.lhs, self.rhs, self.out, + )), + 1 => Box::new(MemoryCoalescingMatmulW16x16::::new( + self.lhs, self.rhs, self.out, + )), + 2 => Box::new(Vec4TilingMatmulDefault::::new( + self.lhs, self.rhs, self.out, + )), + 3 => Box::new(Vec4TilingMatmulUnpaddedDefault::::new( + self.lhs, self.rhs, self.out, + )), + 4 => Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( + self.lhs, self.rhs, self.out, + )), + _ => panic!("Fastest index is out of bound"), } + } } /// Executes autotune on matmul operations pub fn matmul_autotune( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let client = lhs.client.clone(); + let client = lhs.client.clone(); - let output = init_matmul_output(&lhs, &rhs); + let output = init_matmul_output(&lhs, &rhs); - let operation_set = Box::new(MatmulAutotuneOperationSet::::new( - lhs, - rhs, - output.clone(), - )); + let operation_set = Box::new(MatmulAutotuneOperationSet::::new( + lhs, + rhs, + output.clone(), + )); - client.execute_autotune(operation_set); + client.execute_autotune(operation_set); - output + output } macro_rules! 
matmul_tune_ops { - ($name:ident, $func:expr) => { - #[derive(new)] - pub(crate) struct $name { - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, - } - - impl AutotuneOperation for $name { - fn execute(self: Box) { - #[allow(clippy::redundant_closure_call)] - $func(self.lhs, self.rhs, self.out); - } - - fn clone(&self) -> Box { - Box::new(Self { - lhs: self.lhs.clone(), - rhs: self.rhs.clone(), - out: self.out.clone(), - }) - } - } - }; + ($name:ident, $func:expr) => { + #[derive(new)] + pub(crate) struct $name { + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, + } + + impl AutotuneOperation for $name { + fn execute(self: Box) { + #[allow(clippy::redundant_closure_call)] + $func(self.lhs, self.rhs, self.out); + } + + fn clone(&self) -> Box { + Box::new(Self { + lhs: self.lhs.clone(), + rhs: self.rhs.clone(), + out: self.out.clone(), + }) + } + } + }; } // Potentially better for small matrices. matmul_tune_ops!( - MemoryCoalescingMatmulDefault, - crate::kernel::matmul::matmul_mem_coalescing_default + MemoryCoalescingMatmulDefault, + crate::kernel::matmul::matmul_mem_coalescing_default ); // Potentially better for small matrices. matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| { - crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16) + crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16) }); // Maybe the fastest on MacOS. matmul_tune_ops!( - Vec4LhsOnlyTilingMatmulDefault, - crate::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs + Vec4LhsOnlyTilingMatmulDefault, + crate::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs ); // Probably the fastest when fixed sizes. matmul_tune_ops!( - Vec4TilingMatmulDefault, - crate::kernel::matmul::vec4::matmul_tiling_2d_vec4 + Vec4TilingMatmulDefault, + crate::kernel::matmul::vec4::matmul_tiling_2d_vec4 ); // Probably the fastest otherwise. 
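The matmul_tune_ops! macro defined above wraps each matmul kernel function in a small struct that implements AutotuneOperation by forwarding execute() to the kernel and cloning its tensor handles, so the autotuner can box, clone, and benchmark every candidate uniformly. The standalone sketch below demonstrates the same wrap-a-function-into-a-clonable-operation pattern with toy types; Op, Tensor, wrap_op!, and PrintSizes are illustrative names only and are not part of this patch.

// Toy version of the pattern used by matmul_tune_ops! above.
trait Op {
    fn execute(self: Box<Self>);
    fn clone_op(&self) -> Box<dyn Op>;
}

#[derive(Clone)]
struct Tensor(Vec<f32>);

macro_rules! wrap_op {
    ($name:ident, $func:expr) => {
        struct $name {
            lhs: Tensor,
            rhs: Tensor,
        }
        impl Op for $name {
            fn execute(self: Box<Self>) {
                // $func is substituted with the concrete kernel entry point (or a closure).
                #[allow(clippy::redundant_closure_call)]
                $func(self.lhs, self.rhs);
            }
            fn clone_op(&self) -> Box<dyn Op> {
                Box::new(Self { lhs: self.lhs.clone(), rhs: self.rhs.clone() })
            }
        }
    };
}

wrap_op!(PrintSizes, |lhs: Tensor, rhs: Tensor| {
    println!("lhs: {}, rhs: {}", lhs.0.len(), rhs.0.len())
});

fn main() {
    let op: Box<dyn Op> = Box::new(PrintSizes {
        lhs: Tensor(vec![0.0; 4]),
        rhs: Tensor(vec![0.0; 6]),
    });
    let copy = op.clone_op();
    op.execute();   // prints "lhs: 4, rhs: 6"
    copy.execute(); // the clone can be executed independently, as the autotuner requires
}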
matmul_tune_ops!( - Vec4TilingMatmulUnpaddedDefault, - crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded + Vec4TilingMatmulUnpaddedDefault, + crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded ); diff --git a/burn-wgpu/src/kernel/matmul/tune/key.rs b/burn-wgpu/src/kernel/matmul/tune/key.rs index 37f619dde1..48d7655f50 100644 --- a/burn-wgpu/src/kernel/matmul/tune/key.rs +++ b/burn-wgpu/src/kernel/matmul/tune/key.rs @@ -1,119 +1,119 @@ use burn_tensor::Shape; use core::fmt::Debug; use std::{ - cmp::{max, min}, - fmt::Display, - hash::Hash, + cmp::{max, min}, + fmt::Display, + hash::Hash, }; #[derive(Hash, Eq, PartialEq, Debug, Clone)] /// Autotune key representative of matmul versions pub struct MatmulAutotuneKey { - round: bool, // True when all matmul dims are multiples of 64 - broadcast: bool, // True when there are differences in batch size - anchored_m: usize, - anchored_k: usize, - anchored_n: usize, - anchored_batch: usize, + round: bool, // True when all matmul dims are multiples of 64 + broadcast: bool, // True when there are differences in batch size + anchored_m: usize, + anchored_k: usize, + anchored_n: usize, + anchored_batch: usize, } impl Display for MatmulAutotuneKey { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - format!( - "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}", - self.round, - self.broadcast, - self.anchored_m, - self.anchored_k, - self.anchored_n, - self.anchored_batch - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + format!( + "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}", + self.round, + self.broadcast, + self.anchored_m, + self.anchored_k, + self.anchored_n, + self.anchored_batch + ) + .as_str(), + ) + } } impl MatmulAutotuneKey { - /// Create a matmul autotune key from the input shapes - pub fn new(lhs_shape: &Shape, rhs_shape: &Shape) -> Self { - let m = lhs_shape.dims[D - 2]; - let k = lhs_shape.dims[D - 1]; - let n = rhs_shape.dims[D - 1]; + /// Create a matmul autotune key from the input shapes + pub fn new(lhs_shape: &Shape, rhs_shape: &Shape) -> Self { + let m = lhs_shape.dims[D - 2]; + let k = lhs_shape.dims[D - 1]; + let n = rhs_shape.dims[D - 1]; - let mut broadcast = false; - let mut batch_product_lhs = 1; - let mut batch_product_rhs = 1; + let mut broadcast = false; + let mut batch_product_lhs = 1; + let mut batch_product_rhs = 1; - for b in 0..D - 2 { - batch_product_lhs *= lhs_shape.dims[b]; - batch_product_rhs *= rhs_shape.dims[b]; - if lhs_shape.dims[b] != rhs_shape.dims[b] { - broadcast = true; - } - } - let batch_product = max(batch_product_lhs, batch_product_rhs); + for b in 0..D - 2 { + batch_product_lhs *= lhs_shape.dims[b]; + batch_product_rhs *= rhs_shape.dims[b]; + if lhs_shape.dims[b] != rhs_shape.dims[b] { + broadcast = true; + } + } + let batch_product = max(batch_product_lhs, batch_product_rhs); - let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0; + let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0; - Self { - round, - broadcast, - anchored_m: anchor(m, None), - anchored_k: anchor(k, None), - anchored_n: anchor(n, None), - anchored_batch: anchor(batch_product, Some(256)), - } + Self { + round, + broadcast, + anchored_m: anchor(m, None), + anchored_k: anchor(k, None), + anchored_n: anchor(n, None), + anchored_batch: anchor(batch_product, Some(256)), } + } } fn anchor(x: usize, max: Option) -> usize { - let exp = f32::ceil(f32::log2(x as f32)) as u32; - let 
power_of_2 = 2_u32.pow(exp) as usize; - if let Some(max) = max { - min(power_of_2, max) - } else { - power_of_2 - } + let exp = f32::ceil(f32::log2(x as f32)) as u32; + let power_of_2 = 2_u32.pow(exp) as usize; + if let Some(max) = max { + min(power_of_2, max) + } else { + power_of_2 + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn matmul_autotune_key_all_same_and_round() { - let lhs_shape: Shape<3> = [4, 512, 512].into(); - let rhs_shape: Shape<3> = [4, 512, 512].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + #[test] + fn matmul_autotune_key_all_same_and_round() { + let lhs_shape: Shape<3> = [4, 512, 512].into(); + let rhs_shape: Shape<3> = [4, 512, 512].into(); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); - assert!(key.round); - assert!(!key.broadcast); - assert!(key.anchored_m == 512); - assert!(key.anchored_k == 512); - assert!(key.anchored_n == 512); - } + assert!(key.round); + assert!(!key.broadcast); + assert!(key.anchored_m == 512); + assert!(key.anchored_k == 512); + assert!(key.anchored_n == 512); + } - #[test] - fn matmul_autotune_key_all_different() { - let lhs_shape: Shape<4> = [2, 3, 511, 512].into(); - let rhs_shape: Shape<4> = [3, 2, 512, 513].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + #[test] + fn matmul_autotune_key_all_different() { + let lhs_shape: Shape<4> = [2, 3, 511, 512].into(); + let rhs_shape: Shape<4> = [3, 2, 512, 513].into(); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); - assert!(!key.round); - assert!(key.broadcast); - assert!(key.anchored_m == 512); - assert!(key.anchored_k == 512); - assert!(key.anchored_n == 1024); - assert!(key.anchored_batch == 8); - } + assert!(!key.round); + assert!(key.broadcast); + assert!(key.anchored_m == 512); + assert!(key.anchored_k == 512); + assert!(key.anchored_n == 1024); + assert!(key.anchored_batch == 8); + } - #[test] - fn matmul_autotune_key_large_batch() { - let lhs_shape: Shape<4> = [128, 512, 511, 512].into(); - let rhs_shape: Shape<4> = [200, 400, 512, 513].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + #[test] + fn matmul_autotune_key_large_batch() { + let lhs_shape: Shape<4> = [128, 512, 511, 512].into(); + let rhs_shape: Shape<4> = [200, 400, 512, 513].into(); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); - assert!(key.anchored_batch == 256); - } + assert!(key.anchored_batch == 256); + } } diff --git a/burn-wgpu/src/kernel/matmul/utils.rs b/burn-wgpu/src/kernel/matmul/utils.rs index 8a791ba359..d9b4fee4da 100644 --- a/burn-wgpu/src/kernel/matmul/utils.rs +++ b/burn-wgpu/src/kernel/matmul/utils.rs @@ -3,85 +3,86 @@ use burn_tensor::Shape; /// Creates an empty output tensor with matmul output shape pub fn init_matmul_output( - lhs: &WgpuTensor, - rhs: &WgpuTensor, + lhs: &WgpuTensor, + rhs: &WgpuTensor, ) -> WgpuTensor { - empty_device(lhs.client.clone(), lhs.device.clone(), shape_out(lhs, rhs)) + empty_device(lhs.client.clone(), lhs.device.clone(), shape_out(lhs, rhs)) } pub(crate) fn shape_out( - lhs: &WgpuTensor, - rhs: &WgpuTensor, + lhs: &WgpuTensor, + rhs: &WgpuTensor, ) -> Shape { - let mut shape_out = [0; D]; - lhs.shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - shape_out[D - 2] = lhs.shape.dims[D - 2]; - shape_out[D - 1] = rhs.shape.dims[D - 1]; - Shape::new(shape_out) + let mut shape_out = [0; D]; + lhs + .shape + .dims + .iter() + 
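The MatmulAutotuneKey in the hunk above discretizes m, k, n, and the batch product by rounding each value up to the next power of two, with the batch product additionally capped at 256. A minimal standalone sketch of that rounding follows, checked against the expected values in the key tests; the function name mirrors the patch, while the main() harness is illustrative only.

fn anchor(x: usize, max: Option<usize>) -> usize {
    // Round x up to the next power of two, then clamp to `max` if one is given.
    let exp = f32::ceil(f32::log2(x as f32)) as u32;
    let power_of_2 = 2_usize.pow(exp);
    match max {
        Some(cap) => power_of_2.min(cap),
        None => power_of_2,
    }
}

fn main() {
    assert_eq!(anchor(512, None), 512); // already a power of two
    assert_eq!(anchor(511, None), 512); // rounded up
    assert_eq!(anchor(513, None), 1024); // rounded up to the next power
    assert_eq!(anchor(6, None), 8); // batch product 2 * 3 from the broadcast test
    assert_eq!(anchor(128 * 512, Some(256)), 256); // large batch products are capped
}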
.zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + shape_out[D - 2] = lhs.shape.dims[D - 2]; + shape_out[D - 1] = rhs.shape.dims[D - 1]; + Shape::new(shape_out) } #[cfg(test)] pub(crate) mod tests { - use crate::tensor::WgpuTensor; - use crate::tests::{ReferenceTensor, TestTensor}; - use burn_tensor::Shape; + use crate::tensor::WgpuTensor; + use crate::tests::{ReferenceTensor, TestTensor}; + use burn_tensor::Shape; - use super::init_matmul_output; + use super::init_matmul_output; - pub(crate) fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) - where - F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, - S: Into>, - { - let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + pub(crate) fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) + where + F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, + S: Into>, + { + let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let x_wgpu = TestTensor::from_data(x.to_data()).into_primitive(); - let y_wgpu = TestTensor::from_data(y.to_data()).into_primitive(); + let x_wgpu = TestTensor::from_data(x.to_data()).into_primitive(); + let y_wgpu = TestTensor::from_data(y.to_data()).into_primitive(); - let z_reference = x.matmul(y); + let z_reference = x.matmul(y); - let out = init_matmul_output(&x_wgpu, &y_wgpu); - let z = func(x_wgpu, y_wgpu, out); - let z = TestTensor::from_primitive(z); + let out = init_matmul_output(&x_wgpu, &y_wgpu); + let z = func(x_wgpu, y_wgpu, out); + let z = TestTensor::from_primitive(z); - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } - pub(crate) fn same_as_reference_swapped_dims( - func: F, - swap_lhs: [usize; 2], - swap_rhs: [usize; 2], - shape_lhs: S, - shape_rhs: S, - ) where - F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, - S: Into>, - { - let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + pub(crate) fn same_as_reference_swapped_dims( + func: F, + swap_lhs: [usize; 2], + swap_rhs: [usize; 2], + shape_lhs: S, + shape_rhs: S, + ) where + F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, + S: Into>, + { + let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let x_wgpu = TestTensor::from_data(x.to_data()).swap_dims(swap_lhs[0], swap_lhs[1]); - let y_wgpu = TestTensor::from_data(y.to_data()).swap_dims(swap_rhs[0], swap_rhs[1]); + let x_wgpu = TestTensor::from_data(x.to_data()).swap_dims(swap_lhs[0], swap_lhs[1]); + let y_wgpu = TestTensor::from_data(y.to_data()).swap_dims(swap_rhs[0], swap_rhs[1]); - let z_reference = x - .swap_dims(swap_lhs[0], swap_lhs[1]) - .matmul(y.swap_dims(swap_rhs[0], swap_rhs[1])); + let z_reference = x + .swap_dims(swap_lhs[0], swap_lhs[1]) + .matmul(y.swap_dims(swap_rhs[0], swap_rhs[1])); - let out = init_matmul_output( - &x_wgpu.clone().into_primitive(), - &y_wgpu.clone().into_primitive(), - ); - let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive(), 
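shape_out above derives the matmul output shape by broadcasting every batch dimension to the larger of the two operands and then taking the row count from lhs and the column count from rhs. A standalone sketch of that rule on plain arrays follows; matmul_shape_out is an illustrative name, not part of the patch.

fn matmul_shape_out<const D: usize>(lhs: [usize; D], rhs: [usize; D]) -> [usize; D] {
    let mut out = [0; D];
    // Broadcast batch dimensions to the larger operand.
    for i in 0..D {
        out[i] = lhs[i].max(rhs[i]);
    }
    // The last two dimensions come from lhs rows (m) and rhs columns (n).
    out[D - 2] = lhs[D - 2];
    out[D - 1] = rhs[D - 1];
    out
}

fn main() {
    // Using the shapes from the autotune-key test earlier in this patch:
    // [2, 3, 511, 512] x [3, 2, 512, 513] broadcasts to [3, 3, 511, 513].
    assert_eq!(
        matmul_shape_out([2, 3, 511, 512], [3, 2, 512, 513]),
        [3, 3, 511, 513]
    );
}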
out); - let z = TestTensor::from_primitive(z); + let out = init_matmul_output( + &x_wgpu.clone().into_primitive(), + &y_wgpu.clone().into_primitive(), + ); + let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive(), out); + let z = TestTensor::from_primitive(z); - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs b/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs index 3da4c74776..20162ffab8 100644 --- a/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs @@ -1,95 +1,95 @@ use crate::{ - compute::{StaticKernel, WgpuHandle}, - element::WgpuElement, - kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::{StaticKernel, WgpuHandle}, + element::WgpuElement, + kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::Shape; kernel_wgsl!( - AdaptiveAvgPool2d, - "../../template/pool/adaptive_avg_pool2d.wgsl" + AdaptiveAvgPool2d, + "../../template/pool/adaptive_avg_pool2d.wgsl" ); kernel_wgsl!( - AdaptiveAvgPool2dBackward, - "../../template/pool/adaptive_avg_pool2d_backward.wgsl" + AdaptiveAvgPool2dBackward, + "../../template/pool/adaptive_avg_pool2d_backward.wgsl" ); pub(crate) fn adaptive_avg_pool2d( - x: WgpuTensor, - output_size: [usize; 2], + x: WgpuTensor, + output_size: [usize; 2], ) -> WgpuTensor { - let [batch_size, channels, _, _] = x.shape.dims; + let [batch_size, channels, _, _] = x.shape.dims; - let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); - let output = empty_device(x.client.clone(), x.device.clone(), output_shape); + let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); + let output = empty_device(x.client.clone(), x.device.clone(), output_shape); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); - let info_handle = build_info(&x, &output); - x.client - .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); + let info_handle = build_info(&x, &output); + x.client + .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); - output + output } pub(crate) fn adaptive_avg_pool2d_backward( - x: WgpuTensor, - out_grad: WgpuTensor, + x: WgpuTensor, + out_grad: WgpuTensor, ) -> WgpuTensor { - let output_shape = x.shape.clone(); - let num_elems = output_shape.num_elements(); - let output_buffer = x.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - x.client.clone(), - x.device.clone(), - output_shape, - output_buffer, - ); + let output_shape = x.shape.clone(); + let num_elems = output_shape.num_elements(); + let output_buffer = x.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + x.client.clone(), + x.device.clone(), + output_shape, + output_buffer, + ); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); - let 
info_handle = build_info(&x, &out_grad); + let info_handle = build_info(&x, &out_grad); - x.client.execute( - Box::new(kernel), - &[&out_grad.handle, &output.handle, &info_handle], - ); + x.client.execute( + Box::new(kernel), + &[&out_grad.handle, &output.handle, &info_handle], + ); - output + output } fn build_info(x: &WgpuTensor, output: &WgpuTensor) -> WgpuHandle { - let mut info: [u32; 16] = [0; 16]; - info[0] = x.strides[0] as u32; - info[1] = x.strides[1] as u32; - info[2] = x.strides[2] as u32; - info[3] = x.strides[3] as u32; - info[4] = x.shape.dims[0] as u32; - info[5] = x.shape.dims[1] as u32; - info[6] = x.shape.dims[2] as u32; - info[7] = x.shape.dims[3] as u32; + let mut info: [u32; 16] = [0; 16]; + info[0] = x.strides[0] as u32; + info[1] = x.strides[1] as u32; + info[2] = x.strides[2] as u32; + info[3] = x.strides[3] as u32; + info[4] = x.shape.dims[0] as u32; + info[5] = x.shape.dims[1] as u32; + info[6] = x.shape.dims[2] as u32; + info[7] = x.shape.dims[3] as u32; - info[8] = output.strides[0] as u32; - info[9] = output.strides[1] as u32; - info[10] = output.strides[2] as u32; - info[11] = output.strides[3] as u32; - info[12] = output.shape.dims[0] as u32; - info[13] = output.shape.dims[1] as u32; - info[14] = output.shape.dims[2] as u32; - info[15] = output.shape.dims[3] as u32; + info[8] = output.strides[0] as u32; + info[9] = output.strides[1] as u32; + info[10] = output.strides[2] as u32; + info[11] = output.strides[3] as u32; + info[12] = output.shape.dims[0] as u32; + info[13] = output.shape.dims[1] as u32; + info[14] = output.shape.dims[2] as u32; + info[15] = output.shape.dims[3] as u32; - output.client.create(bytemuck::cast_slice(&info)) + output.client.create(bytemuck::cast_slice(&info)) } diff --git a/burn-wgpu/src/kernel/pool/avg_pool2d.rs b/burn-wgpu/src/kernel/pool/avg_pool2d.rs index 05a5d840d3..85f2e44f38 100644 --- a/burn-wgpu/src/kernel/pool/avg_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/avg_pool2d.rs @@ -1,169 +1,154 @@ use crate::{ - compute::{Kernel, StaticKernel}, - element::WgpuElement, - kernel::{ - self, elemwise_workgroup, - pool::{build_output_and_info_pool2d, build_pool2d_info}, - KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, - }, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::{Kernel, StaticKernel}, + element::WgpuElement, + kernel::{ + self, elemwise_workgroup, + pool::{build_output_and_info_pool2d, build_pool2d_info}, + KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, + }, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(AvgPool2dRaw, "../../template/pool/avg_pool2d.wgsl"); kernel_wgsl!( - AvgPool2dBackwardRaw, - "../../template/pool/avg_pool2d_backward.wgsl" + AvgPool2dBackwardRaw, + "../../template/pool/avg_pool2d_backward.wgsl" ); struct AvgPool2dBackward; struct AvgPool2d; impl StaticKernelSource for AvgPool2dBackward { - fn source() -> kernel::SourceTemplate { - AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) - } + fn source() -> kernel::SourceTemplate { + AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) + } } impl StaticKernelSource for AvgPool2d { - fn source() -> kernel::SourceTemplate { - AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) - } + fn source() -> kernel::SourceTemplate { + AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) + } } pub(crate) fn avg_pool2d( - x: WgpuTensor, - kernel_size: 
[usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> WgpuTensor { - let (info_handle, output) = - build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]); - - let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); - let kernel: Box = match count_include_pad { - true => Box::new(StaticKernel::< - KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(workgroup)), - false => Box::new(StaticKernel::< - KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(workgroup)), - }; - - x.client - .execute(kernel, &[&x.handle, &output.handle, &info_handle]); - - output + let (info_handle, output) = + build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]); + + let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); + let kernel: Box = match count_include_pad { + true => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, + >::new(workgroup)), + false => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, + >::new(workgroup)), + }; + + x.client + .execute(kernel, &[&x.handle, &output.handle, &info_handle]); + + output } pub(crate) fn avg_pool2d_backward( - x: WgpuTensor, - grad: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: WgpuTensor, + grad: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> WgpuTensor { - let grad = kernel::into_contiguous(grad); - let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone()); - let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]); - let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); - - let kernel: Box = match count_include_pad { - true => Box::new(StaticKernel::< - KernelSettings< - AvgPool2dBackward, - E, - i32, - WORKGROUP_DEFAULT, - WORKGROUP_DEFAULT, - 1, - >, - >::new(workgroup)), - false => Box::new(StaticKernel::< - KernelSettings< - AvgPool2dBackward, - E, - i32, - WORKGROUP_DEFAULT, - WORKGROUP_DEFAULT, - 1, - >, - >::new(workgroup)), - }; - - x.client - .execute(kernel, &[&grad.handle, &output.handle, &info_handle]); - - output + let grad = kernel::into_contiguous(grad); + let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone()); + let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]); + let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); + + let kernel: Box = match count_include_pad { + true => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, + >::new(workgroup)), + false => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, + >::new(workgroup)), + }; + + x.client + .execute(kernel, &[&grad.handle, &output.handle, &info_handle]); + + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, module, ops::ModuleOps, Distribution, Tensor}; - - #[test] - fn avg_pool2d_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size 
= [3, 4]; - let stride = [1, 2]; - let padding = [1, 2]; - let count_include_pad = true; - - let pooled = module::avg_pool2d(tensor, kernel_size, stride, padding, count_include_pad); - let pooled_ref = - module::avg_pool2d(tensor_ref, kernel_size, stride, padding, count_include_pad); - - pooled - .into_data() - .assert_approx_eq(&pooled_ref.into_data(), 3); - } - - #[test] - fn avg_pool2d_backward_should_work_with_multiple_invocations() { - TestBackend::seed(0); - ReferenceBackend::seed(0); - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size = [3, 3]; - let stride = [1, 1]; - let padding = [1, 1]; - let count_include_pad = true; - - let shape_out = module::avg_pool2d( - tensor.clone(), - kernel_size, - stride, - padding, - count_include_pad, - ) - .shape(); - let grad_output = Tensor::::random(shape_out, Distribution::Default); - let grad_output_ref = Tensor::::from_data(grad_output.to_data()); - - let grad: Tensor = - Tensor::from_primitive(TestBackend::avg_pool2d_backward( - tensor.into_primitive(), - grad_output.into_primitive(), - kernel_size, - stride, - padding, - count_include_pad, - )); - let grad_ref: Tensor = - Tensor::from_primitive(ReferenceBackend::avg_pool2d_backward( - tensor_ref.into_primitive(), - grad_output_ref.into_primitive(), - kernel_size, - stride, - padding, - count_include_pad, - )); - - grad.into_data().assert_approx_eq(&grad_ref.into_data(), 3); - } + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, module, ops::ModuleOps, Distribution, Tensor}; + + #[test] + fn avg_pool2d_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 4]; + let stride = [1, 2]; + let padding = [1, 2]; + let count_include_pad = true; + + let pooled = module::avg_pool2d(tensor, kernel_size, stride, padding, count_include_pad); + let pooled_ref = + module::avg_pool2d(tensor_ref, kernel_size, stride, padding, count_include_pad); + + pooled + .into_data() + .assert_approx_eq(&pooled_ref.into_data(), 3); + } + + #[test] + fn avg_pool2d_backward_should_work_with_multiple_invocations() { + TestBackend::seed(0); + ReferenceBackend::seed(0); + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 3]; + let stride = [1, 1]; + let padding = [1, 1]; + let count_include_pad = true; + + let shape_out = module::avg_pool2d( + tensor.clone(), + kernel_size, + stride, + padding, + count_include_pad, + ) + .shape(); + let grad_output = Tensor::::random(shape_out, Distribution::Default); + let grad_output_ref = Tensor::::from_data(grad_output.to_data()); + + let grad: Tensor = Tensor::from_primitive(TestBackend::avg_pool2d_backward( + tensor.into_primitive(), + grad_output.into_primitive(), + kernel_size, + stride, + padding, + count_include_pad, + )); + let grad_ref: Tensor = + Tensor::from_primitive(ReferenceBackend::avg_pool2d_backward( + tensor_ref.into_primitive(), + grad_output_ref.into_primitive(), + kernel_size, + stride, + padding, + count_include_pad, + )); + + grad.into_data().assert_approx_eq(&grad_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/pool/base.rs b/burn-wgpu/src/kernel/pool/base.rs index 13a16d6ec6..9e48b21326 100644 --- a/burn-wgpu/src/kernel/pool/base.rs +++ b/burn-wgpu/src/kernel/pool/base.rs 
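The pool/base.rs hunk below computes each pooled output dimension as (x + 2*padding - dilation*(kernel - 1) - 1) / stride + 1. A standalone check of that arithmetic against the shapes used by the pooling tests in this patch; pool2d_out_dim is an illustrative name.

fn pool2d_out_dim(x: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (x + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // avg_pool2d backward test: size 32, kernel 3, stride 1, padding 1 -> 32.
    assert_eq!(pool2d_out_dim(32, 3, 1, 1, 1), 32);
    // max_pool2d backward test: kernel 3, stride 2, padding 1 -> 16, matching the 16x16 grad_output.
    assert_eq!(pool2d_out_dim(32, 3, 2, 1, 1), 16);
    // avg_pool2d forward test, width axis: kernel 4, stride 2, padding 2 -> 17.
    assert_eq!(pool2d_out_dim(32, 4, 2, 2, 1), 17);
}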
@@ -1,73 +1,72 @@ use crate::{ - compute::WgpuHandle, element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor, + compute::WgpuHandle, element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor, }; use burn_tensor::Shape; /// Build basic info to launch pool 2d kernels. pub fn build_output_and_info_pool2d( - x: &WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: &WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (WgpuHandle, WgpuTensor) { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape.dims; + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + let [batch_size, channels, x_height, x_width] = x.shape.dims; - let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) - / stride_height) - + 1; - let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) - / stride_width) - + 1; - let shape_out = Shape::new([batch_size, channels, out_height, out_width]); - let output = empty_device(x.client.clone(), x.device.clone(), shape_out); + let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) + / stride_height) + + 1; + let out_width = + ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; + let shape_out = Shape::new([batch_size, channels, out_height, out_width]); + let output = empty_device(x.client.clone(), x.device.clone(), shape_out); - let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation); + let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation); - (info_buffer, output) + (info_buffer, output) } pub fn build_pool2d_info( - input: &WgpuTensor, - output: &WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + input: &WgpuTensor, + output: &WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> WgpuHandle { - let mut info: [u32; 24] = [0; 24]; - info[0] = input.strides[0] as u32; - info[1] = input.strides[1] as u32; - info[2] = input.strides[2] as u32; - info[3] = input.strides[3] as u32; - info[4] = input.shape.dims[0] as u32; - info[5] = input.shape.dims[1] as u32; - info[6] = input.shape.dims[2] as u32; - info[7] = input.shape.dims[3] as u32; + let mut info: [u32; 24] = [0; 24]; + info[0] = input.strides[0] as u32; + info[1] = input.strides[1] as u32; + info[2] = input.strides[2] as u32; + info[3] = input.strides[3] as u32; + info[4] = input.shape.dims[0] as u32; + info[5] = input.shape.dims[1] as u32; + info[6] = input.shape.dims[2] as u32; + info[7] = input.shape.dims[3] as u32; - info[8] = output.strides[0] as u32; - info[9] = output.strides[1] as u32; - info[10] = output.strides[2] as u32; - info[11] = output.strides[3] as u32; - info[12] = output.shape.dims[0] as u32; - info[13] = output.shape.dims[1] as u32; - info[14] = output.shape.dims[2] as u32; - info[15] = output.shape.dims[3] as u32; + info[8] = output.strides[0] as u32; + info[9] = output.strides[1] as u32; + info[10] = 
output.strides[2] as u32; + info[11] = output.strides[3] as u32; + info[12] = output.shape.dims[0] as u32; + info[13] = output.shape.dims[1] as u32; + info[14] = output.shape.dims[2] as u32; + info[15] = output.shape.dims[3] as u32; - info[16] = kernel_size[0] as u32; - info[17] = kernel_size[1] as u32; - info[18] = stride[0] as u32; - info[19] = stride[1] as u32; - info[20] = padding[0] as u32; - info[21] = padding[1] as u32; - info[22] = dilation[0] as u32; - info[23] = dilation[1] as u32; + info[16] = kernel_size[0] as u32; + info[17] = kernel_size[1] as u32; + info[18] = stride[0] as u32; + info[19] = stride[1] as u32; + info[20] = padding[0] as u32; + info[21] = padding[1] as u32; + info[22] = dilation[0] as u32; + info[23] = dilation[1] as u32; - let info_buffer = input.client.create(bytemuck::cast_slice(&info)); + let info_buffer = input.client.create(bytemuck::cast_slice(&info)); - info_buffer + info_buffer } diff --git a/burn-wgpu/src/kernel/pool/max_pool2d.rs b/burn-wgpu/src/kernel/pool/max_pool2d.rs index 77ce5d998b..e06588755b 100644 --- a/burn-wgpu/src/kernel/pool/max_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/max_pool2d.rs @@ -1,194 +1,189 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{ - self, elemwise_workgroup, - pool::{build_output_and_info_pool2d, build_pool2d_info}, - KernelSettings, WORKGROUP_DEFAULT, - }, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{ + self, elemwise_workgroup, + pool::{build_output_and_info_pool2d, build_pool2d_info}, + KernelSettings, WORKGROUP_DEFAULT, + }, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(MaxPool2d, "../../template/pool/max_pool2d.wgsl"); kernel_wgsl!( - MaxPool2dWithIndicesBackward, - "../../template/pool/max_pool2d_with_indices_backward.wgsl" + MaxPool2dWithIndicesBackward, + "../../template/pool/max_pool2d_with_indices_backward.wgsl" ); kernel_wgsl!( - MaxPool2dWithIndices, - "../../template/pool/max_pool2d_with_indices.wgsl" + MaxPool2dWithIndices, + "../../template/pool/max_pool2d_with_indices.wgsl" ); pub(crate) fn max_pool2d( - x: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> WgpuTensor { - let (info_handle, output) = - build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - x.client - .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); - - output + let (info_handle, output) = + build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + + x.client + .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); + + output } pub(crate) fn max_pool2d_with_indices( - x: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (WgpuTensor, WgpuTensor) { - let (info_handle, output) = - build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); - let indices = empty_device(x.client.clone(), 
x.device, output.shape.clone()); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - x.client.execute( - Box::new(kernel), - &[&x.handle, &output.handle, &indices.handle, &info_handle], - ); - - (output, indices) + let (info_handle, output) = + build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); + let indices = empty_device(x.client.clone(), x.device, output.shape.clone()); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + + x.client.execute( + Box::new(kernel), + &[&x.handle, &output.handle, &indices.handle, &info_handle], + ); + + (output, indices) } pub(crate) fn max_pool2d_with_indices_backward( - x: WgpuTensor, - grad: WgpuTensor, - indices: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: WgpuTensor, + grad: WgpuTensor, + indices: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> WgpuTensor { - let grad = kernel::into_contiguous(grad); - let indices = kernel::into_contiguous(indices); - - let num_elems = x.shape.num_elements(); - let buffer = x.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer); - - let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - x.client.execute( - Box::new(kernel), - &[&indices.handle, &grad.handle, &output.handle, &info_handle], - ); - output + let grad = kernel::into_contiguous(grad); + let indices = kernel::into_contiguous(indices); + + let num_elems = x.shape.num_elements(); + let buffer = x.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer); + + let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + + x.client.execute( + Box::new(kernel), + &[&indices.handle, &grad.handle, &output.handle, &info_handle], + ); + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor}; - - #[test] - pub fn max_pool2d_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size = [3, 3]; - let stride = [2, 2]; - let padding = [1, 1]; - let dilation = [1, 1]; - - let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation); - let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation); - - pooled - .into_data() - .assert_approx_eq(&pooled_ref.into_data(), 3); - } - - #[test] - pub fn max_pool2d_with_indices_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size = [3, 3]; - let stride = [2, 2]; - let padding = [1, 1]; - let dilation = [1, 1]; - - let (pooled, indices) = - module::max_pool2d_with_indices(tensor, kernel_size, 
stride, padding, dilation); - let (pooled_ref, indices_ref) = - module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation); - - pooled - .into_data() - .assert_approx_eq(&pooled_ref.into_data(), 3); - assert_eq!(indices.into_data(), indices_ref.into_data().convert()); - } - - #[test] - pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let grad_output = Tensor::::random([32, 32, 16, 16], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let grad_output_ref = Tensor::::from_data(grad_output.to_data()); - let kernel_size = [3, 3]; - let stride = [2, 2]; - let padding = [1, 1]; - let dilation = [1, 1]; - - let (_, indices) = - module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding, dilation); - let (_, indices_ref) = module::max_pool2d_with_indices( - tensor_ref.clone(), - kernel_size, - stride, - padding, - dilation, - ); - let grad = TestBackend::max_pool2d_with_indices_backward( - tensor.into_primitive(), - kernel_size, - stride, - padding, - dilation, - grad_output.into_primitive(), - indices.into_primitive(), - ) - .x_grad; - let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward( - tensor_ref.into_primitive(), - kernel_size, - stride, - padding, - dilation, - grad_output_ref.into_primitive(), - indices_ref.into_primitive(), - ) - .x_grad; - - Tensor::::from_primitive(grad) - .into_data() - .assert_approx_eq( - &Tensor::::from_primitive(grad_ref).into_data(), - 3, - ); - } + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor}; + + #[test] + pub fn max_pool2d_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 3]; + let stride = [2, 2]; + let padding = [1, 1]; + let dilation = [1, 1]; + + let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation); + let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation); + + pooled + .into_data() + .assert_approx_eq(&pooled_ref.into_data(), 3); + } + + #[test] + pub fn max_pool2d_with_indices_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 3]; + let stride = [2, 2]; + let padding = [1, 1]; + let dilation = [1, 1]; + + let (pooled, indices) = + module::max_pool2d_with_indices(tensor, kernel_size, stride, padding, dilation); + let (pooled_ref, indices_ref) = + module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation); + + pooled + .into_data() + .assert_approx_eq(&pooled_ref.into_data(), 3); + assert_eq!(indices.into_data(), indices_ref.into_data().convert()); + } + + #[test] + pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let grad_output = Tensor::::random([32, 32, 16, 16], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let grad_output_ref = Tensor::::from_data(grad_output.to_data()); + let kernel_size = [3, 3]; + let stride = [2, 2]; + let padding = [1, 1]; + let dilation = [1, 1]; + + let (_, indices) = + module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, 
padding, dilation); + let (_, indices_ref) = + module::max_pool2d_with_indices(tensor_ref.clone(), kernel_size, stride, padding, dilation); + let grad = TestBackend::max_pool2d_with_indices_backward( + tensor.into_primitive(), + kernel_size, + stride, + padding, + dilation, + grad_output.into_primitive(), + indices.into_primitive(), + ) + .x_grad; + let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward( + tensor_ref.into_primitive(), + kernel_size, + stride, + padding, + dilation, + grad_output_ref.into_primitive(), + indices_ref.into_primitive(), + ) + .x_grad; + + Tensor::::from_primitive(grad) + .into_data() + .assert_approx_eq( + &Tensor::::from_primitive(grad_ref).into_data(), + 3, + ); + } } diff --git a/burn-wgpu/src/kernel/prng/base.rs b/burn-wgpu/src/kernel/prng/base.rs index c0478db0e5..fd5338bdc4 100644 --- a/burn-wgpu/src/kernel/prng/base.rs +++ b/burn-wgpu/src/kernel/prng/base.rs @@ -1,7 +1,7 @@ use crate::{ - compute::{WgpuComputeClient, WgpuHandle}, - element::WgpuElement, - kernel_wgsl, SEED, + compute::{WgpuComputeClient, WgpuHandle}, + element::WgpuElement, + kernel_wgsl, SEED, }; use burn_common::rand::get_seeded_rng; use rand::Rng; @@ -9,86 +9,86 @@ use rand::Rng; kernel_wgsl!(Prng, "../../template/prng/prng.wgsl"); pub(crate) fn get_seeds() -> Vec { - let mut seed = SEED.lock().unwrap(); - let mut rng = match seed.as_ref() { - Some(rng_seeded) => rng_seeded.clone(), - None => get_seeded_rng(), - }; - let mut seeds: Vec = Vec::with_capacity(4); - for _ in 0..4 { - seeds.push(rng.gen()); - } - *seed = Some(rng); - seeds + let mut seed = SEED.lock().unwrap(); + let mut rng = match seed.as_ref() { + Some(rng_seeded) => rng_seeded.clone(), + None => get_seeded_rng(), + }; + let mut seeds: Vec = Vec::with_capacity(4); + for _ in 0..4 { + seeds.push(rng.gen()); + } + *seed = Some(rng); + seeds } pub(crate) fn make_info_buffer( - client: WgpuComputeClient, - n_values_per_thread: usize, + client: WgpuComputeClient, + n_values_per_thread: usize, ) -> WgpuHandle { - let mut info = get_seeds(); - info.insert(0, n_values_per_thread as u32); - client.create(bytemuck::cast_slice(&info)) + let mut info = get_seeds(); + info.insert(0, n_values_per_thread as u32); + client.create(bytemuck::cast_slice(&info)) } pub(crate) fn make_args_buffer( - client: WgpuComputeClient, - args: &[E], + client: WgpuComputeClient, + args: &[E], ) -> WgpuHandle { - client.create(E::as_bytes(args)) + client.create(E::as_bytes(args)) } #[cfg(test)] pub mod tests { - use burn_tensor::Element; + use burn_tensor::Element; - #[derive(Default, Copy, Clone)] - pub struct BinStats { - pub count: usize, - pub n_runs: usize, // Number of sequences of same bin - } + #[derive(Default, Copy, Clone)] + pub struct BinStats { + pub count: usize, + pub n_runs: usize, // Number of sequences of same bin + } - pub fn calculate_bin_stats( - numbers: Vec, - number_of_bins: usize, - low: f32, - high: f32, - ) -> Vec { - let range = (high - low) / number_of_bins as f32; - let mut output: Vec = (0..number_of_bins).map(|_| Default::default()).collect(); - let mut initialized = false; - let mut current_runs = number_of_bins; // impossible value for starting point - for number in numbers { - let num = number.elem::(); - if num < low || num > high { - continue; - } - let index = f32::floor((num - low) / range) as usize; - output[index].count += 1; - if initialized && index != current_runs { - output[current_runs].n_runs += 1; - } - initialized = true; - current_runs = index; - } + pub fn calculate_bin_stats( + numbers: Vec, 
+ number_of_bins: usize, + low: f32, + high: f32, + ) -> Vec { + let range = (high - low) / number_of_bins as f32; + let mut output: Vec = (0..number_of_bins).map(|_| Default::default()).collect(); + let mut initialized = false; + let mut current_runs = number_of_bins; // impossible value for starting point + for number in numbers { + let num = number.elem::(); + if num < low || num > high { + continue; + } + let index = f32::floor((num - low) / range) as usize; + output[index].count += 1; + if initialized && index != current_runs { output[current_runs].n_runs += 1; - output + } + initialized = true; + current_runs = index; } + output[current_runs].n_runs += 1; + output + } - #[test] - fn test_count_bins() { - let numbers = vec![0., 1., 1.5, 2., 2.5, 3., 2.5, 1.5, 3.5]; - let number_of_bins = 4; - let low = 0.; - let high = 4.; - let stats = calculate_bin_stats(numbers, number_of_bins, low, high); - assert_eq!(stats[0].count, 1); - assert_eq!(stats[0].n_runs, 1); - assert_eq!(stats[1].count, 3); - assert_eq!(stats[1].n_runs, 2); - assert_eq!(stats[2].count, 3); - assert_eq!(stats[2].n_runs, 2); - assert_eq!(stats[3].count, 2); - assert_eq!(stats[3].n_runs, 2); - } + #[test] + fn test_count_bins() { + let numbers = vec![0., 1., 1.5, 2., 2.5, 3., 2.5, 1.5, 3.5]; + let number_of_bins = 4; + let low = 0.; + let high = 4.; + let stats = calculate_bin_stats(numbers, number_of_bins, low, high); + assert_eq!(stats[0].count, 1); + assert_eq!(stats[0].n_runs, 1); + assert_eq!(stats[1].count, 3); + assert_eq!(stats[1].n_runs, 2); + assert_eq!(stats[2].count, 3); + assert_eq!(stats[2].n_runs, 2); + assert_eq!(stats[3].count, 2); + assert_eq!(stats[3].n_runs, 2); + } } diff --git a/burn-wgpu/src/kernel/prng/bernoulli.rs b/burn-wgpu/src/kernel/prng/bernoulli.rs index 69c4cbf21e..1a2324420c 100644 --- a/burn-wgpu/src/kernel/prng/bernoulli.rs +++ b/burn-wgpu/src/kernel/prng/bernoulli.rs @@ -1,13 +1,13 @@ use crate::{ - compute::{compute_client, StaticKernel}, - element::WgpuElement, - kernel::{ - prng::base::{make_args_buffer, make_info_buffer}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, - }, - ops::numeric::empty_device, - tensor::WgpuTensor, - GraphicsApi, WgpuDevice, + compute::{compute_client, StaticKernel}, + element::WgpuElement, + kernel::{ + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, + }, + ops::numeric::empty_device, + tensor::WgpuTensor, + GraphicsApi, WgpuDevice, }; use burn_tensor::Shape; @@ -16,122 +16,115 @@ use super::base::Prng; struct BernoulliPrng; impl StaticKernelSource for BernoulliPrng { - fn source() -> SourceTemplate { - Prng::source() - .register("num_args", "1") - .register( - "prng_loop", - include_str!("../../template/prng/bernoulli_inner_loop.wgsl"), - ) - .add_template("fn cast_elem(e: bool) -> {{ elem }} {return {{elem}}(e);}") - } + fn source() -> SourceTemplate { + Prng::source() + .register("num_args", "1") + .register( + "prng_loop", + include_str!("../../template/prng/bernoulli_inner_loop.wgsl"), + ) + .add_template("fn cast_elem(e: bool) -> {{ elem }} {return {{elem}}(e);}") + } } /// Pseudo-random generator for bernoulli pub fn random_bernoulli( - shape: Shape, - device: &WgpuDevice, - prob: E, + shape: Shape, + device: &WgpuDevice, + prob: E, ) -> WgpuTensor { - const N_VALUES_PER_THREAD: usize = 128; - - let client = compute_client::(device); - let output = empty_device(client.clone(), device.clone(), shape.clone()); - let 
info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); - let args_handle = make_args_buffer(client.clone(), &[prob]); - let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); - - client.execute( - Box::new(kernel), - &[&output.handle, &info_handle, &args_handle], - ); - - output + const N_VALUES_PER_THREAD: usize = 128; + + let client = compute_client::(device); + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[prob]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); + + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], + ); + + output } #[cfg(test)] mod tests { - use core::f32; - - use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; - use serial_test::serial; - - use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; - - #[test] - #[serial] - fn subsequent_calls_give_different_tensors() { - TestBackend::seed(0); - let shape: Shape<2> = [40, 40].into(); - let device = WgpuDevice::default(); - - let tensor_1 = Tensor::::random_device( - shape.clone(), - Distribution::Bernoulli(0.5), - &device, - ); - let tensor_2 = Tensor::::random_device( - shape.clone(), - Distribution::Bernoulli(0.5), - &device, - ); - let mut diff_exists = false; - for i in 0..shape.num_elements() { - if tensor_1.to_data().value[i] != tensor_2.to_data().value[i] { - diff_exists = true; - break; - } - } - assert!(diff_exists); - } - - #[test] - #[serial] - fn number_of_1_proportional_to_prob() { - TestBackend::seed(0); - let shape: Shape<2> = [40, 40].into(); - let device = WgpuDevice::default(); - let prob = 0.7; - - let tensor_1 = Tensor::::random_device( - shape.clone(), - Distribution::Bernoulli(prob), - &device, - ); - - // High bound slightly over 1 so 1.0 is included in second bin - let bin_stats = calculate_bin_stats(tensor_1.into_data().value, 2, 0., 1.1); - assert!( - f32::abs((bin_stats[1].count as f32 / shape.num_elements() as f32) - prob as f32) - < 0.05 - ); + use core::f32; + + use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; + use serial_test::serial; + + use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; + + #[test] + #[serial] + fn subsequent_calls_give_different_tensors() { + TestBackend::seed(0); + let shape: Shape<2> = [40, 40].into(); + let device = WgpuDevice::default(); + + let tensor_1 = + Tensor::::random_device(shape.clone(), Distribution::Bernoulli(0.5), &device); + let tensor_2 = + Tensor::::random_device(shape.clone(), Distribution::Bernoulli(0.5), &device); + let mut diff_exists = false; + for i in 0..shape.num_elements() { + if tensor_1.to_data().value[i] != tensor_2.to_data().value[i] { + diff_exists = true; + break; + } } + assert!(diff_exists); + } + + #[test] + #[serial] + fn number_of_1_proportional_to_prob() { + TestBackend::seed(0); + let shape: Shape<2> = [40, 40].into(); + let device = WgpuDevice::default(); + let prob = 0.7; + + let tensor_1 = Tensor::::random_device( + shape.clone(), + Distribution::Bernoulli(prob), + &device, + ); - #[test] - #[serial] - fn runs_test() { - TestBackend::seed(0); - let shape = Shape::new([512, 512]); - let device = 
WgpuDevice::default(); - let tensor = - Tensor::::random_device(shape, Distribution::Bernoulli(0.5), &device); - - let numbers = tensor.into_data().value; - let stats = calculate_bin_stats(numbers, 2, 0., 1.1); - let n_0 = stats[0].count as f32; - let n_1 = stats[1].count as f32; - let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32; - - let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; - let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) - / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.)); - let z = (n_runs - expectation) / variance.sqrt(); - - // below 2 means we can have good confidence in the randomness - // we put 2.5 to make sure it passes even when very unlucky - assert!(z.abs() < 2.5); - } + // High bound slightly over 1 so 1.0 is included in second bin + let bin_stats = calculate_bin_stats(tensor_1.into_data().value, 2, 0., 1.1); + assert!( + f32::abs((bin_stats[1].count as f32 / shape.num_elements() as f32) - prob as f32) < 0.05 + ); + } + + #[test] + #[serial] + fn runs_test() { + TestBackend::seed(0); + let shape = Shape::new([512, 512]); + let device = WgpuDevice::default(); + let tensor = + Tensor::::random_device(shape, Distribution::Bernoulli(0.5), &device); + + let numbers = tensor.into_data().value; + let stats = calculate_bin_stats(numbers, 2, 0., 1.1); + let n_0 = stats[0].count as f32; + let n_1 = stats[1].count as f32; + let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32; + + let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; + let variance = + ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.)); + let z = (n_runs - expectation) / variance.sqrt(); + + // below 2 means we can have good confidence in the randomness + // we put 2.5 to make sure it passes even when very unlucky + assert!(z.abs() < 2.5); + } } diff --git a/burn-wgpu/src/kernel/prng/normal.rs b/burn-wgpu/src/kernel/prng/normal.rs index fc80f3f1e7..dd88e5f5fe 100644 --- a/burn-wgpu/src/kernel/prng/normal.rs +++ b/burn-wgpu/src/kernel/prng/normal.rs @@ -1,15 +1,15 @@ use burn_tensor::Shape; use crate::{ - compute::{compute_client, StaticKernel}, - element::WgpuElement, - kernel::{ - prng::base::{make_args_buffer, make_info_buffer}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, - }, - ops::numeric::empty_device, - tensor::WgpuTensor, - GraphicsApi, WgpuDevice, + compute::{compute_client, StaticKernel}, + element::WgpuElement, + kernel::{ + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, + }, + ops::numeric::empty_device, + tensor::WgpuTensor, + GraphicsApi, WgpuDevice, }; use super::base::Prng; @@ -17,110 +17,107 @@ use super::base::Prng; struct NormalPrng; impl StaticKernelSource for NormalPrng { - fn source() -> SourceTemplate { - Prng::source() - .register("num_args", "2") - .register( - "prng_loop", - include_str!("../../template/prng/normal_inner_loop.wgsl"), - ) - .add_template(include_str!( - "../../template/prng/box_muller_transform.wgsl" - )) - } + fn source() -> SourceTemplate { + Prng::source() + .register("num_args", "2") + .register( + "prng_loop", + include_str!("../../template/prng/normal_inner_loop.wgsl"), + ) + .add_template(include_str!( + "../../template/prng/box_muller_transform.wgsl" + )) + } } /// Pseudo-random generator for normal distribution pub fn random_normal( - shape: Shape, - device: &WgpuDevice, - mean: E, - std: E, + shape: Shape, + device: &WgpuDevice, + mean: E, + std: E, ) -> 
WgpuTensor { - const N_VALUES_PER_THREAD: usize = 128; // must be even + const N_VALUES_PER_THREAD: usize = 128; // must be even - let client = compute_client::(device); - let output = empty_device(client.clone(), device.clone(), shape.clone()); - let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); - let args_handle = make_args_buffer(client.clone(), &[mean, std]); - let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); + let client = compute_client::(device); + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[mean, std]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); - client.execute( - Box::new(kernel), - &[&output.handle, &info_handle, &args_handle], - ); + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], + ); - output + output } #[cfg(test)] mod tests { - use burn_tensor::{backend::Backend, Data, Distribution, Shape, Tensor}; - use serial_test::serial; + use burn_tensor::{backend::Backend, Data, Distribution, Shape, Tensor}; + use serial_test::serial; - use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; + use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; - #[test] - #[serial] - fn subsequent_calls_give_different_tensors() { - TestBackend::seed(0); - let shape = [4, 5]; - let device = WgpuDevice::default(); + #[test] + #[serial] + fn subsequent_calls_give_different_tensors() { + TestBackend::seed(0); + let shape = [4, 5]; + let device = WgpuDevice::default(); - let tensor_1 = - Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); - let tensor_2 = - Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); - for i in 0..20 { - assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); - } + let tensor_1 = + Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); + let tensor_2 = + Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); + for i in 0..20 { + assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); } + } - #[test] - #[serial] - fn empirical_mean_close_to_expectation() { - TestBackend::seed(0); - let shape = [128, 128]; - let device = WgpuDevice::default(); - let mean = 10.; - let tensor = - Tensor::::random_device(shape, Distribution::Normal(mean, 2.), &device); - let empirical_mean = tensor.mean().into_data(); - empirical_mean.assert_approx_eq(&Data::from([mean as f32]), 1); - } + #[test] + #[serial] + fn empirical_mean_close_to_expectation() { + TestBackend::seed(0); + let shape = [128, 128]; + let device = WgpuDevice::default(); + let mean = 10.; + let tensor = + Tensor::::random_device(shape, Distribution::Normal(mean, 2.), &device); + let empirical_mean = tensor.mean().into_data(); + empirical_mean.assert_approx_eq(&Data::from([mean as f32]), 1); + } - #[test] - #[serial] - fn normal_respects_68_95_99_rule() { - // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule - let shape: Shape<2> = [1000, 1000].into(); - let device = WgpuDevice::default(); - let mu = 0.; - let s = 1.; - let tensor = Tensor::::random_device( - shape.clone(), - Distribution::Normal(mu, 
s), - &device, - ); - let stats = calculate_bin_stats( - tensor.into_data().value, - 6, - (mu - 3. * s) as f32, - (mu + 3. * s) as f32, - ); - let assert_approx_eq = |count, percent| { - let expected = percent * shape.num_elements() as f32 / 100.; - assert!(f32::abs(count as f32 - expected) < 2000.); - }; - assert_approx_eq(stats[0].count, 2.1); - assert_approx_eq(stats[1].count, 13.6); - assert_approx_eq(stats[2].count, 34.1); - assert_approx_eq(stats[3].count, 34.1); - assert_approx_eq(stats[4].count, 13.6); - assert_approx_eq(stats[5].count, 2.1); - } + #[test] + #[serial] + fn normal_respects_68_95_99_rule() { + // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule + let shape: Shape<2> = [1000, 1000].into(); + let device = WgpuDevice::default(); + let mu = 0.; + let s = 1.; + let tensor = + Tensor::::random_device(shape.clone(), Distribution::Normal(mu, s), &device); + let stats = calculate_bin_stats( + tensor.into_data().value, + 6, + (mu - 3. * s) as f32, + (mu + 3. * s) as f32, + ); + let assert_approx_eq = |count, percent| { + let expected = percent * shape.num_elements() as f32 / 100.; + assert!(f32::abs(count as f32 - expected) < 2000.); + }; + assert_approx_eq(stats[0].count, 2.1); + assert_approx_eq(stats[1].count, 13.6); + assert_approx_eq(stats[2].count, 34.1); + assert_approx_eq(stats[3].count, 34.1); + assert_approx_eq(stats[4].count, 13.6); + assert_approx_eq(stats[5].count, 2.1); + } } diff --git a/burn-wgpu/src/kernel/prng/uniform.rs b/burn-wgpu/src/kernel/prng/uniform.rs index ec9f8e00a7..bf9880ba35 100644 --- a/burn-wgpu/src/kernel/prng/uniform.rs +++ b/burn-wgpu/src/kernel/prng/uniform.rs @@ -1,15 +1,15 @@ use burn_tensor::Shape; use crate::{ - compute::{compute_client, StaticKernel, WgpuComputeClient}, - element::WgpuElement, - kernel::{ - prng::base::{make_args_buffer, make_info_buffer}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, - }, - ops::numeric::empty_device, - tensor::WgpuTensor, - GraphicsApi, WgpuDevice, + compute::{compute_client, StaticKernel, WgpuComputeClient}, + element::WgpuElement, + kernel::{ + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, + }, + ops::numeric::empty_device, + tensor::WgpuTensor, + GraphicsApi, WgpuDevice, }; use super::base::Prng; @@ -17,154 +17,149 @@ use super::base::Prng; struct UniformPrng; impl StaticKernelSource for UniformPrng { - fn source() -> SourceTemplate { - Prng::source().register("num_args", "2").register( - "prng_loop", - include_str!("../../template/prng/uniform_inner_loop.wgsl"), - ) - } + fn source() -> SourceTemplate { + Prng::source().register("num_args", "2").register( + "prng_loop", + include_str!("../../template/prng/uniform_inner_loop.wgsl"), + ) + } } /// Pseudo-random generator for uniform distribution pub fn random_uniform( - shape: Shape, - device: &WgpuDevice, - low: E, - high: E, + shape: Shape, + device: &WgpuDevice, + low: E, + high: E, ) -> WgpuTensor { - let client = compute_client::(device); - uniform_kernel(client, device, &shape, low, high) + let client = compute_client::(device); + uniform_kernel(client, device, &shape, low, high) } /// Pseudo-random generator for uniform distribution, based on /// another tensor's client, device and shape pub fn random_like_uniform( - tensor: &WgpuTensor, - low: E, - high: E, + tensor: &WgpuTensor, + low: E, + high: E, ) -> WgpuTensor { - uniform_kernel( - tensor.client.clone(), - &tensor.device, - &tensor.shape, 
- low, - high, - ) + uniform_kernel( + tensor.client.clone(), + &tensor.device, + &tensor.shape, + low, + high, + ) } fn uniform_kernel( - client: WgpuComputeClient, - device: &WgpuDevice, - shape: &Shape, - low: E, - high: E, + client: WgpuComputeClient, + device: &WgpuDevice, + shape: &Shape, + low: E, + high: E, ) -> WgpuTensor { - const N_VALUES_PER_THREAD: usize = 128; - - let output = empty_device(client.clone(), device.clone(), shape.clone()); - let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); - let args_handle = make_args_buffer(client.clone(), &[low, high]); - let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); - - client.execute( - Box::new(kernel), - &[&output.handle, &info_handle, &args_handle], - ); - - output + const N_VALUES_PER_THREAD: usize = 128; + + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[low, high]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); + + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], + ); + + output } #[cfg(test)] mod tests { - use core::f32; - - use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; - use serial_test::serial; - - use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; - - #[test] - #[serial] - fn subsequent_calls_give_different_tensors() { - TestBackend::seed(0); - let shape = [4, 5]; - let device = WgpuDevice::default(); - - let tensor_1 = - Tensor::::random_device(shape, Distribution::Default, &device); - let tensor_2 = - Tensor::::random_device(shape, Distribution::Default, &device); - for i in 0..20 { - assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); - } - } + use core::f32; - #[test] - #[serial] - fn values_all_within_interval_default() { - TestBackend::seed(0); - let shape = [24, 24]; - let device = WgpuDevice::default(); + use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; + use serial_test::serial; - let tensor = Tensor::::random_device(shape, Distribution::Default, &device); - tensor.to_data().assert_within_range(0..1); - } + use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; - #[test] - #[serial] - fn values_all_within_interval_uniform() { - TestBackend::seed(0); - let shape = [24, 24]; - let device = WgpuDevice::default(); - - let tensor = - Tensor::::random_device(shape, Distribution::Uniform(5., 17.), &device); - tensor.to_data().assert_within_range(5..17); - } - - #[test] - #[serial] - fn at_least_one_value_per_bin_uniform() { - TestBackend::seed(0); - let shape = [64, 64]; - let device = WgpuDevice::default(); - - let tensor = Tensor::::random_device( - shape, - Distribution::Uniform(-5., 10.), - &device, - ); - let numbers = tensor.into_data().value; - let stats = calculate_bin_stats(numbers, 3, -5., 10.); - assert!(stats[0].count >= 1); - assert!(stats[1].count >= 1); - assert!(stats[2].count >= 1); - } + #[test] + #[serial] + fn subsequent_calls_give_different_tensors() { + TestBackend::seed(0); + let shape = [4, 5]; + let device = WgpuDevice::default(); - #[test] - #[serial] - fn runs_test() { - TestBackend::seed(0); - let shape = Shape::new([512, 
512]); - let device = WgpuDevice::default(); - let tensor = Tensor::::random_device(shape, Distribution::Default, &device); - - let numbers = tensor.into_data().value; - let stats = calculate_bin_stats(numbers, 2, 0., 1.); - let n_0 = stats[0].count as f32; - let n_1 = stats[1].count as f32; - let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32; - - let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; - let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) - / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.)); - let z = (n_runs - expectation) / variance.sqrt(); - - // below 2 means we can have good confidence in the randomness - // we put 2.5 to make sure it passes even when very unlucky - assert!(z.abs() < 2.5); + let tensor_1 = Tensor::::random_device(shape, Distribution::Default, &device); + let tensor_2 = Tensor::::random_device(shape, Distribution::Default, &device); + for i in 0..20 { + assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); } + } + + #[test] + #[serial] + fn values_all_within_interval_default() { + TestBackend::seed(0); + let shape = [24, 24]; + let device = WgpuDevice::default(); + + let tensor = Tensor::::random_device(shape, Distribution::Default, &device); + tensor.to_data().assert_within_range(0..1); + } + + #[test] + #[serial] + fn values_all_within_interval_uniform() { + TestBackend::seed(0); + let shape = [24, 24]; + let device = WgpuDevice::default(); + + let tensor = + Tensor::::random_device(shape, Distribution::Uniform(5., 17.), &device); + tensor.to_data().assert_within_range(5..17); + } + + #[test] + #[serial] + fn at_least_one_value_per_bin_uniform() { + TestBackend::seed(0); + let shape = [64, 64]; + let device = WgpuDevice::default(); + + let tensor = + Tensor::::random_device(shape, Distribution::Uniform(-5., 10.), &device); + let numbers = tensor.into_data().value; + let stats = calculate_bin_stats(numbers, 3, -5., 10.); + assert!(stats[0].count >= 1); + assert!(stats[1].count >= 1); + assert!(stats[2].count >= 1); + } + + #[test] + #[serial] + fn runs_test() { + TestBackend::seed(0); + let shape = Shape::new([512, 512]); + let device = WgpuDevice::default(); + let tensor = Tensor::::random_device(shape, Distribution::Default, &device); + + let numbers = tensor.into_data().value; + let stats = calculate_bin_stats(numbers, 2, 0., 1.); + let n_0 = stats[0].count as f32; + let n_1 = stats[1].count as f32; + let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32; + + let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; + let variance = + ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) / ((n_0 + n_1).powf(2.) 
* (n_0 + n_1 - 1.)); + let z = (n_runs - expectation) / variance.sqrt(); + + // below 2 means we can have good confidence in the randomness + // we put 2.5 to make sure it passes even when very unlucky + assert!(z.abs() < 2.5); + } } diff --git a/burn-wgpu/src/kernel/reduce/base.rs b/burn-wgpu/src/kernel/reduce/base.rs index 0f58369607..bf50288116 100644 --- a/burn-wgpu/src/kernel/reduce/base.rs +++ b/burn-wgpu/src/kernel/reduce/base.rs @@ -2,21 +2,21 @@ use crate::{element::WgpuElement, tensor::WgpuTensor}; /// Creates an empty output tensor with reduce output shape pub fn init_reduce_output( - input: &WgpuTensor, - reduce_dim: usize, + input: &WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[reduce_dim] = 1; + let mut shape_out = input.shape.clone(); + shape_out.dims[reduce_dim] = 1; - // Create output handle - let num_elems_output = shape_out.num_elements(); - let handle = input - .client - .empty(num_elems_output * core::mem::size_of::()); - WgpuTensor::new( - input.client.clone(), - input.device.clone(), - shape_out.clone(), - handle, - ) + // Create output handle + let num_elems_output = shape_out.num_elements(); + let handle = input + .client + .empty(num_elems_output * core::mem::size_of::()); + WgpuTensor::new( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + handle, + ) } diff --git a/burn-wgpu/src/kernel/reduce/reduction.rs b/burn-wgpu/src/kernel/reduce/reduction.rs index 432f678827..aa8bb3f9e4 100644 --- a/burn-wgpu/src/kernel/reduce/reduction.rs +++ b/burn-wgpu/src/kernel/reduce/reduction.rs @@ -1,18 +1,18 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{ - build_info, elemwise_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, - WORKGROUP_DEFAULT, - }, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{ + build_info, elemwise_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, + WORKGROUP_DEFAULT, + }, + kernel_wgsl, + tensor::WgpuTensor, }; use burn_tensor::Shape; kernel_wgsl!( - RecursiveSumRaw, - "../../template/reduction/recursive_sum.wgsl" + RecursiveSumRaw, + "../../template/reduction/recursive_sum.wgsl" ); kernel_wgsl!(ReductionDimRaw, "../../template/reduction/reduce_dim.wgsl"); kernel_wgsl!(ReductionArgsRaw, "../../template/reduction/args.wgsl"); @@ -23,199 +23,199 @@ pub(crate) struct SumDim; pub(crate) struct MeanDim; impl StaticKernelSource for SumDim { - fn source() -> SourceTemplate { - ReductionDimRaw::source().register("assign", "output[id] = sum;") - } + fn source() -> SourceTemplate { + ReductionDimRaw::source().register("assign", "output[id] = sum;") + } } impl StaticKernelSource for MeanDim { - fn source() -> SourceTemplate { - ReductionDimRaw::source() - .add_template( - "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { + fn source() -> SourceTemplate { + ReductionDimRaw::source() + .add_template( + "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { return sum / {{ elem }}(dim); }", - ) - .register("assign", "output[id] = mean_dim(sum, shape_dim);") - } + ) + .register("assign", "output[id] = mean_dim(sum, shape_dim);") + } } impl StaticKernelSource for ArgsMax { - fn source() -> SourceTemplate { - ReductionArgsRaw::source() - .register("cmp", ">") - .register("initial", (-32767).to_string()) - } + fn source() -> SourceTemplate { + ReductionArgsRaw::source() + .register("cmp", ">") + .register("initial", (-32767).to_string()) + } } impl StaticKernelSource for ArgsMin 
{ - fn source() -> SourceTemplate { - ReductionArgsRaw::source() - .register("cmp", "<") - .register("initial", 32767.to_string()) - } + fn source() -> SourceTemplate { + ReductionArgsRaw::source() + .register("cmp", "<") + .register("initial", 32767.to_string()) + } } /// Sum all elements in the input buffer. pub fn sum(input: WgpuTensor) -> WgpuTensor { - let mut input_handle = input.handle; - let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT); + let mut input_handle = input.handle; + let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT); - loop { - let num_invocations = workgroup.num_invocations(); - let handle = input - .client - .empty(core::mem::size_of::() * num_invocations); + loop { + let num_invocations = workgroup.num_invocations(); + let handle = input + .client + .empty(core::mem::size_of::() * num_invocations); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); - input - .client - .execute(Box::new(kernel), &[&input_handle, &handle]); + input + .client + .execute(Box::new(kernel), &[&input_handle, &handle]); - if num_invocations <= 1 { - return WgpuTensor::new(input.client, input.device, Shape::new([1]), handle); - } - - input_handle = handle; - workgroup = elemwise_workgroup(num_invocations, WORKGROUP_DEFAULT); + if num_invocations <= 1 { + return WgpuTensor::new(input.client, input.device, Shape::new([1]), handle); } + + input_handle = handle; + workgroup = elemwise_workgroup(num_invocations, WORKGROUP_DEFAULT); + } } /// Execute the sum dim kernel. pub fn sum_dim( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim::(input, output, dim) + reduction_dim::(input, output, dim) } /// Execute the mean dim kernel. pub fn mean_dim( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim::(input, output, dim) + reduction_dim::(input, output, dim) } fn reduction_dim( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - let kernel = - StaticKernel::>::new( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), - ); - - let mut info = build_info(&[&input, &output]); - info.push(dim as u32); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_handle], + let kernel = + StaticKernel::>::new( + elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), ); - output + let mut info = build_info(&[&input, &output]); + info.push(dim as u32); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], + ); + + output } /// Execute the argmax kernel. pub fn argmax( - input: WgpuTensor, - dim: usize, + input: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_args_dim::(input, dim) + reduction_args_dim::(input, dim) } /// Execute the argmin kernel. 
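The `sum` kernel above reduces in several passes: each launch shrinks the buffer to `num_invocations` partial sums, and the loop relaunches over those partial sums until only one remains. A minimal CPU-side sketch of that pass structure, assuming plain `f32` values and a hypothetical `workgroup_size` in place of the GPU handles and `elemwise_workgroup` sizing:

// Each pass mirrors one kernel launch: every chunk of `workgroup_size`
// elements collapses into one partial sum, and the next pass runs over
// the partial sums until a single value is left.
fn sum_multipass(mut values: Vec<f32>, workgroup_size: usize) -> f32 {
    while values.len() > 1 {
        values = values
            .chunks(workgroup_size)
            .map(|chunk| chunk.iter().sum::<f32>())
            .collect();
    }
    values.first().copied().unwrap_or(0.0)
}

For example, 1536 inputs with a chunk size of 256 take two passes: the first produces 6 partial sums and the second collapses them into the final value, which mirrors the multi-invocation situation that `reduction_sum_should_work_with_multiple_invocations` is testing.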
pub fn argmin( - input: WgpuTensor, - dim: usize, + input: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_args_dim::(input, dim) + reduction_args_dim::(input, dim) } fn reduction_args_dim( - input: WgpuTensor, - dim: usize, + input: WgpuTensor, + dim: usize, ) -> WgpuTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[dim] = 1; - let num_elems = shape_out.num_elements(); - let buffer = input.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - input.client.clone(), - input.device.clone(), - shape_out, - buffer, + let mut shape_out = input.shape.clone(); + shape_out.dims[dim] = 1; + let num_elems = shape_out.num_elements(); + let buffer = input.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + input.client.clone(), + input.device.clone(), + shape_out, + buffer, + ); + + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), ); + let mut info = build_info(&[&input, &output]); + info.push(dim as u32); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), - ); - let mut info = build_info(&[&input, &output]); - info.push(dim as u32); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_handle], - ); + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], + ); - WgpuTensor::new(output.client, output.device, output.shape, output.handle) + WgpuTensor::new(output.client, output.device, output.shape, output.handle) } #[cfg(test)] mod tests { - use super::*; - use crate::{ - kernel::reduce::init_reduce_output, - tests::{ReferenceBackend, TestBackend}, - }; - use burn_tensor::{Distribution, Int, Tensor}; - - #[test] - fn reduction_sum_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let val = Tensor::::from_primitive(sum(tensor.into_primitive())); - let val_ref = tensor_ref.sum(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_sum_dim_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 1024], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 1; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(reduction_dim::( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_args_dim_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 1024], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let val = Tensor::::from_primitive(argmax(tensor.into_primitive(), 1)); - let val_ref = tensor_ref.argmax(1); - - assert_eq!(val_ref.into_data().convert(), val.into_data()); - } + use super::*; + use crate::{ + kernel::reduce::init_reduce_output, + tests::{ReferenceBackend, TestBackend}, + }; + use burn_tensor::{Distribution, Int, Tensor}; + + #[test] + fn reduction_sum_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + + let val = 
Tensor::::from_primitive(sum(tensor.into_primitive())); + let val_ref = tensor_ref.sum(); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 1024], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 1; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(reduction_dim::( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(1); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_args_dim_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 1024], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + + let val = Tensor::::from_primitive(argmax(tensor.into_primitive(), 1)); + let val_ref = tensor_ref.argmax(1); + + assert_eq!(val_ref.into_data().convert(), val.into_data()); + } } diff --git a/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs b/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs index 4d4fb43e3a..cd6d7dfa91 100644 --- a/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs +++ b/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs @@ -1,170 +1,168 @@ use crate::{ - compute::{StaticKernel, WorkGroup}, - element::WgpuElement, - kernel::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}, - kernel_wgsl, - tensor::WgpuTensor, + compute::{StaticKernel, WorkGroup}, + element::WgpuElement, + kernel::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!( - ReductionDimSharedMemoryRaw, - "../../template/reduction/reduce_dim_shared_memory.wgsl" + ReductionDimSharedMemoryRaw, + "../../template/reduction/reduce_dim_shared_memory.wgsl" ); pub(crate) struct SumDimSharedMemory; pub(crate) struct MeanDimSharedMemory; impl StaticKernelSource for SumDimSharedMemory { - fn source() -> SourceTemplate { - ReductionDimSharedMemoryRaw::source() - .register( - "shared_size", - (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), - ) - .register("initial", 0.0.to_string()) - .register("update", "shared_memory[local_id] += value; ") - .register("assign", "output[output_position] = final_value; ") - } + fn source() -> SourceTemplate { + ReductionDimSharedMemoryRaw::source() + .register( + "shared_size", + (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), + ) + .register("initial", 0.0.to_string()) + .register("update", "shared_memory[local_id] += value; ") + .register("assign", "output[output_position] = final_value; ") + } } impl StaticKernelSource for MeanDimSharedMemory { - fn source() -> SourceTemplate { - ReductionDimSharedMemoryRaw::source() - .register( - "shared_size", - (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), - ) - .register("initial", 0.0.to_string()) - .register("update", "shared_memory[local_id] += value; ") - .add_template( - "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { + fn source() -> SourceTemplate { + ReductionDimSharedMemoryRaw::source() + .register( + "shared_size", + (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), + ) + .register("initial", 0.0.to_string()) + .register("update", "shared_memory[local_id] += value; ") + .add_template( + "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { return sum / {{ elem }}(dim); }", - ) - .register( - "assign", - 
"output[output_position] = mean_dim(final_value, shape_input_dim_reduce);", - ) - } + ) + .register( + "assign", + "output[output_position] = mean_dim(final_value, shape_input_dim_reduce);", + ) + } } /// Execute the sum dim kernel leveraging shared memory /// Probably more efficient on tensors where the dimension to reduced /// is much larger than the others pub fn sum_dim_shared_memory( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim_shared_memory::(input, output, dim) + reduction_dim_shared_memory::(input, output, dim) } /// Execute the mean dim kernel leveraging shared memory /// Probably more efficient on tensors where the dimension to reduced /// is much larger than the others pub fn mean_dim_shared_memory( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim_shared_memory::(input, output, dim) + reduction_dim_shared_memory::(input, output, dim) } fn reduction_dim_shared_memory( - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let num_elems_output = output.shape.num_elements(); - let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x); - let grid = WorkGroup::new(n_workgroups_x as u32, n_workgroups_y as u32, 1); + let num_elems_output = output.shape.num_elements(); + let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32)); + let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x); + let grid = WorkGroup::new(n_workgroups_x as u32, n_workgroups_y as u32, 1); - let kernel = - StaticKernel::>::new( - grid, - ); + let kernel = + StaticKernel::>::new(grid); - // Build info - let mut info = build_info(&[&input, &output]); + // Build info + let mut info = build_info(&[&input, &output]); - // Reduce groups are elements that are aligned along the reduce dim - let reduce_group_size = input.shape.dims[reduce_dim]; - let n_invocation_per_workgroup = WORKGROUP_DEFAULT * WORKGROUP_DEFAULT; - let n_reduce_elements_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32; + // Reduce groups are elements that are aligned along the reduce dim + let reduce_group_size = input.shape.dims[reduce_dim]; + let n_invocation_per_workgroup = WORKGROUP_DEFAULT * WORKGROUP_DEFAULT; + let n_reduce_elements_per_thread = + f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32; - // Add dimension of reduction and how many reduce elements are treated per thread - info.push(reduce_dim as u32); - info.push(n_reduce_elements_per_thread); + // Add dimension of reduction and how many reduce elements are treated per thread + info.push(reduce_dim as u32); + info.push(n_reduce_elements_per_thread); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_handle], - ); + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], + ); - output + output } #[cfg(test)] mod tests { - use super::*; - use crate::{ - kernel::reduce::init_reduce_output, - tests::{ReferenceBackend, TestBackend}, - }; - use burn_tensor::{Distribution, Tensor}; - - #[test] - fn 
reduction_sum_dim_shared_memory_small() { - let tensor = Tensor::::random([700], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 0; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(sum_dim_shared_memory( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_sum_dim_shared_memory_medium() { - let tensor = Tensor::::random([6, 1024], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 1; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(sum_dim_shared_memory( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_sum_dim_shared_memory_large() { - let tensor = Tensor::::random([4, 1024, 50], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 2; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(sum_dim_shared_memory( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } + use super::*; + use crate::{ + kernel::reduce::init_reduce_output, + tests::{ReferenceBackend, TestBackend}, + }; + use burn_tensor::{Distribution, Tensor}; + + #[test] + fn reduction_sum_dim_shared_memory_small() { + let tensor = Tensor::::random([700], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 0; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_shared_memory_medium() { + let tensor = Tensor::::random([6, 1024], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 1; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_shared_memory_large() { + let tensor = Tensor::::random([4, 1024, 50], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 2; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/reduce/tune/base.rs b/burn-wgpu/src/kernel/reduce/tune/base.rs index d52bf37dcb..02d1f1b4de 100644 --- a/burn-wgpu/src/kernel/reduce/tune/base.rs +++ b/burn-wgpu/src/kernel/reduce/tune/base.rs @@ -1,27 +1,27 @@ #[macro_export] /// Generate an 
autotune operation for a reduce kernel macro_rules! reduce_tune_ops { - ($name:ident, $func:expr) => { - #[derive(new)] - pub(crate) struct $name { - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, - } + ($name:ident, $func:expr) => { + #[derive(new)] + pub(crate) struct $name { + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, + } - impl AutotuneOperation for $name { - fn execute(self: Box) { - #[allow(clippy::redundant_closure_call)] - $func(self.input, self.output, self.reduce_dim); - } + impl AutotuneOperation for $name { + fn execute(self: Box) { + #[allow(clippy::redundant_closure_call)] + $func(self.input, self.output, self.reduce_dim); + } - fn clone(&self) -> Box { - Box::new(Self { - input: self.input.clone(), - output: self.output.clone(), - reduce_dim: self.reduce_dim.clone(), - }) - } - } - }; + fn clone(&self) -> Box { + Box::new(Self { + input: self.input.clone(), + output: self.output.clone(), + reduce_dim: self.reduce_dim.clone(), + }) + } + } + }; } diff --git a/burn-wgpu/src/kernel/reduce/tune/key.rs b/burn-wgpu/src/kernel/reduce/tune/key.rs index db5e4b21bf..44f61647b5 100644 --- a/burn-wgpu/src/kernel/reduce/tune/key.rs +++ b/burn-wgpu/src/kernel/reduce/tune/key.rs @@ -5,45 +5,45 @@ use burn_tensor::Shape; #[derive(Hash, Eq, PartialEq, Debug, Clone)] /// Autotune key representative of reduce versions pub struct ReduceAutotuneKey { - reduce_dim_length: usize, - others_product: usize, + reduce_dim_length: usize, + others_product: usize, } impl Display for ReduceAutotuneKey { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - format!( - "Reduce - reduce_dim_length: {:?} others_product: {:?}", - self.reduce_dim_length, self.others_product - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + format!( + "Reduce - reduce_dim_length: {:?} others_product: {:?}", + self.reduce_dim_length, self.others_product + ) + .as_str(), + ) + } } impl ReduceAutotuneKey { - /// Create a reduce autotune key from the input shape and reduce dim - pub fn new(shape: &Shape, reduce_dim: usize) -> Self { - let reduce_dim_length = shape.dims[reduce_dim]; - let mut others_product = 1; - for d in 0..D { - if d != reduce_dim { - others_product *= shape.dims[d] - } - } - Self { - reduce_dim_length: anchor(reduce_dim_length, None), - others_product: anchor(others_product, None), - } + /// Create a reduce autotune key from the input shape and reduce dim + pub fn new(shape: &Shape, reduce_dim: usize) -> Self { + let reduce_dim_length = shape.dims[reduce_dim]; + let mut others_product = 1; + for d in 0..D { + if d != reduce_dim { + others_product *= shape.dims[d] + } + } + Self { + reduce_dim_length: anchor(reduce_dim_length, None), + others_product: anchor(others_product, None), } + } } fn anchor(x: usize, max: Option) -> usize { - let exp = f32::ceil(f32::log2(x as f32)) as u32; - let power_of_2 = 2_u32.pow(exp) as usize; - if let Some(max) = max { - min(power_of_2, max) - } else { - power_of_2 - } + let exp = f32::ceil(f32::log2(x as f32)) as u32; + let power_of_2 = 2_u32.pow(exp) as usize; + if let Some(max) = max { + min(power_of_2, max) + } else { + power_of_2 + } } diff --git a/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs index a19fc7cf34..8d272529b8 100644 --- a/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs +++ b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs @@ -2,15 +2,15 @@ use burn_compute::tune::{AutotuneOperation, 
AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; use crate::{ - compute::WgpuAutotuneKey, - element::WgpuElement, - kernel::{ - prng::random_like_uniform, - reduce::{init_reduce_output, mean_dim, mean_dim_shared_memory}, - }, - ops::numeric::empty_device, - reduce_tune_ops, - tensor::WgpuTensor, + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{ + prng::random_like_uniform, + reduce::{init_reduce_output, mean_dim, mean_dim_shared_memory}, + }, + ops::numeric::empty_device, + reduce_tune_ops, + tensor::WgpuTensor, }; use super::ReduceAutotuneKey; @@ -19,90 +19,90 @@ use super::ReduceAutotuneKey; /// Autotune key is given by concatenating the closest upper power of 2 of /// dim to reduce, and product of others pub struct MeanDimAutotuneOperationSet { - key: WgpuAutotuneKey, - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, + key: WgpuAutotuneKey, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, } impl MeanDimAutotuneOperationSet { - fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { - Self { - key: WgpuAutotuneKey::MeanDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), - input, - output, - reduce_dim, - } + fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { + Self { + key: WgpuAutotuneKey::MeanDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), + input, + output, + reduce_dim, } + } } impl AutotuneOperationSet - for MeanDimAutotuneOperationSet + for MeanDimAutotuneOperationSet { - fn key(&self) -> WgpuAutotuneKey { - self.key.clone() - } + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } - fn autotunables(&self) -> Vec> { - let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); - let output = empty_device( - self.output.client.clone(), - self.output.device.clone(), - self.output.shape.clone(), - ); + let output = empty_device( + self.output.client.clone(), + self.output.device.clone(), + self.output.shape.clone(), + ); - vec![ - Box::new(MeanDimAutotune::::new( - input.clone(), - output.clone(), - self.reduce_dim, - )), - Box::new(MeanDimSharedMemoryAutotune::::new( - input.clone(), - output.clone(), - self.reduce_dim, - )), - ] - } + vec![ + Box::new(MeanDimAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + Box::new(MeanDimSharedMemoryAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + ] + } - fn fastest(self: Box, fastest_index: usize) -> Box { - // Warning: since AutotuneOperationSet shares his key with SumDimAutotuneOperationSet - // we must make sure the order here is correlated with SumDim - match fastest_index { - 0 => Box::new(MeanDimAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - 1 => Box::new(MeanDimSharedMemoryAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - _ => panic!("Fastest index is out of bound"), - } + fn fastest(self: Box, fastest_index: usize) -> Box { + // Warning: since AutotuneOperationSet shares his key with SumDimAutotuneOperationSet + // we must make sure the order here is correlated with SumDim + match fastest_index { + 0 => Box::new(MeanDimAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + 1 => Box::new(MeanDimSharedMemoryAutotune::::new( + 
self.input, + self.output, + self.reduce_dim, + )), + _ => panic!("Fastest index is out of bound"), } + } } /// Executes autotune on mean_dim operation pub fn mean_dim_autotune( - input: WgpuTensor, - reduce_dim: usize, + input: WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let client = input.client.clone(); + let client = input.client.clone(); - let output = init_reduce_output(&input, reduce_dim); + let output = init_reduce_output(&input, reduce_dim); - let operation_set = Box::new(MeanDimAutotuneOperationSet::::new( - input, - output.clone(), - reduce_dim, - )); + let operation_set = Box::new(MeanDimAutotuneOperationSet::::new( + input, + output.clone(), + reduce_dim, + )); - client.execute_autotune(operation_set); + client.execute_autotune(operation_set); - output + output } // Probably better on balanced tensor shapes diff --git a/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs index a5831d7016..541a73d6db 100644 --- a/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs +++ b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs @@ -2,15 +2,15 @@ use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; use crate::{ - compute::WgpuAutotuneKey, - element::WgpuElement, - kernel::{ - prng::random_like_uniform, - reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory}, - }, - ops::numeric::empty_device, - reduce_tune_ops, - tensor::WgpuTensor, + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{ + prng::random_like_uniform, + reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory}, + }, + ops::numeric::empty_device, + reduce_tune_ops, + tensor::WgpuTensor, }; use super::ReduceAutotuneKey; @@ -19,90 +19,90 @@ use super::ReduceAutotuneKey; /// Autotune key is given by concatenating the closest upper power of 2 of /// dim to reduce, and product of others pub struct SumDimAutotuneOperationSet { - key: WgpuAutotuneKey, - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, + key: WgpuAutotuneKey, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, } impl SumDimAutotuneOperationSet { - fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { - Self { - key: WgpuAutotuneKey::SumDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), - input, - output, - reduce_dim, - } + fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { + Self { + key: WgpuAutotuneKey::SumDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), + input, + output, + reduce_dim, } + } } impl AutotuneOperationSet - for SumDimAutotuneOperationSet + for SumDimAutotuneOperationSet { - fn key(&self) -> WgpuAutotuneKey { - self.key.clone() - } + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } - fn autotunables(&self) -> Vec> { - let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); - let output = empty_device( - self.output.client.clone(), - self.output.device.clone(), - self.output.shape.clone(), - ); + let output = empty_device( + self.output.client.clone(), + self.output.device.clone(), + self.output.shape.clone(), + ); - vec![ - Box::new(SumDimAutotune::::new( - input.clone(), - output.clone(), - self.reduce_dim, - )), - Box::new(SumDimSharedMemoryAutotune::::new( - input.clone(), - 
output.clone(), - self.reduce_dim, - )), - ] - } + vec![ + Box::new(SumDimAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + Box::new(SumDimSharedMemoryAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + ] + } - fn fastest(self: Box, fastest_index: usize) -> Box { - // Warning: since AutotuneOperationSet shares his key with MeanDimAutotuneOperationSet - // we must make sure the order here is correlated with MeanDim - match fastest_index { - 0 => Box::new(SumDimAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - 1 => Box::new(SumDimSharedMemoryAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - _ => panic!("Fastest index is out of bound"), - } + fn fastest(self: Box, fastest_index: usize) -> Box { + // Warning: since AutotuneOperationSet shares his key with MeanDimAutotuneOperationSet + // we must make sure the order here is correlated with MeanDim + match fastest_index { + 0 => Box::new(SumDimAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + 1 => Box::new(SumDimSharedMemoryAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + _ => panic!("Fastest index is out of bound"), } + } } /// Executes autotune on sum_dim operation pub fn sum_dim_autotune( - input: WgpuTensor, - reduce_dim: usize, + input: WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let client = input.client.clone(); + let client = input.client.clone(); - let output = init_reduce_output(&input, reduce_dim); + let output = init_reduce_output(&input, reduce_dim); - let operation_set = Box::new(SumDimAutotuneOperationSet::::new( - input, - output.clone(), - reduce_dim, - )); + let operation_set = Box::new(SumDimAutotuneOperationSet::::new( + input, + output.clone(), + reduce_dim, + )); - client.execute_autotune(operation_set); + client.execute_autotune(operation_set); - output + output } // Probably better on balanced tensor shapes diff --git a/burn-wgpu/src/kernel/source.rs b/burn-wgpu/src/kernel/source.rs index b13c2f6a50..6bbc0b9752 100644 --- a/burn-wgpu/src/kernel/source.rs +++ b/burn-wgpu/src/kernel/source.rs @@ -6,64 +6,64 @@ use std::collections::HashMap; /// They will be updated with their proper value when `generate` is called. #[derive(Debug)] pub struct SourceTemplate { - items: HashMap, - templates: Vec, + items: HashMap, + templates: Vec, } impl SourceTemplate { - /// Create a new source template. - pub fn new(template: S) -> Self - where - S: Into, - { - Self { - items: HashMap::new(), - templates: vec![template.into()], - } + /// Create a new source template. + pub fn new(template: S) -> Self + where + S: Into, + { + Self { + items: HashMap::new(), + templates: vec![template.into()], } + } - /// Register the value for a placeholder item. - /// - /// # Notes - /// - /// The value can't have placeholders, since it would require recursive templating with - /// possibly circular dependencies. If you want to add a value that has some - /// placeholders, consider adding a new template to the source using - /// [add_template](SourceTemplate::add_template). The added template can be a function, and you can - /// register the function call instead. - pub fn register(mut self, name: Name, value: Value) -> Self - where - Name: Into, - Value: Into, - { - self.items.insert(name.into(), value.into()); - self - } - - /// Add a new template. 
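The `register`/`add_template` documentation above describes the contract: a registered value must itself be placeholder-free, while extra snippets are appended as additional templates before substitution. A minimal sketch of that flow; the `body` placeholder and the WGSL snippets are invented for the example, and only `new`, `register`, `add_template` and `complete` come from `SourceTemplate`:

use crate::kernel::SourceTemplate;

fn demo_template() -> String {
    SourceTemplate::new("fn apply(id: u32) { {{ body }} }")
        // Extra template appended to the source before substitution.
        .add_template("fn plus_one(x: u32) -> u32 { return x + 1u; }")
        // Placeholder value: must not contain placeholders itself.
        .register("body", "output[id] = plus_one(input[id]);")
        .complete()
}

The completed string is the first template with the extra snippet appended and `{{ body }}` replaced by the registered value.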
- pub fn add_template(mut self, template: S) -> Self - where - S: Into, - { - self.templates.push(template.into()); - self - } + /// Register the value for a placeholder item. + /// + /// # Notes + /// + /// The value can't have placeholders, since it would require recursive templating with + /// possibly circular dependencies. If you want to add a value that has some + /// placeholders, consider adding a new template to the source using + /// [add_template](SourceTemplate::add_template). The added template can be a function, and you can + /// register the function call instead. + pub fn register(mut self, name: Name, value: Value) -> Self + where + Name: Into, + Value: Into, + { + self.items.insert(name.into(), value.into()); + self + } - /// Complete the template and returns the source code. - pub fn complete(mut self) -> String { - let mut source = self.templates.remove(0); + /// Add a new template. + pub fn add_template(mut self, template: S) -> Self + where + S: Into, + { + self.templates.push(template.into()); + self + } - for s in self.templates.into_iter() { - source.push_str(&s); - } + /// Complete the template and returns the source code. + pub fn complete(mut self) -> String { + let mut source = self.templates.remove(0); - let template = text_placeholder::Template::new(&source); - let mut context = HashMap::new(); + for s in self.templates.into_iter() { + source.push_str(&s); + } - for (key, value) in self.items.iter() { - context.insert(key.as_str(), value.as_str()); - } + let template = text_placeholder::Template::new(&source); + let mut context = HashMap::new(); - template.fill_with_hashmap(&context) + for (key, value) in self.items.iter() { + context.insert(key.as_str(), value.as_str()); } + + template.fill_with_hashmap(&context) + } } diff --git a/burn-wgpu/src/kernel/unary.rs b/burn-wgpu/src/kernel/unary.rs index e1f28dac12..7584be7b1f 100644 --- a/burn-wgpu/src/kernel/unary.rs +++ b/burn-wgpu/src/kernel/unary.rs @@ -7,202 +7,202 @@ kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl"); /// Creates a unary kernel. #[macro_export] macro_rules! 
unary { - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - let source = $crate::kernel::UnaryRaw::source(); - source.register("body", format!("output[id] = {}(input[id]);", $func)) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + let source = $crate::kernel::UnaryRaw::source(); + source.register("body", format!("output[id] = {}(input[id]);", $func)) + } + } + }; + ( $struct:ident, body $body:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source().register("body", $body) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryRaw::source().register("body", $body) + } + } + }; + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source() - .register("body", format!("output[id] = {}(input[id]);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryRaw::source() + .register("body", format!("output[id] = {}(input[id]);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Creates a unary inplace kernel. #[macro_export] macro_rules! unary_inplace { - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source() - .register("body", format!("input[id] = {}(input[id]);", $func)) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source() + .register("body", format!("input[id] = {}(input[id]);", $func)) + } + } + }; + ( $struct:ident, body $body:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source().register("body", $body) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source().register("body", $body) + } + } + }; + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source() - .register("body", format!("input[id] = {}(input[id]);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source() + .register("body", format!("input[id] = {}(input[id]);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Execute a unary kernel using the default settings. 
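// Editor's note: a stripped-down sketch of the code-generation pattern behind the
// `unary!` / `unary_inplace!` macros above. Each invocation emits a zero-sized struct
// plus a `source()` that splices either a `func`-style call or a raw `body` statement
// into the WGSL template. The `KernelSource` trait and the strings below are
// illustrative assumptions; the real macros register into `UnaryRaw::source()`.
trait KernelSource {
    fn source() -> String;
}

macro_rules! demo_unary {
    ($struct:ident, func $func:expr) => {
        struct $struct;
        impl KernelSource for $struct {
            fn source() -> String {
                // `func` form: wrap the named builtin around the element read.
                format!("output[id] = {}(input[id]);", $func)
            }
        }
    };
    ($struct:ident, body $body:expr) => {
        struct $struct;
        impl KernelSource for $struct {
            fn source() -> String {
                // `body` form: the caller supplies the whole statement.
                $body.to_string()
            }
        }
    };
}

demo_unary!(Exp, func "exp");
demo_unary!(Relu, body "output[id] = max(input[id], 0.0);");

fn main() {
    println!("{}", Exp::source());
    println!("{}", Relu::source());
}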
pub fn unary_default( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - unary::(input) + unary::(input) } /// Execute a unary inplace kernel using the default settings. pub fn unary_inplace_default( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - unary_inplace::(input) + unary_inplace::(input) } /// Execute a unary inplace kernel using the provided WORKGROUP. pub fn unary_inplace< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); + let num_elems = input.shape.num_elements(); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); - input.client.execute(Box::new(kernel), &[&input.handle]); + input.client.execute(Box::new(kernel), &[&input.handle]); - input + input } /// Execute a unary kernel using the provided WORKGROUP. pub fn unary( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let buffer = input.client.empty(num_elems * core::mem::size_of::()); - let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer); - // Since we don't handle the stride inside the kernel, the output tensor have the same strides - // as the input tensor. It might not be in the default format. - output.strides = input.strides; - - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - input - .client - .execute(Box::new(kernel), &[&input.handle, &output.handle]); - - output + let num_elems = input.shape.num_elements(); + let buffer = input.client.empty(num_elems * core::mem::size_of::()); + let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer); + // Since we don't handle the stride inside the kernel, the output tensor have the same strides + // as the input tensor. It might not be in the default format. 
+ output.strides = input.strides; + + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + input + .client + .execute(Box::new(kernel), &[&input.handle, &output.handle]); + + output } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; - unary!(TestKernel, func "log"); - unary_inplace!(TestKernelInplace, func "log"); + unary!(TestKernel, func "log"); + unary_inplace!(TestKernelInplace, func "log"); - #[test] - fn unary_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary::(tensor.into_primitive()); - let expected = tensor_ref.log(); + let actual = unary::(tensor.into_primitive()); + let expected = tensor_ref.log(); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } - #[test] - fn unary_inplace_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_inplace_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary_inplace::(tensor.into_primitive()); - let expected = tensor_ref.log(); + let actual = unary_inplace::(tensor.into_primitive()); + let expected = tensor_ref.log(); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } + + #[test] + fn tanh_should_not_have_numerical_bugs_on_macos() { + fn tanh_one_value(input: f32) -> f32 { + let tensor = Tensor::::ones([1]) * input; + let output = tensor.tanh().into_primitive(); + Tensor::::from_primitive(output) + .into_data() + .value[0] } - #[test] - fn tanh_should_not_have_numerical_bugs_on_macos() { - fn tanh_one_value(input: f32) -> f32 { - let tensor = Tensor::::ones([1]) * input; - let output = tensor.tanh().into_primitive(); - Tensor::::from_primitive(output) - .into_data() - .value[0] - } - - let ok = tanh_one_value(43.0); // metal tanh gives 1.0 which is the right answer - let zero = tanh_one_value(44.0); // metal tanh gives zero when within 43.67..44.36 - let nan = tanh_one_value(45.0); // metal tanh gives nan when over 44.36 - let neg = tanh_one_value(-45.0); // metal works correctly here - - assert!(!ok.is_nan() && ok == 1.0); - assert!(!zero.is_nan() && zero == 1.0); - assert!(!nan.is_nan() && nan == 1.0); - assert!(!neg.is_nan() && neg == -1.0); - } + let ok = tanh_one_value(43.0); // metal tanh gives 1.0 which is the right answer + let zero = tanh_one_value(44.0); // metal tanh gives zero when within 43.67..44.36 + let nan = tanh_one_value(45.0); // metal tanh gives nan when over 44.36 + let neg = tanh_one_value(-45.0); // metal works correctly here + + assert!(!ok.is_nan() && ok == 1.0); + assert!(!zero.is_nan() && zero == 1.0); + 
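// Editor's note: the `safe_tanh.wgsl` shader referenced by the tanh op is not shown
// in this patch; the sketch below is only one plausible CPU-side illustration of the
// workaround idea (clamp the argument before tanh, since tanh(x) already rounds to
// ±1.0 in f32 for |x| > ~10), not the actual shader. The function name and the clamp
// bound are assumptions.
fn safe_tanh_sketch(x: f32) -> f32 {
    // Keeps the argument inside a range where Metal's tanh is well behaved, without
    // changing the result beyond f32 rounding.
    x.clamp(-10.0, 10.0).tanh()
}

#[test]
fn safe_tanh_sketch_saturates_correctly() {
    assert_eq!(safe_tanh_sketch(43.0), 1.0);
    assert_eq!(safe_tanh_sketch(44.0), 1.0);
    assert_eq!(safe_tanh_sketch(45.0), 1.0);
    assert_eq!(safe_tanh_sketch(-45.0), -1.0);
}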
assert!(!nan.is_nan() && nan == 1.0); + assert!(!neg.is_nan() && neg == -1.0); + } } diff --git a/burn-wgpu/src/kernel/unary_scalar.rs b/burn-wgpu/src/kernel/unary_scalar.rs index dc68443df5..fabaa468a9 100644 --- a/burn-wgpu/src/kernel/unary_scalar.rs +++ b/burn-wgpu/src/kernel/unary_scalar.rs @@ -3,218 +3,218 @@ use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::Wg kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl"); kernel_wgsl!( - UnaryScalarInplaceRaw, - "../template/unary_scalar_inplace.wgsl" + UnaryScalarInplaceRaw, + "../template/unary_scalar_inplace.wgsl" ); /// Creates a unary scalar kernel. #[macro_export] macro_rules! unary_scalar { - ( + ( $struct:ident, ops $ops:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = lhs[id] {} rhs;", $ops)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() + .register("body", format!("output[id] = lhs[id] {} rhs;", $ops)) + } + } + }; - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() + .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) + } + } + }; - ( + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() + .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Creates a unary scalar inplace kernel. #[macro_export] macro_rules! 
unary_scalar_inplace { - ( + ( $struct:ident, ops $ops:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() + .register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops)) + } + } + }; - ( + ( $struct:ident, body $body:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body) + } + } + }; - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() + .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) + } + } + }; - ( + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() + .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Execute a unary scalar kernel using the default settings. pub fn unary_scalar_default( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - unary_scalar::(lhs, scalar) + unary_scalar::(lhs, scalar) } /// Execute a unary scalar kernel using the provided WORKGROUP. 
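// Editor's note: a self-contained CPU sketch of what the generated unary-scalar
// kernels compute, to make the `ops`/`func` forms above concrete. The slice-based
// helpers are illustrative stand-ins, not part of burn-wgpu.
fn unary_scalar_mul(lhs: &[f32], rhs: f32) -> Vec<f32> {
    // `ops "*"` form: output[id] = lhs[id] * rhs, written to a fresh buffer.
    lhs.iter().map(|x| x * rhs).collect()
}

fn unary_scalar_mul_inplace(lhs: &mut [f32], rhs: f32) {
    // Inplace form: lhs[id] = lhs[id] * rhs, reusing the input buffer (the GPU path
    // only takes this route when `can_mut()` reports the handle as exclusively owned).
    for x in lhs.iter_mut() {
        *x *= rhs;
    }
}

fn main() {
    let mut data = vec![1.0_f32, 2.0, 3.0];
    assert_eq!(unary_scalar_mul(&data, 5.0), vec![5.0, 10.0, 15.0]);
    unary_scalar_mul_inplace(&mut data, 5.0);
    assert_eq!(data, vec![5.0, 10.0, 15.0]);
}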
pub fn unary_scalar< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); - let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs_handle, &output.handle], - ); - - output + let num_elems = lhs.shape.num_elements(); + let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs_handle, &output.handle], + ); + + output } /// Execute a unary scalar inplace kernel using the default settings. pub fn unary_scalar_inplace_default( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - unary_scalar_inplace::(lhs, scalar) + unary_scalar_inplace::(lhs, scalar) } /// Execute a unary scalar inplace kernel using the provided WORKGROUP. pub fn unary_scalar_inplace< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); + let num_elems = lhs.shape.num_elements(); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); - lhs.client - .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); + lhs + .client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); - lhs + lhs } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; - unary_scalar!(TestKernel, ops "*"); - unary_scalar_inplace!(TestKernelInplace, ops "*"); + unary_scalar!(TestKernel, ops "*"); + unary_scalar_inplace!(TestKernelInplace, ops "*"); - #[test] - fn unary_scalar_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_scalar_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary_scalar::(tensor.into_primitive(), 5.0); - let expected = tensor_ref.mul_scalar(5.0); + let actual = unary_scalar::(tensor.into_primitive(), 5.0); + let expected = tensor_ref.mul_scalar(5.0); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + expected.into_data().assert_approx_eq( + 
&Tensor::::from_primitive(actual).into_data(), + 3, + ); + } - #[test] - fn unary_scalar_inplace_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_scalar_inplace_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = - unary_scalar_inplace::(tensor.into_primitive(), 5.0); - let expected = tensor_ref.mul_scalar(5.0); + let actual = unary_scalar_inplace::(tensor.into_primitive(), 5.0); + let expected = tensor_ref.mul_scalar(5.0); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } } diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs index d04b282eda..52ab6703b7 100644 --- a/burn-wgpu/src/lib.rs +++ b/burn-wgpu/src/lib.rs @@ -32,15 +32,15 @@ mod fusion; #[cfg(test)] mod tests { - use super::*; + use super::*; - pub type TestBackend = Wgpu; - pub type ReferenceBackend = burn_ndarray::NdArray; + pub type TestBackend = Wgpu; + pub type ReferenceBackend = burn_ndarray::NdArray; - pub type TestTensor = burn_tensor::Tensor; - pub type ReferenceTensor = burn_tensor::Tensor; - pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensor = burn_tensor::Tensor; + pub type ReferenceTensor = burn_tensor::Tensor; + pub type TestTensorInt = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!(); + burn_autodiff::testgen_all!(); } diff --git a/burn-wgpu/src/ops/activation_ops.rs b/burn-wgpu/src/ops/activation_ops.rs index 2256628602..ff6dea19a0 100644 --- a/burn-wgpu/src/ops/activation_ops.rs +++ b/burn-wgpu/src/ops/activation_ops.rs @@ -1,25 +1,25 @@ use burn_tensor::ops::{ActivationOps, FloatTensor}; use crate::{ - element::{FloatElement, IntElement}, - kernel::{unary_default, unary_inplace_default}, - unary, unary_inplace, GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel::{unary_default, unary_inplace_default}, + unary, unary_inplace, GraphicsApi, Wgpu, }; impl ActivationOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn relu(tensor: FloatTensor) -> FloatTensor { - unary!(Relu, body "output[id] = max(input[id], 0.0);"); - unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);"); + fn relu(tensor: FloatTensor) -> FloatTensor { + unary!(Relu, body "output[id] = max(input[id], 0.0);"); + unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);"); - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } + + unary_default::(tensor) + } } diff --git a/burn-wgpu/src/ops/base.rs b/burn-wgpu/src/ops/base.rs index 93ea576a73..bf06dcb824 100644 --- a/burn-wgpu/src/ops/base.rs +++ b/burn-wgpu/src/ops/base.rs @@ -1,78 +1,78 @@ use crate::{ - compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi, - WgpuDevice, + compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi, + WgpuDevice, }; use burn_tensor::{Data, Reader, Shape}; pub fn from_data( - data: Data, - device: &WgpuDevice, + data: Data, + device: &WgpuDevice, ) -> 
WgpuTensor { - let client = compute_client::(device); - let buffer = client.create(E::as_bytes(&data.value)); + let client = compute_client::(device); + let buffer = client.create(E::as_bytes(&data.value)); - WgpuTensor::new(client, device.clone(), data.shape, buffer) + WgpuTensor::new(client, device.clone(), data.shape, buffer) } pub fn into_data(tensor: WgpuTensor) -> Reader> { - let tensor = kernel::into_contiguous(tensor); + let tensor = kernel::into_contiguous(tensor); - tensor - .client - .read(&tensor.handle) - .map(|bytes| Data::new(E::from_bytes(&bytes).to_vec(), tensor.shape)) + tensor + .client + .read(&tensor.handle) + .map(|bytes| Data::new(E::from_bytes(&bytes).to_vec(), tensor.shape)) } pub fn bool_into_data(tensor: WgpuTensor) -> Reader> { - let tensor = kernel::into_contiguous(tensor); + let tensor = kernel::into_contiguous(tensor); - tensor.client.read(&tensor.handle).map(|bytes| { - Data::new( - u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), - tensor.shape, - ) - }) + tensor.client.read(&tensor.handle).map(|bytes| { + Data::new( + u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), + tensor.shape, + ) + }) } pub fn to_device( - tensor: WgpuTensor, - device: &WgpuDevice, + tensor: WgpuTensor, + device: &WgpuDevice, ) -> WgpuTensor { - if &tensor.device == device { - return tensor; - } + if &tensor.device == device { + return tensor; + } - let client = compute_client::(device); - tensor.to_client(client, device.clone()) + let client = compute_client::(device); + tensor.to_client(client, device.clone()) } pub fn empty( - shape: Shape, - device: &WgpuDevice, + shape: Shape, + device: &WgpuDevice, ) -> WgpuTensor { - let client = compute_client::(device); - let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); + let client = compute_client::(device); + let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - WgpuTensor::new(client, device.clone(), shape, buffer) + WgpuTensor::new(client, device.clone(), shape, buffer) } pub fn swap_dims( - mut tensor: WgpuTensor, - dim1: usize, - dim2: usize, + mut tensor: WgpuTensor, + dim1: usize, + dim2: usize, ) -> WgpuTensor { - tensor.strides.swap(dim1, dim2); - tensor.shape.dims.swap(dim1, dim2); + tensor.strides.swap(dim1, dim2); + tensor.shape.dims.swap(dim1, dim2); - tensor + tensor } pub fn reshape( - tensor: WgpuTensor, - shape: Shape, + tensor: WgpuTensor, + shape: Shape, ) -> WgpuTensor { - // TODO: Not force standard layout all the time (improve performance). - let tensor = kernel::into_contiguous(tensor); + // TODO: Not force standard layout all the time (improve performance). 
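// Editor's note: a standalone sketch of the stride bookkeeping behind `swap_dims`
// and the contiguity requirement in `reshape` above. Row-major layout is assumed and
// the helper is illustrative only; it is not burn-wgpu code.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    // Row-major: the last dimension is contiguous, earlier strides are running products.
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

fn main() {
    let mut shape = vec![2_usize, 3];
    let mut strides = contiguous_strides(&shape); // [3, 1]

    // `swap_dims(0, 1)` is metadata-only: both the shape entries and the strides are
    // swapped, no element is moved.
    shape.swap(0, 1);
    strides.swap(0, 1);
    assert_eq!(shape, vec![3, 2]);
    assert_eq!(strides, vec![1, 3]);

    // The transposed view is no longer row-major contiguous, which is why `reshape`
    // first calls `into_contiguous` (a copy) before reinterpreting the buffer.
    assert_ne!(strides, contiguous_strides(&shape));
}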
+ let tensor = kernel::into_contiguous(tensor); - WgpuTensor::new(tensor.client, tensor.device, shape, tensor.handle) + WgpuTensor::new(tensor.client, tensor.device, shape, tensor.handle) } diff --git a/burn-wgpu/src/ops/bool_ops.rs b/burn-wgpu/src/ops/bool_ops.rs index 82470694f2..938054b116 100644 --- a/burn-wgpu/src/ops/bool_ops.rs +++ b/burn-wgpu/src/ops/bool_ops.rs @@ -1,8 +1,8 @@ use crate::{ - element::{FloatElement, IntElement}, - kernel, - tensor::WgpuTensor, - GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel, + tensor::WgpuTensor, + GraphicsApi, Wgpu, }; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; use burn_tensor::{ops::BoolTensorOps, Data, Shape}; @@ -11,116 +11,117 @@ use std::ops::Range; impl BoolTensorOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - super::empty::(shape, device) + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + super::empty::(shape, device) + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + tensor.shape.clone() + } + + fn bool_into_data(tensor: BoolTensor) -> Reader> { + super::bool_into_data(tensor) + } + + fn bool_from_data( + data: Data, + device: &Device, + ) -> BoolTensor { + let data: Data = Data::new( + data + .value + .into_iter() + .map(|c| match c { + true => 1, + false => 0, + }) + .collect(), + data.shape, + ); + super::from_data::(data, device) + } + + fn bool_into_int(tensor: BoolTensor) -> IntTensor { + if std::mem::size_of::() == std::mem::size_of::() { + return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); } - fn bool_shape(tensor: &BoolTensor) -> Shape { - tensor.shape.clone() - } - - fn bool_into_data(tensor: BoolTensor) -> Reader> { - super::bool_into_data(tensor) - } - - fn bool_from_data( - data: Data, - device: &Device, - ) -> BoolTensor { - let data: Data = Data::new( - data.value - .into_iter() - .map(|c| match c { - true => 1, - false => 0, - }) - .collect(), - data.shape, - ); - super::from_data::(data, device) - } - - fn bool_into_int(tensor: BoolTensor) -> IntTensor { - if std::mem::size_of::() == std::mem::size_of::() { - return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); - } - - let device = Self::bool_device(&tensor); - let data = Self::bool_into_data(tensor) - .read_sync() - .expect("Can't convert bool to int with a different type size async") - .convert::(); - - Self::int_from_data(data, &device) - } - - fn bool_device(tensor: &BoolTensor) -> Device { - tensor.device.clone() - } - - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - super::to_device::(tensor, device) - } - - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - super::reshape(tensor, shape) - } - - fn bool_slice( - tensor: BoolTensor, - ranges: [Range; D2], - ) -> BoolTensor { - kernel::slice(tensor, ranges) - } - - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [Range; D2], - value: BoolTensor, - ) -> BoolTensor { - kernel::slice_assign(tensor, ranges, value) - } - - fn bool_cat( - tensors: Vec>, - dim: usize, - ) -> BoolTensor { - kernel::cat(tensors, dim) - } - - fn bool_equal( - lhs: BoolTensor, - rhs: BoolTensor, - ) -> BoolTensor { - kernel::equal(lhs, rhs) - } - - fn bool_not(tensor: BoolTensor) -> BoolTensor { - kernel::equal_elem(tensor, 0) - } - - fn bool_into_float(tensor: BoolTensor) -> FloatTensor { - 
kernel::cast(tensor) - } - - fn bool_swap_dims( - mut tensor: BoolTensor, - dim1: usize, - dim2: usize, - ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { - tensor.strides.swap(dim1, dim2); - tensor.shape.dims.swap(dim1, dim2); - - tensor - } + let device = Self::bool_device(&tensor); + let data = Self::bool_into_data(tensor) + .read_sync() + .expect("Can't convert bool to int with a different type size async") + .convert::(); + + Self::int_from_data(data, &device) + } + + fn bool_device(tensor: &BoolTensor) -> Device { + tensor.device.clone() + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + super::to_device::(tensor, device) + } + + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor { + super::reshape(tensor, shape) + } + + fn bool_slice( + tensor: BoolTensor, + ranges: [Range; D2], + ) -> BoolTensor { + kernel::slice(tensor, ranges) + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [Range; D2], + value: BoolTensor, + ) -> BoolTensor { + kernel::slice_assign(tensor, ranges, value) + } + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> BoolTensor { + kernel::cat(tensors, dim) + } + + fn bool_equal( + lhs: BoolTensor, + rhs: BoolTensor, + ) -> BoolTensor { + kernel::equal(lhs, rhs) + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + kernel::equal_elem(tensor, 0) + } + + fn bool_into_float(tensor: BoolTensor) -> FloatTensor { + kernel::cast(tensor) + } + + fn bool_swap_dims( + mut tensor: BoolTensor, + dim1: usize, + dim2: usize, + ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { + tensor.strides.swap(dim1, dim2); + tensor.shape.dims.swap(dim1, dim2); + + tensor + } } diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs index 35dcbe12f3..900bddbef6 100644 --- a/burn-wgpu/src/ops/float_ops.rs +++ b/burn-wgpu/src/ops/float_ops.rs @@ -9,13 +9,13 @@ use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; #[cfg(not(feature = "autotune"))] use crate::kernel::reduce::init_reduce_output; use crate::kernel::{ - self, reduce, unary_default, unary_inplace_default, unary_scalar_default, - unary_scalar_inplace_default, + self, reduce, unary_default, unary_inplace_default, unary_scalar_default, + unary_scalar_inplace_default, }; use crate::{unary, unary_inplace, unary_scalar, FloatElement, GraphicsApi, IntElement, Wgpu}; use crate::{unary_scalar_inplace, WgpuDevice}; use burn_tensor::ops::{ - BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, + BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, }; use burn_tensor::{ops::TensorOps, Data, Distribution, Shape}; use burn_tensor::{ElementConversion, Reader}; @@ -24,503 +24,499 @@ use std::ops::Range; impl TensorOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor { - super::from_data::(data, device) - } - - fn random( - shape: Shape, - distribution: Distribution>, - device: &Device, - ) -> FloatTensor { - match distribution { - Distribution::Default => random_uniform::(shape, device, 0.elem(), 1.elem()), - Distribution::Uniform(low, high) => random_uniform::(shape, device, low, high), - Distribution::Bernoulli(prob) => { - random_bernoulli::(shape, device, prob.elem()) - } - Distribution::Normal(mean, std) => { - random_normal::(shape, device, mean.elem(), std.elem()) - } - } - } - - fn 
shape(tensor: &FloatTensor) -> Shape { - tensor.shape.clone() - } - - fn into_data(tensor: FloatTensor) -> Reader, D>> { - super::into_data(tensor) - } - - fn device(tensor: &FloatTensor) -> Device { - tensor.device.clone() - } - - fn to_device( - tensor: FloatTensor, - device: &Device, - ) -> FloatTensor { - super::to_device::(tensor, device) - } - - fn empty(shape: Shape, device: &Device) -> FloatTensor { - super::empty::(shape, device) - } - - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::add(lhs, rhs) - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::add_scalar(lhs, rhs) - } - - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - numeric::zeros::(shape, device) - } - - fn full( - shape: Shape, - fill_value: FloatElem, - device: &WgpuDevice, - ) -> FloatTensor { - numeric::full::(shape, device, fill_value) - } - - fn ones(shape: Shape, device: &Device) -> FloatTensor { - numeric::ones::(shape, device) - } - - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::sub(lhs, rhs) - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::sub_scalar(lhs, rhs) - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::mul(lhs, rhs) - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::mul_scalar(lhs, rhs) - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::div(lhs, rhs) - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::div_scalar(lhs, rhs) - } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[cfg(feature = "autotune")] - { - matmul_autotune(lhs, rhs) - } - - #[cfg(not(feature = "autotune"))] - { - let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_vec4(lhs, rhs, out) - } - } - - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor { - super::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - super::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - kernel::gather(dim, tensor, indices) - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - kernel::scatter(dim, tensor, indices, value) - } - - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - kernel::select(tensor, dim, indices) - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - kernel::select_assign(tensor, dim, indices, value) - } - - fn slice( - tensor: FloatTensor, - ranges: [Range; D2], - ) -> FloatTensor { - kernel::slice(tensor, ranges) - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [Range; D2], - value: FloatTensor, - ) -> FloatTensor { - kernel::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor { - kernel::mask_where(tensor, mask, value) - } - - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - kernel::mask_fill(tensor, mask, value) - } - - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::equal(lhs, rhs) - } - - fn equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> 
BoolTensor { - kernel::equal_elem(lhs, rhs) - } - - fn greater( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::greater(lhs, rhs) - } - - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::greater_elem(lhs, rhs) - } - - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::greater_equal(lhs, rhs) - } + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor { + super::from_data::(data, device) + } + + fn random( + shape: Shape, + distribution: Distribution>, + device: &Device, + ) -> FloatTensor { + match distribution { + Distribution::Default => random_uniform::(shape, device, 0.elem(), 1.elem()), + Distribution::Uniform(low, high) => random_uniform::(shape, device, low, high), + Distribution::Bernoulli(prob) => random_bernoulli::(shape, device, prob.elem()), + Distribution::Normal(mean, std) => { + random_normal::(shape, device, mean.elem(), std.elem()) + } + } + } + + fn shape(tensor: &FloatTensor) -> Shape { + tensor.shape.clone() + } + + fn into_data(tensor: FloatTensor) -> Reader, D>> { + super::into_data(tensor) + } + + fn device(tensor: &FloatTensor) -> Device { + tensor.device.clone() + } + + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor { + super::to_device::(tensor, device) + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + super::empty::(shape, device) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::add(lhs, rhs) + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + numeric::add_scalar(lhs, rhs) + } + + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + numeric::zeros::(shape, device) + } + + fn full( + shape: Shape, + fill_value: FloatElem, + device: &WgpuDevice, + ) -> FloatTensor { + numeric::full::(shape, device, fill_value) + } + + fn ones(shape: Shape, device: &Device) -> FloatTensor { + numeric::ones::(shape, device) + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::sub(lhs, rhs) + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + numeric::sub_scalar(lhs, rhs) + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::mul(lhs, rhs) + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + numeric::mul_scalar(lhs, rhs) + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::div(lhs, rhs) + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + numeric::div_scalar(lhs, rhs) + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[cfg(feature = "autotune")] + { + matmul_autotune(lhs, rhs) + } + + #[cfg(not(feature = "autotune"))] + { + let out = init_matmul_output(&lhs, &rhs); + matmul_tiling_2d_vec4(lhs, rhs, out) + } + } + + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + super::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + super::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + kernel::gather(dim, tensor, indices) + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + kernel::scatter(dim, tensor, indices, value) + } + + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) 
-> FloatTensor { + kernel::select(tensor, dim, indices) + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + kernel::select_assign(tensor, dim, indices, value) + } + + fn slice( + tensor: FloatTensor, + ranges: [Range; D2], + ) -> FloatTensor { + kernel::slice(tensor, ranges) + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [Range; D2], + value: FloatTensor, + ) -> FloatTensor { + kernel::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor { + kernel::mask_where(tensor, mask, value) + } + + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + kernel::mask_fill(tensor, mask, value) + } + + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::equal(lhs, rhs) + } + + fn equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::equal_elem(lhs, rhs) + } + + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::greater(lhs, rhs) + } + + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::greater_elem(lhs, rhs) + } + + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::greater_equal(lhs, rhs) + } + + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::greater_equal_elem(lhs, rhs) + } + + fn lower( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::lower(lhs, rhs) + } + + fn lower_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::lower_elem(lhs, rhs) + } + + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::lower_equal(lhs, rhs) + } + + fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::lower_equal_elem(lhs, rhs) + } + + fn sum(tensor: FloatTensor) -> FloatTensor { + reduce::sum(tensor) + } + + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[cfg(feature = "autotune")] + { + reduce::sum_dim_autotune(tensor, dim) + } + + #[cfg(not(feature = "autotune"))] + { + let output = init_reduce_output(&tensor, dim); + reduce::sum_dim(tensor, output, dim) + } + } - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::greater_equal_elem(lhs, rhs) + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[cfg(feature = "autotune")] + { + reduce::mean_dim_autotune(tensor, dim) } - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::lower(lhs, rhs) + #[cfg(not(feature = "autotune"))] + { + let output = init_reduce_output(&tensor, dim); + reduce::mean_dim(tensor, output, dim) } + } - fn lower_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::lower_elem(lhs, rhs) - } + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D> { + kernel::cast(tensor.clone()) + } - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::lower_equal(lhs, rhs) - } + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + kernel::cast(tensor) + } - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::lower_equal_elem(lhs, rhs) - } + fn exp(lhs: FloatTensor) -> FloatTensor { + unary!(Exp, func "exp"); + unary_inplace!(ExpInplace, func "exp"); - fn sum(tensor: FloatTensor) -> FloatTensor { - reduce::sum(tensor) + 
if lhs.can_mut() { + return unary_inplace_default::(lhs); } - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[cfg(feature = "autotune")] - { - reduce::sum_dim_autotune(tensor, dim) - } - - #[cfg(not(feature = "autotune"))] - { - let output = init_reduce_output(&tensor, dim); - reduce::sum_dim(tensor, output, dim) - } - } + unary_default::(lhs) + } - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[cfg(feature = "autotune")] - { - reduce::mean_dim_autotune(tensor, dim) - } - - #[cfg(not(feature = "autotune"))] - { - let output = init_reduce_output(&tensor, dim); - reduce::mean_dim(tensor, output, dim) - } - } + fn log(tensor: FloatTensor) -> FloatTensor { + unary!(Log, func "log"); + unary_inplace!(LogInplace, func "log"); - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - kernel::cast(tensor.clone()) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - kernel::cast(tensor) - } - - fn exp(lhs: FloatTensor) -> FloatTensor { - unary!(Exp, func "exp"); - unary_inplace!(ExpInplace, func "exp"); + unary_default::(tensor) + } - if lhs.can_mut() { - return unary_inplace_default::(lhs); - } + fn log1p(tensor: FloatTensor) -> FloatTensor { + unary!(Log1p, body "output[id] = log(1.0 + input[id]);"); + unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);"); - unary_default::(lhs) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn log(tensor: FloatTensor) -> FloatTensor { - unary!(Log, func "log"); - unary_inplace!(LogInplace, func "log"); + unary_default::(tensor) + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } + fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { + unary_scalar!(Powf, func "powf", include "../template/powf.wgsl"); + unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl"); - unary_default::(tensor) + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs.elem()); } - fn log1p(tensor: FloatTensor) -> FloatTensor { - unary!(Log1p, body "output[id] = log(1.0 + input[id]);"); - unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);"); + unary_scalar_default::(lhs, rhs.elem()) + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } + fn sqrt(tensor: FloatTensor) -> FloatTensor { + unary!(Sqrt, func "sqrt"); + unary_inplace!(SqrtInplace, func "sqrt"); - unary_default::(tensor) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { - unary_scalar!(Powf, func "powf", include "../template/powf.wgsl"); - unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl"); + unary_default::(tensor) + } - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs.elem()); - } + fn abs(tensor: FloatTensor) -> FloatTensor { + unary!(Abs, func "abs"); + unary_inplace!(AbsInplace, func "abs"); - unary_scalar_default::(lhs, rhs.elem()) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn sqrt(tensor: FloatTensor) -> FloatTensor { - unary!(Sqrt, func "sqrt"); - unary_inplace!(SqrtInplace, func "sqrt"); + unary_default::(tensor) + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } + fn cos(tensor: FloatTensor) -> FloatTensor { + unary!(Cos, func "cos"); + unary_inplace!(CosInplace, func "cos"); - unary_default::(tensor) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn 
abs(tensor: FloatTensor) -> FloatTensor { - unary!(Abs, func "abs"); - unary_inplace!(AbsInplace, func "abs"); + unary_default::(tensor) + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } + fn sin(tensor: FloatTensor) -> FloatTensor { + unary!(Sin, func "sin"); + unary_inplace!(SinInplace, func "sin"); - unary_default::(tensor) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn cos(tensor: FloatTensor) -> FloatTensor { - unary!(Cos, func "cos"); - unary_inplace!(CosInplace, func "cos"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) - } + unary_default::(tensor) + } - fn sin(tensor: FloatTensor) -> FloatTensor { - unary!(Sin, func "sin"); - unary_inplace!(SinInplace, func "sin"); + fn tanh(tensor: FloatTensor) -> FloatTensor { + // Metal has a weird numerical behaviour with tanh which require a new function + #[cfg(target_os = "macos")] + unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl"); + #[cfg(target_os = "macos")] + unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl"); - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } + #[cfg(not(target_os = "macos"))] + unary!(Tanh, func "tanh"); + #[cfg(not(target_os = "macos"))] + unary_inplace!(TanhInplace, func "tanh"); - unary_default::(tensor) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn tanh(tensor: FloatTensor) -> FloatTensor { - // Metal has a weird numerical behaviour with tanh which require a new function - #[cfg(target_os = "macos")] - unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl"); - #[cfg(target_os = "macos")] - unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl"); + unary_default::(tensor) + } - #[cfg(not(target_os = "macos"))] - unary!(Tanh, func "tanh"); - #[cfg(not(target_os = "macos"))] - unary_inplace!(TanhInplace, func "tanh"); + fn erf(tensor: FloatTensor) -> FloatTensor { + unary!(Erf, func "erf", include "../template/erf.wgsl"); + unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl"); - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn erf(tensor: FloatTensor) -> FloatTensor { - unary!(Erf, func "erf", include "../template/erf.wgsl"); - unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl"); + unary_default::(tensor) + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + kernel::cat(tensors, dim) + } - unary_default::(tensor) - } + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + reduce::argmax(tensor, dim) + } - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - kernel::cat(tensors, dim) - } + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + reduce::argmin(tensor, dim) + } - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - reduce::argmax(tensor, dim) - } + fn into_int(tensor: FloatTensor) -> IntTensor { + kernel::cast(tensor) + } - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - reduce::argmin(tensor, dim) - } + fn clamp_min( + tensor: FloatTensor, + min: FloatElem, + ) -> FloatTensor { + kernel::clamp_min(tensor, min) + } - fn into_int(tensor: FloatTensor) -> IntTensor { - kernel::cast(tensor) - } + fn clamp_max( + tensor: FloatTensor, + max: FloatElem, + ) -> FloatTensor { + kernel::clamp_max(tensor, max) + } - 
fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - kernel::clamp_min(tensor, min) - } + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + kernel::clamp(tensor, min, max) + } - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - kernel::clamp_max(tensor, max) - } + fn recip(tensor: FloatTensor, D>) -> FloatTensor, D> { + unary!(Recip, func "1.0 /"); + unary_inplace!(RecipInplace, func "1.0 /"); - fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - kernel::clamp(tensor, min, max) + if tensor.can_mut() { + return unary_inplace_default::(tensor); } - fn recip( - tensor: FloatTensor, D>, - ) -> FloatTensor, D> { - unary!(Recip, func "1.0 /"); - unary_inplace!(RecipInplace, func "1.0 /"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) - } + unary_default::(tensor) + } } diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs index bbef2dd6a6..c79c601fcd 100644 --- a/burn-wgpu/src/ops/int_ops.rs +++ b/burn-wgpu/src/ops/int_ops.rs @@ -3,8 +3,8 @@ use super::numeric; use crate::kernel::reduce::{self, init_reduce_output}; use crate::kernel::{unary_default, unary_inplace_default}; use crate::{ - element::{FloatElement, IntElement}, - kernel, unary, unary_inplace, GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel, unary, unary_inplace, GraphicsApi, Wgpu, }; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; @@ -14,317 +14,314 @@ use std::ops::Range; impl IntTensorOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn int_empty(shape: Shape, device: &Device) -> IntTensor { - super::empty::(shape, device) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - tensor.shape.clone() - } - - fn int_into_data(tensor: IntTensor) -> Reader> { - super::into_data(tensor) - } - - fn int_from_data( - data: Data, - device: &Device, - ) -> IntTensor { - super::from_data::(data, device) - } - - fn int_device(tensor: &IntTensor) -> Device { - tensor.device.clone() - } - - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - super::to_device::(tensor, device) - } - - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - super::reshape(tensor, shape) - } - - fn int_slice( - tensor: IntTensor, - ranges: [Range; D2], - ) -> IntTensor { - kernel::slice(tensor, ranges) - } - - fn int_slice_assign( - tensor: IntTensor, - ranges: [Range; D2], - value: IntTensor, - ) -> IntTensor { - kernel::slice_assign(tensor, ranges, value) - } - - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - value: IntTensor, - ) -> IntTensor { - kernel::mask_where(tensor, mask, value) - } - - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor { - kernel::mask_fill(tensor, mask, value) - } - - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - kernel::gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - kernel::scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - kernel::select(tensor, dim, indices) - } - - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - 
value: IntTensor, - ) -> IntTensor { - kernel::select_assign(tensor, dim, indices, value) - } - - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - kernel::cat(tensors, dim) - } - - fn int_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::equal::(lhs, rhs) - } - - fn int_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::equal_elem::(lhs, rhs) - } - - fn int_greater( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::greater::(lhs, rhs) - } - - fn int_greater_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::greater_elem::(lhs, rhs) - } - - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::greater_equal::(lhs, rhs) - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::greater_equal_elem::(lhs, rhs) - } - - fn int_lower( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::lower::(lhs, rhs) - } - - fn int_lower_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::lower_elem::(lhs, rhs) - } - - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::lower_equal::(lhs, rhs) - } - - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::lower_equal_elem::(lhs, rhs) - } - - fn int_add( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::add::(lhs, rhs) - } - - fn int_add_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::add_scalar(lhs, rhs) - } - - fn int_sub( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::sub(lhs, rhs) - } - - fn int_sub_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::sub_scalar(lhs, rhs) - } - - fn int_mul( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::mul(lhs, rhs) - } - - fn int_mul_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::mul_scalar(lhs, rhs) - } - - fn int_div( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::div(lhs, rhs) - } - - fn int_div_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::div_scalar(lhs, rhs) - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - numeric::zeros::(shape, device) - } - - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - numeric::ones::(shape, device) - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - kernel::reduce::sum(tensor) - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let output = init_reduce_output(&tensor, dim); - reduce::sum_dim(tensor, output, dim) - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let output = init_reduce_output(&tensor, dim); - reduce::mean_dim(tensor, output, dim) - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax(tensor, dim) - } - - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin(tensor, dim) - } - - fn int_clamp_min( - tensor: IntTensor, - min: IntElem, - ) -> IntTensor { - kernel::clamp_min(tensor, min) - } - - fn int_clamp_max( - tensor: IntTensor, - max: IntElem, - ) -> IntTensor { - kernel::clamp_max(tensor, max) - } - - fn int_clamp( - tensor: IntTensor, - min: IntElem, - max: IntElem, - ) -> IntTensor { - kernel::clamp(tensor, min, max) - } - - fn int_abs(tensor: IntTensor) -> IntTensor { - unary!(IntAbs, func "abs"); - unary_inplace!(IntAbsInplace, func "abs"); - - if tensor.can_mut() { - return 
unary_inplace_default::(tensor); - } - - unary_default::(tensor) - } - - fn int_into_float(tensor: IntTensor) -> FloatTensor { - kernel::cast(tensor) - } - - fn int_swap_dims( - mut tensor: IntTensor, - dim1: usize, - dim2: usize, - ) -> IntTensor { - tensor.strides.swap(dim1, dim2); - tensor.shape.dims.swap(dim1, dim2); - - tensor - } + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + super::empty::(shape, device) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + tensor.shape.clone() + } + + fn int_into_data(tensor: IntTensor) -> Reader> { + super::into_data(tensor) + } + + fn int_from_data(data: Data, device: &Device) -> IntTensor { + super::from_data::(data, device) + } + + fn int_device(tensor: &IntTensor) -> Device { + tensor.device.clone() + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + super::to_device::(tensor, device) + } + + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + super::reshape(tensor, shape) + } + + fn int_slice( + tensor: IntTensor, + ranges: [Range; D2], + ) -> IntTensor { + kernel::slice(tensor, ranges) + } + + fn int_slice_assign( + tensor: IntTensor, + ranges: [Range; D2], + value: IntTensor, + ) -> IntTensor { + kernel::slice_assign(tensor, ranges, value) + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + value: IntTensor, + ) -> IntTensor { + kernel::mask_where(tensor, mask, value) + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor { + kernel::mask_fill(tensor, mask, value) + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + kernel::gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + kernel::scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + kernel::select(tensor, dim, indices) + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + kernel::select_assign(tensor, dim, indices, value) + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + kernel::cat(tensors, dim) + } + + fn int_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::equal::(lhs, rhs) + } + + fn int_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::equal_elem::(lhs, rhs) + } + + fn int_greater( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::greater::(lhs, rhs) + } + + fn int_greater_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::greater_elem::(lhs, rhs) + } + + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::greater_equal::(lhs, rhs) + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::greater_equal_elem::(lhs, rhs) + } + + fn int_lower( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::lower::(lhs, rhs) + } + + fn int_lower_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::lower_elem::(lhs, rhs) + } + + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::lower_equal::(lhs, rhs) + } + + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::lower_equal_elem::(lhs, rhs) + } + + fn int_add( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::add::(lhs, rhs) + 
} + + fn int_add_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::add_scalar(lhs, rhs) + } + + fn int_sub( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::sub(lhs, rhs) + } + + fn int_sub_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::sub_scalar(lhs, rhs) + } + + fn int_mul( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::mul(lhs, rhs) + } + + fn int_mul_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::mul_scalar(lhs, rhs) + } + + fn int_div( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::div(lhs, rhs) + } + + fn int_div_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::div_scalar(lhs, rhs) + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + numeric::zeros::(shape, device) + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + numeric::ones::(shape, device) + } + + fn int_sum(tensor: IntTensor) -> IntTensor { + kernel::reduce::sum(tensor) + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let output = init_reduce_output(&tensor, dim); + reduce::sum_dim(tensor, output, dim) + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let output = init_reduce_output(&tensor, dim); + reduce::mean_dim(tensor, output, dim) + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + kernel::reduce::argmax(tensor, dim) + } + + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + kernel::reduce::argmin(tensor, dim) + } + + fn int_clamp_min( + tensor: IntTensor, + min: IntElem, + ) -> IntTensor { + kernel::clamp_min(tensor, min) + } + + fn int_clamp_max( + tensor: IntTensor, + max: IntElem, + ) -> IntTensor { + kernel::clamp_max(tensor, max) + } + + fn int_clamp( + tensor: IntTensor, + min: IntElem, + max: IntElem, + ) -> IntTensor { + kernel::clamp(tensor, min, max) + } + + fn int_abs(tensor: IntTensor) -> IntTensor { + unary!(IntAbs, func "abs"); + unary_inplace!(IntAbsInplace, func "abs"); + + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } + + unary_default::(tensor) + } + + fn int_into_float(tensor: IntTensor) -> FloatTensor { + kernel::cast(tensor) + } + + fn int_swap_dims( + mut tensor: IntTensor, + dim1: usize, + dim2: usize, + ) -> IntTensor { + tensor.strides.swap(dim1, dim2); + tensor.shape.dims.swap(dim1, dim2); + + tensor + } } diff --git a/burn-wgpu/src/ops/module_ops.rs b/burn-wgpu/src/ops/module_ops.rs index e523c581fd..0d6b171f31 100644 --- a/burn-wgpu/src/ops/module_ops.rs +++ b/burn-wgpu/src/ops/module_ops.rs @@ -1,113 +1,110 @@ use burn_tensor::ops::{ - ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; use crate::{ - element::{FloatElement, IntElement}, - kernel, GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel, GraphicsApi, Wgpu, }; use burn_tensor::ops::{FloatTensor, IntTensor}; impl ModuleOps for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor { - kernel::conv::conv2d(x, weight, bias, options) - } + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor { + kernel::conv::conv2d(x, weight, bias, options) + } - fn 
conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor { - kernel::conv::conv_transpose2d(x, weight, bias, options) - } + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + kernel::conv::conv_transpose2d(x, weight, bias, options) + } - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - kernel::pool::avg_pool2d(x, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + kernel::pool::avg_pool2d(x, kernel_size, stride, padding, count_include_pad) + } - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - kernel::pool::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + kernel::pool::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) + } - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor { - kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation) - } + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor { + kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation) + } - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - let (output, indices) = - kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + let (output, indices) = + kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); - MaxPool2dWithIndices::new(output, indices) - } + MaxPool2dWithIndices::new(output, indices) + } - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward> { - MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward( - x, - output_grad, - indices, - kernel_size, - stride, - padding, - dilation, - )) - } + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward> { + MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward( + x, + output_grad, + indices, + kernel_size, + stride, + padding, + dilation, + )) + } - fn adaptive_avg_pool2d( - x: FloatTensor, - output_size: [usize; 2], - ) -> FloatTensor { - kernel::pool::adaptive_avg_pool2d(x, output_size) - } + fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 
2]) -> FloatTensor { + kernel::pool::adaptive_avg_pool2d(x, output_size) + } - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - kernel::pool::adaptive_avg_pool2d_backward(x, grad) - } + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + kernel::pool::adaptive_avg_pool2d_backward(x, grad) + } } diff --git a/burn-wgpu/src/ops/numeric.rs b/burn-wgpu/src/ops/numeric.rs index f4b57e83d8..0ff3fb890a 100644 --- a/burn-wgpu/src/ops/numeric.rs +++ b/burn-wgpu/src/ops/numeric.rs @@ -1,197 +1,197 @@ use crate::compute::{compute_client, WgpuComputeClient}; use crate::kernel::{ - binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default, - unary_scalar_inplace_default, + binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default, + unary_scalar_inplace_default, }; use crate::{ - binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, - unary_scalar, unary_scalar_inplace, + binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary_scalar, + unary_scalar_inplace, }; use crate::{GraphicsApi, WgpuDevice}; use burn_tensor::{Element, ElementConversion, Shape}; pub fn full( - shape: Shape, - device: &WgpuDevice, - value: E, + shape: Shape, + device: &WgpuDevice, + value: E, ) -> WgpuTensor { - let client = compute_client::(device); + let client = compute_client::(device); - full_device(client, shape, device.clone(), value) + full_device(client, shape, device.clone(), value) } pub fn full_device( - client: WgpuComputeClient, - shape: Shape, - device: WgpuDevice, - value: E, + client: WgpuComputeClient, + shape: Shape, + device: WgpuDevice, + value: E, ) -> WgpuTensor { - let empty = empty_device(client, device, shape); + let empty = empty_device(client, device, shape); - unary_scalar_inplace!(Full, body "lhs[id] = rhs;"); - unary_scalar_inplace_default::(empty, value) + unary_scalar_inplace!(Full, body "lhs[id] = rhs;"); + unary_scalar_inplace_default::(empty, value) } pub fn zeros( - shape: Shape, - device: &WgpuDevice, + shape: Shape, + device: &WgpuDevice, ) -> WgpuTensor { - let client = compute_client::(device); + let client = compute_client::(device); - zeros_device(client, device.clone(), shape) + zeros_device(client, device.clone(), shape) } pub fn zeros_device( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, ) -> WgpuTensor { - full_device::(client, shape, device, 0.elem()) + full_device::(client, shape, device, 0.elem()) } pub fn ones( - shape: Shape, - device: &WgpuDevice, + shape: Shape, + device: &WgpuDevice, ) -> WgpuTensor { - let client = compute_client::(device); + let client = compute_client::(device); - ones_device(client, device.clone(), shape) + ones_device(client, device.clone(), shape) } pub fn ones_device( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, ) -> WgpuTensor { - full_device::(client, shape, device, 1.elem()) + full_device::(client, shape, device, 1.elem()) } pub fn empty_device( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, ) -> WgpuTensor { - let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); + let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - WgpuTensor::new(client, device, shape, 
buffer) + WgpuTensor::new(client, device, shape, buffer) } pub fn add( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Add, "+"); - binary_elemwise_inplace!(AddInplace, "+"); + binary_elemwise!(Add, "+"); + binary_elemwise_inplace!(AddInplace, "+"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - if rhs.can_mut_broadcast(&lhs) { - return binary_elemwise_inplace_default::(rhs, lhs); - } + if rhs.can_mut_broadcast(&lhs) { + return binary_elemwise_inplace_default::(rhs, lhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn add_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(AddScalar, ops "+"); - unary_scalar_inplace!(AddScalarInplace, ops "+"); + unary_scalar!(AddScalar, ops "+"); + unary_scalar_inplace!(AddScalarInplace, ops "+"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } pub fn sub( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Sub, "-"); - binary_elemwise_inplace!(SubInplace, "-"); + binary_elemwise!(Sub, "-"); + binary_elemwise_inplace!(SubInplace, "-"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn sub_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(SubScalar, ops "-"); - unary_scalar_inplace!(SubScalarInplace, ops "-"); + unary_scalar!(SubScalar, ops "-"); + unary_scalar_inplace!(SubScalarInplace, ops "-"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } pub fn mul( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Mul, "*"); - binary_elemwise_inplace!(MulInplace, "*"); + binary_elemwise!(Mul, "*"); + binary_elemwise_inplace!(MulInplace, "*"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - if rhs.can_mut_broadcast(&lhs) { - return binary_elemwise_inplace_default::(rhs, lhs); - } + if rhs.can_mut_broadcast(&lhs) { + return binary_elemwise_inplace_default::(rhs, lhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn mul_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(MulScalar, ops "*"); - unary_scalar_inplace!(MulScalarInplace, ops "*"); + unary_scalar!(MulScalar, ops "*"); + unary_scalar_inplace!(MulScalarInplace, ops "*"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } pub fn div( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + 
rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Div, "/"); - binary_elemwise_inplace!(DivInplace, "/"); + binary_elemwise!(Div, "/"); + binary_elemwise_inplace!(DivInplace, "/"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn div_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(DivScalar, ops "/"); - unary_scalar_inplace!(DivScalarInplace, ops "/"); + unary_scalar!(DivScalar, ops "/"); + unary_scalar_inplace!(DivScalarInplace, ops "/"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } diff --git a/burn-wgpu/src/tensor/base.rs b/burn-wgpu/src/tensor/base.rs index f939a0fcdf..9388f2f3fd 100644 --- a/burn-wgpu/src/tensor/base.rs +++ b/burn-wgpu/src/tensor/base.rs @@ -1,6 +1,6 @@ use crate::{ - compute::{WgpuComputeClient, WgpuHandle}, - unary, WgpuDevice, + compute::{WgpuComputeClient, WgpuHandle}, + unary, WgpuDevice, }; use crate::{element::WgpuElement, kernel::unary_default}; use burn_tensor::Shape; @@ -9,135 +9,135 @@ use std::marker::PhantomData; /// The basic tensor primitive struct. #[derive(Debug, Clone)] pub struct WgpuTensor { - /// Compute client for wgpu. - pub client: WgpuComputeClient, - /// The buffer where the data are stored. - pub handle: WgpuHandle, - /// The shape of the current tensor. - pub shape: Shape, - /// The device of the current tensor. - pub device: WgpuDevice, - /// The strides of the current tensor. - pub strides: [usize; D], - pub(crate) elem: PhantomData, + /// Compute client for wgpu. + pub client: WgpuComputeClient, + /// The buffer where the data are stored. + pub handle: WgpuHandle, + /// The shape of the current tensor. + pub shape: Shape, + /// The device of the current tensor. + pub device: WgpuDevice, + /// The strides of the current tensor. + pub strides: [usize; D], + pub(crate) elem: PhantomData, } impl WgpuTensor { - /// Create a new tensor. - pub fn new( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, - handle: WgpuHandle, - ) -> Self { - let mut strides = [0; D]; - - let mut current = 1; - shape - .dims - .iter() - .enumerate() - .rev() - .for_each(|(index, val)| { - strides[index] = current; - current *= val; - }); - - Self { - client, - handle, - shape, - strides, - device, - elem: PhantomData, - } + /// Create a new tensor. + pub fn new( + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, + handle: WgpuHandle, + ) -> Self { + let mut strides = [0; D]; + + let mut current = 1; + shape + .dims + .iter() + .enumerate() + .rev() + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); + + Self { + client, + handle, + shape, + strides, + device, + elem: PhantomData, } - - /// Change the context of the current tensor and return the newly transferred tensor. 
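/// Editorial note (not part of the original patch): the handle is read back synchronously from the
/// current client and re-uploaded to the target client, so this helper is intended for explicit
/// device-to-device moves rather than lazy transfers.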
- pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self { - let bytes = self - .client - .read(&self.handle) - .read_sync() - .expect("Can only change client synchronously"); - let handle = client.create(&bytes); - - Self { - client, - handle, - shape: self.shape.clone(), - strides: self.strides, - device, - elem: PhantomData, - } + } + + /// Change the context of the current tensor and return the newly transferred tensor. + pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self { + let bytes = self + .client + .read(&self.handle) + .read_sync() + .expect("Can only change client synchronously"); + let handle = client.create(&bytes); + + Self { + client, + handle, + shape: self.shape.clone(), + strides: self.strides, + device, + elem: PhantomData, } + } - pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor) -> bool { - if !self.handle.can_mut() { - return false; - } - - for i in 0..D { - // Output tensor will be different from the mutable tensor. - if self.shape.dims[i] < tensor_other.shape.dims[i] { - return false; - } - } - - true + pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor) -> bool { + if !self.handle.can_mut() { + return false; } - /// Copy the current tensor. - pub fn copy(&self) -> Self { - // Seems like using the copy buffer from the `wgpu` API leads to race condition when they - // are used inplace afterward. - // - // To avoid them we need to execute the whole pipeline, which leads to significant - // slowdowns. - // - // The solution is just to use a simple unary compute shader. - unary!(CopyBuffer, body "output[id] = input[id];"); - unary_default::(self.clone()) + for i in 0..D { + // Output tensor will be different from the mutable tensor. + if self.shape.dims[i] < tensor_other.shape.dims[i] { + return false; + } } - /// Check if the tensor is safe to mutate. - pub fn can_mut(&self) -> bool { - self.handle.can_mut() + true + } + + /// Copy the current tensor. + pub fn copy(&self) -> Self { + // Seems like using the copy buffer from the `wgpu` API leads to race condition when they + // are used inplace afterward. + // + // To avoid them we need to execute the whole pipeline, which leads to significant + // slowdowns. + // + // The solution is just to use a simple unary compute shader. + unary!(CopyBuffer, body "output[id] = input[id];"); + unary_default::(self.clone()) + } + + /// Check if the tensor is safe to mutate. + pub fn can_mut(&self) -> bool { + self.handle.can_mut() + } + + /// Assert that both tensors are on the same device. + pub fn assert_is_on_same_device(&self, other: &Self) { + if self.device != other.device { + panic!( + "Both tensors should be on the same device {:?} != {:?}", + self.device, other.device + ); } + } - /// Assert that both tensors are on the same device. - pub fn assert_is_on_same_device(&self, other: &Self) { - if self.device != other.device { - panic!( - "Both tensors should be on the same device {:?} != {:?}", - self.device, other.device - ); - } - } + /// Check if the current tensor is contiguous. + pub fn is_contiguous(&self) -> bool { + let mut current_stride = 0; + for d in 0..D { + let stride = self.strides[D - 1 - d]; - /// Check if the current tensor is contiguous. 
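/// Editorial note (not part of the original patch): a tensor is treated as contiguous when its
/// strides do not decrease going from the innermost dimension outwards, e.g. shape [2, 3, 4]
/// with row-major strides [12, 4, 1].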
- pub fn is_contiguous(&self) -> bool { - let mut current_stride = 0; - for d in 0..D { - let stride = self.strides[D - 1 - d]; + if stride < current_stride { + return false; + } - if stride < current_stride { - return false; - } - - current_stride = stride; - } - - true + current_stride = stride; } - pub(crate) fn batch_swapped_with_row_col(&self) -> bool { - for d in 0..D - 2 { - let stride = self.strides[d]; - if stride < self.strides[D - 2] || stride < self.strides[D - 1] { - return true; - } - } - false + true + } + + pub(crate) fn batch_swapped_with_row_col(&self) -> bool { + for d in 0..D - 2 { + let stride = self.strides[d]; + if stride < self.strides[D - 2] || stride < self.strides[D - 1] { + return true; + } } + false + } } diff --git a/burn/src/lib.rs b/burn/src/lib.rs index 10028b6674..616d7755d8 100644 --- a/burn/src/lib.rs +++ b/burn/src/lib.rs @@ -12,5 +12,5 @@ pub use burn_core::*; /// Train module #[cfg(any(feature = "train", feature = "train-minimal"))] pub mod train { - pub use burn_train::*; + pub use burn_train::*; } diff --git a/examples/custom-renderer/examples/custom-renderer.rs b/examples/custom-renderer/examples/custom-renderer.rs index 94ce8e3e6f..a2de145f19 100644 --- a/examples/custom-renderer/examples/custom-renderer.rs +++ b/examples/custom-renderer/examples/custom-renderer.rs @@ -2,5 +2,5 @@ use burn::backend::wgpu::WgpuDevice; use burn::backend::{Autodiff, Wgpu}; fn main() { - custom_renderer::run::>(WgpuDevice::default()); + custom_renderer::run::>(WgpuDevice::default()); } diff --git a/examples/custom-renderer/src/lib.rs b/examples/custom-renderer/src/lib.rs index f1e066aff4..c08d5088e3 100644 --- a/examples/custom-renderer/src/lib.rs +++ b/examples/custom-renderer/src/lib.rs @@ -2,82 +2,82 @@ use burn::data::dataset::source::huggingface::MNISTDataset; use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress}; use burn::train::LearnerBuilder; use burn::{ - config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig, - tensor::backend::AutodiffBackend, + config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig, + tensor::backend::AutodiffBackend, }; use guide::{data::MNISTBatcher, model::ModelConfig}; #[derive(Config)] pub struct MnistTrainingConfig { - #[config(default = 10)] - pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, - #[config(default = 42)] - pub seed: u64, - #[config(default = 1e-4)] - pub lr: f64, - pub model: ModelConfig, - pub optimizer: AdamConfig, + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1e-4)] + pub lr: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, } struct CustomRenderer {} impl MetricsRenderer for CustomRenderer { - fn update_train(&mut self, _state: MetricState) {} + fn update_train(&mut self, _state: MetricState) {} - fn update_valid(&mut self, _state: MetricState) {} + fn update_valid(&mut self, _state: MetricState) {} - fn render_train(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_train(&mut self, item: TrainingProgress) { + dbg!(item); + } - fn render_valid(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_valid(&mut self, item: TrainingProgress) { + dbg!(item); + } } pub fn run(device: B::Device) { - // Create the configuration. 
- let config_model = ModelConfig::new(10, 1024); - let config_optimizer = AdamConfig::new(); - let config = MnistTrainingConfig::new(config_model, config_optimizer); + // Create the configuration. + let config_model = ModelConfig::new(10, 1024); + let config_optimizer = AdamConfig::new(); + let config = MnistTrainingConfig::new(config_model, config_optimizer); - B::seed(config.seed); + B::seed(config.seed); - // Create the model and optimizer. - let model = config.model.init(); - let optim = config.optimizer.init(); + // Create the model and optimizer. + let model = config.model.init(); + let optim = config.optimizer.init(); - // Create the batcher. - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); + // Create the batcher. + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); - // Create the dataloaders. - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); + // Create the dataloaders. + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); - // artifact dir does not need to be provided when log_to_file is false - let builder = LearnerBuilder::new("") - .devices(vec![device]) - .num_epochs(config.num_epochs) - .renderer(CustomRenderer {}) - .log_to_file(false); - // can be used to interrupt training - let _interrupter = builder.interrupter(); + // artifact dir does not need to be provided when log_to_file is false + let builder = LearnerBuilder::new("") + .devices(vec![device]) + .num_epochs(config.num_epochs) + .renderer(CustomRenderer {}) + .log_to_file(false); + // can be used to interrupt training + let _interrupter = builder.interrupter(); - let learner = builder.build(model, optim, config.lr); + let learner = builder.build(model, optim, config.lr); - let _model_trained = learner.fit(dataloader_train, dataloader_test); + let _model_trained = learner.fit(dataloader_train, dataloader_test); } diff --git a/examples/custom-training-loop/examples/custom-training-loop.rs b/examples/custom-training-loop/examples/custom-training-loop.rs index 1b264c527c..9e0e5ade36 100644 --- a/examples/custom-training-loop/examples/custom-training-loop.rs +++ b/examples/custom-training-loop/examples/custom-training-loop.rs @@ -2,5 +2,5 @@ use burn::backend::wgpu::WgpuDevice; use burn::backend::{Autodiff, Wgpu}; fn main() { - custom_training_loop::run::>(WgpuDevice::default()); + custom_training_loop::run::>(WgpuDevice::default()); } diff --git a/examples/custom-training-loop/src/lib.rs b/examples/custom-training-loop/src/lib.rs index 9eb8b154cb..80a72a262a 100644 --- a/examples/custom-training-loop/src/lib.rs +++ b/examples/custom-training-loop/src/lib.rs @@ -2,171 +2,171 @@ use std::marker::PhantomData; use burn::data::dataset::source::huggingface::MNISTDataset; use burn::{ - config::Config, - data::dataloader::DataLoaderBuilder, - 
module::AutodiffModule, - nn::loss::CrossEntropyLoss, - optim::{AdamConfig, GradientsParams, Optimizer}, - tensor::{ - backend::{AutodiffBackend, Backend}, - ElementConversion, Int, Tensor, - }, + config::Config, + data::dataloader::DataLoaderBuilder, + module::AutodiffModule, + nn::loss::CrossEntropyLoss, + optim::{AdamConfig, GradientsParams, Optimizer}, + tensor::{ + backend::{AutodiffBackend, Backend}, + ElementConversion, Int, Tensor, + }, }; use guide::{ - data::{MNISTBatch, MNISTBatcher}, - model::{Model, ModelConfig}, + data::{MNISTBatch, MNISTBatcher}, + model::{Model, ModelConfig}, }; #[derive(Config)] pub struct MnistTrainingConfig { - #[config(default = 10)] - pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, - #[config(default = 42)] - pub seed: u64, - #[config(default = 1e-4)] - pub lr: f64, - pub model: ModelConfig, - pub optimizer: AdamConfig, + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1e-4)] + pub lr: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, } pub fn run(device: B::Device) { - // Create the configuration. - let config_model = ModelConfig::new(10, 1024); - let config_optimizer = AdamConfig::new(); - let config = MnistTrainingConfig::new(config_model, config_optimizer); - - B::seed(config.seed); - - // Create the model and optimizer. - let mut model = config.model.init(); - let mut optim = config.optimizer.init(); - - // Create the batcher. - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); - - // Create the dataloaders. - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); - - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); - - // Iterate over our training and validation loop for X epochs. - for epoch in 1..config.num_epochs + 1 { - // Implement our training loop. - for (iteration, batch) in dataloader_train.iter().enumerate() { - let output = model.forward(batch.images); - let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); - let accuracy = accuracy(output, batch.targets); - - println!( - "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", - epoch, - iteration, - loss.clone().into_scalar(), - accuracy, - ); - - // Gradients for the current backward pass - let grads = loss.backward(); - // Gradients linked to each parameter of the model. - let grads = GradientsParams::from_grads(grads, &model); - // Update the model using the optimizer. - model = optim.step(config.lr, model, grads); - } - - // Get the model without autodiff. - let model_valid = model.valid(); - - // Implement our validation loop. - for (iteration, batch) in dataloader_test.iter().enumerate() { - let output = model_valid.forward(batch.images); - let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); - let accuracy = accuracy(output, batch.targets); - - println!( - "[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}", - iteration, - epoch, - loss.clone().into_scalar(), - accuracy, - ); - } + // Create the configuration. 
+ let config_model = ModelConfig::new(10, 1024); + let config_optimizer = AdamConfig::new(); + let config = MnistTrainingConfig::new(config_model, config_optimizer); + + B::seed(config.seed); + + // Create the model and optimizer. + let mut model = config.model.init(); + let mut optim = config.optimizer.init(); + + // Create the batcher. + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); + + // Create the dataloaders. + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); + + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); + + // Iterate over our training and validation loop for X epochs. + for epoch in 1..config.num_epochs + 1 { + // Implement our training loop. + for (iteration, batch) in dataloader_train.iter().enumerate() { + let output = model.forward(batch.images); + let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); + let accuracy = accuracy(output, batch.targets); + + println!( + "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", + epoch, + iteration, + loss.clone().into_scalar(), + accuracy, + ); + + // Gradients for the current backward pass + let grads = loss.backward(); + // Gradients linked to each parameter of the model. + let grads = GradientsParams::from_grads(grads, &model); + // Update the model using the optimizer. + model = optim.step(config.lr, model, grads); } + + // Get the model without autodiff. + let model_valid = model.valid(); + + // Implement our validation loop. + for (iteration, batch) in dataloader_test.iter().enumerate() { + let output = model_valid.forward(batch.images); + let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); + let accuracy = accuracy(output, batch.targets); + + println!( + "[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}", + iteration, + epoch, + loss.clone().into_scalar(), + accuracy, + ); + } + } } /// Create out own accuracy metric calculation. 
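/// Editorial note (not part of the original patch): the metric below is the percentage of items
/// whose argmax prediction matches the target, i.e. 100 * num_corrects / num_predictions.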
fn accuracy(output: Tensor, targets: Tensor) -> f32 { - let predictions = output.argmax(1).squeeze(1); - let num_predictions: usize = targets.dims().iter().product(); - let num_corrects = predictions.equal(targets).int().sum().into_scalar(); + let predictions = output.argmax(1).squeeze(1); + let num_predictions: usize = targets.dims().iter().product(); + let num_corrects = predictions.equal(targets).int().sum().into_scalar(); - num_corrects.elem::() / num_predictions as f32 * 100.0 + num_corrects.elem::() / num_predictions as f32 * 100.0 } #[allow(dead_code)] struct Learner1 where - B: AutodiffBackend, + B: AutodiffBackend, { - model: Model, - optim: O, + model: Model, + optim: O, } #[allow(dead_code)] struct Learner2 { - model: M, - optim: O, + model: M, + optim: O, } #[allow(dead_code)] struct Learner3 { - model: M, - optim: O, - _b: PhantomData, + model: M, + optim: O, + _b: PhantomData, } #[allow(dead_code)] impl Learner1 where - B: AutodiffBackend, - O: Optimizer, B>, + B: AutodiffBackend, + O: Optimizer, B>, { - pub fn step1(&mut self, _batch: MNISTBatch) { - // - } + pub fn step1(&mut self, _batch: MNISTBatch) { + // + } } #[allow(dead_code)] impl Learner2, O> where - B: AutodiffBackend, - O: Optimizer, B>, + B: AutodiffBackend, + O: Optimizer, B>, { - pub fn step2(&mut self, _batch: MNISTBatch) { - // - } + pub fn step2(&mut self, _batch: MNISTBatch) { + // + } } #[allow(dead_code)] impl Learner2 { - pub fn step3(&mut self, _batch: MNISTBatch) - where - B: AutodiffBackend, - M: AutodiffModule, - O: Optimizer, - { - // - } + pub fn step3(&mut self, _batch: MNISTBatch) + where + B: AutodiffBackend, + M: AutodiffModule, + O: Optimizer, + { + // + } } diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index 9a2e9d96fc..67912aa1a6 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -1,76 +1,76 @@ use burn::tensor::{Distribution, Tensor}; use custom_wgpu_kernel::{ - matmul_add_relu_custom, matmul_add_relu_reference, AutodiffBackend, Backend, + matmul_add_relu_custom, matmul_add_relu_reference, AutodiffBackend, Backend, }; fn inference() { - let lhs = Tensor::::random([1, 32, 32], Distribution::Default); - let rhs = Tensor::random([32, 32, 32], Distribution::Default); - let bias = Tensor::random([32, 32, 32], Distribution::Default); + let lhs = Tensor::::random([1, 32, 32], Distribution::Default); + let rhs = Tensor::random([32, 32, 32], Distribution::Default); + let bias = Tensor::random([32, 32, 32], Distribution::Default); - let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()) - .into_data() - .convert::(); - let custom = matmul_add_relu_custom(lhs, rhs, bias) - .into_data() - .convert::(); + let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()) + .into_data() + .convert::(); + let custom = matmul_add_relu_custom(lhs, rhs, bias) + .into_data() + .convert::(); - reference.assert_approx_eq(&custom, 3); + reference.assert_approx_eq(&custom, 3); - println!("Both reference and the custom fused kernel have the same output"); + println!("Both reference and the custom fused kernel have the same output"); } fn autodiff() { - let lhs = Tensor::::random([1, 32, 32], Distribution::Default).require_grad(); - let rhs = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); - let bias = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); + let lhs = 
Tensor::::random([1, 32, 32], Distribution::Default).require_grad(); + let rhs = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); + let bias = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); - let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()); + let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()); - let mut gradients = reference.backward(); + let mut gradients = reference.backward(); - let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap(); - let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap(); - let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap(); + let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap(); + let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap(); + let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap(); - let lhs = lhs.detach(); - let rhs = rhs.detach(); - let bias = bias.detach(); + let lhs = lhs.detach(); + let rhs = rhs.detach(); + let bias = bias.detach(); - let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone()); + let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone()); - let mut gradients = custom.backward(); + let mut gradients = custom.backward(); - let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap(); - let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap(); - let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap(); + let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap(); + let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap(); + let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap(); - lhs_grad_ref - .into_data() - .convert::() - .assert_approx_eq(&lhs_grad_custom.into_data().convert(), 3); + lhs_grad_ref + .into_data() + .convert::() + .assert_approx_eq(&lhs_grad_custom.into_data().convert(), 3); - println!("Both reference and the custom fused kernel have the same lhs gradient"); + println!("Both reference and the custom fused kernel have the same lhs gradient"); - rhs_grad_ref - .into_data() - .convert::() - .assert_approx_eq(&rhs_grad_custom.into_data().convert(), 3); + rhs_grad_ref + .into_data() + .convert::() + .assert_approx_eq(&rhs_grad_custom.into_data().convert(), 3); - println!("Both reference and the custom fused kernel have the same rhs gradient"); + println!("Both reference and the custom fused kernel have the same rhs gradient"); - bias_grad_ref - .into_data() - .convert::() - .assert_approx_eq(&bias_grad_custom.into_data().convert(), 3); + bias_grad_ref + .into_data() + .convert::() + .assert_approx_eq(&bias_grad_custom.into_data().convert(), 3); - println!("Both reference and the custom fused kernel have the same bias gradient"); + println!("Both reference and the custom fused kernel have the same bias gradient"); } fn main() { - type MyBackend = burn::backend::Wgpu; - type MyAutodiffBackend = burn::backend::Autodiff; + type MyBackend = burn::backend::Wgpu; + type MyAutodiffBackend = burn::backend::Autodiff; - inference::(); - autodiff::(); + inference::(); + autodiff::(); } diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index fd50df6079..6cc755f55d 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -2,9 +2,9 @@ use crate::FloatTensor; use super::{AutodiffBackend, Backend}; use burn::backend::autodiff::{ - grads::Gradients, - ops::{broadcast_shape, Backward, Ops, OpsKind}, - Autodiff, + 
grads::Gradients, + ops::{broadcast_shape, Backward, Ops, OpsKind}, + Autodiff, }; use burn::backend::wgpu::{FloatElement, GraphicsApi, IntElement, Wgpu}; use burn::tensor::Shape; @@ -17,106 +17,103 @@ impl AutodiffBackend for Autodif // also implements our own API. This would allow us to call any function only implemented for Wgpu // and potentially call a custom kernel crafted only for this task. impl Backend for Autodiff { - fn fused_matmul_add_relu( - lhs: FloatTensor, - rhs: FloatTensor, - bias: FloatTensor, - ) -> FloatTensor { - // Create our zero-sized type that will implement the Backward trait. - #[derive(Debug)] - struct FusedMatmulAddReluBackward; + fn fused_matmul_add_relu( + lhs: FloatTensor, + rhs: FloatTensor, + bias: FloatTensor, + ) -> FloatTensor { + // Create our zero-sized type that will implement the Backward trait. + #[derive(Debug)] + struct FusedMatmulAddReluBackward; - // Implement the backward trait for the given backend B, the node gradient being of rank D - // with three other gradients to calculate (lhs, rhs, and bias). - impl Backward for FusedMatmulAddReluBackward { - // Our state that we must build during the forward pass to compute the backward pass. - // - // Note that we could improve the performance further by only keeping the state of - // tensors that are tracked, improving memory management, but for simplicity, we avoid - // that part. - type State = ( - FloatTensor, - FloatTensor, - FloatTensor, - Shape, - ); + // Implement the backward trait for the given backend B, the node gradient being of rank D + // with three other gradients to calculate (lhs, rhs, and bias). + impl Backward for FusedMatmulAddReluBackward { + // Our state that we must build during the forward pass to compute the backward pass. + // + // Note that we could improve the performance further by only keeping the state of + // tensors that are tracked, improving memory management, but for simplicity, we avoid + // that part. + type State = ( + FloatTensor, + FloatTensor, + FloatTensor, + Shape, + ); - fn backward(self, ops: Ops, grads: &mut Gradients) { - // Get the nodes of each variable. - let [node_lhs, node_rhs, node_bias] = ops.parents; - // Fetch the gradient for the current node. - let grad = grads.consume::(&ops.node); + fn backward(self, ops: Ops, grads: &mut Gradients) { + // Get the nodes of each variable. + let [node_lhs, node_rhs, node_bias] = ops.parents; + // Fetch the gradient for the current node. + let grad = grads.consume::(&ops.node); - // Set our state. - let (lhs, rhs, output, shape_bias) = ops.state; + // Set our state. + let (lhs, rhs, output, shape_bias) = ops.state; - // Fetch shapes of our tensor to support broadcasting. - let shape_lhs = B::shape(&lhs); - let shape_rhs = B::shape(&rhs); + // Fetch shapes of our tensor to support broadcasting. + let shape_lhs = B::shape(&lhs); + let shape_rhs = B::shape(&rhs); - // Compute the gradient of the output using the already existing `relu_backward` - // function in the basic Burn backend trait. - let grad_output = B::relu_backward(output, grad); + // Compute the gradient of the output using the already existing `relu_backward` + // function in the basic Burn backend trait. + let grad_output = B::relu_backward(output, grad); - // Compute the lhs gradient, which is the derivative of matmul with support for - // broadcasting. 
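// Editorial note (not part of the original patch): for output = relu(lhs · rhs + bias) and
// g = relu_backward(output, grad), the matmul derivatives are d(lhs) = g · rhs^T and
// d(rhs) = lhs^T · g, with broadcast_shape reducing each gradient back to its operand's shape.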
- let grad_lhs = broadcast_shape::( - B::matmul(grad_output.clone(), B::transpose(rhs)), - &shape_lhs, - ); - // Compute the rhs gradient, which is the derivative of matmul with support for - // broadcasting. - let grad_rhs = broadcast_shape::( - B::matmul(B::transpose(lhs), grad_output.clone()), - &shape_rhs, - ); - // The add derivative is only 1, so we just need to support broadcasting to - // compute the bias gradient. - let grad_bias = broadcast_shape::(grad_output, &shape_bias); + // Compute the lhs gradient, which is the derivative of matmul with support for + // broadcasting. + let grad_lhs = broadcast_shape::( + B::matmul(grad_output.clone(), B::transpose(rhs)), + &shape_lhs, + ); + // Compute the rhs gradient, which is the derivative of matmul with support for + // broadcasting. + let grad_rhs = broadcast_shape::( + B::matmul(B::transpose(lhs), grad_output.clone()), + &shape_rhs, + ); + // The add derivative is only 1, so we just need to support broadcasting to + // compute the bias gradient. + let grad_bias = broadcast_shape::(grad_output, &shape_bias); - // Register the gradient for each variable based on whether they are marked as - // `tracked`. - if let Some(node) = node_bias { - grads.register::(node, grad_bias); - } - if let Some(node) = node_lhs { - grads.register::(node, grad_lhs); - } - if let Some(node) = node_rhs { - grads.register::(node, grad_rhs); - } - } + // Register the gradient for each variable based on whether they are marked as + // `tracked`. + if let Some(node) = node_bias { + grads.register::(node, grad_bias); } + if let Some(node) = node_lhs { + grads.register::(node, grad_lhs); + } + if let Some(node) = node_rhs { + grads.register::(node, grad_rhs); + } + } + } - // Prepare a stateful operation with each variable node and corresponding graph. - // - // Each node can be fetched with `ops.parents` in the same order as defined here. - match FusedMatmulAddReluBackward - .prepare( - [lhs.node, rhs.node, bias.node], - [lhs.graph, rhs.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => { - // When at least one node is tracked, we should register our backward step. - // We compute the output and the state before finishing the preparation. - let bias_shape = B::shape(&bias.primitive); - let output = B::fused_matmul_add_relu( - lhs.primitive.clone(), - rhs.primitive.clone(), - bias.primitive, - ); + // Prepare a stateful operation with each variable node and corresponding graph. + // + // Each node can be fetched with `ops.parents` in the same order as defined here. + match FusedMatmulAddReluBackward + .prepare( + [lhs.node, rhs.node, bias.node], + [lhs.graph, rhs.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => { + // When at least one node is tracked, we should register our backward step. + // We compute the output and the state before finishing the preparation. + let bias_shape = B::shape(&bias.primitive); + let output = + B::fused_matmul_add_relu(lhs.primitive.clone(), rhs.primitive.clone(), bias.primitive); - let state = (lhs.primitive, rhs.primitive, output.clone(), bias_shape); - prep.finish(state, output) - } - OpsKind::UnTracked(prep) => { - // When no node is tracked, we can just compute the original operation without - // keeping any state. 
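// Editorial note (not part of the original patch): skipping the state in the untracked branch
// means lhs, rhs and the output are not kept alive for a backward pass that will never run.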
- let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive); - prep.finish(output) - } - } + let state = (lhs.primitive, rhs.primitive, output.clone(), bias_shape); + prep.finish(state, output) + } + OpsKind::UnTracked(prep) => { + // When no node is tracked, we can just compute the original operation without + // keeping any state. + let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive); + prep.finish(output) + } } + } } diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index 01e2e7db83..d92dede8f3 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -2,13 +2,11 @@ use crate::FloatTensor; use super::Backend; use burn::backend::wgpu::{ - compute::{DynamicKernel, WorkGroup}, - kernel::{ - build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, - }, - kernel_wgsl, - tensor::WgpuTensor, - FloatElement, GraphicsApi, IntElement, Wgpu, + compute::{DynamicKernel, WorkGroup}, + kernel::{build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, + kernel_wgsl, + tensor::WgpuTensor, + FloatElement, GraphicsApi, IntElement, Wgpu, }; use burn::tensor::Shape; use derive_new::new; @@ -20,95 +18,95 @@ kernel_wgsl!(FusedMatmulAddReluRaw, "./kernel.wgsl"); // Define our kernel type with workgroup information. #[derive(new, Debug)] struct FusedMatmulAddRelu { - workgroup_size_x: usize, - workgroup_size_y: usize, - _elem: PhantomData, + workgroup_size_x: usize, + workgroup_size_y: usize, + _elem: PhantomData, } // Implement the dynamic kernel trait for our kernel type. impl DynamicKernelSource for FusedMatmulAddRelu { - fn source(&self) -> SourceTemplate { - // Extend our raw kernel with workgroup size information using the - // `SourceTemplate` trait. - FusedMatmulAddReluRaw::source() - .register("workgroup_size_x", self.workgroup_size_x.to_string()) - .register("workgroup_size_y", self.workgroup_size_y.to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + // Extend our raw kernel with workgroup size information using the + // `SourceTemplate` trait. + FusedMatmulAddReluRaw::source() + .register("workgroup_size_x", self.workgroup_size_x.to_string()) + .register("workgroup_size_y", self.workgroup_size_y.to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + format!("{:?}", self) + } } /// Implement our custom backend trait for the existing backend `WgpuBackend`. impl Backend for Wgpu { - fn fused_matmul_add_relu( - lhs: FloatTensor, - rhs: FloatTensor, - bias: FloatTensor, - ) -> WgpuTensor { - // Define workgroup size, hardcoded for simplicity. - let workgroup_size_x = 16; - let workgroup_size_y = 16; - - lhs.assert_is_on_same_device(&rhs); - lhs.assert_is_on_same_device(&bias); - - // For simplicity, make sure each tensor is continuous. - let lhs = into_contiguous(lhs); - let rhs = into_contiguous(rhs); - let bias = into_contiguous(bias); - - // Get the matmul relevant shapes. - let num_rows = lhs.shape.dims[D - 2]; - let num_cols = rhs.shape.dims[D - 1]; - - // Compute shape of output, while tracking number of batches. 
- let mut num_batches = 1; - let mut shape_out = [0; D]; - for i in shape_out.into_iter().take(D - 2) { - shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]); - num_batches *= shape_out[i]; - } - shape_out[D - 2] = num_rows; - shape_out[D - 1] = num_cols; - let shape_out = Shape::new(shape_out); - - // Create a buffer for the output tensor. - let buffer = lhs - .client - .empty(shape_out.num_elements() * core::mem::size_of::()); - - // Create the output tensor primitive. - let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); - - // Create the kernel. - let kernel = FusedMatmulAddRelu::::new(workgroup_size_x, workgroup_size_y); - - // Build info buffer with tensor information needed by the kernel, such as shapes and strides. - let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - - // Declare the wgsl workgroup with the number of blocks in x, y and z. - let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_batches as u32); - - // Execute lazily the kernel with the launch information and the given buffers. - lhs.client.execute( - Box::new(DynamicKernel::new(kernel, workgroup)), - &[ - &lhs.handle, - &rhs.handle, - &bias.handle, - &output.handle, - &info_handle, - ], - ); - - // Return the output tensor. - output + fn fused_matmul_add_relu( + lhs: FloatTensor, + rhs: FloatTensor, + bias: FloatTensor, + ) -> WgpuTensor { + // Define workgroup size, hardcoded for simplicity. + let workgroup_size_x = 16; + let workgroup_size_y = 16; + + lhs.assert_is_on_same_device(&rhs); + lhs.assert_is_on_same_device(&bias); + + // For simplicity, make sure each tensor is continuous. + let lhs = into_contiguous(lhs); + let rhs = into_contiguous(rhs); + let bias = into_contiguous(bias); + + // Get the matmul relevant shapes. + let num_rows = lhs.shape.dims[D - 2]; + let num_cols = rhs.shape.dims[D - 1]; + + // Compute shape of output, while tracking number of batches. + let mut num_batches = 1; + let mut shape_out = [0; D]; + for i in shape_out.into_iter().take(D - 2) { + shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]); + num_batches *= shape_out[i]; } + shape_out[D - 2] = num_rows; + shape_out[D - 1] = num_cols; + let shape_out = Shape::new(shape_out); + + // Create a buffer for the output tensor. + let buffer = lhs + .client + .empty(shape_out.num_elements() * core::mem::size_of::()); + + // Create the output tensor primitive. + let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); + + // Create the kernel. + let kernel = FusedMatmulAddRelu::::new(workgroup_size_x, workgroup_size_y); + + // Build info buffer with tensor information needed by the kernel, such as shapes and strides. + let info = build_info(&[&lhs, &rhs, &output]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + // Declare the wgsl workgroup with the number of blocks in x, y and z. + let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; + let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_batches as u32); + + // Execute lazily the kernel with the launch information and the given buffers. 
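// Editorial note (not part of the original patch): for example, a 32 x 32 matmul with the
// 16 x 16 workgroup above launches ceil(32 / 16) = 2 blocks along x and y, and one workgroup
// per batch along z.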
+ lhs.client.execute( + Box::new(DynamicKernel::new(kernel, workgroup)), + &[ + &lhs.handle, + &rhs.handle, + &bias.handle, + &output.handle, + &info_handle, + ], + ); + + // Return the output tensor. + output + } } diff --git a/examples/custom-wgpu-kernel/src/lib.rs b/examples/custom-wgpu-kernel/src/lib.rs index eb8cfd1570..ac1f6ecc58 100644 --- a/examples/custom-wgpu-kernel/src/lib.rs +++ b/examples/custom-wgpu-kernel/src/lib.rs @@ -8,11 +8,11 @@ pub type FloatTensor = : /// We create our own Backend trait that extends the Burn backend trait. pub trait Backend: burn::tensor::backend::Backend { - fn fused_matmul_add_relu( - lhs: FloatTensor, - rhs: FloatTensor, - bias: FloatTensor, - ) -> FloatTensor; + fn fused_matmul_add_relu( + lhs: FloatTensor, + rhs: FloatTensor, + bias: FloatTensor, + ) -> FloatTensor; } /// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait. @@ -20,26 +20,26 @@ pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {} /// We define our custom implementation using the added function on our custom backend. pub fn matmul_add_relu_custom( - lhs: Tensor, - rhs: Tensor, - bias: Tensor, + lhs: Tensor, + rhs: Tensor, + bias: Tensor, ) -> Tensor { - let output = B::fused_matmul_add_relu( - lhs.into_primitive(), - rhs.into_primitive(), - bias.into_primitive(), - ); + let output = B::fused_matmul_add_relu( + lhs.into_primitive(), + rhs.into_primitive(), + bias.into_primitive(), + ); - Tensor::from_primitive(output) + Tensor::from_primitive(output) } /// We define a reference implementation using basic tensor operations. pub fn matmul_add_relu_reference( - lhs: Tensor, - rhs: Tensor, - bias: Tensor, + lhs: Tensor, + rhs: Tensor, + bias: Tensor, ) -> Tensor { - let x = lhs.matmul(rhs) + bias; + let x = lhs.matmul(rhs) + bias; - activation::relu(x) + activation::relu(x) } diff --git a/examples/guide/examples/guide.rs b/examples/guide/examples/guide.rs index 274e511cfe..682502987b 100644 --- a/examples/guide/examples/guide.rs +++ b/examples/guide/examples/guide.rs @@ -5,21 +5,21 @@ use burn::optim::AdamConfig; use guide::{model::ModelConfig, training::TrainingConfig}; fn main() { - type MyBackend = Wgpu; - type MyAutodiffBackend = Autodiff; + type MyBackend = Wgpu; + type MyAutodiffBackend = Autodiff; - let device = burn::backend::wgpu::WgpuDevice::default(); - let artifact_dir = "/tmp/guide"; - guide::training::train::( - artifact_dir, - TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()), - device.clone(), - ); - guide::inference::infer::( - artifact_dir, - device, - burn::data::dataset::source::huggingface::MNISTDataset::test() - .get(42) - .unwrap(), - ); + let device = burn::backend::wgpu::WgpuDevice::default(); + let artifact_dir = "/tmp/guide"; + guide::training::train::( + artifact_dir, + TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()), + device.clone(), + ); + guide::inference::infer::( + artifact_dir, + device, + burn::data::dataset::source::huggingface::MNISTDataset::test() + .get(42) + .unwrap(), + ); } diff --git a/examples/guide/src/data.rs b/examples/guide/src/data.rs index fb43b4a08e..d089bb7226 100644 --- a/examples/guide/src/data.rs +++ b/examples/guide/src/data.rs @@ -1,45 +1,45 @@ use burn::{ - data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, - tensor::{backend::Backend, Data, ElementConversion, Int, Tensor}, + data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, + tensor::{backend::Backend, Data, ElementConversion, 
Int, Tensor}, }; pub struct MNISTBatcher { - device: B::Device, + device: B::Device, } impl MNISTBatcher { - pub fn new(device: B::Device) -> Self { - Self { device } - } + pub fn new(device: B::Device) -> Self { + Self { device } + } } #[derive(Clone, Debug)] pub struct MNISTBatch { - pub images: Tensor, - pub targets: Tensor, + pub images: Tensor, + pub targets: Tensor, } impl Batcher> for MNISTBatcher { - fn batch(&self, items: Vec) -> MNISTBatch { - let images = items - .iter() - .map(|item| Data::::from(item.image)) - .map(|data| Tensor::::from_data(data.convert())) - .map(|tensor| tensor.reshape([1, 28, 28])) - // normalize: make between [0,1] and make the mean = 0 and std = 1 - // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example - // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 - .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) - .collect(); + fn batch(&self, items: Vec) -> MNISTBatch { + let images = items + .iter() + .map(|item| Data::::from(item.image)) + .map(|data| Tensor::::from_data(data.convert())) + .map(|tensor| tensor.reshape([1, 28, 28])) + // normalize: make between [0,1] and make the mean = 0 and std = 1 + // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example + // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) + .collect(); - let targets = items - .iter() - .map(|item| Tensor::::from_data([(item.label as i64).elem()])) - .collect(); + let targets = items + .iter() + .map(|item| Tensor::::from_data([(item.label as i64).elem()])) + .collect(); - let images = Tensor::cat(images, 0).to_device(&self.device); - let targets = Tensor::cat(targets, 0).to_device(&self.device); + let images = Tensor::cat(images, 0).to_device(&self.device); + let targets = Tensor::cat(targets, 0).to_device(&self.device); - MNISTBatch { images, targets } - } + MNISTBatch { images, targets } + } } diff --git a/examples/guide/src/inference.rs b/examples/guide/src/inference.rs index c2f0fbb494..9d665dcfd7 100644 --- a/examples/guide/src/inference.rs +++ b/examples/guide/src/inference.rs @@ -1,27 +1,27 @@ use crate::{data::MNISTBatcher, training::TrainingConfig}; use burn::data::dataset::source::huggingface::MNISTItem; use burn::{ - config::Config, - data::dataloader::batcher::Batcher, - module::Module, - record::{CompactRecorder, Recorder}, - tensor::backend::Backend, + config::Config, + data::dataloader::batcher::Batcher, + module::Module, + record::{CompactRecorder, Recorder}, + tensor::backend::Backend, }; pub fn infer(artifact_dir: &str, device: B::Device, item: MNISTItem) { - let config = - TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists"); - let record = CompactRecorder::new() - .load(format!("{artifact_dir}/model").into()) - .expect("Failed to load trained model"); + let config = + TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into()) + .expect("Failed to load trained model"); - let model = config.model.init_with::(record).to_device(&device); + let model = config.model.init_with::(record).to_device(&device); - let label = item.label; - let batcher = MNISTBatcher::new(device); - let batch = batcher.batch(vec![item]); - let output = model.forward(batch.images); - let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar(); + let label = 
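As a quick numeric illustration of the per-pixel normalization used in the batcher above (not part of the patch): raw bytes in [0, 255] are first scaled into [0, 1], then shifted by the MNIST mean and divided by its standard deviation.

// Illustration only: the same arithmetic the batcher applies element-wise.
fn normalize_pixel(raw: f32) -> f32 {
    ((raw / 255.0) - 0.1307) / 0.3081
}

// normalize_pixel(0.0)   ≈ -0.4242  (background)
// normalize_pixel(255.0) ≈  2.8215  (brightest stroke)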
item.label; + let batcher = MNISTBatcher::new(device); + let batch = batcher.batch(vec![item]); + let output = model.forward(batch.images); + let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar(); - println!("Predicted {} Expected {}", predicted, label); + println!("Predicted {} Expected {}", predicted, label); } diff --git a/examples/guide/src/model.rs b/examples/guide/src/model.rs index 665612b50b..8d842fbb65 100644 --- a/examples/guide/src/model.rs +++ b/examples/guide/src/model.rs @@ -1,83 +1,82 @@ use burn::{ - config::Config, - module::Module, - nn::{ - conv::{Conv2d, Conv2dConfig}, - pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, - Dropout, DropoutConfig, Linear, LinearConfig, ReLU, - }, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, + Dropout, DropoutConfig, Linear, LinearConfig, ReLU, + }, + tensor::{backend::Backend, Tensor}, }; #[derive(Module, Debug)] pub struct Model { - conv1: Conv2d, - conv2: Conv2d, - pool: AdaptiveAvgPool2d, - dropout: Dropout, - linear1: Linear, - linear2: Linear, - activation: ReLU, + conv1: Conv2d, + conv2: Conv2d, + pool: AdaptiveAvgPool2d, + dropout: Dropout, + linear1: Linear, + linear2: Linear, + activation: ReLU, } #[derive(Config, Debug)] pub struct ModelConfig { - num_classes: usize, - hidden_size: usize, - #[config(default = "0.5")] - dropout: f64, + num_classes: usize, + hidden_size: usize, + #[config(default = "0.5")] + dropout: f64, } impl ModelConfig { - /// Returns the initialized model. - pub fn init(&self) -> Model { - Model { - conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), - conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), - pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), - activation: ReLU::new(), - linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), - linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), - dropout: DropoutConfig::new(self.dropout).init(), - } + /// Returns the initialized model. + pub fn init(&self) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: ReLU::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), + linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), + dropout: DropoutConfig::new(self.dropout).init(), } - /// Returns the initialized model using the recorded weights. - pub fn init_with(&self, record: ModelRecord) -> Model { - Model { - conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1), - conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2), - pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), - activation: ReLU::new(), - linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1), - linear2: LinearConfig::new(self.hidden_size, self.num_classes) - .init_with(record.linear2), - dropout: DropoutConfig::new(self.dropout).init(), - } + } + /// Returns the initialized model using the recorded weights. 
+ pub fn init_with(&self, record: ModelRecord) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: ReLU::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1), + linear2: LinearConfig::new(self.hidden_size, self.num_classes).init_with(record.linear2), + dropout: DropoutConfig::new(self.dropout).init(), } + } } impl Model { - /// # Shapes - /// - Images [batch_size, height, width] - /// - Output [batch_size, class_prob] - pub fn forward(&self, images: Tensor) -> Tensor { - let [batch_size, height, width] = images.dims(); + /// # Shapes + /// - Images [batch_size, height, width] + /// - Output [batch_size, class_prob] + pub fn forward(&self, images: Tensor) -> Tensor { + let [batch_size, height, width] = images.dims(); - // Create a channel. - let x = images.reshape([batch_size, 1, height, width]); + // Create a channel. + let x = images.reshape([batch_size, 1, height, width]); - let x = self.conv1.forward(x); // [batch_size, 8, _, _] - let x = self.dropout.forward(x); - let x = self.conv2.forward(x); // [batch_size, 16, _, _] - let x = self.dropout.forward(x); - let x = self.activation.forward(x); + let x = self.conv1.forward(x); // [batch_size, 8, _, _] + let x = self.dropout.forward(x); + let x = self.conv2.forward(x); // [batch_size, 16, _, _] + let x = self.dropout.forward(x); + let x = self.activation.forward(x); - let x = self.pool.forward(x); // [batch_size, 16, 8, 8] - let x = x.reshape([batch_size, 16 * 8 * 8]); - let x = self.linear1.forward(x); - let x = self.dropout.forward(x); - let x = self.activation.forward(x); + let x = self.pool.forward(x); // [batch_size, 16, 8, 8] + let x = x.reshape([batch_size, 16 * 8 * 8]); + let x = self.linear1.forward(x); + let x = self.dropout.forward(x); + let x = self.activation.forward(x); - self.linear2.forward(x) // [batch_size, num_classes] - } + self.linear2.forward(x) // [batch_size, num_classes] + } } diff --git a/examples/guide/src/training.rs b/examples/guide/src/training.rs index f04d132fd8..7582b459ea 100644 --- a/examples/guide/src/training.rs +++ b/examples/guide/src/training.rs @@ -1,109 +1,109 @@ use crate::{ - data::{MNISTBatch, MNISTBatcher}, - model::{Model, ModelConfig}, + data::{MNISTBatch, MNISTBatcher}, + model::{Model, ModelConfig}, }; use burn::data::dataset::source::huggingface::MNISTDataset; use burn::train::{ - metric::{AccuracyMetric, LossMetric}, - ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, + metric::{AccuracyMetric, LossMetric}, + ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, }; use burn::{ - self, - config::Config, - data::dataloader::DataLoaderBuilder, - module::Module, - nn::loss::CrossEntropyLoss, - optim::AdamConfig, - record::CompactRecorder, - tensor::{ - backend::{AutodiffBackend, Backend}, - Int, Tensor, - }, + self, + config::Config, + data::dataloader::DataLoaderBuilder, + module::Module, + nn::loss::CrossEntropyLoss, + optim::AdamConfig, + record::CompactRecorder, + tensor::{ + backend::{AutodiffBackend, Backend}, + Int, Tensor, + }, }; impl Model { - pub fn forward_classification( - &self, - images: Tensor, - targets: Tensor, - ) -> ClassificationOutput { - let output = self.forward(images); - let loss = CrossEntropyLoss::default().forward(output.clone(), targets.clone()); + pub fn forward_classification( + &self, + images: Tensor, + 
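The `Model::forward` shown above takes a [batch_size, 28, 28] image batch and returns one row of class scores per image. A small smoke test along these lines can catch shape mistakes early; this is a sketch that assumes the guide's `ModelConfig` as configured in the example's main function, and the batch size of 4 is arbitrary.

use burn::tensor::{backend::Backend, Distribution, Tensor};

fn forward_smoke_test<B: Backend>() {
    // 10 classes, hidden size 512, as in the guide's main function.
    let model = ModelConfig::new(10, 512).init::<B>();
    let images = Tensor::<B, 3>::random([4, 28, 28], Distribution::Default);

    let output = model.forward(images);

    // One row of class scores per input image: [4, num_classes].
    assert_eq!(output.dims(), [4, 10]);
}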
targets: Tensor, + ) -> ClassificationOutput { + let output = self.forward(images); + let loss = CrossEntropyLoss::default().forward(output.clone(), targets.clone()); - ClassificationOutput::new(loss, output, targets) - } + ClassificationOutput::new(loss, output, targets) + } } impl TrainStep, ClassificationOutput> for Model { - fn step(&self, batch: MNISTBatch) -> TrainOutput> { - let item = self.forward_classification(batch.images, batch.targets); + fn step(&self, batch: MNISTBatch) -> TrainOutput> { + let item = self.forward_classification(batch.images, batch.targets); - TrainOutput::new(self, item.loss.backward(), item) - } + TrainOutput::new(self, item.loss.backward(), item) + } } impl ValidStep, ClassificationOutput> for Model { - fn step(&self, batch: MNISTBatch) -> ClassificationOutput { - self.forward_classification(batch.images, batch.targets) - } + fn step(&self, batch: MNISTBatch) -> ClassificationOutput { + self.forward_classification(batch.images, batch.targets) + } } #[derive(Config)] pub struct TrainingConfig { - pub model: ModelConfig, - pub optimizer: AdamConfig, - #[config(default = 10)] - pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, - #[config(default = 42)] - pub seed: u64, - #[config(default = 1.0e-4)] - pub learning_rate: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1.0e-4)] + pub learning_rate: f64, } pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { - std::fs::create_dir_all(artifact_dir).ok(); - config - .save(format!("{artifact_dir}/config.json")) - .expect("Save without error"); + std::fs::create_dir_all(artifact_dir).ok(); + config + .save(format!("{artifact_dir}/config.json")) + .expect("Save without error"); - B::seed(config.seed); + B::seed(config.seed); - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); - let learner = LearnerBuilder::new(artifact_dir) - .metric_train_numeric(AccuracyMetric::new()) - .metric_valid_numeric(AccuracyMetric::new()) - .metric_train_numeric(LossMetric::new()) - .metric_valid_numeric(LossMetric::new()) - .with_file_checkpointer(CompactRecorder::new()) - .devices(vec![device]) - .num_epochs(config.num_epochs) - .build( - config.model.init::(), - config.optimizer.init(), - config.learning_rate, - ); + let learner = LearnerBuilder::new(artifact_dir) + 
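Because of the `#[config(default = ...)]` attributes above, only `model` and `optimizer` must be supplied when building a `TrainingConfig`; the remaining fields fall back to their defaults and can be overridden through the `with_*` setters generated by the `Config` derive (the same mechanism behind `AdamConfig::new().with_weight_decay(...)` used later in the mnist example). A sketch of a shortened debug configuration:

// Sketch only: override a couple of defaults for a quick debug run.
fn quick_debug_config() -> TrainingConfig {
    TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new())
        .with_num_epochs(2) // short run instead of the default 10
        .with_batch_size(32)
}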
.metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .devices(vec![device]) + .num_epochs(config.num_epochs) + .build( + config.model.init::(), + config.optimizer.init(), + config.learning_rate, + ); - let model_trained = learner.fit(dataloader_train, dataloader_test); + let model_trained = learner.fit(dataloader_train, dataloader_test); - model_trained - .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) - .expect("Failed to save trained model"); + model_trained + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Failed to save trained model"); } diff --git a/examples/image-classification-web/build.rs b/examples/image-classification-web/build.rs index e9c4ad2fa1..627a3ea84c 100644 --- a/examples/image-classification-web/build.rs +++ b/examples/image-classification-web/build.rs @@ -13,45 +13,45 @@ const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx"; const OUT_DIR: &str = "model/"; fn main() { - // Re-run the build script if model files change. - println!("cargo:rerun-if-changed=src/model"); + // Re-run the build script if model files change. + println!("cargo:rerun-if-changed=src/model"); - // Check if half precision is enabled. - let half_precision = cfg!(feature = "half_precision"); + // Check if half precision is enabled. + let half_precision = cfg!(feature = "half_precision"); - // Generate the model code from the ONNX file. - ModelGen::new() - .input(INPUT_ONNX_FILE) - .out_dir(OUT_DIR) - .record_type(RecordType::Bincode) - .embed_states(true) - .half_precision(half_precision) - .run_from_script(); + // Generate the model code from the ONNX file. + ModelGen::new() + .input(INPUT_ONNX_FILE) + .out_dir(OUT_DIR) + .record_type(RecordType::Bincode) + .embed_states(true) + .half_precision(half_precision) + .run_from_script(); - // Generate the labels from the synset.txt file. - generate_labels_from_txt_file().unwrap(); + // Generate the labels from the synset.txt file. + generate_labels_from_txt_file().unwrap(); } /// Read labels from synset.txt and store them in a vector of strings in a Rust file. fn generate_labels_from_txt_file() -> std::io::Result<()> { - let out_dir = env::var("OUT_DIR").unwrap(); - let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE); - let mut f = File::create(dest_path)?; + let out_dir = env::var("OUT_DIR").unwrap(); + let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE); + let mut f = File::create(dest_path)?; - let file = File::open(LABEL_SOURCE_FILE)?; - let reader = BufReader::new(file); + let file = File::open(LABEL_SOURCE_FILE)?; + let reader = BufReader::new(file); - writeln!(f, "pub static LABELS: &[&str] = &[")?; - for line in reader.lines() { - writeln!( - f, - " \"{}\",", - extract_simple_label(line.unwrap()).unwrap() - )?; - } - writeln!(f, "];")?; + writeln!(f, "pub static LABELS: &[&str] = &[")?; + for line in reader.lines() { + writeln!( + f, + " \"{}\",", + extract_simple_label(line.unwrap()).unwrap() + )?; + } + writeln!(f, "];")?; - Ok(()) + Ok(()) } /// Extract the simple label from the full label. 
@@ -59,17 +59,17 @@ fn generate_labels_from_txt_file() -> std::io::Result<()> { /// The full label is of the form: "n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea" /// The simple label is of the form: "indigo bunting" fn extract_simple_label(input: String) -> Option { - // Split the string based on the space character. - let mut parts = input.split(' '); + // Split the string based on the space character. + let mut parts = input.split(' '); - // Skip the first part (the alphanumeric code). - parts.next()?; + // Skip the first part (the alphanumeric code). + parts.next()?; - // Get the remaining string. - let remaining = parts.collect::>().join(" "); + // Get the remaining string. + let remaining = parts.collect::>().join(" "); - // Find the first comma, if it exists, and take the substring before it. - let end_index = remaining.find(',').unwrap_or(remaining.len()); + // Find the first comma, if it exists, and take the substring before it. + let end_index = remaining.find(',').unwrap_or(remaining.len()); - Some(remaining[0..end_index].to_string()) + Some(remaining[0..end_index].to_string()) } diff --git a/examples/image-classification-web/src/model/normalizer.rs b/examples/image-classification-web/src/model/normalizer.rs index 7e1e6019ce..d55cbd100a 100644 --- a/examples/image-classification-web/src/model/normalizer.rs +++ b/examples/image-classification-web/src/model/normalizer.rs @@ -7,32 +7,32 @@ const STD: [f32; 3] = [0.229, 0.224, 0.225]; /// Normalizer for the imagenet dataset. pub struct Normalizer { - pub mean: Tensor, - pub std: Tensor, + pub mean: Tensor, + pub std: Tensor, } impl Normalizer { - /// Creates a new normalizer. - pub fn new() -> Self { - let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]); - let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]); - Self { mean, std } - } + /// Creates a new normalizer. + pub fn new() -> Self { + let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]); + let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]); + Self { mean, std } + } - /// Normalizes the input image according to the imagenet dataset. - /// - /// The input image should be in the range [0, 1]. - /// The output image will be in the range [-1, 1]. - /// - /// The normalization is done according to the following formula: - /// `input = (input - mean) / std` - pub fn normalize(&self, input: Tensor) -> Tensor { - (input - self.mean.clone()) / self.std.clone() - } + /// Normalizes the input image according to the imagenet dataset. + /// + /// The input image should be in the range [0, 1]. + /// The output image will be in the range [-1, 1]. 
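The expected behaviour of `extract_simple_label` above is fully pinned down by its doc comment, so a one-case test makes a useful guard. This is a sketch, not part of the patch:

#[test]
fn extracts_simple_label_from_synset_entry() {
    // Example string taken from the doc comment above.
    let full = "n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea";
    assert_eq!(
        extract_simple_label(full.to_string()),
        Some("indigo bunting".to_string())
    );
}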
+ /// + /// The normalization is done according to the following formula: + /// `input = (input - mean) / std` + pub fn normalize(&self, input: Tensor) -> Tensor { + (input - self.mean.clone()) / self.std.clone() + } } impl Default for Normalizer { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } diff --git a/examples/image-classification-web/src/model/squeezenet.rs b/examples/image-classification-web/src/model/squeezenet.rs index d796ae629f..c729aa0177 100644 --- a/examples/image-classification-web/src/model/squeezenet.rs +++ b/examples/image-classification-web/src/model/squeezenet.rs @@ -1,6 +1,6 @@ // Generated model from squeezenet1.onnx mod internal_model { - include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs")); + include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs")); } pub use internal_model::*; diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 52e50c5653..df093a0ce4 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -1,19 +1,19 @@ #![allow(clippy::new_without_default)] use alloc::{ - string::{String, ToString}, - vec::Vec, + string::{String, ToString}, + vec::Vec, }; use core::convert::Into; use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel}; use burn::{ - backend::{ - wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice}, - NdArray, - }, - tensor::{activation::softmax, backend::Backend, Tensor}, + backend::{ + wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice}, + NdArray, + }, + tensor::{activation::softmax, backend::Backend, Tensor}, }; use burn_candle::Candle; @@ -24,14 +24,14 @@ use wasm_timer::Instant; #[allow(clippy::large_enum_variant)] /// The model is loaded to a specific backend pub enum ModelType { - /// The model is loaded to the Candle backend - WithCandleBackend(Model>), + /// The model is loaded to the Candle backend + WithCandleBackend(Model>), - /// The model is loaded to the NdArray backend - WithNdArrayBackend(Model>), + /// The model is loaded to the NdArray backend + WithNdArrayBackend(Model>), - /// The model is loaded to the Wgpu backend - WithWgpuBackend(Model>), + /// The model is loaded to the Wgpu backend + WithWgpuBackend(Model>), } /// The image is 224x224 pixels with 3 channels (RGB) @@ -42,150 +42,150 @@ const CHANNELS: usize = 3; /// The image classifier #[wasm_bindgen] pub struct ImageClassifier { - model: ModelType, + model: ModelType, } #[wasm_bindgen] impl ImageClassifier { - /// Constructor called by JavaScripts with the new keyword. - #[wasm_bindgen(constructor)] - pub fn new() -> Self { - // Initialize the logger so that the logs are printed to the console - wasm_logger::init(wasm_logger::Config::default()); + /// Constructor called by JavaScripts with the new keyword. 
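Because `mean` and `std` above are reshaped to [1, 3, 1, 1], the subtraction and division broadcast over the batch and spatial dimensions, so each RGB channel is shifted and scaled independently. A usage sketch follows; it assumes the `Normalizer` in this file is generic over a backend `B`, as the stripped generics suggest, and the image shape is just the example's 224x224 input.

use burn::tensor::{backend::Backend, Distribution, Tensor};

fn normalize_batch<B: Backend>(normalizer: &Normalizer<B>) -> Tensor<B, 4> {
    // A fake RGB batch shaped [batch, channels, height, width].
    let images = Tensor::<B, 4>::random([1, 3, 224, 224], Distribution::Default);

    // Same shape out as in; only the per-channel statistics change.
    normalizer.normalize(images)
}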
+ #[wasm_bindgen(constructor)] + pub fn new() -> Self { + // Initialize the logger so that the logs are printed to the console + wasm_logger::init(wasm_logger::Config::default()); - log::info!("Initializing the image classifier"); + log::info!("Initializing the image classifier"); - Self { - model: ModelType::WithNdArrayBackend(Model::new()), - } - } - - /// Runs inference on the image - pub async fn inference(&self, input: &[f32]) -> Result { - log::info!("Running inference on the image"); - - let start = Instant::now(); - - let result = match self.model { - ModelType::WithCandleBackend(ref model) => model.forward(input).await, - ModelType::WithNdArrayBackend(ref model) => model.forward(input).await, - ModelType::WithWgpuBackend(ref model) => model.forward(input).await, - }; - - let duration = start.elapsed(); - - log::debug!("Inference is completed in {:?}", duration); - - top_5_classes(result) - } - - /// Sets the backend to Candle - pub async fn set_backend_candle(&mut self) -> Result<(), JsValue> { - log::info!("Loading the model to the Candle backend"); - let start = Instant::now(); - self.model = ModelType::WithCandleBackend(Model::new()); - let duration = start.elapsed(); - log::debug!("Model is loaded to the Candle backend in {:?}", duration); - Ok(()) - } - - /// Sets the backend to NdArray - pub async fn set_backend_ndarray(&mut self) -> Result<(), JsValue> { - log::info!("Loading the model to the NdArray backend"); - let start = Instant::now(); - self.model = ModelType::WithNdArrayBackend(Model::new()); - let duration = start.elapsed(); - log::debug!("Model is loaded to the NdArray backend in {:?}", duration); - Ok(()) - } - - /// Sets the backend to Wgpu - pub async fn set_backend_wgpu(&mut self) -> Result<(), JsValue> { - log::info!("Loading the model to the Wgpu backend"); - let start = Instant::now(); - init_async::(&WgpuDevice::default()).await; - self.model = ModelType::WithWgpuBackend(Model::new()); - let duration = start.elapsed(); - log::debug!("Model is loaded to the Wgpu backend in {:?}", duration); - - log::debug!("Warming up the model"); - let start = Instant::now(); - let _ = self.inference(&[0.0; HEIGHT * WIDTH * CHANNELS]).await; - let duration = start.elapsed(); - log::debug!("Warming up is completed in {:?}", duration); - Ok(()) + Self { + model: ModelType::WithNdArrayBackend(Model::new()), } + } + + /// Runs inference on the image + pub async fn inference(&self, input: &[f32]) -> Result { + log::info!("Running inference on the image"); + + let start = Instant::now(); + + let result = match self.model { + ModelType::WithCandleBackend(ref model) => model.forward(input).await, + ModelType::WithNdArrayBackend(ref model) => model.forward(input).await, + ModelType::WithWgpuBackend(ref model) => model.forward(input).await, + }; + + let duration = start.elapsed(); + + log::debug!("Inference is completed in {:?}", duration); + + top_5_classes(result) + } + + /// Sets the backend to Candle + pub async fn set_backend_candle(&mut self) -> Result<(), JsValue> { + log::info!("Loading the model to the Candle backend"); + let start = Instant::now(); + self.model = ModelType::WithCandleBackend(Model::new()); + let duration = start.elapsed(); + log::debug!("Model is loaded to the Candle backend in {:?}", duration); + Ok(()) + } + + /// Sets the backend to NdArray + pub async fn set_backend_ndarray(&mut self) -> Result<(), JsValue> { + log::info!("Loading the model to the NdArray backend"); + let start = Instant::now(); + self.model = 
ModelType::WithNdArrayBackend(Model::new()); + let duration = start.elapsed(); + log::debug!("Model is loaded to the NdArray backend in {:?}", duration); + Ok(()) + } + + /// Sets the backend to Wgpu + pub async fn set_backend_wgpu(&mut self) -> Result<(), JsValue> { + log::info!("Loading the model to the Wgpu backend"); + let start = Instant::now(); + init_async::(&WgpuDevice::default()).await; + self.model = ModelType::WithWgpuBackend(Model::new()); + let duration = start.elapsed(); + log::debug!("Model is loaded to the Wgpu backend in {:?}", duration); + + log::debug!("Warming up the model"); + let start = Instant::now(); + let _ = self.inference(&[0.0; HEIGHT * WIDTH * CHANNELS]).await; + let duration = start.elapsed(); + log::debug!("Warming up is completed in {:?}", duration); + Ok(()) + } } /// The image classifier model pub struct Model { - model: SqueezenetModel, - normalizer: Normalizer, + model: SqueezenetModel, + normalizer: Normalizer, } impl Model { - /// Constructor - pub fn new() -> Self { - Self { - model: SqueezenetModel::from_embedded(), - normalizer: Normalizer::new(), - } + /// Constructor + pub fn new() -> Self { + Self { + model: SqueezenetModel::from_embedded(), + normalizer: Normalizer::new(), } + } - /// Normalizes input and runs inference on the image - pub async fn forward(&self, input: &[f32]) -> Vec { - // Reshape from the 1D array to 3d tensor [ width, height, channels] - let input: Tensor = Tensor::from_floats(input).reshape([1, CHANNELS, HEIGHT, WIDTH]); + /// Normalizes input and runs inference on the image + pub async fn forward(&self, input: &[f32]) -> Vec { + // Reshape from the 1D array to 3d tensor [ width, height, channels] + let input: Tensor = Tensor::from_floats(input).reshape([1, CHANNELS, HEIGHT, WIDTH]); - // Normalize input: make between [-1,1] and make the mean=0 and std=1 - let input = self.normalizer.normalize(input); + // Normalize input: make between [-1,1] and make the mean=0 and std=1 + let input = self.normalizer.normalize(input); - // Run the tensor input through the model - let output = self.model.forward(input); + // Run the tensor input through the model + let output = self.model.forward(input); - // Convert the model output into probability distribution using softmax formula - let probabilies = softmax(output, 1); + // Convert the model output into probability distribution using softmax formula + let probabilies = softmax(output, 1); - #[cfg(not(target_family = "wasm"))] - let result = probabilies.into_data().convert::().value; + #[cfg(not(target_family = "wasm"))] + let result = probabilies.into_data().convert::().value; - // Forces the result to be computed - #[cfg(target_family = "wasm")] - let result = probabilies.into_data().await.convert::().value; + // Forces the result to be computed + #[cfg(target_family = "wasm")] + let result = probabilies.into_data().await.convert::().value; - result - } + result + } } #[wasm_bindgen] #[derive(Serialize)] pub struct InferenceResult { - index: usize, - probability: f32, - label: String, + index: usize, + probability: f32, + label: String, } /// Returns the top 5 classes and convert them into a JsValue fn top_5_classes(probabilies: Vec) -> Result { - // Convert the probabilities into a vector of (index, probability) - let mut probabilies: Vec<_> = probabilies.iter().enumerate().collect(); - - // Sort the probabilities in descending order - probabilies.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); - - // Take the top 5 probabilities - probabilies.truncate(5); - - // Convert the 
probabilities into InferenceResult - let result: Vec = probabilies - .into_iter() - .map(|(index, probability)| InferenceResult { - index, - probability: *probability, - label: LABELS[index].to_string(), - }) - .collect(); - - // Convert the InferenceResult into a JsValue - Ok(serde_wasm_bindgen::to_value(&result)?) + // Convert the probabilities into a vector of (index, probability) + let mut probabilies: Vec<_> = probabilies.iter().enumerate().collect(); + + // Sort the probabilities in descending order + probabilies.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + + // Take the top 5 probabilities + probabilies.truncate(5); + + // Convert the probabilities into InferenceResult + let result: Vec = probabilies + .into_iter() + .map(|(index, probability)| InferenceResult { + index, + probability: *probability, + label: LABELS[index].to_string(), + }) + .collect(); + + // Convert the InferenceResult into a JsValue + Ok(serde_wasm_bindgen::to_value(&result)?) } diff --git a/examples/mnist-inference-web/src/model.rs b/examples/mnist-inference-web/src/model.rs index 66054100a3..10cc05ae50 100644 --- a/examples/mnist-inference-web/src/model.rs +++ b/examples/mnist-inference-web/src/model.rs @@ -3,94 +3,94 @@ // Originally copied from the burn/examples/mnist package use burn::{ - module::Module, - nn::{self, BatchNorm, PaddingConfig2d}, - tensor::{backend::Backend, Tensor}, + module::Module, + nn::{self, BatchNorm, PaddingConfig2d}, + tensor::{backend::Backend, Tensor}, }; #[derive(Module, Debug)] pub struct Model { - conv1: ConvBlock, - conv2: ConvBlock, - conv3: ConvBlock, - dropout: nn::Dropout, - fc1: nn::Linear, - fc2: nn::Linear, - activation: nn::GELU, + conv1: ConvBlock, + conv2: ConvBlock, + conv3: ConvBlock, + dropout: nn::Dropout, + fc1: nn::Linear, + fc2: nn::Linear, + activation: nn::GELU, } const NUM_CLASSES: usize = 10; impl Model { - pub fn new() -> Self { - let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] - let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] - let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] - let hidden_size = 24 * 22 * 22; - let fc1 = nn::LinearConfig::new(hidden_size, 32) - .with_bias(false) - .init(); - let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) - .with_bias(false) - .init(); - - let dropout = nn::DropoutConfig::new(0.5).init(); - - Self { - conv1, - conv2, - conv3, - fc1, - fc2, - dropout, - activation: nn::GELU::new(), - } + pub fn new() -> Self { + let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] + let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] + let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] + let hidden_size = 24 * 22 * 22; + let fc1 = nn::LinearConfig::new(hidden_size, 32) + .with_bias(false) + .init(); + let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) + .with_bias(false) + .init(); + + let dropout = nn::DropoutConfig::new(0.5).init(); + + Self { + conv1, + conv2, + conv3, + fc1, + fc2, + dropout, + activation: nn::GELU::new(), } + } - pub fn forward(&self, input: Tensor) -> Tensor { - let [batch_size, height, width] = input.dims(); + pub fn forward(&self, input: Tensor) -> Tensor { + let [batch_size, height, width] = input.dims(); - let x = input.reshape([batch_size, 1, height, width]).detach(); - let x = self.conv1.forward(x); - let x = self.conv2.forward(x); - let x = self.conv3.forward(x); + let x = input.reshape([batch_size, 1, height, width]).detach(); + let x = self.conv1.forward(x); + let x = self.conv2.forward(x); 
+ let x = self.conv3.forward(x); - let [batch_size, channels, height, width] = x.dims(); - let x = x.reshape([batch_size, channels * height * width]); + let [batch_size, channels, height, width] = x.dims(); + let x = x.reshape([batch_size, channels * height * width]); - let x = self.dropout.forward(x); - let x = self.fc1.forward(x); - let x = self.activation.forward(x); + let x = self.dropout.forward(x); + let x = self.fc1.forward(x); + let x = self.activation.forward(x); - self.fc2.forward(x) - } + self.fc2.forward(x) + } } #[derive(Module, Debug)] pub struct ConvBlock { - conv: nn::conv::Conv2d, - norm: BatchNorm, - activation: nn::GELU, + conv: nn::conv::Conv2d, + norm: BatchNorm, + activation: nn::GELU, } impl ConvBlock { - pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { - let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) - .with_padding(PaddingConfig2d::Valid) - .init(); - let norm = nn::BatchNormConfig::new(channels[1]).init(); - - Self { - conv, - norm, - activation: nn::GELU::new(), - } + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) + .with_padding(PaddingConfig2d::Valid) + .init(); + let norm = nn::BatchNormConfig::new(channels[1]).init(); + + Self { + conv, + norm, + activation: nn::GELU::new(), } + } - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.conv.forward(input); - let x = self.norm.forward(x); + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.conv.forward(input); + let x = self.norm.forward(x); - self.activation.forward(x) - } + self.activation.forward(x) + } } diff --git a/examples/mnist-inference-web/src/state.rs b/examples/mnist-inference-web/src/state.rs index 5fabc868e4..d1b1311cae 100644 --- a/examples/mnist-inference-web/src/state.rs +++ b/examples/mnist-inference-web/src/state.rs @@ -17,13 +17,13 @@ static STATE_ENCODED: &[u8] = include_bytes!("../model.bin"); /// Builds and loads trained parameters into the model. pub async fn build_and_load_model() -> Model { - #[cfg(feature = "wgpu")] - init_async::(&WgpuDevice::default()).await; + #[cfg(feature = "wgpu")] + init_async::(&WgpuDevice::default()).await; - let model: Model = Model::new(); - let record = BinBytesRecorder::::default() - .load(STATE_ENCODED.to_vec()) - .expect("Failed to decode state"); + let model: Model = Model::new(); + let record = BinBytesRecorder::::default() + .load(STATE_ENCODED.to_vec()) + .expect("Failed to decode state"); - model.load_record(record) + model.load_record(record) } diff --git a/examples/mnist-inference-web/src/web.rs b/examples/mnist-inference-web/src/web.rs index d15aa4a257..7c69308d3e 100644 --- a/examples/mnist-inference-web/src/web.rs +++ b/examples/mnist-inference-web/src/web.rs @@ -15,63 +15,63 @@ use burn::tensor::Tensor; /// See:[exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html) #[cfg_attr(target_family = "wasm", wasm_bindgen)] pub struct Mnist { - model: Option>, + model: Option>, } #[cfg_attr(target_family = "wasm", wasm_bindgen)] impl Mnist { - /// Constructor called by JavaScripts with the new keyword. - #[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))] - pub fn new() -> Self { - Self { model: None } + /// Constructor called by JavaScripts with the new keyword. + #[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))] + pub fn new() -> Self { + Self { model: None } + } + + /// Returns the inference results. 
+ /// + /// This method is called from JavaScript via generated wrapper code by wasm-bindgen. + /// + /// # Arguments + /// + /// * `input` - A f32 slice of input 28x28 image + /// + /// See bindgen support types for passing and returning arrays: + /// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html) + /// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html) + /// + pub async fn inference(&mut self, input: &[f32]) -> Result { + if self.model.is_none() { + self.model = Some(build_and_load_model().await); } - /// Returns the inference results. - /// - /// This method is called from JavaScript via generated wrapper code by wasm-bindgen. - /// - /// # Arguments - /// - /// * `input` - A f32 slice of input 28x28 image - /// - /// See bindgen support types for passing and returning arrays: - /// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html) - /// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html) - /// - pub async fn inference(&mut self, input: &[f32]) -> Result { - if self.model.is_none() { - self.model = Some(build_and_load_model().await); - } - - let model = self.model.as_ref().unwrap(); - - // Reshape from the 1D array to 3d tensor [batch, height, width] - let input: Tensor = Tensor::from_floats(input).reshape([1, 28, 28]); - - // Normalize input: make between [0,1] and make the mean=0 and std=1 - // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example - // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 - - let input = ((input / 255) - 0.1307) / 0.3081; - - // Run the tensor input through the model - let output: Tensor = model.forward(input); - - // Convert the model output into probability distribution using softmax formula - let output = burn::tensor::activation::softmax(output, 1); - - // Flatten output tensor with [1, 10] shape into boxed slice of [f32] - #[cfg(not(target_family = "wasm"))] - let output = output.into_data().convert::().value; - - #[cfg(target_family = "wasm")] - let output = output.into_data().await.convert::().value; - - let array = Array::new(); - for value in output { - array.push(&value.into()); - } - - Ok(array) + let model = self.model.as_ref().unwrap(); + + // Reshape from the 1D array to 3d tensor [batch, height, width] + let input: Tensor = Tensor::from_floats(input).reshape([1, 28, 28]); + + // Normalize input: make between [0,1] and make the mean=0 and std=1 + // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example + // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + + let input = ((input / 255) - 0.1307) / 0.3081; + + // Run the tensor input through the model + let output: Tensor = model.forward(input); + + // Convert the model output into probability distribution using softmax formula + let output = burn::tensor::activation::softmax(output, 1); + + // Flatten output tensor with [1, 10] shape into boxed slice of [f32] + #[cfg(not(target_family = "wasm"))] + let output = output.into_data().convert::().value; + + #[cfg(target_family = "wasm")] + let output = output.into_data().await.convert::().value; + + let array = Array::new(); + for value in output { + array.push(&value.into()); } + + Ok(array) + } } diff --git a/examples/mnist/examples/mnist.rs b/examples/mnist/examples/mnist.rs index e22a209a0e..3f0d16d63c 100644 --- 
a/examples/mnist/examples/mnist.rs +++ b/examples/mnist/examples/mnist.rs @@ -1,72 +1,72 @@ #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - let device = NdArrayDevice::Cpu; - training::run::>(device); - } + pub fn run() { + let device = NdArrayDevice::Cpu; + training::run::>(device); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - training::run::>(device); - } + training::run::>(device); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{Wgpu, WgpuDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::wgpu::{Wgpu, WgpuDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - let device = WgpuDevice::default(); - training::run::>(device); - } + pub fn run() { + let device = WgpuDevice::default(); + training::run::>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - let device = LibTorchDevice::Cpu; - training::run::>(device); - } + pub fn run() { + let device = LibTorchDevice::Cpu; + training::run::>(device); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/mnist/src/data.rs b/examples/mnist/src/data.rs index d424b9e038..0f253e948d 100644 --- a/examples/mnist/src/data.rs +++ b/examples/mnist/src/data.rs @@ -1,45 +1,45 @@ use burn::{ - data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, - tensor::{backend::Backend, Data, ElementConversion, Int, Tensor}, + data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, + tensor::{backend::Backend, Data, ElementConversion, Int, Tensor}, }; pub struct MNISTBatcher { - device: B::Device, + device: B::Device, } #[derive(Clone, Debug)] pub struct MNISTBatch { - pub images: 
Tensor, - pub targets: Tensor, + pub images: Tensor, + pub targets: Tensor, } impl MNISTBatcher { - pub fn new(device: B::Device) -> Self { - Self { device } - } + pub fn new(device: B::Device) -> Self { + Self { device } + } } impl Batcher> for MNISTBatcher { - fn batch(&self, items: Vec) -> MNISTBatch { - let images = items - .iter() - .map(|item| Data::::from(item.image)) - .map(|data| Tensor::::from_data(data.convert())) - .map(|tensor| tensor.reshape([1, 28, 28])) - // normalize: make between [0,1] and make the mean = 0 and std = 1 - // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example - // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 - .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) - .collect(); + fn batch(&self, items: Vec) -> MNISTBatch { + let images = items + .iter() + .map(|item| Data::::from(item.image)) + .map(|data| Tensor::::from_data(data.convert())) + .map(|tensor| tensor.reshape([1, 28, 28])) + // normalize: make between [0,1] and make the mean = 0 and std = 1 + // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example + // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) + .collect(); - let targets = items - .iter() - .map(|item| Tensor::::from_data(Data::from([(item.label as i64).elem()]))) - .collect(); + let targets = items + .iter() + .map(|item| Tensor::::from_data(Data::from([(item.label as i64).elem()]))) + .collect(); - let images = Tensor::cat(images, 0).to_device(&self.device); - let targets = Tensor::cat(targets, 0).to_device(&self.device); + let images = Tensor::cat(images, 0).to_device(&self.device); + let targets = Tensor::cat(targets, 0).to_device(&self.device); - MNISTBatch { images, targets } - } + MNISTBatch { images, targets } + } } diff --git a/examples/mnist/src/model.rs b/examples/mnist/src/model.rs index eca5be9b14..02efcf2906 100644 --- a/examples/mnist/src/model.rs +++ b/examples/mnist/src/model.rs @@ -1,130 +1,130 @@ use crate::data::MNISTBatch; use burn::{ - module::Module, - nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d}, - tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, - }, - train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, + module::Module, + nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d}, + tensor::{ + backend::{AutodiffBackend, Backend}, + Tensor, + }, + train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, }; #[derive(Module, Debug)] pub struct Model { - conv1: ConvBlock, - conv2: ConvBlock, - conv3: ConvBlock, - dropout: nn::Dropout, - fc1: nn::Linear, - fc2: nn::Linear, - activation: nn::GELU, + conv1: ConvBlock, + conv2: ConvBlock, + conv3: ConvBlock, + dropout: nn::Dropout, + fc1: nn::Linear, + fc2: nn::Linear, + activation: nn::GELU, } impl Default for Model { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } const NUM_CLASSES: usize = 10; impl Model { - pub fn new() -> Self { - let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] - let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] - let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] - let hidden_size = 24 * 22 * 22; - let fc1 = nn::LinearConfig::new(hidden_size, 32) - .with_bias(false) - .init(); - let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) - .with_bias(false) - .init(); - - let dropout = nn::DropoutConfig::new(0.5).init(); 
- - Self { - conv1, - conv2, - conv3, - dropout, - fc1, - fc2, - activation: nn::GELU::new(), - } + pub fn new() -> Self { + let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] + let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] + let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] + let hidden_size = 24 * 22 * 22; + let fc1 = nn::LinearConfig::new(hidden_size, 32) + .with_bias(false) + .init(); + let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) + .with_bias(false) + .init(); + + let dropout = nn::DropoutConfig::new(0.5).init(); + + Self { + conv1, + conv2, + conv3, + dropout, + fc1, + fc2, + activation: nn::GELU::new(), } + } - pub fn forward(&self, input: Tensor) -> Tensor { - let [batch_size, height, width] = input.dims(); + pub fn forward(&self, input: Tensor) -> Tensor { + let [batch_size, height, width] = input.dims(); - let x = input.reshape([batch_size, 1, height, width]).detach(); - let x = self.conv1.forward(x); - let x = self.conv2.forward(x); - let x = self.conv3.forward(x); + let x = input.reshape([batch_size, 1, height, width]).detach(); + let x = self.conv1.forward(x); + let x = self.conv2.forward(x); + let x = self.conv3.forward(x); - let [batch_size, channels, height, width] = x.dims(); - let x = x.reshape([batch_size, channels * height * width]); + let [batch_size, channels, height, width] = x.dims(); + let x = x.reshape([batch_size, channels * height * width]); - let x = self.dropout.forward(x); - let x = self.fc1.forward(x); - let x = self.activation.forward(x); + let x = self.dropout.forward(x); + let x = self.fc1.forward(x); + let x = self.activation.forward(x); - self.fc2.forward(x) - } + self.fc2.forward(x) + } + + pub fn forward_classification(&self, item: MNISTBatch) -> ClassificationOutput { + let targets = item.targets; + let output = self.forward(item.images); + let loss = CrossEntropyLoss::default(); + let loss = loss.forward(output.clone(), targets.clone()); - pub fn forward_classification(&self, item: MNISTBatch) -> ClassificationOutput { - let targets = item.targets; - let output = self.forward(item.images); - let loss = CrossEntropyLoss::default(); - let loss = loss.forward(output.clone(), targets.clone()); - - ClassificationOutput { - loss, - output, - targets, - } + ClassificationOutput { + loss, + output, + targets, } + } } #[derive(Module, Debug)] pub struct ConvBlock { - conv: nn::conv::Conv2d, - norm: BatchNorm, - activation: nn::GELU, + conv: nn::conv::Conv2d, + norm: BatchNorm, + activation: nn::GELU, } impl ConvBlock { - pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { - let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) - .with_padding(PaddingConfig2d::Valid) - .init(); - let norm = nn::BatchNormConfig::new(channels[1]).init(); - - Self { - conv, - norm, - activation: nn::GELU::new(), - } + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) + .with_padding(PaddingConfig2d::Valid) + .init(); + let norm = nn::BatchNormConfig::new(channels[1]).init(); + + Self { + conv, + norm, + activation: nn::GELU::new(), } + } - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.conv.forward(input); - let x = self.norm.forward(x); + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.conv.forward(input); + let x = self.norm.forward(x); - self.activation.forward(x) - } + self.activation.forward(x) + } } impl TrainStep, ClassificationOutput> for Model { - fn step(&self, item: 
MNISTBatch) -> TrainOutput> { - let item = self.forward_classification(item); + fn step(&self, item: MNISTBatch) -> TrainOutput> { + let item = self.forward_classification(item); - TrainOutput::new(self, item.loss.backward(), item) - } + TrainOutput::new(self, item.loss.backward(), item) + } } impl ValidStep, ClassificationOutput> for Model { - fn step(&self, item: MNISTBatch) -> ClassificationOutput { - self.forward_classification(item) - } + fn step(&self, item: MNISTBatch) -> ClassificationOutput { + self.forward_classification(item) + } } diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs index 2f7a5bed94..a1f2687401 100644 --- a/examples/mnist/src/training.rs +++ b/examples/mnist/src/training.rs @@ -9,88 +9,88 @@ use burn::train::metric::store::{Aggregate, Direction, Split}; use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse}; use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition}; use burn::{ - config::Config, - data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset}, - tensor::backend::AutodiffBackend, - train::{ - metric::{AccuracyMetric, LossMetric}, - LearnerBuilder, - }, + config::Config, + data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset}, + tensor::backend::AutodiffBackend, + train::{ + metric::{AccuracyMetric, LossMetric}, + LearnerBuilder, + }, }; static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist"; #[derive(Config)] pub struct MnistTrainingConfig { - #[config(default = 10)] - pub num_epochs: usize, + #[config(default = 10)] + pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, + #[config(default = 64)] + pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, + #[config(default = 4)] + pub num_workers: usize, - #[config(default = 42)] - pub seed: u64, + #[config(default = 42)] + pub seed: u64, - pub optimizer: AdamConfig, + pub optimizer: AdamConfig, } pub fn run(device: B::Device) { - // Config - let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))); - let config = MnistTrainingConfig::new(config_optimizer); - B::seed(config.seed); + // Config + let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))); + let config = MnistTrainingConfig::new(config_optimizer); + B::seed(config.seed); - // Data - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); + // Data + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); - // Model - let learner = LearnerBuilder::new(ARTIFACT_DIR) - .metric_train_numeric(AccuracyMetric::new()) - .metric_valid_numeric(AccuracyMetric::new()) - 
.metric_train_numeric(CpuUse::new()) - .metric_valid_numeric(CpuUse::new()) - .metric_train_numeric(CpuMemory::new()) - .metric_valid_numeric(CpuMemory::new()) - .metric_train_numeric(CpuTemperature::new()) - .metric_valid_numeric(CpuTemperature::new()) - .metric_train_numeric(LossMetric::new()) - .metric_valid_numeric(LossMetric::new()) - .with_file_checkpointer(CompactRecorder::new()) - .early_stopping(MetricEarlyStoppingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Valid, - StoppingCondition::NoImprovementSince { n_epochs: 1 }, - )) - .devices(vec![device]) - .num_epochs(config.num_epochs) - .build(Model::new(), config.optimizer.init(), 1e-4); + // Model + let learner = LearnerBuilder::new(ARTIFACT_DIR) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(CpuUse::new()) + .metric_valid_numeric(CpuUse::new()) + .metric_train_numeric(CpuMemory::new()) + .metric_valid_numeric(CpuMemory::new()) + .metric_train_numeric(CpuTemperature::new()) + .metric_valid_numeric(CpuTemperature::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .early_stopping(MetricEarlyStoppingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + StoppingCondition::NoImprovementSince { n_epochs: 1 }, + )) + .devices(vec![device]) + .num_epochs(config.num_epochs) + .build(Model::new(), config.optimizer.init(), 1e-4); - let model_trained = learner.fit(dataloader_train, dataloader_test); + let model_trained = learner.fit(dataloader_train, dataloader_test); - config - .save(format!("{ARTIFACT_DIR}/config.json").as_str()) - .unwrap(); + config + .save(format!("{ARTIFACT_DIR}/config.json").as_str()) + .unwrap(); - model_trained - .save_file( - format!("{ARTIFACT_DIR}/model"), - &NoStdTrainingRecorder::new(), - ) - .expect("Failed to save trained model"); + model_trained + .save_file( + format!("{ARTIFACT_DIR}/model"), + &NoStdTrainingRecorder::new(), + ) + .expect("Failed to save trained model"); } diff --git a/examples/named-tensor/examples/named-tensor.rs b/examples/named-tensor/examples/named-tensor.rs index 7ea0dd159d..967f75712f 100644 --- a/examples/named-tensor/examples/named-tensor.rs +++ b/examples/named-tensor/examples/named-tensor.rs @@ -1,3 +1,3 @@ fn main() { - named_tensor::run::>(); + named_tensor::run::>(); } diff --git a/examples/named-tensor/src/lib.rs b/examples/named-tensor/src/lib.rs index 3aa637a587..f2f7ee0e16 100644 --- a/examples/named-tensor/src/lib.rs +++ b/examples/named-tensor/src/lib.rs @@ -6,42 +6,40 @@ NamedDim!(SeqLength); NamedDim!(DModel); pub fn run() { - let batch_size = 32; - let seq_length = 48; - let d_model = 24; - - let weights = NamedTensor::::random( - [1, d_model, d_model], - Distribution::Default, - ); - - let input = NamedTensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - ); - - // Doesn't compile - // - // mismatched types - // expected reference `&NamedTensor` - // found reference `&NamedTensor` - // let output = weights.matmul(&input); - - let output = input.clone().matmul(weights.clone()); - - // Doesn't compile - // - // mismatched types - // expected reference `&NamedTensor` - // found reference `&NamedTensor` - // let output = output.mul(&weights); - - let output = output.mul(input.clone()); - - let permut = output.clone().swap_dims::<_, 1, 2>(); - - println!("Weights => {weights}"); - println!("Input => {input}"); - println!("Output => 
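The mnist `run` function above saves its config with `Config::save` and its weights with `NoStdTrainingRecorder`; reloading them later follows the same pattern the guide's inference module uses. A rough sketch, assuming the `Model`, `MnistTrainingConfig` and `ARTIFACT_DIR` items from this example and that `NoStdTrainingRecorder` is exported from `burn::record`:

use burn::record::{NoStdTrainingRecorder, Recorder};
use burn::tensor::backend::Backend;

// Sketch only: reload the artifacts written by `run`.
fn reload_trained_model<B: Backend>() -> Model<B> {
    let _config = MnistTrainingConfig::load(format!("{ARTIFACT_DIR}/config.json"))
        .expect("Config should exist after training");

    let record = NoStdTrainingRecorder::new()
        .load(format!("{ARTIFACT_DIR}/model").into())
        .expect("Trained model record should load");

    Model::new().load_record(record)
}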
{output}"); - println!("Permut => {permut}"); + let batch_size = 32; + let seq_length = 48; + let d_model = 24; + + let weights = + NamedTensor::::random([1, d_model, d_model], Distribution::Default); + + let input = NamedTensor::::random( + [batch_size, seq_length, d_model], + Distribution::Default, + ); + + // Doesn't compile + // + // mismatched types + // expected reference `&NamedTensor` + // found reference `&NamedTensor` + // let output = weights.matmul(&input); + + let output = input.clone().matmul(weights.clone()); + + // Doesn't compile + // + // mismatched types + // expected reference `&NamedTensor` + // found reference `&NamedTensor` + // let output = output.mul(&weights); + + let output = output.mul(input.clone()); + + let permut = output.clone().swap_dims::<_, 1, 2>(); + + println!("Weights => {weights}"); + println!("Input => {input}"); + println!("Output => {output}"); + println!("Permut => {permut}"); } diff --git a/examples/onnx-inference/build.rs b/examples/onnx-inference/build.rs index 174d3e517b..5425853542 100644 --- a/examples/onnx-inference/build.rs +++ b/examples/onnx-inference/build.rs @@ -1,21 +1,21 @@ use burn_import::onnx::{ModelGen, RecordType}; fn main() { - // Generate the model code from the ONNX file. + // Generate the model code from the ONNX file. - if cfg!(feature = "embedded-model") { - // If the embedded-model, then model is bundled into the binary. - ModelGen::new() - .input("src/model/mnist.onnx") - .out_dir("model/") - .record_type(RecordType::Bincode) - .embed_states(true) - .run_from_script(); - } else { - // If not embedded-model, then model is loaded from the file system (default). - ModelGen::new() - .input("src/model/mnist.onnx") - .out_dir("model/") - .run_from_script(); - } + if cfg!(feature = "embedded-model") { + // If the embedded-model, then model is bundled into the binary. + ModelGen::new() + .input("src/model/mnist.onnx") + .out_dir("model/") + .record_type(RecordType::Bincode) + .embed_states(true) + .run_from_script(); + } else { + // If not embedded-model, then model is loaded from the file system (default). 
+ ModelGen::new() + .input("src/model/mnist.onnx") + .out_dir("model/") + .run_from_script(); + } } diff --git a/examples/onnx-inference/src/bin/mnist_inference.rs b/examples/onnx-inference/src/bin/mnist_inference.rs index 82913f4693..ca74984b8c 100644 --- a/examples/onnx-inference/src/bin/mnist_inference.rs +++ b/examples/onnx-inference/src/bin/mnist_inference.rs @@ -11,53 +11,53 @@ use onnx_inference::mnist::Model; const IMAGE_INX: usize = 42; // <- Change this to test a different image fn main() { - // Get image index argument (first) from command line + // Get image index argument (first) from command line - let image_index = if let Some(image_index) = args().nth(1) { - println!("Image index: {}", image_index); - image_index - .parse::() - .expect("Failed to parse image index") - } else { - println!("No image index provided; Using default image index: {IMAGE_INX}"); - IMAGE_INX - }; + let image_index = if let Some(image_index) = args().nth(1) { + println!("Image index: {}", image_index); + image_index + .parse::() + .expect("Failed to parse image index") + } else { + println!("No image index provided; Using default image index: {IMAGE_INX}"); + IMAGE_INX + }; - assert!(image_index < 10000, "Image index must be less than 10000"); + assert!(image_index < 10000, "Image index must be less than 10000"); - type Backend = NdArray; + type Backend = NdArray; - // Create a new model and load the state - let model: Model = Model::default(); + // Create a new model and load the state + let model: Model = Model::default(); - // Load the MNIST dataset and get an item - let dataset = MNISTDataset::test(); - let item = dataset.get(image_index).unwrap(); + // Load the MNIST dataset and get an item + let dataset = MNISTDataset::test(); + let item = dataset.get(image_index).unwrap(); - // Create a tensor from the image data - let image_data = item.image.iter().copied().flatten().collect::>(); - let mut input: Tensor = - Tensor::from_floats(image_data.as_slice()).reshape([1, 1, 28, 28]); + // Create a tensor from the image data + let image_data = item.image.iter().copied().flatten().collect::>(); + let mut input: Tensor = + Tensor::from_floats(image_data.as_slice()).reshape([1, 1, 28, 28]); - // Normalize the input - input = ((input / 255) - 0.1307) / 0.3081; + // Normalize the input + input = ((input / 255) - 0.1307) / 0.3081; - // Run the model on the input - let output = model.forward(input); + // Run the model on the input + let output = model.forward(input); - // Get the index of the maximum value - let arg_max = output.argmax(1).into_scalar() as usize; + // Get the index of the maximum value + let arg_max = output.argmax(1).into_scalar() as usize; - // Check if the index matches the label - assert!(arg_max == item.label); + // Check if the index matches the label + assert!(arg_max == item.label); - println!("Success!"); - println!("Predicted: {}", arg_max); - println!("Actual: {}", item.label); + println!("Success!"); + println!("Predicted: {}", arg_max); + println!("Actual: {}", item.label); - // Print the image URL if the image index is less than 100 (the online dataset only has 100 images) - if image_index < 100 { - println!("See the image online, click the link below:"); - println!("https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/{image_index}/image/image.jpg"); - } + // Print the image URL if the image index is less than 100 (the online dataset only has 100 images) + if image_index < 100 { + println!("See the image online, click the link below:"); + 
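// Illustrative sketch, not part of this patch: the pre- and post-processing that
// mnist_inference.rs above performs with tensors, written out on plain slices.
// 0.1307 and 0.3081 are the MNIST mean and standard deviation used in the example;
// `argmax` is simply the index of the largest logit, which is what `output.argmax(1)`
// returns for each row.
fn normalize(pixels: &[f32]) -> Vec<f32> {
    // Scale 0..=255 pixel values to 0..=1, then standardize with the dataset statistics.
    pixels.iter().map(|p| ((p / 255.0) - 0.1307) / 0.3081).collect()
}

fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    let normalized = normalize(&[0.0, 127.5, 255.0]);
    assert!((normalized[2] - ((1.0 - 0.1307) / 0.3081)).abs() < 1e-6);

    // A model produces one logit per digit class; the prediction is the largest one.
    let logits = [0.1_f32, 2.3, -0.4, 0.9, 0.0, -1.2, 0.2, 0.5, 1.1, -0.3];
    assert_eq!(argmax(&logits), 1);
}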
println!("https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/{image_index}/image/image.jpg"); + } } diff --git a/examples/onnx-inference/src/model/mod.rs b/examples/onnx-inference/src/model/mod.rs index 4c821cafd4..adf789cdd1 100644 --- a/examples/onnx-inference/src/model/mod.rs +++ b/examples/onnx-inference/src/model/mod.rs @@ -1,3 +1,3 @@ pub mod mnist { - include!(concat!(env!("OUT_DIR"), "/model/mnist.rs")); + include!(concat!(env!("OUT_DIR"), "/model/mnist.rs")); } diff --git a/examples/text-classification/examples/ag-news-infer.rs b/examples/text-classification/examples/ag-news-infer.rs index a2bfa0ce1a..abde7a1750 100644 --- a/examples/text-classification/examples/ag-news-infer.rs +++ b/examples/text-classification/examples/ag-news-infer.rs @@ -9,7 +9,7 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - text_classification::inference::infer::( + text_classification::inference::infer::( device, "/tmp/text-classification-ag-news", // Samples from the test dataset, but you are free to test with your own text. @@ -22,75 +22,75 @@ pub fn launch(device: B::Device) { } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::Autodiff; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(WgpuDevice::default()); - } + pub fn run() { + launch::>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - 
#[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 4b336c2700..f9c8ec0685 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -12,89 +12,89 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - let config = ExperimentConfig::new( - TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), - AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), - ); - - text_classification::training::train::( - device, - AgNewsDataset::train(), - AgNewsDataset::test(), - config, - "/tmp/text-classification-ag-news", - ); + let config = ExperimentConfig::new( + TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), + AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), + ); + + text_classification::training::train::( + device, + AgNewsDataset::train(), + AgNewsDataset::test(), + config, + "/tmp/text-classification-ag-news", + ); } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use crate::{launch, ElemType}; - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::{Autodiff, Fusion}; + use crate::{launch, ElemType}; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::{Autodiff, Fusion}; - pub fn run() { - launch::>>>(WgpuDevice::default()); - } + pub fn run() { + 
launch::>>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/examples/db-pedia-infer.rs b/examples/text-classification/examples/db-pedia-infer.rs index be8a2d5dd0..f178e5b6e4 100644 --- a/examples/text-classification/examples/db-pedia-infer.rs +++ b/examples/text-classification/examples/db-pedia-infer.rs @@ -9,7 +9,7 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - text_classification::inference::infer::( + text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", // Samples from the test dataset, but you are free to test with your own text. @@ -22,75 +22,75 @@ pub fn launch(device: B::Device) { } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::tch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::tch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::Autodiff; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(WgpuDevice::default()); - } + pub fn run() { + launch::>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = 
"ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/examples/db-pedia-train.rs b/examples/text-classification/examples/db-pedia-train.rs index 81319c32cb..e5478088e7 100644 --- a/examples/text-classification/examples/db-pedia-train.rs +++ b/examples/text-classification/examples/db-pedia-train.rs @@ -12,89 +12,89 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - let config = ExperimentConfig::new( - TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), - AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), - ); - - text_classification::training::train::( - device, - DbPediaDataset::train(), - DbPediaDataset::test(), - config, - "/tmp/text-classification-db-pedia", - ); + let config = ExperimentConfig::new( + TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), + AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), + ); + + text_classification::training::train::( + device, + DbPediaDataset::train(), + DbPediaDataset::test(), + config, + "/tmp/text-classification-db-pedia", + ); } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use crate::{launch, ElemType}; - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use crate::{launch, ElemType}; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::Autodiff; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use 
burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(WgpuDevice::default()); - } + pub fn run() { + launch::>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/src/data/batcher.rs b/examples/text-classification/src/data/batcher.rs index 75ef1080b0..8b8edc9f0a 100644 --- a/examples/text-classification/src/data/batcher.rs +++ b/examples/text-classification/src/data/batcher.rs @@ -12,92 +12,92 @@ use super::{dataset::TextClassificationItem, tokenizer::Tokenizer}; use burn::{ - data::dataloader::batcher::Batcher, - nn::attention::generate_padding_mask, - tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Tensor}, + data::dataloader::batcher::Batcher, + nn::attention::generate_padding_mask, + tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Tensor}, }; use std::sync::Arc; /// Struct for batching text classification items #[derive(new)] pub struct TextClassificationBatcher { - tokenizer: Arc, // Tokenizer for converting text to token IDs - device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) - max_seq_length: usize, // Maximum sequence length for tokenized text + tokenizer: Arc, // Tokenizer for converting text to token IDs + device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) + max_seq_length: usize, // Maximum sequence length for tokenized text } /// Struct for training batch in text classification task #[derive(Debug, Clone, new)] pub struct TextClassificationTrainingBatch { - pub tokens: Tensor, // Tokenized text - pub labels: Tensor, // Labels of the text - pub mask_pad: Tensor, // Padding mask for the tokenized text + pub tokens: Tensor, // Tokenized text + pub labels: Tensor, // Labels of the text + pub mask_pad: Tensor, // Padding mask for the tokenized text } /// Struct for inference batch in text classification task #[derive(Debug, Clone, new)] pub struct TextClassificationInferenceBatch { - pub tokens: Tensor, // Tokenized text - pub mask_pad: Tensor, // Padding mask for the tokenized text + pub tokens: Tensor, // Tokenized text + pub mask_pad: Tensor, // Padding mask for the tokenized text } /// Implement Batcher trait for TextClassificationBatcher struct for training impl Batcher> - for TextClassificationBatcher + for TextClassificationBatcher { - /// Batches a vector of text classification items into a training batch - fn batch(&self, items: Vec) -> TextClassificationTrainingBatch { - let mut tokens_list = Vec::with_capacity(items.len()); - let mut labels_list = Vec::with_capacity(items.len()); + /// Batches a vector of text classification items into a training batch + fn batch(&self, items: Vec) -> TextClassificationTrainingBatch { + let mut tokens_list = Vec::with_capacity(items.len()); + let mut labels_list = Vec::with_capacity(items.len()); - // Tokenize text and create 
label tensor for each item - for item in items { - tokens_list.push(self.tokenizer.encode(&item.text)); - labels_list.push(Tensor::from_data(Data::from([(item.label as i64).elem()]))); - } + // Tokenize text and create label tensor for each item + for item in items { + tokens_list.push(self.tokenizer.encode(&item.text)); + labels_list.push(Tensor::from_data(Data::from([(item.label as i64).elem()]))); + } - // Generate padding mask for tokenized text - let mask = generate_padding_mask( - self.tokenizer.pad_token(), - tokens_list, - Some(self.max_seq_length), - &B::Device::default(), - ); + // Generate padding mask for tokenized text + let mask = generate_padding_mask( + self.tokenizer.pad_token(), + tokens_list, + Some(self.max_seq_length), + &B::Device::default(), + ); - // Create and return training batch - TextClassificationTrainingBatch { - tokens: mask.tensor.to_device(&self.device), - labels: Tensor::cat(labels_list, 0).to_device(&self.device), - mask_pad: mask.mask.to_device(&self.device), - } + // Create and return training batch + TextClassificationTrainingBatch { + tokens: mask.tensor.to_device(&self.device), + labels: Tensor::cat(labels_list, 0).to_device(&self.device), + mask_pad: mask.mask.to_device(&self.device), } + } } /// Implement Batcher trait for TextClassificationBatcher struct for inference impl Batcher> - for TextClassificationBatcher + for TextClassificationBatcher { - /// Batches a vector of strings into an inference batch - fn batch(&self, items: Vec) -> TextClassificationInferenceBatch { - let mut tokens_list = Vec::with_capacity(items.len()); + /// Batches a vector of strings into an inference batch + fn batch(&self, items: Vec) -> TextClassificationInferenceBatch { + let mut tokens_list = Vec::with_capacity(items.len()); - // Tokenize each string - for item in items { - tokens_list.push(self.tokenizer.encode(&item)); - } + // Tokenize each string + for item in items { + tokens_list.push(self.tokenizer.encode(&item)); + } - // Generate padding mask for tokenized text - let mask = generate_padding_mask( - self.tokenizer.pad_token(), - tokens_list, - Some(self.max_seq_length), - &B::Device::default(), - ); + // Generate padding mask for tokenized text + let mask = generate_padding_mask( + self.tokenizer.pad_token(), + tokens_list, + Some(self.max_seq_length), + &B::Device::default(), + ); - // Create and return inference batch - TextClassificationInferenceBatch { - tokens: mask.tensor.to_device(&self.device), - mask_pad: mask.mask.to_device(&self.device), - } + // Create and return inference batch + TextClassificationInferenceBatch { + tokens: mask.tensor.to_device(&self.device), + mask_pad: mask.mask.to_device(&self.device), } + } } diff --git a/examples/text-classification/src/data/dataset.rs b/examples/text-classification/src/data/dataset.rs index 28ae0f2240..43e868879f 100644 --- a/examples/text-classification/src/data/dataset.rs +++ b/examples/text-classification/src/data/dataset.rs @@ -10,162 +10,163 @@ use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset // Define a struct for text classification items #[derive(new, Clone, Debug)] pub struct TextClassificationItem { - pub text: String, // The text for classification - pub label: usize, // The label of the text (classification category) + pub text: String, // The text for classification + pub label: usize, // The label of the text (classification category) } // Trait for text classification datasets pub trait TextClassificationDataset: Dataset { - fn num_classes() -> usize; // 
Returns the number of unique classes in the dataset - fn class_name(label: usize) -> String; // Returns the name of the class given its label + fn num_classes() -> usize; // Returns the number of unique classes in the dataset + fn class_name(label: usize) -> String; // Returns the name of the class given its label } // Struct for items in the AG News dataset #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct AgNewsItem { - pub text: String, // The text for classification - pub label: usize, // The label of the text (classification category) + pub text: String, // The text for classification + pub label: usize, // The label of the text (classification category) } // Struct for the AG News dataset pub struct AgNewsDataset { - dataset: SqliteDataset, // Underlying SQLite dataset + dataset: SqliteDataset, // Underlying SQLite dataset } // Implement the Dataset trait for the AG News dataset impl Dataset for AgNewsDataset { - /// Returns a specific item from the dataset - fn get(&self, index: usize) -> Option { - self.dataset - .get(index) - .map(|item| TextClassificationItem::new(item.text, item.label)) // Map AgNewsItems to TextClassificationItems - } - - /// Returns the length of the dataset - fn len(&self) -> usize { - self.dataset.len() - } + /// Returns a specific item from the dataset + fn get(&self, index: usize) -> Option { + self + .dataset + .get(index) + .map(|item| TextClassificationItem::new(item.text, item.label)) // Map AgNewsItems to TextClassificationItems + } + + /// Returns the length of the dataset + fn len(&self) -> usize { + self.dataset.len() + } } // Implement methods for constructing the AG News dataset impl AgNewsDataset { - /// Returns the training portion of the dataset - pub fn train() -> Self { - Self::new("train") - } - - /// Returns the testing portion of the dataset - pub fn test() -> Self { - Self::new("test") - } - - /// Constructs the dataset from a split (either "train" or "test") - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("ag_news") - .dataset(split) - .unwrap(); - Self { dataset } - } + /// Returns the training portion of the dataset + pub fn train() -> Self { + Self::new("train") + } + + /// Returns the testing portion of the dataset + pub fn test() -> Self { + Self::new("test") + } + + /// Constructs the dataset from a split (either "train" or "test") + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("ag_news") + .dataset(split) + .unwrap(); + Self { dataset } + } } /// Implements the TextClassificationDataset trait for the AG News dataset impl TextClassificationDataset for AgNewsDataset { - /// Returns the number of unique classes in the dataset - fn num_classes() -> usize { - 4 - } - - /// Returns the name of a class given its label - fn class_name(label: usize) -> String { - match label { - 0 => "World", - 1 => "Sports", - 2 => "Business", - 3 => "Technology", - _ => panic!("invalid class"), - } - .to_string() + /// Returns the number of unique classes in the dataset + fn num_classes() -> usize { + 4 + } + + /// Returns the name of a class given its label + fn class_name(label: usize) -> String { + match label { + 0 => "World", + 1 => "Sports", + 2 => "Business", + 3 => "Technology", + _ => panic!("invalid class"), } + .to_string() + } } /// Struct for items in the DbPedia dataset #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { - pub title: String, // The title of the item - pub content: 
String, // The content of the item - pub label: usize, // The label of the item (classification category) + pub title: String, // The title of the item + pub content: String, // The content of the item + pub label: usize, // The label of the item (classification category) } /// Struct for the DbPedia dataset pub struct DbPediaDataset { - dataset: SqliteDataset, // Underlying SQLite dataset + dataset: SqliteDataset, // Underlying SQLite dataset } /// Implements the Dataset trait for the DbPedia dataset impl Dataset for DbPediaDataset { - /// Returns a specific item from the dataset - fn get(&self, index: usize) -> Option { - self.dataset.get(index).map(|item| { - TextClassificationItem::new( - format!("Title: {} - Content: {}", item.title, item.content), - item.label, - ) - }) - } - - /// Returns the length of the dataset - fn len(&self) -> usize { - self.dataset.len() - } + /// Returns a specific item from the dataset + fn get(&self, index: usize) -> Option { + self.dataset.get(index).map(|item| { + TextClassificationItem::new( + format!("Title: {} - Content: {}", item.title, item.content), + item.label, + ) + }) + } + + /// Returns the length of the dataset + fn len(&self) -> usize { + self.dataset.len() + } } /// Implement methods for constructing the DbPedia dataset impl DbPediaDataset { - /// Returns the training portion of the dataset - pub fn train() -> Self { - Self::new("train") - } - - /// Returns the testing portion of the dataset - pub fn test() -> Self { - Self::new("test") - } - - /// Constructs the dataset from a split (either "train" or "test") - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") - .dataset(split) - .unwrap(); - Self { dataset } - } + /// Returns the training portion of the dataset + pub fn train() -> Self { + Self::new("train") + } + + /// Returns the testing portion of the dataset + pub fn test() -> Self { + Self::new("test") + } + + /// Constructs the dataset from a split (either "train" or "test") + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") + .dataset(split) + .unwrap(); + Self { dataset } + } } /// Implement the TextClassificationDataset trait for the DbPedia dataset impl TextClassificationDataset for DbPediaDataset { - /// Returns the number of unique classes in the dataset - fn num_classes() -> usize { - 14 - } - - /// Returns the name of a class given its label - fn class_name(label: usize) -> String { - match label { - 0 => "Company", - 1 => "EducationalInstitution", - 2 => "Artist", - 3 => "Athlete", - 4 => "OfficeHolder", - 5 => "MeanOfTransportation", - 6 => "Building", - 7 => "NaturalPlace", - 8 => "Village", - 9 => "Animal", - 10 => "Plant", - 11 => "Album", - 12 => "Film", - 13 => "WrittenWork", - _ => panic!("invalid class"), - } - .to_string() + /// Returns the number of unique classes in the dataset + fn num_classes() -> usize { + 14 + } + + /// Returns the name of a class given its label + fn class_name(label: usize) -> String { + match label { + 0 => "Company", + 1 => "EducationalInstitution", + 2 => "Artist", + 3 => "Athlete", + 4 => "OfficeHolder", + 5 => "MeanOfTransportation", + 6 => "Building", + 7 => "NaturalPlace", + 8 => "Village", + 9 => "Animal", + 10 => "Plant", + 11 => "Album", + 12 => "Film", + 13 => "WrittenWork", + _ => panic!("invalid class"), } + .to_string() + } } diff --git a/examples/text-classification/src/data/tokenizer.rs b/examples/text-classification/src/data/tokenizer.rs index 
3d1044b365..1fbe07d3f6 100644 --- a/examples/text-classification/src/data/tokenizer.rs +++ b/examples/text-classification/src/data/tokenizer.rs @@ -6,62 +6,62 @@ // The `Send + Sync` bounds are necessary for allowing these operations // to work across thread boundaries. pub trait Tokenizer: Send + Sync { - /// Converts a text string into a sequence of tokens. - fn encode(&self, value: &str) -> Vec; + /// Converts a text string into a sequence of tokens. + fn encode(&self, value: &str) -> Vec; - /// Converts a sequence of tokens back into a text string. - fn decode(&self, tokens: &[usize]) -> String; + /// Converts a sequence of tokens back into a text string. + fn decode(&self, tokens: &[usize]) -> String; - /// Gets the size of the tokenizer's vocabulary. - fn vocab_size(&self) -> usize; + /// Gets the size of the tokenizer's vocabulary. + fn vocab_size(&self) -> usize; - /// Gets the token used for padding sequences to a consistent length. - fn pad_token(&self) -> usize; + /// Gets the token used for padding sequences to a consistent length. + fn pad_token(&self) -> usize; - /// Gets the string representation of the padding token. - /// The default implementation uses `decode` on the padding token. - fn pad_token_value(&self) -> String { - self.decode(&[self.pad_token()]) - } + /// Gets the string representation of the padding token. + /// The default implementation uses `decode` on the padding token. + fn pad_token_value(&self) -> String { + self.decode(&[self.pad_token()]) + } } /// Struct represents a specific tokenizer using the BERT cased tokenization strategy. pub struct BertCasedTokenizer { - // The underlying tokenizer from the `tokenizers` library. - tokenizer: tokenizers::Tokenizer, + // The underlying tokenizer from the `tokenizers` library. + tokenizer: tokenizers::Tokenizer, } // Default implementation for creating a new BertCasedTokenizer. // This uses a pretrained BERT cased tokenizer model. impl Default for BertCasedTokenizer { - fn default() -> Self { - Self { - tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), - } + fn default() -> Self { + Self { + tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), } + } } // Implementation of the Tokenizer trait for BertCasedTokenizer. impl Tokenizer for BertCasedTokenizer { - // Convert a text string into a sequence of tokens using the BERT cased tokenization strategy. - fn encode(&self, value: &str) -> Vec { - let tokens = self.tokenizer.encode(value, true).unwrap(); - tokens.get_ids().iter().map(|t| *t as usize).collect() - } + // Convert a text string into a sequence of tokens using the BERT cased tokenization strategy. + fn encode(&self, value: &str) -> Vec { + let tokens = self.tokenizer.encode(value, true).unwrap(); + tokens.get_ids().iter().map(|t| *t as usize).collect() + } - /// Converts a sequence of tokens back into a text string. - fn decode(&self, tokens: &[usize]) -> String { - let tokens = tokens.iter().map(|t| *t as u32).collect::>(); - self.tokenizer.decode(&tokens, false).unwrap() - } + /// Converts a sequence of tokens back into a text string. + fn decode(&self, tokens: &[usize]) -> String { + let tokens = tokens.iter().map(|t| *t as u32).collect::>(); + self.tokenizer.decode(&tokens, false).unwrap() + } - /// Gets the size of the BERT cased tokenizer's vocabulary. - fn vocab_size(&self) -> usize { - self.tokenizer.get_vocab_size(true) - } + /// Gets the size of the BERT cased tokenizer's vocabulary. 
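// Illustrative sketch, not part of this patch: a minimal byte-level implementation of
// the `Tokenizer` trait defined above, useful for seeing the contract without pulling
// in the `tokenizers` crate. Token 256 is reserved for padding. This is an editor's toy
// example and assumes the trait from this file is in scope; the example itself ships
// only the BERT-cased tokenizer.
pub struct ByteTokenizer;

impl Tokenizer for ByteTokenizer {
    fn encode(&self, value: &str) -> Vec<usize> {
        // One token per UTF-8 byte.
        value.bytes().map(|b| b as usize).collect()
    }

    fn decode(&self, tokens: &[usize]) -> String {
        // Drop the padding token and any value outside the byte range.
        let bytes: Vec<u8> = tokens
            .iter()
            .filter(|&&t| t < 256)
            .map(|&t| t as u8)
            .collect();
        String::from_utf8_lossy(&bytes).into_owned()
    }

    fn vocab_size(&self) -> usize {
        257 // 256 byte values plus the padding token.
    }

    fn pad_token(&self) -> usize {
        256
    }
}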
+ fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size(true) + } - /// Gets the token used for padding sequences to a consistent length. - fn pad_token(&self) -> usize { - self.tokenizer.token_to_id("[PAD]").unwrap() as usize - } + /// Gets the token used for padding sequences to a consistent length. + fn pad_token(&self) -> usize { + self.tokenizer.token_to_id("[PAD]").unwrap() as usize + } } diff --git a/examples/text-classification/src/inference.rs b/examples/text-classification/src/inference.rs index 02ed961809..8360a5a06d 100644 --- a/examples/text-classification/src/inference.rs +++ b/examples/text-classification/src/inference.rs @@ -5,73 +5,73 @@ // Import required modules and types use crate::{ - data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, - model::TextClassificationModelConfig, - training::ExperimentConfig, + data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, + model::TextClassificationModelConfig, + training::ExperimentConfig, }; use burn::{ - config::Config, - data::dataloader::batcher::Batcher, - module::Module, - record::{CompactRecorder, Recorder}, - tensor::backend::Backend, + config::Config, + data::dataloader::batcher::Batcher, + module::Module, + record::{CompactRecorder, Recorder}, + tensor::backend::Backend, }; use std::sync::Arc; // Define inference function pub fn infer( - device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) - artifact_dir: &str, // Directory containing model and config files - samples: Vec, // Text samples for inference + device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) + artifact_dir: &str, // Directory containing model and config files + samples: Vec, // Text samples for inference ) { - // Load experiment configuration - let config = ExperimentConfig::load(format!("{artifact_dir}/config.json").as_str()) - .expect("Config file present"); + // Load experiment configuration + let config = ExperimentConfig::load(format!("{artifact_dir}/config.json").as_str()) + .expect("Config file present"); - // Initialize tokenizer - let tokenizer = Arc::new(BertCasedTokenizer::default()); + // Initialize tokenizer + let tokenizer = Arc::new(BertCasedTokenizer::default()); - // Get number of classes from dataset - let n_classes = D::num_classes(); + // Get number of classes from dataset + let n_classes = D::num_classes(); - // Initialize batcher for batching samples - let batcher = Arc::new(TextClassificationBatcher::::new( - tokenizer.clone(), - device.clone(), - config.max_seq_length, - )); + // Initialize batcher for batching samples + let batcher = Arc::new(TextClassificationBatcher::::new( + tokenizer.clone(), + device.clone(), + config.max_seq_length, + )); - // Load pre-trained model weights - println!("Loading weights ..."); - let record = CompactRecorder::new() - .load(format!("{artifact_dir}/model").into()) - .expect("Trained model weights"); + // Load pre-trained model weights + println!("Loading weights ..."); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into()) + .expect("Trained model weights"); - // Create model using loaded weights - println!("Creating model ..."); - let model = TextClassificationModelConfig::new( - config.transformer, - n_classes, - tokenizer.vocab_size(), - config.max_seq_length, - ) - .init_with::(record) // Initialize model with loaded weights - .to_device(&device); // Move model to computation device + // Create model using 
loaded weights + println!("Creating model ..."); + let model = TextClassificationModelConfig::new( + config.transformer, + n_classes, + tokenizer.vocab_size(), + config.max_seq_length, + ) + .init_with::(record) // Initialize model with loaded weights + .to_device(&device); // Move model to computation device - // Run inference on the given text samples - println!("Running inference ..."); - let item = batcher.batch(samples.clone()); // Batch samples using the batcher - let predictions = model.infer(item); // Get model predictions + // Run inference on the given text samples + println!("Running inference ..."); + let item = batcher.batch(samples.clone()); // Batch samples using the batcher + let predictions = model.infer(item); // Get model predictions - // Print out predictions for each sample - for (i, text) in samples.into_iter().enumerate() { - #[allow(clippy::single_range_in_vec_init)] - let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample - let logits = prediction.to_data(); // Convert prediction tensor to data - let class_index = prediction.argmax(1).into_data().convert::().value[0]; // Get class index with the highest value - let class = D::class_name(class_index as usize); // Get class name + // Print out predictions for each sample + for (i, text) in samples.into_iter().enumerate() { + #[allow(clippy::single_range_in_vec_init)] + let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample + let logits = prediction.to_data(); // Convert prediction tensor to data + let class_index = prediction.argmax(1).into_data().convert::().value[0]; // Get class index with the highest value + let class = D::class_name(class_index as usize); // Get class name - // Print sample text, predicted logits and predicted class - println!("\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: {class}\n================"); - } + // Print sample text, predicted logits and predicted class + println!("\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: {class}\n================"); + } } diff --git a/examples/text-classification/src/model.rs b/examples/text-classification/src/model.rs index 914b14576a..96fd825367 100644 --- a/examples/text-classification/src/model.rs +++ b/examples/text-classification/src/model.rs @@ -5,178 +5,173 @@ use crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch}; use burn::{ - config::Config, - module::Module, - nn::{ - loss::CrossEntropyLoss, - transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, - Embedding, EmbeddingConfig, Linear, LinearConfig, - }, - tensor::backend::{AutodiffBackend, Backend}, - tensor::{activation::softmax, Tensor}, - train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, + config::Config, + module::Module, + nn::{ + loss::CrossEntropyLoss, + transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, + Embedding, EmbeddingConfig, Linear, LinearConfig, + }, + tensor::backend::{AutodiffBackend, Backend}, + tensor::{activation::softmax, Tensor}, + train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, }; // Define the model configuration #[derive(Config)] pub struct TextClassificationModelConfig { - transformer: TransformerEncoderConfig, - n_classes: usize, - vocab_size: usize, - max_seq_length: usize, + transformer: TransformerEncoderConfig, + n_classes: usize, + vocab_size: usize, + max_seq_length: usize, } // Define the model structure 
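// Illustrative sketch, not part of this patch: what the forward pass below does when it
// builds position indices and averages the token and positional embeddings, written out
// on plain nested vectors. The tiny `d_model` is only for readability; the real model
// uses the transformer's configured dimension.
fn main() {
    let (batch_size, seq_length, d_model) = (2usize, 4usize, 3usize);

    // Equivalent of `arange(0..seq_length).reshape([1, seq_length]).repeat(0, batch_size)`:
    // every sequence in the batch gets the position indices 0..seq_length.
    let index_positions: Vec<Vec<usize>> = (0..batch_size)
        .map(|_| (0..seq_length).collect())
        .collect();
    assert_eq!(index_positions, vec![vec![0, 1, 2, 3], vec![0, 1, 2, 3]]);

    // Stand-ins for the two embedding lookups: one vector of length d_model per position.
    let embedding_tokens = vec![vec![vec![1.0_f32; d_model]; seq_length]; batch_size];
    let embedding_positions = vec![vec![vec![0.5_f32; d_model]; seq_length]; batch_size];

    // Equivalent of `(embedding_positions + embedding_tokens) / 2`: an element-wise
    // average, so both signals contribute with equal weight.
    let embedding: Vec<Vec<Vec<f32>>> = embedding_tokens
        .iter()
        .zip(&embedding_positions)
        .map(|(toks, poss)| {
            toks.iter()
                .zip(poss)
                .map(|(t, p)| t.iter().zip(p).map(|(a, b)| (a + b) / 2.0).collect())
                .collect()
        })
        .collect();
    assert_eq!(embedding[0][0][0], 0.75);
}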
#[derive(Module, Debug)] pub struct TextClassificationModel { - transformer: TransformerEncoder, - embedding_token: Embedding, - embedding_pos: Embedding, - output: Linear, - n_classes: usize, - max_seq_length: usize, + transformer: TransformerEncoder, + embedding_token: Embedding, + embedding_pos: Embedding, + output: Linear, + n_classes: usize, + max_seq_length: usize, } // Define functions for model initialization impl TextClassificationModelConfig { - /// Initializes a model with default weights - pub fn init(&self) -> TextClassificationModel { - let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(); - let transformer = self.transformer.init(); - let embedding_token = - EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); - let embedding_pos = - EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); - - TextClassificationModel { - transformer, - embedding_token, - embedding_pos, - output, - n_classes: self.n_classes, - max_seq_length: self.max_seq_length, - } + /// Initializes a model with default weights + pub fn init(&self) -> TextClassificationModel { + let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(); + let transformer = self.transformer.init(); + let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); + let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); + + TextClassificationModel { + transformer, + embedding_token, + embedding_pos, + output, + n_classes: self.n_classes, + max_seq_length: self.max_seq_length, } - - /// Initializes a model with provided weights - pub fn init_with( - &self, - record: TextClassificationModelRecord, - ) -> TextClassificationModel { - let output = - LinearConfig::new(self.transformer.d_model, self.n_classes).init_with(record.output); - let transformer = self.transformer.init_with(record.transformer); - let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model) - .init_with(record.embedding_token); - let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model) - .init_with(record.embedding_pos); - - TextClassificationModel { - transformer, - embedding_token, - embedding_pos, - output, - n_classes: self.n_classes, - max_seq_length: self.max_seq_length, - } + } + + /// Initializes a model with provided weights + pub fn init_with( + &self, + record: TextClassificationModelRecord, + ) -> TextClassificationModel { + let output = + LinearConfig::new(self.transformer.d_model, self.n_classes).init_with(record.output); + let transformer = self.transformer.init_with(record.transformer); + let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model) + .init_with(record.embedding_token); + let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model) + .init_with(record.embedding_pos); + + TextClassificationModel { + transformer, + embedding_token, + embedding_pos, + output, + n_classes: self.n_classes, + max_seq_length: self.max_seq_length, } + } } /// Define model behavior impl TextClassificationModel { - // Defines forward pass for training - pub fn forward(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { - // Get batch and sequence length, and the device - let [batch_size, seq_length] = item.tokens.dims(); - let device = &self.embedding_token.devices()[0]; - - // Move tensors to the correct device - let tokens = item.tokens.to_device(device); - let labels = 
item.labels.to_device(device); - let mask_pad = item.mask_pad.to_device(device); - - // Calculate token and position embeddings, and combine them - let index_positions = Tensor::arange_device(0..seq_length, device) - .reshape([1, seq_length]) - .repeat(0, batch_size); - let embedding_positions = self.embedding_pos.forward(index_positions); - let embedding_tokens = self.embedding_token.forward(tokens); - let embedding = (embedding_positions + embedding_tokens) / 2; - - // Perform transformer encoding, calculate output and loss - let encoded = self - .transformer - .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); - let output = self.output.forward(encoded); - - let output_classification = output - .slice([0..batch_size, 0..1]) - .reshape([batch_size, self.n_classes]); - - let loss = CrossEntropyLoss::default(); - let loss = loss.forward(output_classification.clone(), labels.clone()); - - // Return the output and loss - ClassificationOutput { - loss, - output: output_classification, - targets: labels, - } - } - - /// Defines forward pass for inference - pub fn infer(&self, item: TextClassificationInferenceBatch) -> Tensor { - // Get batch and sequence length, and the device - let [batch_size, seq_length] = item.tokens.dims(); - let device = &self.embedding_token.devices()[0]; - - // Move tensors to the correct device - let tokens = item.tokens.to_device(device); - let mask_pad = item.mask_pad.to_device(device); - - // Calculate token and position embeddings, and combine them - let index_positions = Tensor::arange_device(0..seq_length, device) - .reshape([1, seq_length]) - .repeat(0, batch_size); - let embedding_positions = self.embedding_pos.forward(index_positions); - let embedding_tokens = self.embedding_token.forward(tokens); - let embedding = (embedding_positions + embedding_tokens) / 2; - - // Perform transformer encoding, calculate output and apply softmax for prediction - let encoded = self - .transformer - .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); - let output = self.output.forward(encoded); - let output = output - .slice([0..batch_size, 0..1]) - .reshape([batch_size, self.n_classes]); - - softmax(output, 1) + // Defines forward pass for training + pub fn forward(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { + // Get batch and sequence length, and the device + let [batch_size, seq_length] = item.tokens.dims(); + let device = &self.embedding_token.devices()[0]; + + // Move tensors to the correct device + let tokens = item.tokens.to_device(device); + let labels = item.labels.to_device(device); + let mask_pad = item.mask_pad.to_device(device); + + // Calculate token and position embeddings, and combine them + let index_positions = Tensor::arange_device(0..seq_length, device) + .reshape([1, seq_length]) + .repeat(0, batch_size); + let embedding_positions = self.embedding_pos.forward(index_positions); + let embedding_tokens = self.embedding_token.forward(tokens); + let embedding = (embedding_positions + embedding_tokens) / 2; + + // Perform transformer encoding, calculate output and loss + let encoded = self + .transformer + .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); + let output = self.output.forward(encoded); + + let output_classification = output + .slice([0..batch_size, 0..1]) + .reshape([batch_size, self.n_classes]); + + let loss = CrossEntropyLoss::default(); + let loss = loss.forward(output_classification.clone(), labels.clone()); + + // Return the output and loss + 
ClassificationOutput { + loss, + output: output_classification, + targets: labels, } + } + + /// Defines forward pass for inference + pub fn infer(&self, item: TextClassificationInferenceBatch) -> Tensor { + // Get batch and sequence length, and the device + let [batch_size, seq_length] = item.tokens.dims(); + let device = &self.embedding_token.devices()[0]; + + // Move tensors to the correct device + let tokens = item.tokens.to_device(device); + let mask_pad = item.mask_pad.to_device(device); + + // Calculate token and position embeddings, and combine them + let index_positions = Tensor::arange_device(0..seq_length, device) + .reshape([1, seq_length]) + .repeat(0, batch_size); + let embedding_positions = self.embedding_pos.forward(index_positions); + let embedding_tokens = self.embedding_token.forward(tokens); + let embedding = (embedding_positions + embedding_tokens) / 2; + + // Perform transformer encoding, calculate output and apply softmax for prediction + let encoded = self + .transformer + .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); + let output = self.output.forward(encoded); + let output = output + .slice([0..batch_size, 0..1]) + .reshape([batch_size, self.n_classes]); + + softmax(output, 1) + } } /// Define training step impl TrainStep, ClassificationOutput> - for TextClassificationModel + for TextClassificationModel { - fn step( - &self, - item: TextClassificationTrainingBatch, - ) -> TrainOutput> { - // Run forward pass, calculate gradients and return them along with the output - let item = self.forward(item); - let grads = item.loss.backward(); - - TrainOutput::new(self, grads, item) - } + fn step(&self, item: TextClassificationTrainingBatch) -> TrainOutput> { + // Run forward pass, calculate gradients and return them along with the output + let item = self.forward(item); + let grads = item.loss.backward(); + + TrainOutput::new(self, grads, item) + } } /// Define validation step impl ValidStep, ClassificationOutput> - for TextClassificationModel + for TextClassificationModel { - fn step(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { - // Run forward pass and return the output - self.forward(item) - } + fn step(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { + // Run forward pass and return the output + self.forward(item) + } } diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs index eed1ddc4d0..241db1abda 100644 --- a/examples/text-classification/src/training.rs +++ b/examples/text-classification/src/training.rs @@ -6,112 +6,109 @@ // then saved to the specified directory. 
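// Illustrative sketch, not part of this patch: the learning-rate curve behind the
// `NoamLrSchedulerConfig::new(0.25).with_warmup_steps(1000).with_model_size(d_model)`
// call further down in this file. This is the standard Noam schedule from the original
// Transformer paper, scaled by a base factor; whether burn applies the factor in exactly
// this way is an assumption, so treat the numbers as qualitative.
fn noam_lr(factor: f64, d_model: usize, warmup_steps: usize, step: usize) -> f64 {
    let step = step.max(1) as f64;
    let warmup = warmup_steps as f64;
    // Linear warmup for `warmup_steps` steps, then inverse-square-root decay.
    factor * (d_model as f64).powf(-0.5) * f64::min(step.powf(-0.5), step * warmup.powf(-1.5))
}

fn main() {
    // d_model = 256 matches the TransformerEncoderConfig used by these training examples.
    let (factor, d_model, warmup) = (0.25, 256, 1000);
    let early = noam_lr(factor, d_model, warmup, 100);
    let peak = noam_lr(factor, d_model, warmup, 1000);
    let late = noam_lr(factor, d_model, warmup, 10_000);
    // The rate grows during warmup, peaks around `warmup_steps`, then decays.
    assert!(early < peak && late < peak);
    println!("lr@100 = {early:.6}, lr@1000 = {peak:.6}, lr@10000 = {late:.6}");
}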
use crate::{ - data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, - model::TextClassificationModelConfig, + data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, + model::TextClassificationModelConfig, }; use burn::{ - config::Config, - data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset}, - lr_scheduler::noam::NoamLrSchedulerConfig, - module::Module, - nn::transformer::TransformerEncoderConfig, - optim::AdamConfig, - record::{CompactRecorder, Recorder}, - tensor::backend::AutodiffBackend, - train::{ - metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric}, - LearnerBuilder, - }, + config::Config, + data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset}, + lr_scheduler::noam::NoamLrSchedulerConfig, + module::Module, + nn::transformer::TransformerEncoderConfig, + optim::AdamConfig, + record::{CompactRecorder, Recorder}, + tensor::backend::AutodiffBackend, + train::{ + metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric}, + LearnerBuilder, + }, }; use std::sync::Arc; // Define configuration struct for the experiment #[derive(Config)] pub struct ExperimentConfig { - pub transformer: TransformerEncoderConfig, - pub optimizer: AdamConfig, - #[config(default = 256)] - pub max_seq_length: usize, - #[config(default = 32)] - pub batch_size: usize, - #[config(default = 5)] - pub num_epochs: usize, + pub transformer: TransformerEncoderConfig, + pub optimizer: AdamConfig, + #[config(default = 256)] + pub max_seq_length: usize, + #[config(default = 32)] + pub batch_size: usize, + #[config(default = 5)] + pub num_epochs: usize, } // Define train function pub fn train( - device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) - dataset_train: D, // Training dataset - dataset_test: D, // Testing dataset - config: ExperimentConfig, // Experiment configuration - artifact_dir: &str, // Directory to save model and config files + device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) + dataset_train: D, // Training dataset + dataset_test: D, // Testing dataset + config: ExperimentConfig, // Experiment configuration + artifact_dir: &str, // Directory to save model and config files ) { - // Initialize tokenizer - let tokenizer = Arc::new(BertCasedTokenizer::default()); + // Initialize tokenizer + let tokenizer = Arc::new(BertCasedTokenizer::default()); - // Initialize batchers for training and testing data - let batcher_train = TextClassificationBatcher::::new( - tokenizer.clone(), - device.clone(), - config.max_seq_length, - ); - let batcher_test = TextClassificationBatcher::::new( - tokenizer.clone(), - device.clone(), - config.max_seq_length, - ); + // Initialize batchers for training and testing data + let batcher_train = + TextClassificationBatcher::::new(tokenizer.clone(), device.clone(), config.max_seq_length); + let batcher_test = TextClassificationBatcher::::new( + tokenizer.clone(), + device.clone(), + config.max_seq_length, + ); - // Initialize model - let model = TextClassificationModelConfig::new( - config.transformer.clone(), - D::num_classes(), - tokenizer.vocab_size(), - config.max_seq_length, - ) - .init(); + // Initialize model + let model = TextClassificationModelConfig::new( + config.transformer.clone(), + D::num_classes(), + tokenizer.vocab_size(), + config.max_seq_length, + ) + .init(); - // Initialize data loaders for training and testing data - let dataloader_train = 
DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .num_workers(4) - .build(SamplerDataset::new(dataset_train, 50_000)); - let dataloader_test = DataLoaderBuilder::new(batcher_test) - .batch_size(config.batch_size) - .num_workers(4) - .build(SamplerDataset::new(dataset_test, 5_000)); + // Initialize data loaders for training and testing data + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .num_workers(4) + .build(SamplerDataset::new(dataset_train, 50_000)); + let dataloader_test = DataLoaderBuilder::new(batcher_test) + .batch_size(config.batch_size) + .num_workers(4) + .build(SamplerDataset::new(dataset_test, 5_000)); - // Initialize optimizer - let optim = config.optimizer.init(); + // Initialize optimizer + let optim = config.optimizer.init(); - // Initialize learning rate scheduler - let lr_scheduler = NoamLrSchedulerConfig::new(0.25) - .with_warmup_steps(1000) - .with_model_size(config.transformer.d_model) - .init(); + // Initialize learning rate scheduler + let lr_scheduler = NoamLrSchedulerConfig::new(0.25) + .with_warmup_steps(1000) + .with_model_size(config.transformer.d_model) + .init(); - // Initialize learner - let learner = LearnerBuilder::new(artifact_dir) - .metric_train(CUDAMetric::new()) - .metric_valid(CUDAMetric::new()) - .metric_train(AccuracyMetric::new()) - .metric_valid(AccuracyMetric::new()) - .metric_train_numeric(LossMetric::new()) - .metric_valid_numeric(LossMetric::new()) - .metric_train_numeric(LearningRateMetric::new()) - .with_file_checkpointer(CompactRecorder::new()) - .devices(vec![device]) - .num_epochs(config.num_epochs) - .build(model, optim, lr_scheduler); + // Initialize learner + let learner = LearnerBuilder::new(artifact_dir) + .metric_train(CUDAMetric::new()) + .metric_valid(CUDAMetric::new()) + .metric_train(AccuracyMetric::new()) + .metric_valid(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .metric_train_numeric(LearningRateMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .devices(vec![device]) + .num_epochs(config.num_epochs) + .build(model, optim, lr_scheduler); - // Train the model - let model_trained = learner.fit(dataloader_train, dataloader_test); + // Train the model + let model_trained = learner.fit(dataloader_train, dataloader_test); - // Save the configuration and the trained model - config.save(format!("{artifact_dir}/config.json")).unwrap(); - CompactRecorder::new() - .record( - model_trained.into_record(), - format!("{artifact_dir}/model").into(), - ) - .unwrap(); + // Save the configuration and the trained model + config.save(format!("{artifact_dir}/config.json")).unwrap(); + CompactRecorder::new() + .record( + model_trained.into_record(), + format!("{artifact_dir}/model").into(), + ) + .unwrap(); } diff --git a/examples/text-generation/examples/text-generation.rs b/examples/text-generation/examples/text-generation.rs index 6b6823f5ff..3b58b793a3 100644 --- a/examples/text-generation/examples/text-generation.rs +++ b/examples/text-generation/examples/text-generation.rs @@ -9,21 +9,20 @@ type Elem = f32; type Backend = burn::backend::Autodiff>; fn main() { - let config = ExperimentConfig::new( - burn::nn::transformer::TransformerEncoderConfig::new(384, 1536, 12, 6) - .with_norm_first(true), - burn::optim::AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1.0e-6))), - ); + let config = ExperimentConfig::new( + burn::nn::transformer::TransformerEncoderConfig::new(384, 
1536, 12, 6).with_norm_first(true), + burn::optim::AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1.0e-6))), + ); - text_generation::training::train::( - if cfg!(target_os = "macos") { - burn::tensor::Device::::Mps - } else { - burn::tensor::Device::::Cuda(0) - }, - DbPediaDataset::train(), - DbPediaDataset::test(), - config, - "/tmp/text-generation", - ); + text_generation::training::train::( + if cfg!(target_os = "macos") { + burn::tensor::Device::::Mps + } else { + burn::tensor::Device::::Cuda(0) + }, + DbPediaDataset::train(), + DbPediaDataset::test(), + config, + "/tmp/text-generation", + ); } diff --git a/examples/text-generation/src/data/batcher.rs b/examples/text-generation/src/data/batcher.rs index 598676eafa..acaff5e33a 100644 --- a/examples/text-generation/src/data/batcher.rs +++ b/examples/text-generation/src/data/batcher.rs @@ -1,66 +1,66 @@ use super::{dataset::TextGenerationItem, tokenizer::Tokenizer}; use burn::{ - data::dataloader::batcher::Batcher, - nn::attention::generate_padding_mask, - tensor::{backend::Backend, Bool, Int, Tensor}, + data::dataloader::batcher::Batcher, + nn::attention::generate_padding_mask, + tensor::{backend::Backend, Bool, Int, Tensor}, }; use std::sync::Arc; #[derive(new)] pub struct TextGenerationBatcher { - tokenizer: Arc, - max_seq_length: usize, + tokenizer: Arc, + max_seq_length: usize, } #[derive(Debug, Clone, new)] pub struct TextGenerationBatch { - pub tokens: Tensor, - pub mask_pad: Tensor, + pub tokens: Tensor, + pub mask_pad: Tensor, } #[derive(Debug, Clone, new)] pub struct TrainingTextGenerationBatch { - pub tokens_inputs: Tensor, - pub targets: Tensor, - pub mask_pad: Tensor, + pub tokens_inputs: Tensor, + pub targets: Tensor, + pub mask_pad: Tensor, } impl Batcher> for TextGenerationBatcher { - fn batch(&self, items: Vec) -> TextGenerationBatch { - let mut tokens_list = Vec::with_capacity(items.len()); + fn batch(&self, items: Vec) -> TextGenerationBatch { + let mut tokens_list = Vec::with_capacity(items.len()); - for item in items { - tokens_list.push(self.tokenizer.encode(&item.text, true)); - } + for item in items { + tokens_list.push(self.tokenizer.encode(&item.text, true)); + } - let mask = generate_padding_mask( - self.tokenizer.pad_token(), - tokens_list, - Some(self.max_seq_length), - &B::Device::default(), - ); + let mask = generate_padding_mask( + self.tokenizer.pad_token(), + tokens_list, + Some(self.max_seq_length), + &B::Device::default(), + ); - TextGenerationBatch { - tokens: mask.tensor, - mask_pad: mask.mask, - } + TextGenerationBatch { + tokens: mask.tensor, + mask_pad: mask.mask, } + } } impl Batcher> - for TextGenerationBatcher + for TextGenerationBatcher { - fn batch(&self, items: Vec) -> TrainingTextGenerationBatch { - let item: TextGenerationBatch = self.batch(items); - let [batch_size, seq_length] = item.tokens.dims(); + fn batch(&self, items: Vec) -> TrainingTextGenerationBatch { + let item: TextGenerationBatch = self.batch(items); + let [batch_size, seq_length] = item.tokens.dims(); - let inputs = item - .tokens - .clone() - .slice([0..batch_size, 0..seq_length - 1]); - let targets = item.tokens.slice([0..batch_size, 1..seq_length]); - let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]); + let inputs = item + .tokens + .clone() + .slice([0..batch_size, 0..seq_length - 1]); + let targets = item.tokens.slice([0..batch_size, 1..seq_length]); + let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]); - TrainingTextGenerationBatch::new(inputs, targets, mask_pad) 
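The training batcher above turns each padded token matrix into next-token-prediction pairs: inputs drop the last position, targets drop the first, and the pad mask follows the inputs. A minimal sketch of the same shift on a single flat sequence (illustration only; the batcher does this with `slice` on dimension 1 of the batch tensor):

fn shift_for_next_token(tokens: &[usize]) -> (Vec<usize>, Vec<usize>) {
    // Assumes a non-empty sequence. Inputs cover positions 0..L-1, targets cover
    // 1..L, so position t is trained to predict token t + 1.
    let seq_length = tokens.len();
    let inputs = tokens[0..seq_length - 1].to_vec();
    let targets = tokens[1..seq_length].to_vec();
    (inputs, targets)
}

// Usage: for [a, b, c, d] this yields inputs [a, b, c] and targets [b, c, d].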
- } + TrainingTextGenerationBatch::new(inputs, targets, mask_pad) + } } diff --git a/examples/text-generation/src/data/dataset.rs b/examples/text-generation/src/data/dataset.rs index f198143582..b22b3f6598 100644 --- a/examples/text-generation/src/data/dataset.rs +++ b/examples/text-generation/src/data/dataset.rs @@ -2,42 +2,43 @@ use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset #[derive(new, Clone, Debug)] pub struct TextGenerationItem { - pub text: String, + pub text: String, } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { - pub content: String, + pub content: String, } pub struct DbPediaDataset { - dataset: SqliteDataset, + dataset: SqliteDataset, } impl Dataset for DbPediaDataset { - fn get(&self, index: usize) -> Option { - self.dataset - .get(index) - .map(|item| TextGenerationItem::new(item.content)) - } + fn get(&self, index: usize) -> Option { + self + .dataset + .get(index) + .map(|item| TextGenerationItem::new(item.content)) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } impl DbPediaDataset { - pub fn train() -> Self { - Self::new("train") - } + pub fn train() -> Self { + Self::new("train") + } - pub fn test() -> Self { - Self::new("test") - } - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") - .dataset(split) - .unwrap(); - Self { dataset } - } + pub fn test() -> Self { + Self::new("test") + } + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") + .dataset(split) + .unwrap(); + Self { dataset } + } } diff --git a/examples/text-generation/src/data/tokenizer.rs b/examples/text-generation/src/data/tokenizer.rs index cf6fc81bae..25296cb893 100644 --- a/examples/text-generation/src/data/tokenizer.rs +++ b/examples/text-generation/src/data/tokenizer.rs @@ -1,93 +1,93 @@ pub trait Tokenizer: Send + Sync { - fn encode(&self, value: &str, special_tokens: bool) -> Vec; - fn decode(&self, tokens: &[usize]) -> String; - fn vocab_size(&self) -> usize; - fn pad_token(&self) -> usize; - fn start_token(&self) -> usize; - fn end_token(&self) -> usize; - fn pad_token_value(&self) -> String { - self.decode(&[self.pad_token()]) - } - fn start_token_value(&self) -> String { - self.decode(&[self.start_token()]) - } - fn end_token_value(&self) -> String { - self.decode(&[self.end_token()]) - } + fn encode(&self, value: &str, special_tokens: bool) -> Vec; + fn decode(&self, tokens: &[usize]) -> String; + fn vocab_size(&self) -> usize; + fn pad_token(&self) -> usize; + fn start_token(&self) -> usize; + fn end_token(&self) -> usize; + fn pad_token_value(&self) -> String { + self.decode(&[self.pad_token()]) + } + fn start_token_value(&self) -> String { + self.decode(&[self.start_token()]) + } + fn end_token_value(&self) -> String { + self.decode(&[self.end_token()]) + } } pub struct Gpt2Tokenizer { - tokenizer: tokenizers::Tokenizer, + tokenizer: tokenizers::Tokenizer, } impl Default for Gpt2Tokenizer { - fn default() -> Self { - let mut tokenizer = tokenizers::Tokenizer::from_pretrained("gpt2", None).unwrap(); - tokenizer.add_special_tokens(&[ - tokenizers::AddedToken::from("[START]", true), - tokenizers::AddedToken::from("[END]", true), - tokenizers::AddedToken::from("[PAD]", true), - ]); - - Self { tokenizer } - } + fn default() -> Self { + let mut tokenizer = tokenizers::Tokenizer::from_pretrained("gpt2", None).unwrap(); + 
tokenizer.add_special_tokens(&[ + tokenizers::AddedToken::from("[START]", true), + tokenizers::AddedToken::from("[END]", true), + tokenizers::AddedToken::from("[PAD]", true), + ]); + + Self { tokenizer } + } } impl Tokenizer for Gpt2Tokenizer { - fn encode(&self, value: &str, special_tokens: bool) -> Vec { - let text = match special_tokens { - true => "[START]".to_owned() + value + "[END]", - false => value.to_string(), - }; - let tokens = self.tokenizer.encode(text, true).unwrap(); - tokens.get_ids().iter().map(|t| *t as usize).collect() - } - - fn decode(&self, tokens: &[usize]) -> String { - let tokens = tokens.iter().map(|t| *t as u32).collect::>(); - self.tokenizer.decode(&tokens, false).unwrap() - } - - fn vocab_size(&self) -> usize { - self.tokenizer.get_vocab_size(true) - } - - fn pad_token(&self) -> usize { - self.tokenizer.token_to_id("[PAD]").unwrap() as usize - } - - fn start_token(&self) -> usize { - self.tokenizer.token_to_id("[START]").unwrap() as usize - } - - fn end_token(&self) -> usize { - self.tokenizer.token_to_id("[END]").unwrap() as usize - } + fn encode(&self, value: &str, special_tokens: bool) -> Vec { + let text = match special_tokens { + true => "[START]".to_owned() + value + "[END]", + false => value.to_string(), + }; + let tokens = self.tokenizer.encode(text, true).unwrap(); + tokens.get_ids().iter().map(|t| *t as usize).collect() + } + + fn decode(&self, tokens: &[usize]) -> String { + let tokens = tokens.iter().map(|t| *t as u32).collect::>(); + self.tokenizer.decode(&tokens, false).unwrap() + } + + fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size(true) + } + + fn pad_token(&self) -> usize { + self.tokenizer.token_to_id("[PAD]").unwrap() as usize + } + + fn start_token(&self) -> usize { + self.tokenizer.token_to_id("[START]").unwrap() as usize + } + + fn end_token(&self) -> usize { + self.tokenizer.token_to_id("[END]").unwrap() as usize + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_encode_decode() { - let tokenizer = Gpt2Tokenizer::default(); - let text = "A sentence"; + #[test] + fn test_encode_decode() { + let tokenizer = Gpt2Tokenizer::default(); + let text = "A sentence"; - let tokens = tokenizer.encode(text, false); - let decoded = tokenizer.decode(&tokens); + let tokens = tokenizer.encode(text, false); + let decoded = tokenizer.decode(&tokens); - assert_eq!(decoded, text); - } + assert_eq!(decoded, text); + } - #[test] - fn test_add_start_end_token() { - let tokenizer = Gpt2Tokenizer::default(); - let text = "A sentence"; + #[test] + fn test_add_start_end_token() { + let tokenizer = Gpt2Tokenizer::default(); + let text = "A sentence"; - let tokens_without = tokenizer.encode(text, false); - let tokens_with = tokenizer.encode(text, true); + let tokens_without = tokenizer.encode(text, false); + let tokens_with = tokenizer.encode(text, true); - assert_eq!(tokens_with.len() - 2, tokens_without.len()); - } + assert_eq!(tokens_with.len() - 2, tokens_without.len()); + } } diff --git a/examples/text-generation/src/model.rs b/examples/text-generation/src/model.rs index 6e23121424..099811a2e4 100644 --- a/examples/text-generation/src/model.rs +++ b/examples/text-generation/src/model.rs @@ -1,116 +1,111 @@ use crate::data::TrainingTextGenerationBatch; use burn::{ - config::Config, - module::Module, - nn::{ - attention::generate_autoregressive_mask, - loss::CrossEntropyLossConfig, - transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, - Embedding, EmbeddingConfig, Linear, 
LinearConfig, - }, - tensor::backend::{AutodiffBackend, Backend}, - tensor::Tensor, - train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, + config::Config, + module::Module, + nn::{ + attention::generate_autoregressive_mask, + loss::CrossEntropyLossConfig, + transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, + Embedding, EmbeddingConfig, Linear, LinearConfig, + }, + tensor::backend::{AutodiffBackend, Backend}, + tensor::Tensor, + train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, }; #[derive(Config)] pub struct TextGenerationModelConfig { - transformer: TransformerEncoderConfig, - vocab_size: usize, - pad_token: usize, - max_seq_length: usize, + transformer: TransformerEncoderConfig, + vocab_size: usize, + pad_token: usize, + max_seq_length: usize, } #[derive(Module, Debug)] pub struct TextGenerationModel { - transformer: TransformerEncoder, - embedding_token: Embedding, - embedding_pos: Embedding, - output: Linear, - vocab_size: usize, - pad_token: usize, - max_seq_length: usize, + transformer: TransformerEncoder, + embedding_token: Embedding, + embedding_pos: Embedding, + output: Linear, + vocab_size: usize, + pad_token: usize, + max_seq_length: usize, } impl TextGenerationModelConfig { - pub fn init(&self) -> TextGenerationModel { - let output = LinearConfig::new(self.transformer.d_model, self.vocab_size).init(); - let transformer = self.transformer.init(); - let embedding_token = - EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); - let embedding_pos = - EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); + pub fn init(&self) -> TextGenerationModel { + let output = LinearConfig::new(self.transformer.d_model, self.vocab_size).init(); + let transformer = self.transformer.init(); + let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); + let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); - TextGenerationModel { - transformer, - embedding_token, - embedding_pos, - output, - vocab_size: self.vocab_size, - pad_token: self.pad_token, - max_seq_length: self.max_seq_length, - } + TextGenerationModel { + transformer, + embedding_token, + embedding_pos, + output, + vocab_size: self.vocab_size, + pad_token: self.pad_token, + max_seq_length: self.max_seq_length, } + } } impl TextGenerationModel { - pub fn forward_training( - &self, - item: TrainingTextGenerationBatch, - ) -> ClassificationOutput { - let [batch_size, seq_length] = item.tokens_inputs.dims(); - let device = &self.devices()[0]; + pub fn forward_training(&self, item: TrainingTextGenerationBatch) -> ClassificationOutput { + let [batch_size, seq_length] = item.tokens_inputs.dims(); + let device = &self.devices()[0]; - let inputs = item.tokens_inputs.to_device(device); - let targets = item.targets.to_device(device); - let mask_pad = item.mask_pad.to_device(device); + let inputs = item.tokens_inputs.to_device(device); + let targets = item.targets.to_device(device); + let mask_pad = item.mask_pad.to_device(device); - let index_positions = Tensor::arange_device(0..seq_length, device) - .reshape([1, seq_length]) - .repeat(0, batch_size); + let index_positions = Tensor::arange_device(0..seq_length, device) + .reshape([1, seq_length]) + .repeat(0, batch_size); - let embedding_positions = self.embedding_pos.forward(index_positions); - let embedding_tokens = self.embedding_token.forward(inputs); - let embedding = (embedding_positions + embedding_tokens) / 
2; + let embedding_positions = self.embedding_pos.forward(index_positions); + let embedding_tokens = self.embedding_token.forward(inputs); + let embedding = (embedding_positions + embedding_tokens) / 2; - let mask_attn = generate_autoregressive_mask::(batch_size, seq_length, device); - let encoded = self.transformer.forward( - TransformerEncoderInput::new(embedding) - .mask_pad(mask_pad) - .mask_attn(mask_attn), - ); + let mask_attn = generate_autoregressive_mask::(batch_size, seq_length, device); + let encoded = self.transformer.forward( + TransformerEncoderInput::new(embedding) + .mask_pad(mask_pad) + .mask_attn(mask_attn), + ); - let output = self.output.forward(encoded); - let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]); - let targets_flatten = targets.reshape([batch_size * seq_length]); + let output = self.output.forward(encoded); + let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]); + let targets_flatten = targets.reshape([batch_size * seq_length]); - let loss = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![self.pad_token])) - .init(); - let loss = loss.forward(output_flatten.clone(), targets_flatten.clone()); + let loss = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![self.pad_token])) + .init(); + let loss = loss.forward(output_flatten.clone(), targets_flatten.clone()); - ClassificationOutput { - loss, - output: output_flatten, - targets: targets_flatten, - } + ClassificationOutput { + loss, + output: output_flatten, + targets: targets_flatten, } + } } impl TrainStep, ClassificationOutput> - for TextGenerationModel + for TextGenerationModel { - fn step(&self, item: TrainingTextGenerationBatch) -> TrainOutput> { - let item = self.forward_training(item); - let grads = item.loss.backward(); + fn step(&self, item: TrainingTextGenerationBatch) -> TrainOutput> { + let item = self.forward_training(item); + let grads = item.loss.backward(); - TrainOutput::new(self, grads, item) - } + TrainOutput::new(self, grads, item) + } } impl ValidStep, ClassificationOutput> - for TextGenerationModel + for TextGenerationModel { - fn step(&self, item: TrainingTextGenerationBatch) -> ClassificationOutput { - self.forward_training(item) - } + fn step(&self, item: TrainingTextGenerationBatch) -> ClassificationOutput { + self.forward_training(item) + } } diff --git a/examples/text-generation/src/training.rs b/examples/text-generation/src/training.rs index 782012b8ba..e59475e952 100644 --- a/examples/text-generation/src/training.rs +++ b/examples/text-generation/src/training.rs @@ -1,94 +1,94 @@ use crate::{ - data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer}, - model::TextGenerationModelConfig, + data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer}, + model::TextGenerationModelConfig, }; use burn::data::dataset::transform::SamplerDataset; use burn::{ - config::Config, - data::{dataloader::DataLoaderBuilder, dataset::Dataset}, - lr_scheduler::noam::NoamLrSchedulerConfig, - module::Module, - nn::transformer::TransformerEncoderConfig, - optim::AdamConfig, - record::{CompactRecorder, DefaultRecorder, Recorder}, - tensor::backend::AutodiffBackend, - train::{ - metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric}, - LearnerBuilder, - }, + config::Config, + data::{dataloader::DataLoaderBuilder, dataset::Dataset}, + lr_scheduler::noam::NoamLrSchedulerConfig, + module::Module, + nn::transformer::TransformerEncoderConfig, + optim::AdamConfig, + 
record::{CompactRecorder, DefaultRecorder, Recorder}, + tensor::backend::AutodiffBackend, + train::{ + metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric}, + LearnerBuilder, + }, }; use std::sync::Arc; #[derive(Config)] pub struct ExperimentConfig { - transformer: TransformerEncoderConfig, - optimizer: AdamConfig, - #[config(default = 512)] - max_seq_length: usize, - #[config(default = 6)] - batch_size: usize, - #[config(default = 50)] - num_epochs: usize, + transformer: TransformerEncoderConfig, + optimizer: AdamConfig, + #[config(default = 512)] + max_seq_length: usize, + #[config(default = 6)] + batch_size: usize, + #[config(default = 50)] + num_epochs: usize, } pub fn train + 'static>( - device: B::Device, - dataset_train: D, - dataset_test: D, - config: ExperimentConfig, - artifact_dir: &str, + device: B::Device, + dataset_train: D, + dataset_test: D, + config: ExperimentConfig, + artifact_dir: &str, ) { - let tokenizer = Arc::new(Gpt2Tokenizer::default()); - let batcher_train = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length); - let batcher_test = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length); + let tokenizer = Arc::new(Gpt2Tokenizer::default()); + let batcher_train = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length); + let batcher_test = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length); - let model = TextGenerationModelConfig::new( - config.transformer.clone(), - tokenizer.vocab_size(), - tokenizer.pad_token(), - config.max_seq_length, - ) - .init::(); + let model = TextGenerationModelConfig::new( + config.transformer.clone(), + tokenizer.vocab_size(), + tokenizer.pad_token(), + config.max_seq_length, + ) + .init::(); - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .num_workers(4) - .build(SamplerDataset::new(dataset_train, 10_000)); + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .num_workers(4) + .build(SamplerDataset::new(dataset_train, 10_000)); - let dataloader_test = DataLoaderBuilder::new(batcher_test) - .batch_size(config.batch_size) - .num_workers(4) - .build(SamplerDataset::new(dataset_test, 1000)); + let dataloader_test = DataLoaderBuilder::new(batcher_test) + .batch_size(config.batch_size) + .num_workers(4) + .build(SamplerDataset::new(dataset_test, 1000)); - let accum = 6; // Effective batch size = 6 * 6 = 32. - let optim = config.optimizer.init(); - let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64) - .with_warmup_steps(6000) - .with_model_size(config.transformer.d_model) - .init(); + let accum = 6; // Effective batch size = 6 * 6 = 32. 
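With the default `batch_size = 6` and `grads_accumulation(accum)` where `accum = 6`, each optimizer step accumulates 6 x 6 = 36 samples' worth of gradients before updating. The learning rate comes from the Noam scheduler configured just below; a minimal sketch of the standard Noam warmup/decay formula, assuming `NoamLrSchedulerConfig::new(init_lr)` uses `init_lr` as an overall scale (`noam_lr` is a hypothetical helper, not burn API):

fn noam_lr(init_lr: f64, step: usize, warmup_steps: usize, model_size: usize) -> f64 {
    // Linear warmup up to `warmup_steps`, then decay proportional to 1 / sqrt(step),
    // scaled by d_model^(-1/2) and the configured initial learning rate.
    let step = step.max(1) as f64;
    let warmup = warmup_steps as f64;
    init_lr * (model_size as f64).powf(-0.5) * f64::min(step.powf(-0.5), step * warmup.powf(-1.5))
}

// With the values used here (d_model = 384 in the example config):
// noam_lr(0.01 / 6.0, step, 6000, 384) peaks at step 6000 and decays afterwards.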
+ let optim = config.optimizer.init(); + let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64) + .with_warmup_steps(6000) + .with_model_size(config.transformer.d_model) + .init(); - let learner = LearnerBuilder::new(artifact_dir) - .metric_train(CUDAMetric::new()) - .metric_valid(CUDAMetric::new()) - .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) - .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) - .metric_train(LossMetric::new()) - .metric_valid(LossMetric::new()) - .metric_train_numeric(LearningRateMetric::new()) - .with_file_checkpointer(CompactRecorder::new()) - .devices(vec![device]) - .grads_accumulation(accum) - .num_epochs(config.num_epochs) - .build(model, optim, lr_scheduler); + let learner = LearnerBuilder::new(artifact_dir) + .metric_train(CUDAMetric::new()) + .metric_valid(CUDAMetric::new()) + .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) + .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) + .metric_train(LossMetric::new()) + .metric_valid(LossMetric::new()) + .metric_train_numeric(LearningRateMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .devices(vec![device]) + .grads_accumulation(accum) + .num_epochs(config.num_epochs) + .build(model, optim, lr_scheduler); - let model_trained = learner.fit(dataloader_train, dataloader_test); + let model_trained = learner.fit(dataloader_train, dataloader_test); - config.save(format!("{artifact_dir}/config.json")).unwrap(); + config.save(format!("{artifact_dir}/config.json")).unwrap(); - DefaultRecorder::new() - .record( - model_trained.into_record(), - format!("{artifact_dir}/model").into(), - ) - .unwrap(); + DefaultRecorder::new() + .record( + model_trained.into_record(), + format!("{artifact_dir}/model").into(), + ) + .unwrap(); } diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 9658a66bb9..b33db64842 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -6,29 +6,29 @@ mod runchecks; #[derive(Parser)] #[command(author, version, about, long_about = None)] struct Args { - #[command(subcommand)] - command: Command, + #[command(subcommand)] + command: Command, } #[derive(Subcommand)] enum Command { - /// Publish a crate to crates.io - Publish { - /// The name of the crate to publish on crates.io - name: String, - }, - /// Run the specified `burn` tests and checks locally. - RunChecks { - /// The environment to run checks against - env: runchecks::CheckType, - }, + /// Publish a crate to crates.io + Publish { + /// The name of the crate to publish on crates.io + name: String, + }, + /// Run the specified `burn` tests and checks locally. 
+ RunChecks { + /// The environment to run checks against + env: runchecks::CheckType, + }, } fn main() -> anyhow::Result<()> { - let args = Args::parse(); + let args = Args::parse(); - match args.command { - Command::RunChecks { env } => runchecks::run(env), - Command::Publish { name } => publish::run(name), - } + match args.command { + Command::RunChecks { env } => runchecks::run(env), + Command::Publish { name } => publish::run(name), + } } diff --git a/xtask/src/publish.rs b/xtask/src/publish.rs index bff5672166..0d13a3bd29 100644 --- a/xtask/src/publish.rs +++ b/xtask/src/publish.rs @@ -13,116 +13,116 @@ const CRATES_IO_API_TOKEN: &str = "CRATES_IO_API_TOKEN"; // Obtain local crate version fn local_version(crate_name: &str) -> String { - // Obtain local crate version contained in cargo pkgid data - let cargo_pkgid_output = Command::new("cargo") - .args(["pkgid", "-p", crate_name]) - .output() - .expect("Failed to run cargo pkgid"); - - // Convert cargo pkgid output into a str - let cargo_pkgid_str = str::from_utf8(&cargo_pkgid_output.stdout) - .expect("Failed to convert pkgid output into a str"); - - // Extract only the local crate version from str - let (_, local_version) = cargo_pkgid_str - .split_once('#') - .expect("Failed to get local crate version"); - - local_version.trim_end().to_string() + // Obtain local crate version contained in cargo pkgid data + let cargo_pkgid_output = Command::new("cargo") + .args(["pkgid", "-p", crate_name]) + .output() + .expect("Failed to run cargo pkgid"); + + // Convert cargo pkgid output into a str + let cargo_pkgid_str = + str::from_utf8(&cargo_pkgid_output.stdout).expect("Failed to convert pkgid output into a str"); + + // Extract only the local crate version from str + let (_, local_version) = cargo_pkgid_str + .split_once('#') + .expect("Failed to get local crate version"); + + local_version.trim_end().to_string() } // Obtain remote crate version fn remote_version(crate_name: &str) -> Option { - // Obtain remote crate version contained in cargo search data - let cargo_search_output = Command::new("cargo") - .args(["search", crate_name, "--limit", "1"]) - .output() - .expect("Failed to run cargo search"); - - // Cargo search returns an empty string in case of a crate not present on - // crates.io - if cargo_search_output.stdout.is_empty() { - None - } else { - // Convert cargo search output into a str - let remote_version_str = str::from_utf8(&cargo_search_output.stdout) - .expect("Failed to convert cargo search output into a str"); - - // Extract only the remote crate version from str - remote_version_str - .split_once('=') - .and_then(|(_, second)| second.trim_start().split_once(' ')) - .map(|(s, _)| s.trim_matches('"').to_string()) - } + // Obtain remote crate version contained in cargo search data + let cargo_search_output = Command::new("cargo") + .args(["search", crate_name, "--limit", "1"]) + .output() + .expect("Failed to run cargo search"); + + // Cargo search returns an empty string in case of a crate not present on + // crates.io + if cargo_search_output.stdout.is_empty() { + None + } else { + // Convert cargo search output into a str + let remote_version_str = str::from_utf8(&cargo_search_output.stdout) + .expect("Failed to convert cargo search output into a str"); + + // Extract only the remote crate version from str + remote_version_str + .split_once('=') + .and_then(|(_, second)| second.trim_start().split_once(' ')) + .map(|(s, _)| s.trim_matches('"').to_string()) + } } // Run cargo publish fn cargo_publish(params: &[&str]) { - 
// Run cargo publish - let mut cargo_publish = Command::new("cargo") - .arg("publish") - .arg("--color=always") - .args(params) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to run cargo publish"); - - // Wait for cargo publish command to finish - let status = cargo_publish - .wait() - .expect("Failed to wait for cargo publish child process"); - - // If exit status is not a success, terminate the process with an error - if !status.success() { - // Use the exit code associated to a command to terminate the process, - // if any exit code had been found, use the default value 1 - std::process::exit(status.code().unwrap_or(1)); - } + // Run cargo publish + let mut cargo_publish = Command::new("cargo") + .arg("publish") + .arg("--color=always") + .args(params) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect("Failed to run cargo publish"); + + // Wait for cargo publish command to finish + let status = cargo_publish + .wait() + .expect("Failed to wait for cargo publish child process"); + + // If exit status is not a success, terminate the process with an error + if !status.success() { + // Use the exit code associated to a command to terminate the process, + // if any exit code had been found, use the default value 1 + std::process::exit(status.code().unwrap_or(1)); + } } // Publishes a crate fn publish(crate_name: String) { - // Run cargo publish --dry-run - cargo_publish(&["-p", &crate_name, "--dry-run"]); + // Run cargo publish --dry-run + cargo_publish(&["-p", &crate_name, "--dry-run"]); - let crates_io_token = - env::var(CRATES_IO_API_TOKEN).expect("Failed to retrieve the crates.io API token"); + let crates_io_token = + env::var(CRATES_IO_API_TOKEN).expect("Failed to retrieve the crates.io API token"); - // Publish crate - cargo_publish(&["-p", &crate_name, "--token", &crates_io_token]); + // Publish crate + cargo_publish(&["-p", &crate_name, "--token", &crates_io_token]); } pub fn run(crate_name: String) -> anyhow::Result<()> { - println!("Publishing {crate_name}...\n"); + println!("Publishing {crate_name}...\n"); - // Retrieve local version for crate - let local_version = local_version(&crate_name); + // Retrieve local version for crate + let local_version = local_version(&crate_name); + // Print local version for crate + println!("{crate_name} local version: {local_version}"); + + // Retrieve remote version for crate + // + // If remote version is None, the crate will be published for the first time + // on crates.io + if let Some(remote_version) = remote_version(&crate_name) { // Print local version for crate - println!("{crate_name} local version: {local_version}"); - - // Retrieve remote version for crate - // - // If remote version is None, the crate will be published for the first time - // on crates.io - if let Some(remote_version) = remote_version(&crate_name) { - // Print local version for crate - println!("{crate_name} remote version: {remote_version}\n"); - - // If local and remote versions are equal, do not publish - if local_version == remote_version { - println!("Remote version {remote_version} is up to date, skipping deployment"); - } else { - // Publish crate - publish(crate_name); - } + println!("{crate_name} remote version: {remote_version}\n"); + + // If local and remote versions are equal, do not publish + if local_version == remote_version { + println!("Remote 
version {remote_version} is up to date, skipping deployment"); } else { - // Print crate publishing message - println!("\nFirst time publishing {crate_name} on crates.io!\n"); - // Publish crate - publish(crate_name); + // Publish crate + publish(crate_name); } + } else { + // Print crate publishing message + println!("\nFirst time publishing {crate_name} on crates.io!\n"); + // Publish crate + publish(crate_name); + } - Ok(()) + Ok(()) } diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 4d790aa3f7..062b7178dc 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -15,169 +15,169 @@ const ARM_TARGET: &str = "thumbv7m-none-eabi"; // Handle child process fn handle_child_process(mut child: Child, error: &str) { - // Wait for the child process to finish - let status = child.wait().expect(error); - - // If exit status is not a success, terminate the process with an error - if !status.success() { - // Use the exit code associated to a command to terminate the process, - // if any exit code had been found, use the default value 1 - std::process::exit(status.code().unwrap_or(1)); - } + // Wait for the child process to finish + let status = child.wait().expect(error); + + // If exit status is not a success, terminate the process with an error + if !status.success() { + // Use the exit code associated to a command to terminate the process, + // if any exit code had been found, use the default value 1 + std::process::exit(status.code().unwrap_or(1)); + } } // Run a command fn run_command(command: &str, args: &[&str], command_error: &str, child_error: &str) { - // Format command - println!("{command} {}\n\n", args.join(" ")); - - // Run command as child process - let command = Command::new(command) - .args(args) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect(command_error); - - // Handle command child process - handle_child_process(command, child_error); + // Format command + println!("{command} {}\n\n", args.join(" ")); + + // Run command as child process + let command = Command::new(command) + .args(args) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect(command_error); + + // Handle command child process + handle_child_process(command, child_error); } // Define and run rustup command fn rustup(command: &str, target: &str) { - run_command( - "rustup", - &[command, "add", target], - "Failed to run rustup", - "Failed to wait for rustup child process", - ) + run_command( + "rustup", + &[command, "add", target], + "Failed to run rustup", + "Failed to wait for rustup child process", + ) } // Define and run a cargo command fn run_cargo(command: &str, params: Params, error: &str) { - // Print cargo command - println!("\ncargo {} {}\n", command, params); - - // Run cargo - let cargo = Command::new("cargo") - .env("CARGO_INCREMENTAL", "0") - .arg(command) - .args(params.params) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect(error); - - // Handle cargo child process - handle_child_process(cargo, "Failed to wait for cargo child process"); + // Print cargo command + println!("\ncargo {} {}\n", command, params); + + // Run cargo + let cargo = Command::new("cargo") + .env("CARGO_INCREMENTAL", "0") + .arg(command) + .args(params.params) + .stdout(Stdio::inherit()) // Send stdout directly 
to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect(error); + + // Handle cargo child process + handle_child_process(cargo, "Failed to wait for cargo child process"); } // Run cargo build command fn cargo_build(params: Params) { - // Run cargo build - run_cargo( - "build", - params + "--color=always", - "Failed to run cargo build", - ); + // Run cargo build + run_cargo( + "build", + params + "--color=always", + "Failed to run cargo build", + ); } // Run cargo install command fn cargo_install(params: Params) { - // Run cargo install - run_cargo( - "install", - params + "--color=always", - "Failed to run cargo install", - ); + // Run cargo install + run_cargo( + "install", + params + "--color=always", + "Failed to run cargo install", + ); } // Run cargo test command fn cargo_test(params: Params) { - // Run cargo test - run_cargo( - "test", - params + "--color=always" + "--" + "--color=always", - "Failed to run cargo test", - ); + // Run cargo test + run_cargo( + "test", + params + "--color=always" + "--" + "--color=always", + "Failed to run cargo test", + ); } // Run cargo fmt command fn cargo_fmt() { - // Run cargo fmt - run_cargo( - "fmt", - ["--check", "--all", "--", "--color=always"].into(), - "Failed to run cargo fmt", - ); + // Run cargo fmt + run_cargo( + "fmt", + ["--check", "--all", "--", "--color=always"].into(), + "Failed to run cargo fmt", + ); } // Run cargo clippy command fn cargo_clippy() { - if std::env::var("CI_RUN").is_ok() { - return; - } - // Run cargo clippy - run_cargo( - "clippy", - ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), - "Failed to run cargo clippy", - ); + if std::env::var("CI_RUN").is_ok() { + return; + } + // Run cargo clippy + run_cargo( + "clippy", + ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), + "Failed to run cargo clippy", + ); } // Run cargo doc command fn cargo_doc(params: Params) { - // Run cargo doc - run_cargo("doc", params + "--color=always", "Failed to run cargo doc"); + // Run cargo doc + run_cargo("doc", params + "--color=always", "Failed to run cargo doc"); } // Build and test a crate in a no_std environment fn build_and_test_no_std(crate_name: &str, extra_args: [&str; N]) { - println!("\nRun checks for `{}` crate", crate_name); - - // Run cargo build --no-default-features - cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); - - // Run cargo test --no-default-features - cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); - - // Run cargo build --no-default-features --target wasm32-unknown-unknowns - cargo_build( - Params::from([ - "-p", - crate_name, - "--no-default-features", - "--target", - WASM32_TARGET, - ]) + extra_args, - ); - - // Run cargo build --no-default-features --target thumbv7m-none-eabi - cargo_build( - Params::from([ - "-p", - crate_name, - "--no-default-features", - "--target", - ARM_TARGET, - ]) + extra_args, - ); + println!("\nRun checks for `{}` crate", crate_name); + + // Run cargo build --no-default-features + cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); + + // Run cargo test --no-default-features + cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); + + // Run cargo build --no-default-features --target wasm32-unknown-unknowns + cargo_build( + Params::from([ + "-p", + crate_name, + "--no-default-features", + "--target", + WASM32_TARGET, + ]) + extra_args, + ); + + // Run cargo build 
--no-default-features --target thumbv7m-none-eabi + cargo_build( + Params::from([ + "-p", + crate_name, + "--no-default-features", + "--target", + ARM_TARGET, + ]) + extra_args, + ); } // Setup code coverage fn setup_coverage() { - // Install llvm-tools-preview - rustup("component", "llvm-tools-preview"); + // Install llvm-tools-preview + rustup("component", "llvm-tools-preview"); - // Set coverage environment variables - env::set_var("RUSTFLAGS", "-Cinstrument-coverage"); - env::set_var("LLVM_PROFILE_FILE", "burn-%p-%m.profraw"); + // Set coverage environment variables + env::set_var("RUSTFLAGS", "-Cinstrument-coverage"); + env::set_var("LLVM_PROFILE_FILE", "burn-%p-%m.profraw"); } // Run grcov to produce lcov.info fn run_grcov() { - // grcov arguments - #[rustfmt::skip] + // grcov arguments + #[rustfmt::skip] let args = [ ".", "--binary-path", "./target/debug/", @@ -191,245 +191,245 @@ fn run_grcov() { "-o", "lcov.info", ]; - run_command( - "grcov", - &args, - "Failed to run grcov", - "Failed to wait for grcov child process", - ); + run_command( + "grcov", + &args, + "Failed to run grcov", + "Failed to wait for grcov child process", + ); } // Run no_std checks fn no_std_checks() { - println!("Checks for no_std environment...\n\n"); - - // Install wasm32 target - rustup("target", WASM32_TARGET); - - // Install ARM target - rustup("target", ARM_TARGET); - - // Run checks for the following crates - build_and_test_no_std("burn", []); - build_and_test_no_std("burn-core", []); - build_and_test_no_std( - "burn-compute", - ["--features", "channel-mutex storage-bytes"], - ); - build_and_test_no_std("burn-common", []); - build_and_test_no_std("burn-tensor", []); - build_and_test_no_std("burn-ndarray", []); - build_and_test_no_std("burn-no-std-tests", []); + println!("Checks for no_std environment...\n\n"); + + // Install wasm32 target + rustup("target", WASM32_TARGET); + + // Install ARM target + rustup("target", ARM_TARGET); + + // Run checks for the following crates + build_and_test_no_std("burn", []); + build_and_test_no_std("burn-core", []); + build_and_test_no_std( + "burn-compute", + ["--features", "channel-mutex storage-bytes"], + ); + build_and_test_no_std("burn-common", []); + build_and_test_no_std("burn-tensor", []); + build_and_test_no_std("burn-ndarray", []); + build_and_test_no_std("burn-no-std-tests", []); } // Test burn-core with tch and wgpu backend fn burn_core_std() { - println!("\n\nRun checks for burn-core crate with tch and wgpu backend"); + println!("\n\nRun checks for burn-core crate with tch and wgpu backend"); - // Run cargo test --features test-tch - cargo_test(["-p", "burn-core", "--features", "test-tch"].into()); + // Run cargo test --features test-tch + cargo_test(["-p", "burn-core", "--features", "test-tch"].into()); - // Run cargo test --features test-wgpu - cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into()); + // Run cargo test --features test-wgpu + cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into()); } // Test burn-dataset features fn burn_dataset_features_std() { - println!("\n\nRun checks for burn-dataset features"); + println!("\n\nRun checks for burn-dataset features"); - // Run cargo build --all-features - cargo_build(["-p", "burn-dataset", "--all-features"].into()); + // Run cargo build --all-features + cargo_build(["-p", "burn-dataset", "--all-features"].into()); - // Run cargo test --all-features - cargo_test(["-p", "burn-dataset", "--all-features"].into()); + // Run cargo test --all-features + cargo_test(["-p", "burn-dataset", 
"--all-features"].into()); - // Run cargo doc --all-features - cargo_doc(["-p", "burn-dataset", "--all-features"].into()); + // Run cargo doc --all-features + cargo_doc(["-p", "burn-dataset", "--all-features"].into()); } fn std_checks() { - // Set RUSTDOCFLAGS environment variable to treat warnings as errors - // for the documentation build - env::set_var("RUSTDOCFLAGS", "-D warnings"); + // Set RUSTDOCFLAGS environment variable to treat warnings as errors + // for the documentation build + env::set_var("RUSTDOCFLAGS", "-D warnings"); - // Check if COVERAGE environment variable is set - let is_coverage = std::env::var("COVERAGE").is_ok(); + // Check if COVERAGE environment variable is set + let is_coverage = std::env::var("COVERAGE").is_ok(); - println!("Running std checks"); + println!("Running std checks"); - // Check format - cargo_fmt(); + // Check format + cargo_fmt(); - // Check clippy lints - cargo_clippy(); + // Check clippy lints + cargo_clippy(); - // Build each workspace - cargo_build(["--workspace", "--exclude=xtask"].into()); + // Build each workspace + cargo_build(["--workspace", "--exclude=xtask"].into()); - // Produce documentation for each workspace - cargo_doc(["--workspace"].into()); + // Produce documentation for each workspace + cargo_doc(["--workspace"].into()); - // Setup code coverage - if is_coverage { - setup_coverage(); - } + // Setup code coverage + if is_coverage { + setup_coverage(); + } - // Test each workspace - cargo_test(["--workspace"].into()); + // Test each workspace + cargo_test(["--workspace"].into()); - // Test burn-dataset features - burn_dataset_features_std(); + // Test burn-dataset features + burn_dataset_features_std(); - // Test burn-core with tch and wgpu backend - burn_core_std(); + // Test burn-core with tch and wgpu backend + burn_core_std(); - // Run grcov and produce lcov.info - if is_coverage { - run_grcov(); - } + // Run grcov and produce lcov.info + if is_coverage { + run_grcov(); + } } fn check_typos() { - // This path defines where typos-cl is installed on different - // operating systems. - let typos_cli_path = std::env::var("CARGO_HOME") - .map(|v| std::path::Path::new(&v).join("bin/typos-cli")) - .unwrap(); - - // Do not run cargo install on CI to speed up the computation. - // Check whether the file has been installed on - if std::env::var("CI_RUN").is_err() && !typos_cli_path.exists() { - // Install typos-cli - cargo_install(["typos-cli", "--version", "1.16.5"].into()); - } - - println!("Running typos check \n\n"); - - // Run typos command as child process - let typos = Command::new("typos") - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to run typos"); - - // Handle typos child process - handle_child_process(typos, "Failed to wait for typos child process"); + // This path defines where typos-cl is installed on different + // operating systems. + let typos_cli_path = std::env::var("CARGO_HOME") + .map(|v| std::path::Path::new(&v).join("bin/typos-cli")) + .unwrap(); + + // Do not run cargo install on CI to speed up the computation. 
+ // Check whether the file has been installed on + if std::env::var("CI_RUN").is_err() && !typos_cli_path.exists() { + // Install typos-cli + cargo_install(["typos-cli", "--version", "1.16.5"].into()); + } + + println!("Running typos check \n\n"); + + // Run typos command as child process + let typos = Command::new("typos") + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect("Failed to run typos"); + + // Handle typos child process + handle_child_process(typos, "Failed to wait for typos child process"); } fn check_examples() { - println!("Checking examples compile \n\n"); - - std::fs::read_dir("examples").unwrap().for_each(|dir| { - let dir = dir.unwrap(); - let path = dir.path(); - // Skip if not a directory - if !path.is_dir() { - return; - } - if path.file_name().unwrap().to_str().unwrap() == "notebook" { - // not a crate - return; - } - let path = path.to_str().unwrap(); - println!("Checking {path} \n\n"); - - let child = Command::new("cargo") - .arg("check") - .arg("--examples") - .current_dir(dir.path()) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to check examples"); - - // Handle typos child process - handle_child_process(child, "Failed to wait for examples child process"); - }); + println!("Checking examples compile \n\n"); + + std::fs::read_dir("examples").unwrap().for_each(|dir| { + let dir = dir.unwrap(); + let path = dir.path(); + // Skip if not a directory + if !path.is_dir() { + return; + } + if path.file_name().unwrap().to_str().unwrap() == "notebook" { + // not a crate + return; + } + let path = path.to_str().unwrap(); + println!("Checking {path} \n\n"); + + let child = Command::new("cargo") + .arg("check") + .arg("--examples") + .current_dir(dir.path()) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect("Failed to check examples"); + + // Handle typos child process + handle_child_process(child, "Failed to wait for examples child process"); + }); } #[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)] pub enum CheckType { - /// Run all checks. - #[default] - All, - /// Run `std` environment checks - Std, - /// Run `no-std` environment checks - NoStd, - /// Check for typos - Typos, - /// Test the examples - Examples, + /// Run all checks. + #[default] + All, + /// Run `std` environment checks + Std, + /// Run `no-std` environment checks + NoStd, + /// Check for typos + Typos, + /// Test the examples + Examples, } pub fn run(env: CheckType) -> anyhow::Result<()> { - // Start time measurement - let start = Instant::now(); - - // The environment can assume ONLY "std", "no_std", "typos", "examples" - // as values. - // - // Depending on the input argument, the respective environment checks - // are run. - // - // If no environment has been passed, run all checks. - match env { - CheckType::Std => std_checks(), - CheckType::NoStd => no_std_checks(), - CheckType::Typos => check_typos(), - CheckType::Examples => check_examples(), - CheckType::All => { - /* Run all checks */ - check_typos(); - std_checks(); - no_std_checks(); - check_examples(); - } + // Start time measurement + let start = Instant::now(); + + // The environment can assume ONLY "std", "no_std", "typos", "examples" + // as values. 
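The recurring pattern in these checks (`run_command`, `check_typos`, `check_examples`) is to spawn an external tool as a child process with inherited stdio, wait for it, and mirror a non-zero exit code. A condensed standalone sketch of that pattern using only the standard library; `run_tool` is an illustrative name, not part of the xtask code:

use std::process::{Command, Stdio};

fn run_tool(program: &str, args: &[&str]) {
    // Spawn the tool, stream its stdout/stderr straight to the terminal, and wait.
    let status = Command::new(program)
        .args(args)
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit())
        .status()
        .unwrap_or_else(|_| panic!("Failed to run {program}"));

    // Propagate a failing exit code, defaulting to 1 if no code is available.
    if !status.success() {
        std::process::exit(status.code().unwrap_or(1));
    }
}

// Usage: run_tool("typos", &[]); // same spawn/wait/exit behavior as `check_typos` above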
+ // + // Depending on the input argument, the respective environment checks + // are run. + // + // If no environment has been passed, run all checks. + match env { + CheckType::Std => std_checks(), + CheckType::NoStd => no_std_checks(), + CheckType::Typos => check_typos(), + CheckType::Examples => check_examples(), + CheckType::All => { + /* Run all checks */ + check_typos(); + std_checks(); + no_std_checks(); + check_examples(); } + } - // Stop time measurement - // - // Compute runtime duration - let duration = start.elapsed(); + // Stop time measurement + // + // Compute runtime duration + let duration = start.elapsed(); - // Print duration - println!("Time elapsed for the current execution: {:?}", duration); + // Print duration + println!("Time elapsed for the current execution: {:?}", duration); - Ok(()) + Ok(()) } struct Params { - params: Vec, + params: Vec, } impl From<[&str; N]> for Params { - fn from(value: [&str; N]) -> Self { - Self { - params: value.iter().map(|v| v.to_string()).collect(), - } + fn from(value: [&str; N]) -> Self { + Self { + params: value.iter().map(|v| v.to_string()).collect(), } + } } impl From<&str> for Params { - fn from(value: &str) -> Self { - Self { - params: vec![value.to_string()], - } + fn from(value: &str) -> Self { + Self { + params: vec![value.to_string()], } + } } impl std::fmt::Display for Params { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.params.join(" ").as_str()) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.params.join(" ").as_str()) + } } impl> std::ops::Add for Params { - type Output = Params; + type Output = Params; - fn add(mut self, rhs: Rhs) -> Self::Output { - let rhs: Params = rhs.into(); - self.params.extend(rhs.params); - self - } + fn add(mut self, rhs: Rhs) -> Self::Output { + let rhs: Params = rhs.into(); + self.params.extend(rhs.params); + self + } } From 395e9e32988978468d47736168b9724901df2cfe Mon Sep 17 00:00:00 2001 From: Will Brickner Date: Sat, 18 Nov 2023 20:33:52 -0600 Subject: [PATCH 2/2] Undid bad formatting --- backend-comparison/benches/binary.rs | 52 +- backend-comparison/benches/custom_gelu.rs | 142 +- backend-comparison/benches/data.rs | 104 +- backend-comparison/benches/matmul.rs | 86 +- backend-comparison/benches/unary.rs | 50 +- backend-comparison/src/lib.rs | 108 +- burn-autodiff/src/backend.rs | 114 +- burn-autodiff/src/grads.rs | 140 +- burn-autodiff/src/graph/backward.rs | 39 +- burn-autodiff/src/graph/base.rs | 130 +- burn-autodiff/src/graph/node.rs | 44 +- burn-autodiff/src/graph/requirement.rs | 46 +- burn-autodiff/src/graph/traversal.rs | 62 +- burn-autodiff/src/ops/activation.rs | 94 +- burn-autodiff/src/ops/backward.rs | 118 +- burn-autodiff/src/ops/base.rs | 218 +- burn-autodiff/src/ops/bool_tensor.rs | 173 +- burn-autodiff/src/ops/int_tensor.rs | 624 +-- burn-autodiff/src/ops/maxmin.rs | 18 +- burn-autodiff/src/ops/module.rs | 1687 +++---- burn-autodiff/src/ops/tensor.rs | 2782 +++++------ burn-autodiff/src/tensor.rs | 150 +- burn-autodiff/src/tests/abs.rs | 40 +- burn-autodiff/src/tests/adaptive_avgpool1d.rs | 78 +- burn-autodiff/src/tests/adaptive_avgpool2d.rs | 110 +- burn-autodiff/src/tests/add.rs | 90 +- burn-autodiff/src/tests/aggregation.rs | 236 +- burn-autodiff/src/tests/avgpool1d.rs | 164 +- burn-autodiff/src/tests/avgpool2d.rs | 214 +- burn-autodiff/src/tests/backward.rs | 44 +- burn-autodiff/src/tests/broadcast.rs | 106 +- burn-autodiff/src/tests/cat.rs | 116 +- 
burn-autodiff/src/tests/complex.rs | 156 +- burn-autodiff/src/tests/conv1d.rs | 448 +- burn-autodiff/src/tests/conv2d.rs | 1458 +++--- burn-autodiff/src/tests/conv_transpose1d.rs | 476 +- burn-autodiff/src/tests/conv_transpose2d.rs | 1250 ++--- burn-autodiff/src/tests/cos.rs | 44 +- burn-autodiff/src/tests/cross_entropy.rs | 45 +- burn-autodiff/src/tests/div.rs | 172 +- burn-autodiff/src/tests/erf.rs | 40 +- burn-autodiff/src/tests/exp.rs | 40 +- burn-autodiff/src/tests/gather_scatter.rs | 108 +- burn-autodiff/src/tests/gelu.rs | 36 +- burn-autodiff/src/tests/gradients.rs | 31 +- burn-autodiff/src/tests/log.rs | 40 +- burn-autodiff/src/tests/log1p.rs | 42 +- burn-autodiff/src/tests/mask.rs | 101 +- burn-autodiff/src/tests/matmul.rs | 150 +- burn-autodiff/src/tests/maxmin.rs | 84 +- burn-autodiff/src/tests/maxpool1d.rs | 215 +- burn-autodiff/src/tests/maxpool2d.rs | 336 +- burn-autodiff/src/tests/mod.rs | 106 +- burn-autodiff/src/tests/mul.rs | 92 +- burn-autodiff/src/tests/multithread.rs | 122 +- burn-autodiff/src/tests/neg.rs | 32 +- burn-autodiff/src/tests/pow.rs | 40 +- burn-autodiff/src/tests/recip.rs | 27 +- burn-autodiff/src/tests/relu.rs | 34 +- burn-autodiff/src/tests/reshape.rs | 32 +- burn-autodiff/src/tests/select.rs | 100 +- burn-autodiff/src/tests/sin.rs | 42 +- burn-autodiff/src/tests/slice.rs | 148 +- burn-autodiff/src/tests/softmax.rs | 112 +- burn-autodiff/src/tests/sqrt.rs | 40 +- burn-autodiff/src/tests/sub.rs | 86 +- burn-autodiff/src/tests/tanh.rs | 40 +- burn-autodiff/src/tests/transpose.rs | 94 +- burn-autodiff/src/utils.rs | 22 +- burn-candle/src/backend.rs | 84 +- burn-candle/src/lib.rs | 234 +- burn-candle/src/ops/activation.rs | 16 +- burn-candle/src/ops/base.rs | 82 +- burn-candle/src/ops/bool_tensor.rs | 209 +- burn-candle/src/ops/candle_utils.rs | 28 +- burn-candle/src/ops/int_tensor.rs | 709 ++- burn-candle/src/ops/module.rs | 401 +- burn-candle/src/ops/tensor.rs | 881 ++-- burn-candle/src/tensor.rs | 56 +- burn-common/src/benchmark.rs | 204 +- burn-common/src/id.rs | 90 +- burn-common/src/rand.rs | 26 +- burn-common/src/reader.rs | 144 +- burn-common/src/stub.rs | 64 +- burn-compute/src/channel/base.rs | 20 +- burn-compute/src/channel/cell.rs | 59 +- burn-compute/src/channel/mpsc.rs | 228 +- burn-compute/src/channel/mutex.rs | 54 +- burn-compute/src/client.rs | 105 +- burn-compute/src/compute.rs | 116 +- burn-compute/src/id.rs | 76 +- burn-compute/src/memory_management/base.rs | 62 +- burn-compute/src/memory_management/simple.rs | 736 +-- burn-compute/src/server.rs | 70 +- burn-compute/src/storage/base.rs | 46 +- burn-compute/src/storage/bytes_cpu.rs | 160 +- burn-compute/src/tune/operation.rs | 32 +- burn-compute/src/tune/tune_benchmark.rs | 36 +- burn-compute/src/tune/tune_cache.rs | 44 +- burn-compute/src/tune/tuner.rs | 148 +- burn-compute/tests/dummy/compute.rs | 18 +- burn-compute/tests/dummy/kernel.rs | 22 +- burn-compute/tests/dummy/server.rs | 70 +- .../tests/dummy/tune/autotune_operations.rs | 34 +- burn-compute/tests/dummy/tune/kernels.rs | 130 +- .../tests/dummy/tune/operation_sets.rs | 260 +- burn-compute/tests/integration_test.rs | 182 +- burn-core/src/config.rs | 121 +- burn-core/src/data/dataloader/base.rs | 16 +- burn-core/src/data/dataloader/batch.rs | 407 +- burn-core/src/data/dataloader/batcher.rs | 26 +- burn-core/src/data/dataloader/builder.rs | 194 +- burn-core/src/data/dataloader/multithread.rs | 192 +- burn-core/src/data/dataloader/strategy.rs | 110 +- burn-core/src/data/mod.rs | 2 +- burn-core/src/grad_clipping/base.rs | 212 +- 
burn-core/src/lr_scheduler/base.rs | 18 +- burn-core/src/lr_scheduler/constant.rs | 40 +- burn-core/src/lr_scheduler/noam.rs | 116 +- burn-core/src/module/base.rs | 390 +- burn-core/src/module/param/base.rs | 44 +- burn-core/src/module/param/constant.rs | 300 +- burn-core/src/module/param/id.rs | 46 +- burn-core/src/module/param/primitive.rs | 233 +- burn-core/src/module/param/running.rs | 252 +- burn-core/src/module/param/tensor.rs | 160 +- burn-core/src/module/param/visitor.rs | 28 +- burn-core/src/nn/attention/mask.rs | 218 +- burn-core/src/nn/attention/mha.rs | 774 ++-- burn-core/src/nn/cache/autoregressive.rs | 86 +- burn-core/src/nn/cache/base.rs | 24 +- burn-core/src/nn/conv/checks.rs | 10 +- burn-core/src/nn/conv/conv1d.rs | 240 +- burn-core/src/nn/conv/conv2d.rs | 237 +- burn-core/src/nn/conv/conv_transpose1d.rs | 250 +- burn-core/src/nn/conv/conv_transpose2d.rs | 253 +- burn-core/src/nn/dropout.rs | 88 +- burn-core/src/nn/embedding.rs | 128 +- burn-core/src/nn/gelu.rs | 26 +- burn-core/src/nn/initializer.rs | 626 +-- burn-core/src/nn/linear.rs | 246 +- burn-core/src/nn/loss/binary_cross_entropy.rs | 278 +- burn-core/src/nn/loss/cross_entropy.rs | 677 +-- burn-core/src/nn/loss/mse.rs | 104 +- burn-core/src/nn/loss/reduction.rs | 12 +- burn-core/src/nn/norm/batch.rs | 678 ++- burn-core/src/nn/norm/layer.rs | 202 +- burn-core/src/nn/padding.rs | 86 +- burn-core/src/nn/pool/adaptive_avg_pool1d.rs | 34 +- burn-core/src/nn/pool/adaptive_avg_pool2d.rs | 34 +- burn-core/src/nn/pool/avg_pool1d.rs | 84 +- burn-core/src/nn/pool/avg_pool2d.rs | 85 +- burn-core/src/nn/pool/max_pool1d.rs | 72 +- burn-core/src/nn/pool/max_pool2d.rs | 73 +- burn-core/src/nn/pos_encoding.rs | 326 +- burn-core/src/nn/relu.rs | 26 +- burn-core/src/nn/rnn/gate_controller.rs | 120 +- burn-core/src/nn/rnn/gru.rs | 424 +- burn-core/src/nn/rnn/lstm.rs | 585 +-- burn-core/src/nn/transformer/decoder.rs | 748 +-- burn-core/src/nn/transformer/encoder.rs | 610 +-- burn-core/src/nn/transformer/pwff.rs | 110 +- burn-core/src/nn/unfold.rs | 68 +- burn-core/src/optim/adagrad.rs | 446 +- burn-core/src/optim/adam.rs | 595 +-- burn-core/src/optim/adamw.rs | 637 +-- burn-core/src/optim/base.rs | 22 +- burn-core/src/optim/decay.rs | 78 +- burn-core/src/optim/grad_accum.rs | 179 +- burn-core/src/optim/grads.rs | 214 +- burn-core/src/optim/momentum.rs | 127 +- burn-core/src/optim/rmsprop.rs | 899 ++-- burn-core/src/optim/sgd.rs | 262 +- burn-core/src/optim/simple/adaptor.rs | 228 +- burn-core/src/optim/simple/base.rs | 38 +- burn-core/src/optim/simple/record/base.rs | 92 +- burn-core/src/optim/simple/record/v1.rs | 278 +- burn-core/src/optim/visitor.rs | 42 +- burn-core/src/record/base.rs | 12 +- burn-core/src/record/file.rs | 569 +-- burn-core/src/record/memory.rs | 128 +- burn-core/src/record/primitive.rs | 165 +- burn-core/src/record/recorder.rs | 376 +- burn-core/src/record/settings.rs | 22 +- burn-core/src/record/tensor.rs | 148 +- burn-core/tests/derive_config.rs | 88 +- burn-core/tests/derive_module.rs | 202 +- burn-core/tests/derive_record.rs | 4 +- burn-core/tests/record_resilience.rs | 580 +-- burn-dataset/examples/speech_commands.rs | 24 +- burn-dataset/src/audio/speech_commands.rs | 269 +- burn-dataset/src/dataset/base.rs | 82 +- burn-dataset/src/dataset/fake.rs | 44 +- burn-dataset/src/dataset/in_memory.rs | 270 +- burn-dataset/src/dataset/iterator.rs | 34 +- burn-dataset/src/dataset/sqlite.rs | 1387 +++--- burn-dataset/src/lib.rs | 16 +- .../src/source/huggingface/downloader.rs | 406 +- 
burn-dataset/src/source/huggingface/mnist.rs | 96 +- burn-dataset/src/transform/composed.rs | 36 +- burn-dataset/src/transform/mapper.rs | 74 +- burn-dataset/src/transform/partial.rs | 204 +- burn-dataset/src/transform/random.rs | 66 +- burn-dataset/src/transform/sampler.rs | 198 +- burn-derive/src/config/analyzer.rs | 110 +- burn-derive/src/config/analyzer_enum.rs | 240 +- burn-derive/src/config/analyzer_struct.rs | 450 +- burn-derive/src/config/base.rs | 34 +- burn-derive/src/lib.rs | 12 +- burn-derive/src/module/base.rs | 260 +- burn-derive/src/module/codegen.rs | 14 +- burn-derive/src/module/codegen_struct.rs | 234 +- burn-derive/src/module/display.rs | 10 +- burn-derive/src/module/record.rs | 4 +- burn-derive/src/module/record_struct.rs | 34 +- burn-derive/src/record/base.rs | 120 +- burn-derive/src/record/codegen.rs | 12 +- burn-derive/src/record/codegen_struct.rs | 94 +- burn-derive/src/shared/attribute.rs | 80 +- burn-derive/src/shared/field.rs | 158 +- burn-fusion/src/backend.rs | 182 +- burn-fusion/src/client/base.rs | 112 +- burn-fusion/src/client/mutex.rs | 267 +- burn-fusion/src/fusion.rs | 96 +- burn-fusion/src/graph/base.rs | 146 +- burn-fusion/src/graph/execution.rs | 106 +- burn-fusion/src/graph/ops.rs | 2442 +++++----- burn-fusion/src/handle.rs | 232 +- burn-fusion/src/ops/binary.rs | 106 +- burn-fusion/src/ops/boolean.rs | 731 +-- burn-fusion/src/ops/float.rs | 3265 +++++++------ burn-fusion/src/ops/int.rs | 2689 ++++++----- burn-fusion/src/ops/module.rs | 1695 +++---- burn-fusion/src/ops/unary.rs | 166 +- burn-fusion/src/server.rs | 289 +- burn-fusion/src/tensor.rs | 222 +- burn-import/build.rs | 18 +- burn-import/onnx-tests/build.rs | 194 +- burn-import/onnx-tests/tests/onnx_tests.rs | 1076 ++--- .../onnx-tests/tests/record_type_tests.rs | 100 +- burn-import/src/burn/codegen.rs | 94 +- burn-import/src/burn/graph.rs | 1156 +++-- burn-import/src/burn/imports.rs | 49 +- burn-import/src/burn/node/avg_pool2d.rs | 246 +- burn-import/src/burn/node/base.rs | 671 +-- burn-import/src/burn/node/batch_norm.rs | 306 +- burn-import/src/burn/node/binary.rs | 506 +- burn-import/src/burn/node/clip.rs | 274 +- burn-import/src/burn/node/concat.rs | 159 +- burn-import/src/burn/node/constant.rs | 280 +- burn-import/src/burn/node/conv1d.rs | 338 +- burn-import/src/burn/node/conv2d.rs | 336 +- burn-import/src/burn/node/dropout.rs | 230 +- burn-import/src/burn/node/gather.rs | 160 +- burn-import/src/burn/node/global_avg_pool.rs | 334 +- burn-import/src/burn/node/linear.rs | 294 +- burn-import/src/burn/node/matmul.rs | 144 +- burn-import/src/burn/node/max_pool2d.rs | 254 +- burn-import/src/burn/node/reshape.rs | 126 +- burn-import/src/burn/node/test.rs | 6 +- burn-import/src/burn/node/unary.rs | 716 +-- burn-import/src/burn/scope.rs | 114 +- burn-import/src/burn/ty.rs | 202 +- burn-import/src/formatter.rs | 12 +- burn-import/src/logger.rs | 26 +- burn-import/src/main.rs | 22 +- burn-import/src/onnx/coalesce.rs | 262 +- burn-import/src/onnx/dim_inference.rs | 637 ++- burn-import/src/onnx/from_onnx.rs | 667 +-- burn-import/src/onnx/ir.rs | 1246 ++--- burn-import/src/onnx/node_remap.rs | 48 +- burn-import/src/onnx/op_configuration.rs | 920 ++-- burn-import/src/onnx/proto_conversion.rs | 396 +- burn-import/src/onnx/protos/mod.rs | 2 +- burn-import/src/onnx/to_burn.rs | 1262 ++--- burn-ndarray/build.rs | 8 +- burn-ndarray/src/backend.rs | 50 +- burn-ndarray/src/element.rs | 238 +- burn-ndarray/src/lib.rs | 22 +- burn-ndarray/src/ops/activations.rs | 22 +- burn-ndarray/src/ops/adaptive_avgpool.rs | 
158 +- burn-ndarray/src/ops/avgpool.rs | 213 +- burn-ndarray/src/ops/base.rs | 793 ++-- burn-ndarray/src/ops/bool_tensor.rs | 224 +- burn-ndarray/src/ops/conv.rs | 445 +- burn-ndarray/src/ops/int_tensor.rs | 704 +-- burn-ndarray/src/ops/macros.rs | 40 +- burn-ndarray/src/ops/matmul.rs | 156 +- burn-ndarray/src/ops/maxpool.rs | 280 +- burn-ndarray/src/ops/module.rs | 172 +- burn-ndarray/src/ops/padding.rs | 46 +- burn-ndarray/src/ops/tensor.rs | 829 ++-- burn-ndarray/src/parallel.rs | 48 +- burn-ndarray/src/sharing.rs | 16 +- burn-ndarray/src/tensor.rs | 198 +- burn-no-std-tests/src/conv.rs | 58 +- burn-no-std-tests/src/mlp.rs | 84 +- burn-no-std-tests/src/model.rs | 78 +- burn-no-std-tests/tests/integration_test.rs | 28 +- burn-tch/src/backend.rs | 100 +- burn-tch/src/lib.rs | 12 +- burn-tch/src/ops/activation.rs | 34 +- burn-tch/src/ops/base.rs | 799 ++-- burn-tch/src/ops/bool_tensor.rs | 217 +- burn-tch/src/ops/int_tensor.rs | 742 +-- burn-tch/src/ops/module.rs | 560 +-- burn-tch/src/ops/tensor.rs | 867 ++-- burn-tch/src/tensor.rs | 428 +- burn-tensor-testgen/src/lib.rs | 30 +- burn-tensor/src/tensor/activation/base.rs | 52 +- burn-tensor/src/tensor/api/base.rs | 2456 +++++----- burn-tensor/src/tensor/api/bool.rs | 42 +- burn-tensor/src/tensor/api/check.rs | 1001 ++-- burn-tensor/src/tensor/api/float.rs | 616 +-- burn-tensor/src/tensor/api/int.rs | 140 +- burn-tensor/src/tensor/api/kind.rs | 32 +- burn-tensor/src/tensor/api/numeric.rs | 4084 +++++++++-------- burn-tensor/src/tensor/backend/base.rs | 255 +- burn-tensor/src/tensor/container.rs | 111 +- burn-tensor/src/tensor/data.rs | 734 +-- burn-tensor/src/tensor/element.rs | 146 +- burn-tensor/src/tensor/loss/mod.rs | 12 +- burn-tensor/src/tensor/module.rs | 246 +- burn-tensor/src/tensor/named/base.rs | 100 +- burn-tensor/src/tensor/named/dims.rs | 112 +- burn-tensor/src/tensor/named/matmul.rs | 74 +- burn-tensor/src/tensor/named/swap_dims.rs | 74 +- burn-tensor/src/tensor/ops/activation.rs | 190 +- burn-tensor/src/tensor/ops/bool_tensor.rs | 461 +- burn-tensor/src/tensor/ops/int_tensor.rs | 1677 +++---- burn-tensor/src/tensor/ops/modules/base.rs | 741 +-- burn-tensor/src/tensor/ops/modules/conv.rs | 1316 +++--- burn-tensor/src/tensor/ops/modules/pool.rs | 246 +- burn-tensor/src/tensor/ops/modules/unfold.rs | 105 +- burn-tensor/src/tensor/ops/tensor.rs | 2123 ++++----- burn-tensor/src/tensor/shape.rs | 88 +- burn-tensor/src/tensor/stats/mod.rs | 38 +- burn-tensor/src/tests/activation/gelu.rs | 30 +- burn-tensor/src/tests/activation/relu.rs | 20 +- burn-tensor/src/tests/activation/sigmoid.rs | 36 +- burn-tensor/src/tests/activation/silu.rs | 20 +- burn-tensor/src/tests/activation/softmax.rs | 20 +- .../src/tests/activation/tanh_activation.rs | 20 +- burn-tensor/src/tests/clone_invariance.rs | 1394 +++--- burn-tensor/src/tests/mod.rs | 142 +- .../src/tests/module/adaptive_avgpool1d.rs | 120 +- .../src/tests/module/adaptive_avgpool2d.rs | 180 +- burn-tensor/src/tests/module/avgpool1d.rs | 150 +- burn-tensor/src/tests/module/avgpool2d.rs | 200 +- burn-tensor/src/tests/module/conv1d.rs | 240 +- burn-tensor/src/tests/module/conv2d.rs | 304 +- .../src/tests/module/conv_transpose1d.rs | 262 +- .../src/tests/module/conv_transpose2d.rs | 632 +-- burn-tensor/src/tests/module/forward.rs | 30 +- burn-tensor/src/tests/module/maxpool1d.rs | 226 +- burn-tensor/src/tests/module/maxpool2d.rs | 594 +-- burn-tensor/src/tests/module/unfold4d.rs | 230 +- burn-tensor/src/tests/ops/abs.rs | 36 +- burn-tensor/src/tests/ops/add.rs | 160 +- 
burn-tensor/src/tests/ops/aggregation.rs | 172 +- burn-tensor/src/tests/ops/arange.rs | 30 +- burn-tensor/src/tests/ops/arange_step.rs | 86 +- burn-tensor/src/tests/ops/arg.rs | 100 +- burn-tensor/src/tests/ops/cast.rs | 80 +- burn-tensor/src/tests/ops/cat.rs | 164 +- burn-tensor/src/tests/ops/clamp.rs | 114 +- burn-tensor/src/tests/ops/cos.rs | 20 +- burn-tensor/src/tests/ops/create_like.rs | 71 +- burn-tensor/src/tests/ops/div.rs | 164 +- burn-tensor/src/tests/ops/erf.rs | 42 +- burn-tensor/src/tests/ops/exp.rs | 20 +- burn-tensor/src/tests/ops/flatten.rs | 98 +- burn-tensor/src/tests/ops/full.rs | 38 +- burn-tensor/src/tests/ops/gather_scatter.rs | 340 +- burn-tensor/src/tests/ops/init.rs | 96 +- burn-tensor/src/tests/ops/iter_dim.rs | 78 +- burn-tensor/src/tests/ops/log.rs | 26 +- burn-tensor/src/tests/ops/log1p.rs | 26 +- burn-tensor/src/tests/ops/map_comparison.rs | 610 +-- burn-tensor/src/tests/ops/mask.rs | 104 +- burn-tensor/src/tests/ops/matmul.rs | 207 +- burn-tensor/src/tests/ops/maxmin.rs | 68 +- burn-tensor/src/tests/ops/mul.rs | 164 +- burn-tensor/src/tests/ops/neg.rs | 20 +- burn-tensor/src/tests/ops/one_hot.rs | 46 +- burn-tensor/src/tests/ops/powf.rs | 68 +- burn-tensor/src/tests/ops/random.rs | 36 +- burn-tensor/src/tests/ops/recip.rs | 20 +- burn-tensor/src/tests/ops/repeat.rs | 30 +- burn-tensor/src/tests/ops/reshape.rs | 140 +- burn-tensor/src/tests/ops/select.rs | 248 +- burn-tensor/src/tests/ops/sin.rs | 20 +- burn-tensor/src/tests/ops/slice.rs | 218 +- burn-tensor/src/tests/ops/sqrt.rs | 22 +- burn-tensor/src/tests/ops/squeeze.rs | 66 +- burn-tensor/src/tests/ops/sub.rs | 160 +- burn-tensor/src/tests/ops/tanh.rs | 20 +- burn-tensor/src/tests/ops/transpose.rs | 184 +- burn-tensor/src/tests/stats/cov.rs | 116 +- burn-tensor/src/tests/stats/diagonal.rs | 24 +- burn-tensor/src/tests/stats/display.rs | 236 +- burn-tensor/src/tests/stats/var.rs | 74 +- burn-train/src/checkpoint/async_checkpoint.rs | 166 +- burn-train/src/checkpoint/base.rs | 50 +- burn-train/src/checkpoint/file.rs | 95 +- burn-train/src/checkpoint/strategy/base.rs | 34 +- .../src/checkpoint/strategy/composed.rs | 227 +- burn-train/src/checkpoint/strategy/lastn.rs | 72 +- burn-train/src/checkpoint/strategy/metric.rs | 224 +- burn-train/src/components.rs | 106 +- burn-train/src/learner/base.rs | 180 +- burn-train/src/learner/builder.rs | 570 +-- burn-train/src/learner/classification.rs | 24 +- burn-train/src/learner/early_stopping.rs | 352 +- burn-train/src/learner/epoch.rs | 426 +- burn-train/src/learner/log.rs | 59 +- burn-train/src/learner/regression.rs | 18 +- burn-train/src/learner/step/train.rs | 206 +- burn-train/src/learner/train_val.rs | 302 +- burn-train/src/logger/async_logger.rs | 115 +- burn-train/src/logger/base.rs | 36 +- burn-train/src/logger/file.rs | 48 +- burn-train/src/logger/in_memory.rs | 10 +- burn-train/src/logger/metric.rs | 268 +- burn-train/src/metric/acc.rs | 202 +- burn-train/src/metric/base.rs | 94 +- burn-train/src/metric/cpu_temp.rs | 58 +- burn-train/src/metric/cpu_use.rs | 78 +- burn-train/src/metric/cuda.rs | 154 +- burn-train/src/metric/learning_rate.rs | 49 +- burn-train/src/metric/loss.rs | 41 +- burn-train/src/metric/memory_use.rs | 94 +- burn-train/src/metric/processor/base.rs | 50 +- burn-train/src/metric/processor/full.rs | 144 +- burn-train/src/metric/processor/metrics.rs | 282 +- burn-train/src/metric/processor/minimal.rs | 80 +- burn-train/src/metric/processor/mod.rs | 66 +- burn-train/src/metric/state.rs | 130 +- burn-train/src/metric/store/aggregate.rs | 248 
+- burn-train/src/metric/store/base.rs | 72 +- burn-train/src/metric/store/client.rs | 241 +- burn-train/src/metric/store/log.rs | 170 +- burn-train/src/renderer/base.rs | 94 +- burn-train/src/renderer/cli.rs | 24 +- burn-train/src/renderer/mod.rs | 12 +- burn-train/src/renderer/tui/base.rs | 64 +- burn-train/src/renderer/tui/controls.rs | 72 +- burn-train/src/renderer/tui/full_history.rs | 340 +- burn-train/src/renderer/tui/metric_numeric.rs | 361 +- burn-train/src/renderer/tui/metric_text.rs | 148 +- burn-train/src/renderer/tui/plot_utils.rs | 64 +- burn-train/src/renderer/tui/popup.rs | 200 +- burn-train/src/renderer/tui/progress.rs | 422 +- burn-train/src/renderer/tui/recent_history.rs | 356 +- burn-train/src/renderer/tui/renderer.rs | 258 +- burn-train/src/renderer/tui/status.rs | 126 +- burn-wgpu/benches/fused_elemwise.rs | 88 +- burn-wgpu/benches/matmul.rs | 168 +- burn-wgpu/benches/reduction.rs | 130 +- burn-wgpu/src/backend.rs | 66 +- burn-wgpu/src/compute/base.rs | 358 +- burn-wgpu/src/compute/kernel.rs | 113 +- burn-wgpu/src/compute/server.rs | 564 +-- burn-wgpu/src/compute/storage.rs | 170 +- burn-wgpu/src/compute/tune_key.rs | 24 +- burn-wgpu/src/device.rs | 56 +- burn-wgpu/src/element.rs | 64 +- burn-wgpu/src/fusion/base.rs | 204 +- burn-wgpu/src/fusion/codegen/body.rs | 22 +- burn-wgpu/src/fusion/codegen/function.rs | 22 +- burn-wgpu/src/fusion/codegen/operator.rs | 238 +- burn-wgpu/src/fusion/codegen/shader.rs | 252 +- burn-wgpu/src/fusion/codegen/variable.rs | 22 +- burn-wgpu/src/fusion/elemwise/ops.rs | 817 ++-- burn-wgpu/src/fusion/kernel.rs | 468 +- burn-wgpu/src/graphics.rs | 52 +- burn-wgpu/src/kernel/base.rs | 350 +- burn-wgpu/src/kernel/binary_elemwise.rs | 254 +- burn-wgpu/src/kernel/cast.rs | 99 +- burn-wgpu/src/kernel/cat.rs | 154 +- burn-wgpu/src/kernel/clamp.rs | 153 +- burn-wgpu/src/kernel/comparison/base.rs | 168 +- burn-wgpu/src/kernel/comparison/binary.rs | 269 +- burn-wgpu/src/kernel/comparison/elem.rs | 205 +- burn-wgpu/src/kernel/conv/conv2d.rs | 162 +- burn-wgpu/src/kernel/conv/conv_transpose2d.rs | 213 +- burn-wgpu/src/kernel/index/gather.rs | 133 +- burn-wgpu/src/kernel/index/scatter.rs | 251 +- burn-wgpu/src/kernel/index/select.rs | 313 +- burn-wgpu/src/kernel/index/slice.rs | 204 +- burn-wgpu/src/kernel/mask/base.rs | 34 +- burn-wgpu/src/kernel/mask/mask_fill.rs | 210 +- burn-wgpu/src/kernel/mask/mask_where.rs | 266 +- burn-wgpu/src/kernel/matmul/mem_coalescing.rs | 326 +- burn-wgpu/src/kernel/matmul/naive.rs | 298 +- burn-wgpu/src/kernel/matmul/tiling2d/base.rs | 120 +- .../src/kernel/matmul/tiling2d/padding.rs | 531 +-- .../src/kernel/matmul/tiling2d/unpadded.rs | 332 +- burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs | 242 +- .../src/kernel/matmul/tiling2d/vec4_lhs.rs | 244 +- burn-wgpu/src/kernel/matmul/tune/base.rs | 246 +- burn-wgpu/src/kernel/matmul/tune/key.rs | 172 +- burn-wgpu/src/kernel/matmul/utils.rs | 123 +- .../src/kernel/pool/adaptive_avg_pool2d.rs | 132 +- burn-wgpu/src/kernel/pool/avg_pool2d.rs | 269 +- burn-wgpu/src/kernel/pool/base.rs | 107 +- burn-wgpu/src/kernel/pool/max_pool2d.rs | 335 +- burn-wgpu/src/kernel/prng/base.rs | 136 +- burn-wgpu/src/kernel/prng/bernoulli.rs | 225 +- burn-wgpu/src/kernel/prng/normal.rs | 189 +- burn-wgpu/src/kernel/prng/uniform.rs | 265 +- burn-wgpu/src/kernel/reduce/base.rs | 30 +- burn-wgpu/src/kernel/reduce/reduction.rs | 296 +- .../kernel/reduce/reduction_shared_memory.rs | 248 +- burn-wgpu/src/kernel/reduce/tune/base.rs | 42 +- burn-wgpu/src/kernel/reduce/tune/key.rs | 62 +- 
burn-wgpu/src/kernel/reduce/tune/mean_dim.rs | 144 +- burn-wgpu/src/kernel/reduce/tune/sum_dim.rs | 144 +- burn-wgpu/src/kernel/source.rs | 98 +- burn-wgpu/src/kernel/unary.rs | 274 +- burn-wgpu/src/kernel/unary_scalar.rs | 272 +- burn-wgpu/src/lib.rs | 16 +- burn-wgpu/src/ops/activation_ops.rs | 28 +- burn-wgpu/src/ops/base.rs | 84 +- burn-wgpu/src/ops/bool_ops.rs | 227 +- burn-wgpu/src/ops/float_ops.rs | 882 ++-- burn-wgpu/src/ops/int_ops.rs | 625 +-- burn-wgpu/src/ops/module_ops.rs | 181 +- burn-wgpu/src/ops/numeric.rs | 214 +- burn-wgpu/src/tensor/base.rs | 232 +- burn/src/lib.rs | 2 +- .../examples/custom-renderer.rs | 2 +- examples/custom-renderer/src/lib.rs | 108 +- .../examples/custom-training-loop.rs | 2 +- examples/custom-training-loop/src/lib.rs | 244 +- .../examples/custom-wgpu-kernel.rs | 94 +- examples/custom-wgpu-kernel/src/backward.rs | 185 +- examples/custom-wgpu-kernel/src/forward.rs | 176 +- examples/custom-wgpu-kernel/src/lib.rs | 38 +- examples/guide/examples/guide.rs | 32 +- examples/guide/src/data.rs | 54 +- examples/guide/src/inference.rs | 34 +- examples/guide/src/model.rs | 121 +- examples/guide/src/training.rs | 154 +- examples/image-classification-web/build.rs | 76 +- .../src/model/normalizer.rs | 42 +- .../src/model/squeezenet.rs | 2 +- examples/image-classification-web/src/web.rs | 256 +- examples/mnist-inference-web/src/model.rs | 130 +- examples/mnist-inference-web/src/state.rs | 14 +- examples/mnist-inference-web/src/web.rs | 104 +- examples/mnist/examples/mnist.rs | 96 +- examples/mnist/src/data.rs | 54 +- examples/mnist/src/model.rs | 180 +- examples/mnist/src/training.rs | 130 +- .../named-tensor/examples/named-tensor.rs | 2 +- examples/named-tensor/src/lib.rs | 74 +- examples/onnx-inference/build.rs | 32 +- .../onnx-inference/src/bin/mnist_inference.rs | 74 +- examples/onnx-inference/src/model/mod.rs | 2 +- .../examples/ag-news-infer.rs | 92 +- .../examples/ag-news-train.rs | 114 +- .../examples/db-pedia-infer.rs | 92 +- .../examples/db-pedia-train.rs | 114 +- .../text-classification/src/data/batcher.rs | 108 +- .../text-classification/src/data/dataset.rs | 221 +- .../text-classification/src/data/tokenizer.rs | 74 +- examples/text-classification/src/inference.rs | 106 +- examples/text-classification/src/model.rs | 281 +- examples/text-classification/src/training.rs | 171 +- .../examples/text-generation.rs | 31 +- examples/text-generation/src/data/batcher.rs | 74 +- examples/text-generation/src/data/dataset.rs | 47 +- .../text-generation/src/data/tokenizer.rs | 144 +- examples/text-generation/src/model.rs | 157 +- examples/text-generation/src/training.rs | 144 +- xtask/src/main.rs | 34 +- xtask/src/publish.rs | 178 +- xtask/src/runchecks.rs | 578 +-- 579 files changed, 67861 insertions(+), 67442 deletions(-) diff --git a/backend-comparison/benches/binary.rs b/backend-comparison/benches/binary.rs index 43ae124d91..cb5b3264f5 100644 --- a/backend-comparison/benches/binary.rs +++ b/backend-comparison/benches/binary.rs @@ -2,46 +2,46 @@ use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; pub struct BinaryBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for BinaryBenchmark { - type Args = (Tensor, Tensor); + type Args = (Tensor, Tensor); - fn name(&self) -> String { - "Binary Ops".into() - } + fn name(&self) -> String { + "Binary Ops".into() + } - fn execute(&self, (lhs, rhs): Self::Args) { - 
for _ in 0..self.num_repeats { - // Choice of add is arbitrary - B::add(lhs.clone().into_primitive(), rhs.clone().into_primitive()); + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.num_repeats { + // Choice of add is arbitrary + B::add(lhs.clone().into_primitive(), rhs.clone().into_primitive()); + } } - } - fn prepare(&self) -> Self::Args { - let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] fn bench(device: &B::Device) { - run_benchmark(BinaryBenchmark:: { - shape: [32, 512, 1024].into(), - num_repeats: 10, - device: device.clone(), - }) + run_benchmark(BinaryBenchmark:: { + shape: [32, 512, 1024].into(), + num_repeats: 10, + device: device.clone(), + }) } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/custom_gelu.rs b/backend-comparison/benches/custom_gelu.rs index 63c52eefc0..71db646a97 100644 --- a/backend-comparison/benches/custom_gelu.rs +++ b/backend-comparison/benches/custom_gelu.rs @@ -5,62 +5,62 @@ use derive_new::new; #[derive(Debug)] enum GeluKind { - Reference, - WithReferenceErf, - WithCustomErf, + Reference, + WithReferenceErf, + WithCustomErf, } /// Benchmark how well a backend executes a custom activation function with a lot of basic tensor /// operations. 
#[derive(new)] struct CustomGeluBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, - kind: GeluKind, + shape: Shape, + num_repeats: usize, + device: B::Device, + kind: GeluKind, } impl Benchmark for CustomGeluBenchmark { - type Args = Tensor; - - fn name(&self) -> String { - format!("Gelu {:?}", self.kind) - } - - fn execute(&self, args: Self::Args) { - for _ in 0..self.num_repeats { - match self.kind { - GeluKind::Reference => burn::tensor::activation::gelu(args.clone()), - GeluKind::WithReferenceErf => gelu_custom(args.clone(), Tensor::erf), - GeluKind::WithCustomErf => gelu_custom(args.clone(), erf_custom), - }; + type Args = Tensor; + + fn name(&self) -> String { + format!("Gelu {:?}", self.kind) } - } - fn prepare(&self) -> Self::Args { - Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } + fn execute(&self, args: Self::Args) { + for _ in 0..self.num_repeats { + match self.kind { + GeluKind::Reference => burn::tensor::activation::gelu(args.clone()), + GeluKind::WithReferenceErf => gelu_custom(args.clone(), Tensor::erf), + GeluKind::WithCustomErf => gelu_custom(args.clone(), erf_custom), + }; + } + } - fn sync(&self) { - B::sync(&self.device) - } + fn prepare(&self) -> Self::Args { + Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } + + fn sync(&self) { + B::sync(&self.device) + } } fn gelu_custom(x: Tensor, erf: Erf) -> Tensor where - B: Backend, - Erf: Fn(Tensor) -> Tensor, + B: Backend, + Erf: Fn(Tensor) -> Tensor, { - let x = x.clone() * (erf(x / SQRT_2) + 1); - x / 2 + let x = x.clone() * (erf(x / SQRT_2) + 1); + x / 2 } fn erf_custom(x: Tensor) -> Tensor { - let x1 = -erf_positive(-x.clone()); - let x2 = erf_positive(x.clone()); - let mask = x.greater_elem(0); + let x1 = -erf_positive(-x.clone()); + let x2 = erf_positive(x.clone()); + let mask = x.greater_elem(0); - x1.mask_where(mask, x2) + x1.mask_where(mask, x2) } /// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations @@ -68,47 +68,47 @@ fn erf_custom(x: Tensor) -> Tensor { /// > (maximum error: 1.5×10−7) /// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x). 
fn erf_positive(x: Tensor) -> Tensor { - let p = 0.3275911; - let a1 = 0.254829592; - let a2 = -0.284496736; - let a3 = 1.421413741; - let a4 = -1.453152027; - let a5 = 1.061405429; - - let x1 = x.clone().abs() * p + 1; - let t = x1.recip(); - let tmp = (((((t.clone() * a5) + a4) * t.clone()) + a3) * t.clone() + a2) * t.clone() + a1; - - -(tmp * t * (-x.clone() * x).exp()) + 1.0 + let p = 0.3275911; + let a1 = 0.254829592; + let a2 = -0.284496736; + let a3 = 1.421413741; + let a4 = -1.453152027; + let a5 = 1.061405429; + + let x1 = x.clone().abs() * p + 1; + let t = x1.recip(); + let tmp = (((((t.clone() * a5) + a4) * t.clone()) + a3) * t.clone() + a2) * t.clone() + a1; + + -(tmp * t * (-x.clone() * x).exp()) + 1.0 } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let shape: Shape = [32, 512, 2048].into(); - let num_repeats = 1; - - println!("Backend {}", B::name()); - run_benchmark(CustomGeluBenchmark::::new( - shape.clone(), - num_repeats, - device.clone(), - GeluKind::Reference, - )); - run_benchmark(CustomGeluBenchmark::::new( - shape.clone(), - num_repeats, - device.clone(), - GeluKind::WithReferenceErf, - )); - run_benchmark(CustomGeluBenchmark::::new( - shape, - num_repeats, - device.clone(), - GeluKind::WithCustomErf, - )); + const D: usize = 3; + let shape: Shape = [32, 512, 2048].into(); + let num_repeats = 1; + + println!("Backend {}", B::name()); + run_benchmark(CustomGeluBenchmark::::new( + shape.clone(), + num_repeats, + device.clone(), + GeluKind::Reference, + )); + run_benchmark(CustomGeluBenchmark::::new( + shape.clone(), + num_repeats, + device.clone(), + GeluKind::WithReferenceErf, + )); + run_benchmark(CustomGeluBenchmark::::new( + shape, + num_repeats, + device.clone(), + GeluKind::WithCustomErf, + )); } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/data.rs b/backend-comparison/benches/data.rs index df1f439e1e..e9379b3933 100644 --- a/backend-comparison/benches/data.rs +++ b/backend-comparison/benches/data.rs @@ -4,83 +4,83 @@ use derive_new::new; #[derive(new)] struct ToDataBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for ToDataBenchmark { - type Args = Tensor; + type Args = Tensor; - fn name(&self) -> String { - format!("to-data-{:?}-{}", self.shape.dims, self.num_repeats) - } + fn name(&self) -> String { + format!("to-data-{:?}-{}", self.shape.dims, self.num_repeats) + } - fn execute(&self, args: Self::Args) { - for _ in 0..self.num_repeats { - let _data = args.to_data(); + fn execute(&self, args: Self::Args) { + for _ in 0..self.num_repeats { + let _data = args.to_data(); + } } - } - fn prepare(&self) -> Self::Args { - Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } + fn prepare(&self) -> Self::Args { + Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[derive(new)] struct FromDataBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for FromDataBenchmark { - type Args = (Data, B::Device); + type Args = (Data, B::Device); + + fn name(&self) -> String { + format!("from-data-{:?}-{}", self.shape.dims, self.num_repeats) + } + + fn execute(&self, (data, device): Self::Args) { + for _ in 
0..self.num_repeats { + let _data = Tensor::::from_data_device(data.clone(), &device); + } + } - fn name(&self) -> String { - format!("from-data-{:?}-{}", self.shape.dims, self.num_repeats) - } + fn prepare(&self) -> Self::Args { + ( + Data::random( + self.shape.clone(), + Distribution::Default, + &mut rand::thread_rng(), + ), + self.device.clone(), + ) + } - fn execute(&self, (data, device): Self::Args) { - for _ in 0..self.num_repeats { - let _data = Tensor::::from_data_device(data.clone(), &device); + fn sync(&self) { + B::sync(&self.device) } - } - - fn prepare(&self) -> Self::Args { - ( - Data::random( - self.shape.clone(), - Distribution::Default, - &mut rand::thread_rng(), - ), - self.device.clone(), - ) - } - - fn sync(&self) { - B::sync(&self.device) - } } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let shape: Shape = [32, 512, 1024].into(); - let num_repeats = 10; + const D: usize = 3; + let shape: Shape = [32, 512, 1024].into(); + let num_repeats = 10; - let to_benchmark = ToDataBenchmark::::new(shape.clone(), num_repeats, device.clone()); - let from_benchmark = FromDataBenchmark::::new(shape, num_repeats, device.clone()); + let to_benchmark = ToDataBenchmark::::new(shape.clone(), num_repeats, device.clone()); + let from_benchmark = FromDataBenchmark::::new(shape, num_repeats, device.clone()); - println!("Backend {}", B::name()); - run_benchmark(to_benchmark); - run_benchmark(from_benchmark) + println!("Backend {}", B::name()); + run_benchmark(to_benchmark); + run_benchmark(from_benchmark) } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index 5114300afa..7574e21970 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -4,60 +4,62 @@ use derive_new::new; #[derive(new)] struct MatmulBenchmark { - shape_lhs: Shape, - shape_rhs: Shape, - num_repeats: usize, - device: B::Device, + shape_lhs: Shape, + shape_rhs: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for MatmulBenchmark { - type Args = (Tensor, Tensor); - - fn name(&self) -> String { - format!( - "Matmul {:?} x {:?}", - self.shape_lhs.dims, self.shape_rhs.dims - ) - } - - fn num_samples(&self) -> usize { - 10 - } - - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.num_repeats { - lhs.clone().matmul(rhs.clone()); + type Args = (Tensor, Tensor); + + fn name(&self) -> String { + format!( + "Matmul {:?} x {:?}", + self.shape_lhs.dims, self.shape_rhs.dims + ) + } + + fn num_samples(&self) -> usize { + 10 + } + + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.num_repeats { + lhs.clone().matmul(rhs.clone()); + } } - } - fn prepare(&self) -> Self::Args { - let lhs = Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); - let rhs = Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + let lhs = + Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); + let rhs = + Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let num_repeats = 3; - let batch_size = 3; - let m = 1024; - let k = 2048; - let n = 1024; - let 
shape_lhs = [batch_size, m, k].into(); - let shape_rhs = [batch_size, k, n].into(); - - let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, num_repeats, device.clone()); - println!("Backend {}", B::name()); - run_benchmark(benchmark); + const D: usize = 3; + let num_repeats = 3; + let batch_size = 3; + let m = 1024; + let k = 2048; + let n = 1024; + let shape_lhs = [batch_size, m, k].into(); + let shape_rhs = [batch_size, k, n].into(); + + let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, num_repeats, device.clone()); + println!("Backend {}", B::name()); + run_benchmark(benchmark); } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/benches/unary.rs b/backend-comparison/benches/unary.rs index b836f0a845..4befcdd2ad 100644 --- a/backend-comparison/benches/unary.rs +++ b/backend-comparison/benches/unary.rs @@ -4,46 +4,46 @@ use derive_new::new; #[derive(new)] struct UnaryBenchmark { - shape: Shape, - num_repeats: usize, - device: B::Device, + shape: Shape, + num_repeats: usize, + device: B::Device, } impl Benchmark for UnaryBenchmark { - type Args = Tensor; + type Args = Tensor; - fn name(&self) -> String { - "Unary Ops".into() - } + fn name(&self) -> String { + "Unary Ops".into() + } - fn execute(&self, args: Self::Args) { - for _ in 0..self.num_repeats { - // Choice of tanh is arbitrary - B::tanh(args.clone().into_primitive()); + fn execute(&self, args: Self::Args) { + for _ in 0..self.num_repeats { + // Choice of tanh is arbitrary + B::tanh(args.clone().into_primitive()); + } } - } - fn prepare(&self) -> Self::Args { - Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } + fn prepare(&self) -> Self::Args { + Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] fn bench(device: &B::Device) { - const D: usize = 3; - let shape: Shape = [32, 512, 1024].into(); - let num_repeats = 10; + const D: usize = 3; + let shape: Shape = [32, 512, 1024].into(); + let num_repeats = 10; - let benchmark = UnaryBenchmark::::new(shape, num_repeats, device.clone()); + let benchmark = UnaryBenchmark::::new(shape, num_repeats, device.clone()); - println!("Backend {}", B::name()); - run_benchmark(benchmark) + println!("Backend {}", B::name()); + run_benchmark(benchmark) } fn main() { - backend_comparison::bench_on_backend!(); + backend_comparison::bench_on_backend!(); } diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index c82e5e351f..065b50f418 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -1,70 +1,70 @@ #[macro_export] macro_rules! 
bench_on_backend { - () => { - #[cfg(feature = "wgpu-fusion")] - { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::Fusion; + () => { + #[cfg(feature = "wgpu-fusion")] + { + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Fusion; - bench::<Fusion<Wgpu<AutoGraphicsApi, f32, i32>>>(&WgpuDevice::default()); - } + bench::<Fusion<Wgpu<AutoGraphicsApi, f32, i32>>>(&WgpuDevice::default()); + } - #[cfg(feature = "wgpu")] - { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + #[cfg(feature = "wgpu")] + { + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - bench::<Wgpu<AutoGraphicsApi, f32, i32>>(&WgpuDevice::default()); - } + bench::<Wgpu<AutoGraphicsApi, f32, i32>>(&WgpuDevice::default()); + } - #[cfg(feature = "tch-gpu")] - { - use burn::backend::{libtorch::LibTorchDevice, LibTorch}; + #[cfg(feature = "tch-gpu")] + { + use burn::backend::{libtorch::LibTorchDevice, LibTorch}; - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; - bench::<LibTorch>(&device); - } + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + bench::<LibTorch>(&device); + } - #[cfg(feature = "tch-cpu")] - { - use burn::backend::{libtorch::LibTorchDevice, LibTorch}; + #[cfg(feature = "tch-cpu")] + { + use burn::backend::{libtorch::LibTorchDevice, LibTorch}; - let device = LibTorchDevice::Cpu; - bench::<LibTorch>(&device); - } + let device = LibTorchDevice::Cpu; + bench::<LibTorch>(&device); + } - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - { - use burn::backend::ndarray::NdArrayDevice; - use burn::backend::NdArray; + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + { + use burn::backend::ndarray::NdArrayDevice; + use burn::backend::NdArray; - let device = NdArrayDevice::Cpu; - bench::<NdArray>(&device); - } + let device = NdArrayDevice::Cpu; + bench::<NdArray>(&device); + } - #[cfg(feature = "candle-cpu")] - { - use burn::backend::candle::CandleDevice; - use burn::backend::Candle; + #[cfg(feature = "candle-cpu")] + { + use burn::backend::candle::CandleDevice; + use burn::backend::Candle; - let device = CandleDevice::Cpu; - bench::<Candle>(&device); - } + let device = CandleDevice::Cpu; + bench::<Candle>(&device); + } - #[cfg(feature = "candle-cuda")] - { - use burn::backend::candle::CandleDevice; - use burn::backend::Candle; + #[cfg(feature = "candle-cuda")] + { + use burn::backend::candle::CandleDevice; + use burn::backend::Candle; - let device = CandleDevice::Cuda(0); - bench::<Candle>(&device); - } - }; + let device = CandleDevice::Cuda(0); + bench::<Candle>(&device); + } + }; } diff --git a/burn-autodiff/src/backend.rs b/burn-autodiff/src/backend.rs index 302fa2d7d2..e0039ae290 100644 --- a/burn-autodiff/src/backend.rs +++ b/burn-autodiff/src/backend.rs @@ -8,75 +8,75 @@ use core::marker::PhantomData; /// backpropagation.
#[derive(Clone, Copy, Debug, Default)] pub struct Autodiff { - _b: PhantomData, + _b: PhantomData, } impl Backend for Autodiff { - type Device = B::Device; + type Device = B::Device; - type FullPrecisionElem = B::FullPrecisionElem; - type FullPrecisionBackend = Autodiff; + type FullPrecisionElem = B::FullPrecisionElem; + type FullPrecisionBackend = Autodiff; - type TensorPrimitive = AutodiffTensor; - type FloatElem = B::FloatElem; + type TensorPrimitive = AutodiffTensor; + type FloatElem = B::FloatElem; - type IntTensorPrimitive = B::IntTensorPrimitive; - type IntElem = B::IntElem; + type IntTensorPrimitive = B::IntTensorPrimitive; + type IntElem = B::IntElem; - type BoolTensorPrimitive = B::BoolTensorPrimitive; + type BoolTensorPrimitive = B::BoolTensorPrimitive; - fn ad_enabled() -> bool { - true - } + fn ad_enabled() -> bool { + true + } - fn name() -> String { - format!("autodiff<{}>", B::name()) - } + fn name() -> String { + format!("autodiff<{}>", B::name()) + } - fn seed(seed: u64) { - B::seed(seed) - } + fn seed(seed: u64) { + B::seed(seed) + } - fn sync(device: &B::Device) { - B::sync(device); - } + fn sync(device: &B::Device) { + B::sync(device); + } } impl AutodiffBackend for Autodiff { - type InnerBackend = B; - type Gradients = Gradients; - - fn backward(tensor: AutodiffTensor) -> Gradients { - backward(tensor) - } - - fn grad( - tensor: &AutodiffTensor, - grads: &Gradients, - ) -> Option> { - grads.get(tensor) - } - - fn grad_remove( - tensor: &AutodiffTensor, - grads: &mut Gradients, - ) -> Option> { - grads.remove(tensor) - } - fn inner(tensor: AutodiffTensor) -> B::TensorPrimitive { - tensor.primitive - } - - fn from_inner(tensor: B::TensorPrimitive) -> AutodiffTensor { - AutodiffTensor::new(tensor) - } - - fn grad_replace( - tensor: &AutodiffTensor, - grads: &mut Self::Gradients, - grad: B::TensorPrimitive, - ) { - grads.remove(tensor); - grads.register::(tensor.node.clone(), grad); - } + type InnerBackend = B; + type Gradients = Gradients; + + fn backward(tensor: AutodiffTensor) -> Gradients { + backward(tensor) + } + + fn grad( + tensor: &AutodiffTensor, + grads: &Gradients, + ) -> Option> { + grads.get(tensor) + } + + fn grad_remove( + tensor: &AutodiffTensor, + grads: &mut Gradients, + ) -> Option> { + grads.remove(tensor) + } + fn inner(tensor: AutodiffTensor) -> B::TensorPrimitive { + tensor.primitive + } + + fn from_inner(tensor: B::TensorPrimitive) -> AutodiffTensor { + AutodiffTensor::new(tensor) + } + + fn grad_replace( + tensor: &AutodiffTensor, + grads: &mut Self::Gradients, + grad: B::TensorPrimitive, + ) { + grads.remove(tensor); + grads.register::(tensor.node.clone(), grad); + } } diff --git a/burn-autodiff/src/grads.rs b/burn-autodiff/src/grads.rs index 1e090e6ea4..9e2c5b3e8c 100644 --- a/burn-autodiff/src/grads.rs +++ b/burn-autodiff/src/grads.rs @@ -1,8 +1,8 @@ use burn_tensor::{backend::Backend, container::TensorContainer, Tensor}; use crate::{ - graph::{NodeRef, Requirement}, - tensor::AutodiffTensor, + graph::{NodeRef, Requirement}, + tensor::AutodiffTensor, }; /// Gradient identifier. @@ -10,85 +10,81 @@ pub type GradID = u64; /// Gradients container used during the backward pass. pub struct Gradients { - container: TensorContainer, + container: TensorContainer, } type TensorPrimitive = ::TensorPrimitive; impl Gradients { - /// Creates a new gradients container. 
- pub fn new( - root_node: NodeRef, - root_tensor: TensorPrimitive, - ) -> Self { - let mut gradients = Self { - container: TensorContainer::new(), - }; - gradients.register::( - root_node, - B::ones(B::shape(&root_tensor), &B::device(&root_tensor)), - ); - gradients - } + /// Creates a new gradients container. + pub fn new( + root_node: NodeRef, + root_tensor: TensorPrimitive, + ) -> Self { + let mut gradients = Self { + container: TensorContainer::new(), + }; + gradients.register::( + root_node, + B::ones(B::shape(&root_tensor), &B::device(&root_tensor)), + ); + gradients + } - /// Consumes the gradients for a given tensor. - /// - /// Each tensor should be consumed exactly 1 time if its gradients are only required during the - /// backward pass, otherwise, it may be consume multiple times. - pub fn consume(&mut self, node: &NodeRef) -> TensorPrimitive { - match node.requirement { - Requirement::Grad => self - .container - .get::(&node.id.value) - .map(|tensor| tensor.into_primitive()) - .expect("Can't consume the gradients before they are registered at least once."), - Requirement::GradInBackward => self - .container - .remove::(&node.id.value) - .map(|tensor| tensor.into_primitive()) - .expect("Can't consume the gradients before they are registered at least once."), - Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"), + /// Consumes the gradients for a given tensor. + /// + /// Each tensor should be consumed exactly 1 time if its gradients are only required during the + /// backward pass, otherwise, it may be consume multiple times. + pub fn consume(&mut self, node: &NodeRef) -> TensorPrimitive { + match node.requirement { + Requirement::Grad => self + .container + .get::(&node.id.value) + .map(|tensor| tensor.into_primitive()) + .expect("Can't consume the gradients before they are registered at least once."), + Requirement::GradInBackward => self + .container + .remove::(&node.id.value) + .map(|tensor| tensor.into_primitive()) + .expect("Can't consume the gradients before they are registered at least once."), + Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"), + } } - } - /// Removes a grad tensor from the container. - pub fn remove( - &mut self, - tensor: &AutodiffTensor, - ) -> Option> { - self - .container - .remove::(&tensor.node.id.value) - .map(|tensor| tensor.into_primitive()) - } + /// Removes a grad tensor from the container. + pub fn remove( + &mut self, + tensor: &AutodiffTensor, + ) -> Option> { + self.container + .remove::(&tensor.node.id.value) + .map(|tensor| tensor.into_primitive()) + } - /// Gets a grad tensor from the container. - pub fn get( - &self, - tensor: &AutodiffTensor, - ) -> Option> { - self - .container - .get::(&tensor.node.id.value) - .map(|tensor| tensor.into_primitive()) - } + /// Gets a grad tensor from the container. + pub fn get( + &self, + tensor: &AutodiffTensor, + ) -> Option> { + self.container + .get::(&tensor.node.id.value) + .map(|tensor| tensor.into_primitive()) + } - /// Register a grad tensor in the container. - /// - /// If the tensor already exists, add both tensors together before saving the result. 
- pub fn register( - &mut self, - node: NodeRef, - value: TensorPrimitive, - ) { - if let Some(tensor_old) = self.container.remove::(&node.id.value) { - self - .container - .register(node.id.value, Tensor::from_primitive(value).add(tensor_old)); - } else { - self - .container - .register::(node.id.value, Tensor::from_primitive(value)); + /// Register a grad tensor in the container. + /// + /// If the tensor already exists, add both tensors together before saving the result. + pub fn register( + &mut self, + node: NodeRef, + value: TensorPrimitive, + ) { + if let Some(tensor_old) = self.container.remove::(&node.id.value) { + self.container + .register(node.id.value, Tensor::from_primitive(value).add(tensor_old)); + } else { + self.container + .register::(node.id.value, Tensor::from_primitive(value)); + } } - } } diff --git a/burn-autodiff/src/graph/backward.rs b/burn-autodiff/src/graph/backward.rs index ea1517c81c..ea1edf7cc0 100644 --- a/burn-autodiff/src/graph/backward.rs +++ b/burn-autodiff/src/graph/backward.rs @@ -5,34 +5,33 @@ use crate::{grads::Gradients, tensor::AutodiffTensor}; use super::{traversal::BreadthFirstSearch, Graph, NodeRef, StepBoxed}; pub fn backward(root: AutodiffTensor) -> Gradients { - let grads = Gradients::new::(root.node.clone(), root.primitive); - let tape = build_tape(root.node, root.graph); + let grads = Gradients::new::(root.node.clone(), root.primitive); + let tape = build_tape(root.node, root.graph); - execute_steps(tape, grads) + execute_steps(tape, grads) } fn build_tape(root: NodeRef, graph: Graph) -> Vec> { - let mut tape = (0..root.order) - .map(|_| Vec::with_capacity(1)) - .collect::>(); + let mut tape = (0..root.order) + .map(|_| Vec::with_capacity(1)) + .collect::>(); - BreadthFirstSearch.traverse(root, graph, |node, step| { - if node.order == 0 { - return; - } + BreadthFirstSearch.traverse(root, graph, |node, step| { + if node.order == 0 { + return; + } - if let Some(steps) = tape.get_mut(node.order - 1) { - steps.push(step) - }; - }); + if let Some(steps) = tape.get_mut(node.order - 1) { + steps.push(step) + }; + }); - tape + tape } fn execute_steps(tape: Vec>, mut grads: Gradients) -> Gradients { - tape - .into_iter() - .rev() - .for_each(|steps| steps.into_iter().for_each(|step| step.step(&mut grads))); - grads + tape.into_iter() + .rev() + .for_each(|steps| steps.into_iter().for_each(|step| step.step(&mut grads))); + grads } diff --git a/burn-autodiff/src/graph/base.rs b/burn-autodiff/src/graph/base.rs index 5c3830a57e..1e8e989f32 100644 --- a/burn-autodiff/src/graph/base.rs +++ b/burn-autodiff/src/graph/base.rs @@ -7,10 +7,10 @@ use super::{NodeID, NodeRef}; /// Backward step for reverse mode autodiff. pub trait Step: Send + Sync + std::fmt::Debug { - /// Executes the step and consumes it. - fn step(self: Box, grads: &mut Gradients); - /// The node associated to the step. - fn node(&self) -> NodeRef; + /// Executes the step and consumes it. + fn step(self: Box, grads: &mut Gradients); + /// The node associated to the step. + fn node(&self) -> NodeRef; } pub type StepBoxed = Box; @@ -21,76 +21,76 @@ pub type NodeSteps = HashMap; /// The graph contains the [node steps](Step), which can be access by [node id](NodeID). #[derive(Default, Clone, Debug)] pub struct Graph { - steps: Arc>, + steps: Arc>, } impl Graph { - /// Create a new graph. - pub fn new() -> Self { - Self::default() - } - - /// Get all the steps for the graph. - /// - /// # Notes - /// - /// This is a owned method, so the current graph will be freed. 
However, the steps can - /// be shared with other graphs, therefore they are going to be cleared. - /// - /// This is useful, since the graph is supposed to be consumed only once for backprop, and - /// keeping all the tensors alive for multiple backward call is a heavy waste of resources. - pub fn steps(self) -> NodeSteps { - let mut map_drain = HashMap::new(); - self.execute_mut(|map| { - std::mem::swap(&mut *map, &mut map_drain); - }); - map_drain - } + /// Create a new graph. + pub fn new() -> Self { + Self::default() + } - /// Register a new step into the graph. - pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self { - self.execute_mut(|map| { - map.insert(id.clone(), ops); - }) - } + /// Get all the steps for the graph. + /// + /// # Notes + /// + /// This is a owned method, so the current graph will be freed. However, the steps can + /// be shared with other graphs, therefore they are going to be cleared. + /// + /// This is useful, since the graph is supposed to be consumed only once for backprop, and + /// keeping all the tensors alive for multiple backward call is a heavy waste of resources. + pub fn steps(self) -> NodeSteps { + let mut map_drain = HashMap::new(); + self.execute_mut(|map| { + std::mem::swap(&mut *map, &mut map_drain); + }); + map_drain + } - /// Merge two graphs. - pub fn merge(self, other: Self) -> Self { - if Arc::ptr_eq(&self.steps, &other.steps) { - return self; + /// Register a new step into the graph. + pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self { + self.execute_mut(|map| { + map.insert(id.clone(), ops); + }) } - self.merge_different(other) - } + /// Merge two graphs. + pub fn merge(self, other: Self) -> Self { + if Arc::ptr_eq(&self.steps, &other.steps) { + return self; + } - fn execute_mut(mut self, func: F) -> Self { - match Arc::get_mut(&mut self.steps) { - Some(mutex) => { - let map = mutex.get_mut(); - func(map); - } - None => { - // Only lock when there are multiple references to the graph. - let mut map = self.steps.lock(); - func(&mut map); - } - }; + self.merge_different(other) + } - self - } + fn execute_mut(mut self, func: F) -> Self { + match Arc::get_mut(&mut self.steps) { + Some(mutex) => { + let map = mutex.get_mut(); + func(map); + } + None => { + // Only lock when there are multiple references to the graph. + let mut map = self.steps.lock(); + func(&mut map); + } + }; - fn merge_different(self, other: Self) -> Self { - let mut map2 = other.steps(); + self + } - self.execute_mut(|map1| { - if map1.len() > map2.len() { - map1.extend(map2); - } else { - let mut map_drain = HashMap::new(); - std::mem::swap(map1, &mut map_drain); - map2.extend(map_drain); - std::mem::swap(map1, &mut map2); - } - }) - } + fn merge_different(self, other: Self) -> Self { + let mut map2 = other.steps(); + + self.execute_mut(|map1| { + if map1.len() > map2.len() { + map1.extend(map2); + } else { + let mut map_drain = HashMap::new(); + std::mem::swap(map1, &mut map_drain); + map2.extend(map_drain); + std::mem::swap(map1, &mut map2); + } + }) + } } diff --git a/burn-autodiff/src/graph/node.rs b/burn-autodiff/src/graph/node.rs index 7c448742b3..38665408ec 100644 --- a/burn-autodiff/src/graph/node.rs +++ b/burn-autodiff/src/graph/node.rs @@ -6,43 +6,43 @@ use super::Requirement; /// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning. 
#[derive(new, Debug)] pub struct Node { - pub parents: Vec<NodeID>, - pub order: usize, - pub id: NodeID, - pub requirement: Requirement, + pub parents: Vec<NodeID>, + pub order: usize, + pub id: NodeID, + pub requirement: Requirement, } pub type NodeRef = Arc<Node>; impl Node { - /// Returns the [node](Node) only if gradients are required. - pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> { - match self.requirement.is_none() { - true => None, - false => Some(self.clone()), + /// Returns the [node](Node) only if gradients are required. + pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> { + match self.requirement.is_none() { + true => None, + false => Some(self.clone()), + } } - } } /// Unique identifier generated for each [node](Node). #[derive(Clone, Hash, PartialEq, Eq, Debug)] pub struct NodeID { - pub value: u64, + pub value: u64, } impl NodeID { - /// Create a unique [node id](NodeID). - pub fn new() -> Self { - static COUNTER: AtomicU64 = AtomicU64::new(0); - let value = COUNTER.fetch_add(1, Ordering::Relaxed); - if value == u64::MAX { - panic!("NodeID overflowed"); + /// Create a unique [node id](NodeID). + pub fn new() -> Self { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let value = COUNTER.fetch_add(1, Ordering::Relaxed); + if value == u64::MAX { + panic!("NodeID overflowed"); + } + Self { value } } - Self { value } - } } impl Default for NodeID { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } diff --git a/burn-autodiff/src/graph/requirement.rs b/burn-autodiff/src/graph/requirement.rs index 9d405b9562..c2825ff131 100644 --- a/burn-autodiff/src/graph/requirement.rs +++ b/burn-autodiff/src/graph/requirement.rs @@ -3,32 +3,32 @@ use super::NodeRef; /// Requirement for each tensor in the graph. #[derive(Debug, Clone, Copy)] pub enum Requirement { - /// Operations that require gradients. - Grad, - /// Operations that require gradients only for backprop. - GradInBackward, - /// Operations that don't need gradients, therefore not to be included in the graph. - None, + /// Operations that require gradients. + Grad, + /// Operations that require gradients only for backprop. + GradInBackward, + /// Operations that don't need gradients, therefore not to be included in the graph. + None, } impl Requirement { - /// Returns true if gradients are not required. - pub fn is_none(&self) -> bool { - matches!(self, Self::None) - } - /// Returns the right requirement from a list of nodes. - pub fn from_nodes(nodes: &[NodeRef]) -> Self { - nodes - .iter() - .map(|node| node.requirement) - .reduce(|acc, requirement| requirement.infer(&acc)) - .unwrap_or(Requirement::None) - } + /// Returns true if gradients are not required. + pub fn is_none(&self) -> bool { + matches!(self, Self::None) + } + /// Returns the right requirement from a list of nodes.
+ pub fn from_nodes(nodes: &[NodeRef]) -> Self { + nodes + .iter() + .map(|node| node.requirement) + .reduce(|acc, requirement| requirement.infer(&acc)) + .unwrap_or(Requirement::None) + } - fn infer(&self, other: &Self) -> Self { - match self.is_none() && other.is_none() { - true => Self::None, - false => Self::GradInBackward, + fn infer(&self, other: &Self) -> Self { + match self.is_none() && other.is_none() { + true => Self::None, + false => Self::GradInBackward, + } } - } } diff --git a/burn-autodiff/src/graph/traversal.rs b/burn-autodiff/src/graph/traversal.rs index de9e10ed3b..eefebe2b78 100644 --- a/burn-autodiff/src/graph/traversal.rs +++ b/burn-autodiff/src/graph/traversal.rs @@ -6,45 +6,45 @@ use super::{Graph, NodeRef, StepBoxed}; pub struct BreadthFirstSearch; impl BreadthFirstSearch { - /// Traverse the graph of backward steps from a root node. - pub fn traverse( - &self, - root: NodeRef, - graph: Graph, - mut callback: F, - ) { - let mut visited = HashSet::with_capacity(root.order); - let mut parents = Vec::with_capacity(root.order); - let mut steps = graph.steps(); - let root_step = steps + /// Traverse the graph of backward steps from a root node. + pub fn traverse( + &self, + root: NodeRef, + graph: Graph, + mut callback: F, + ) { + let mut visited = HashSet::with_capacity(root.order); + let mut parents = Vec::with_capacity(root.order); + let mut steps = graph.steps(); + let root_step = steps .remove(&root.id) .expect("Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?"); - visited.insert(root.id.clone()); - parents.append(&mut root.parents.clone()); - callback(root, root_step); + visited.insert(root.id.clone()); + parents.append(&mut root.parents.clone()); + callback(root, root_step); - while let Some(id) = parents.pop() { - let step = match steps.remove(&id) { - Some(step) => step, - None => continue, - }; + while let Some(id) = parents.pop() { + let step = match steps.remove(&id) { + Some(step) => step, + None => continue, + }; - let node = step.node(); + let node = step.node(); - if visited.contains(&node.id) { - continue; - } + if visited.contains(&node.id) { + continue; + } - visited.insert(node.id.clone()); + visited.insert(node.id.clone()); - for id in node.parents.iter() { - if !visited.contains(id) { - parents.push(id.clone()); - } - } + for id in node.parents.iter() { + if !visited.contains(id) { + parents.push(id.clone()); + } + } - callback(node, step); + callback(node, step); + } } - } } diff --git a/burn-autodiff/src/ops/activation.rs b/burn-autodiff/src/ops/activation.rs index 965f13a64e..0a34499f28 100644 --- a/burn-autodiff/src/ops/activation.rs +++ b/burn-autodiff/src/ops/activation.rs @@ -1,57 +1,57 @@ use crate::{ - grads::Gradients, - ops::{unary, Backward, Ops, OpsKind}, - Autodiff, + grads::Gradients, + ops::{unary, Backward, Ops, OpsKind}, + Autodiff, }; use burn_tensor::{ - backend::Backend, - ops::{ActivationOps, FloatTensor}, + backend::Backend, + ops::{ActivationOps, FloatTensor}, }; impl ActivationOps> for Autodiff { - fn gelu(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Gelu; - - impl Backward for Gelu { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let input = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - B::gelu_backward(input, grad) - }); - } - } - - match Gelu::.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = 
B::gelu(tensor.primitive.clone()); - prep.finish(tensor.primitive, output) - } - OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)), - } - } - - fn relu(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Relu; - - impl Backward for Relu { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::relu_backward(ops.state, grad) - }); - } + fn gelu(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Gelu; + + impl Backward for Gelu { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let input = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + B::gelu_backward(input, grad) + }); + } + } + + match Gelu::.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::gelu(tensor.primitive.clone()); + prep.finish(tensor.primitive, output) + } + OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)), + } } - let output = B::relu(tensor.primitive); - match Relu.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(output.clone(), output), - OpsKind::UnTracked(prep) => prep.finish(output), + fn relu(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Relu; + + impl Backward for Relu { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::relu_backward(ops.state, grad) + }); + } + } + let output = B::relu(tensor.primitive); + + match Relu.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(output.clone(), output), + OpsKind::UnTracked(prep) => prep.finish(output), + } } - } } diff --git a/burn-autodiff/src/ops/backward.rs b/burn-autodiff/src/ops/backward.rs index dd2a71954f..53e251c4f7 100644 --- a/burn-autodiff/src/ops/backward.rs +++ b/burn-autodiff/src/ops/backward.rs @@ -1,8 +1,8 @@ use super::{Ops, OpsPrep}; use crate::{ - grads::Gradients, - graph::{Graph, NodeRef, Requirement}, - utils::duplicate, + grads::Gradients, + graph::{Graph, NodeRef, Requirement}, + utils::duplicate, }; use burn_tensor::backend::Backend; @@ -15,84 +15,88 @@ use burn_tensor::backend::Backend; /// they should be declared with the associated type 'State'. pub trait Backward: Send + Sync + std::fmt::Debug where - Self: Sized + 'static, - B: Backend, + Self: Sized + 'static, + B: Backend, { - /// Associated type to compute the backward pass. - type State: Clone + Send + Sync + std::fmt::Debug + 'static; + /// Associated type to compute the backward pass. + type State: Clone + Send + Sync + std::fmt::Debug + 'static; - /// The backward pass. - fn backward(self, ops: Ops, grads: &mut Gradients); + /// The backward pass. + fn backward(self, ops: Ops, grads: &mut Gradients); - /// Prepare the backward ops. - fn prepare(self, nodes: [NodeRef; N], graphs: [Graph; N]) -> OpsPrep { - let requirement = Requirement::from_nodes(&nodes); - OpsPrep::new(nodes, graphs, requirement, self) - } + /// Prepare the backward ops. + fn prepare( + self, + nodes: [NodeRef; N], + graphs: [Graph; N], + ) -> OpsPrep { + let requirement = Requirement::from_nodes(&nodes); + OpsPrep::new(nodes, graphs, requirement, self) + } } /// Execute a binary operation during the backward step. 
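// Illustrative sketch, not part of the patch: the `binary` helper defined next fans
// the consumed upstream gradient out to at most two tracked parents, duplicating it
// only when both parents actually need it. A simplified, self-contained mirror of
// that dispatch; the real helper works on backend tensor primitives, `NodeRef`s and
// `Gradients`, the `u64` ids here are stand-ins:
fn fan_out_binary<T: Clone>(
    parents: [Option<u64>; 2],
    upstream: T,
    grad_lhs: impl FnOnce(T) -> T,
    grad_rhs: impl FnOnce(T) -> T,
    mut register: impl FnMut(u64, T),
) {
    // Duplicate the upstream gradient only if both parents are tracked.
    let (for_lhs, for_rhs) = match parents {
        [Some(_), Some(_)] => (Some(upstream.clone()), Some(upstream)),
        [Some(_), None] => (Some(upstream), None),
        [None, Some(_)] => (None, Some(upstream)),
        [None, None] => (None, None),
    };
    if let Some(node) = parents[0] {
        register(node, grad_lhs(for_lhs.unwrap()));
    }
    if let Some(node) = parents[1] {
        register(node, grad_rhs(for_rhs.unwrap()));
    }
}

// Example: for `lhs * rhs`, `grad_lhs` would multiply the upstream gradient by `rhs`
// and `grad_rhs` would multiply it by `lhs`.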
pub fn binary( - parents: [Option; 2], - node: NodeRef, - grads: &mut Gradients, - func_lhs: FLhs, - func_rhs: FRhs, + parents: [Option; 2], + node: NodeRef, + grads: &mut Gradients, + func_lhs: FLhs, + func_rhs: FRhs, ) where - B: Backend, - FLhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, - FRhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, + B: Backend, + FLhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, + FRhs: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, { - let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::(&node))); - let [node_lhs, node_rhs] = parents; + let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::(&node))); + let [node_lhs, node_rhs] = parents; - if let Some(node) = node_lhs { - let grad = func_lhs(grad_4lhs.unwrap()); - grads.register::(node, grad) - } + if let Some(node) = node_lhs { + let grad = func_lhs(grad_4lhs.unwrap()); + grads.register::(node, grad) + } - if let Some(node) = node_rhs { - let grad = func_rhs(grad_4rhs.unwrap()); - grads.register::(node, grad) - } + if let Some(node) = node_rhs { + let grad = func_rhs(grad_4rhs.unwrap()); + grads.register::(node, grad) + } } /// Execute a unary operation during the backward step. pub fn unary( - parents: [Option; 1], - node: NodeRef, - grads: &mut Gradients, - func: F, + parents: [Option; 1], + node: NodeRef, + grads: &mut Gradients, + func: F, ) where - B: Backend, - F: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, + B: Backend, + F: FnOnce(B::TensorPrimitive) -> B::TensorPrimitive, { - let [parent_node] = parents; - let grad = grads.consume::(&node); + let [parent_node] = parents; + let grad = grads.consume::(&node); - if let Some(node) = parent_node { - let grad = func(grad); - grads.register::(node, grad) - } + if let Some(node) = parent_node { + let grad = func(grad); + grads.register::(node, grad) + } } /// Execute a unary operation during the backward step where the input backend /// is different from the output backend. pub fn unary_different_backend( - parents: [Option; 1], - node: NodeRef, - grads: &mut Gradients, - func: F, + parents: [Option; 1], + node: NodeRef, + grads: &mut Gradients, + func: F, ) where - BIn: Backend, - BOut: Backend, - F: FnOnce(BOut::TensorPrimitive) -> BIn::TensorPrimitive, + BIn: Backend, + BOut: Backend, + F: FnOnce(BOut::TensorPrimitive) -> BIn::TensorPrimitive, { - let [parent_node] = parents; - let grad = grads.consume::(&node); + let [parent_node] = parents; + let grad = grads.consume::(&node); - if let Some(node) = parent_node { - let grad = func(grad); - grads.register::(node, grad) - } + if let Some(node) = parent_node { + let grad = func(grad); + grads.register::(node, grad) + } } diff --git a/burn-autodiff/src/ops/base.rs b/burn-autodiff/src/ops/base.rs index 2081d49473..d212525943 100644 --- a/burn-autodiff/src/ops/base.rs +++ b/burn-autodiff/src/ops/base.rs @@ -1,10 +1,10 @@ use super::Backward; use crate::{ - grads::Gradients, - graph::{ - NodeRef, Requirement, {Graph, Step}, - }, - tensor::AutodiffTensor, + grads::Gradients, + graph::{ + NodeRef, Requirement, {Graph, Step}, + }, + tensor::AutodiffTensor, }; use burn_tensor::{backend::Backend, Shape}; use std::marker::PhantomData; @@ -15,13 +15,13 @@ use std::marker::PhantomData; /// Each mode has its own set of functions to minimize cloning for unused backward states. 
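// Illustrative sketch, not part of the patch: `OpsPrep` below is a type-state
// builder; `stateful()` decides at runtime whether the op is tracked, and the marker
// type then decides at compile time which `finish` is available (the tracked variant
// accepts a saved backward state, the untracked one does not). A self-contained
// mirror of that idea with scalar outputs standing in for tensors:
use std::marker::PhantomData;

struct Init;
struct Tracked;
struct UnTracked;

struct Prep<Mode> {
    _mode: PhantomData<Mode>,
}

enum PrepKind {
    Tracked(Prep<Tracked>),
    UnTracked(Prep<UnTracked>),
}

impl Prep<Init> {
    fn stateful(self, requires_grad: bool) -> PrepKind {
        if requires_grad {
            PrepKind::Tracked(Prep { _mode: PhantomData })
        } else {
            PrepKind::UnTracked(Prep { _mode: PhantomData })
        }
    }
}

impl Prep<Tracked> {
    // Only the tracked variant keeps a state around for the backward pass.
    fn finish<S>(self, state: S, output: f32) -> (Option<S>, f32) {
        (Some(state), output)
    }
}

impl Prep<UnTracked> {
    fn finish(self, output: f32) -> (Option<()>, f32) {
        (None, output)
    }
}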
#[derive(new)] pub struct OpsPrep { - nodes: [NodeRef; N], - graphs: [Graph; N], - requirement: Requirement, - backward: Backward, - phantom_backend: PhantomData, - phantom_state: PhantomData, - marker: PhantomData, + nodes: [NodeRef; N], + graphs: [Graph; N], + requirement: Requirement, + backward: Backward, + phantom_backend: PhantomData, + phantom_state: PhantomData, + marker: PhantomData, } /// Init operation tag. @@ -33,130 +33,130 @@ pub struct UnTracked; impl OpsPrep where - B: Backend, - BO: Backward, + B: Backend, + BO: Backward, { - /// Prepare a stateless operation. - pub fn stateless(self, output: ::TensorPrimitive) -> AutodiffTensor { - match self.stateful() { - OpsKind::Tracked(prep) => prep.finish((), output), - OpsKind::UnTracked(prep) => prep.finish(output), + /// Prepare a stateless operation. + pub fn stateless(self, output: ::TensorPrimitive) -> AutodiffTensor { + match self.stateful() { + OpsKind::Tracked(prep) => prep.finish((), output), + OpsKind::UnTracked(prep) => prep.finish(output), + } } - } } impl OpsPrep where - B: Backend, - S: Clone + Send + Sync + std::fmt::Debug + 'static, - BO: Backward, + B: Backend, + S: Clone + Send + Sync + std::fmt::Debug + 'static, + BO: Backward, { - /// Prepare an operation that requires a state during the backward pass. - pub fn stateful(self) -> OpsKind { - match self.requirement.is_none() { - false => OpsKind::Tracked(OpsPrep::new( - self.nodes, - self.graphs, - self.requirement, - self.backward, - )), - true => OpsKind::UnTracked(OpsPrep::new( - self.nodes, - self.graphs, - self.requirement, - self.backward, - )), + /// Prepare an operation that requires a state during the backward pass. + pub fn stateful(self) -> OpsKind { + match self.requirement.is_none() { + false => OpsKind::Tracked(OpsPrep::new( + self.nodes, + self.graphs, + self.requirement, + self.backward, + )), + true => OpsKind::UnTracked(OpsPrep::new( + self.nodes, + self.graphs, + self.requirement, + self.backward, + )), + } } - } } impl OpsPrep where - B: Backend, - S: Clone + Send + Sync + std::fmt::Debug + 'static, - BO: Backward, + B: Backend, + S: Clone + Send + Sync + std::fmt::Debug + 'static, + BO: Backward, { - /// Finish the preparation of an untracked operation and returns the output tensor. - pub fn finish(self, output: ::TensorPrimitive) -> AutodiffTensor { - AutodiffTensor::from_parents( - output, - &self.nodes, - self.graphs.into_iter(), - self.requirement, - ) - } + /// Finish the preparation of an untracked operation and returns the output tensor. + pub fn finish(self, output: ::TensorPrimitive) -> AutodiffTensor { + AutodiffTensor::from_parents( + output, + &self.nodes, + self.graphs.into_iter(), + self.requirement, + ) + } } impl OpsPrep where - B: Backend, - S: Clone + Send + Sync + std::fmt::Debug + 'static, - BO: Backward, + B: Backend, + S: Clone + Send + Sync + std::fmt::Debug + 'static, + BO: Backward, { - /// Finish the preparation of a tracked operation and returns the output tensor. - pub fn finish( - self, - state: S, - output: ::TensorPrimitive, - ) -> AutodiffTensor { - let output = AutodiffTensor::from_parents( - output, - &self.nodes, - self.graphs.into_iter(), - self.requirement, - ); - let parents = self.nodes.map(|node| node.clone_if_require_grad()); - let ops = Ops::new(parents, output.node.clone(), state); + /// Finish the preparation of a tracked operation and returns the output tensor. 
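// Illustrative sketch, not part of the patch: the tracked `finish` implemented next
// packs the parent handles and the saved state into an `Ops` value and registers a
// type-erased step on the output node; the backward pass later pops these steps and
// lets each one push gradients to its parents. A self-contained mirror of that
// type-erasure pattern, using plain `u64` node ids and `f32` gradients as stand-ins:
trait StepLike {
    fn step(self: Box<Self>, upstream: f32, grads: &mut Vec<(u64, f32)>);
    fn node(&self) -> u64;
}

struct MulStep {
    node: u64,
    parents: [Option<u64>; 2],
    state: (f32, f32), // operands saved during the forward pass
}

impl StepLike for MulStep {
    fn step(self: Box<Self>, upstream: f32, grads: &mut Vec<(u64, f32)>) {
        let (lhs, rhs) = self.state;
        // d(lhs * rhs)/d(lhs) = rhs and d(lhs * rhs)/d(rhs) = lhs.
        if let Some(p) = self.parents[0] {
            grads.push((p, upstream * rhs));
        }
        if let Some(p) = self.parents[1] {
            grads.push((p, upstream * lhs));
        }
    }

    fn node(&self) -> u64 {
        self.node
    }
}

// A graph can then hold heterogeneous steps as `Vec<Box<dyn StepLike>>` and replay
// them from the root during the backward traversal.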
+ pub fn finish( + self, + state: S, + output: ::TensorPrimitive, + ) -> AutodiffTensor { + let output = AutodiffTensor::from_parents( + output, + &self.nodes, + self.graphs.into_iter(), + self.requirement, + ); + let parents = self.nodes.map(|node| node.clone_if_require_grad()); + let ops = Ops::new(parents, output.node.clone(), state); - output.register_step(OpsStep::new(ops, self.backward)) - } + output.register_step(OpsStep::new(ops, self.backward)) + } } /// Enum used before finishing tracked and untracked operations. pub enum OpsKind { - /// Tracked operation preparation. - Tracked(OpsPrep), - /// Untracked operation preparation. - UnTracked(OpsPrep), + /// Tracked operation preparation. + Tracked(OpsPrep), + /// Untracked operation preparation. + UnTracked(OpsPrep), } /// Operation containing its parent nodes, its own node and the backward step state. #[derive(new, Debug)] pub struct Ops { - /// Parents nodes. - pub parents: [Option; N], - /// The node. - pub node: NodeRef, - /// The state. - pub state: S, + /// Parents nodes. + pub parents: [Option; N], + /// The node. + pub node: NodeRef, + /// The state. + pub state: S, } /// Operation implementing backward [step](Step) with type erasing. #[derive(new, Debug)] struct OpsStep where - B: Backend, - T: Backward, - SB: Clone + Send + Sync + std::fmt::Debug + 'static, + B: Backend, + T: Backward, + SB: Clone + Send + Sync + std::fmt::Debug + 'static, { - ops: Ops, - backward: T, - phantom: PhantomData, + ops: Ops, + backward: T, + phantom: PhantomData, } impl Step for OpsStep where - B: Backend, - T: Backward, - SB: Clone + Send + Sync + std::fmt::Debug + 'static, + B: Backend, + T: Backward, + SB: Clone + Send + Sync + std::fmt::Debug + 'static, { - fn step(self: Box, grads: &mut Gradients) { - self.backward.backward(self.ops, grads); - } + fn step(self: Box, grads: &mut Gradients) { + self.backward.backward(self.ops, grads); + } - fn node(&self) -> NodeRef { - self.ops.node.clone() - } + fn node(&self) -> NodeRef { + self.ops.node.clone() + } } /// Make sure the grad tensor has the given shape. @@ -164,22 +164,22 @@ where /// If broadcasting happened during the forward pass, the gradients will be sum along the /// broadcasted dimension. pub fn broadcast_shape( - mut grad: B::TensorPrimitive, - shape: &Shape, + mut grad: B::TensorPrimitive, + shape: &Shape, ) -> B::TensorPrimitive { - let shape_grad = B::shape(&grad); + let shape_grad = B::shape(&grad); - for i in 0..D { - if shape_grad.dims[i] != shape.dims[i] { - if shape.dims[i] != 1 { - panic!( - "Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}", - shape.dims, shape_grad.dims, "Expected the shape of the next grad to be 1." - ); - } - grad = B::sum_dim(grad, i); + for i in 0..D { + if shape_grad.dims[i] != shape.dims[i] { + if shape.dims[i] != 1 { + panic!( + "Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}", + shape.dims, shape_grad.dims, "Expected the shape of the next grad to be 1." 
+ ); + } + grad = B::sum_dim(grad, i); + } } - } - grad + grad } diff --git a/burn-autodiff/src/ops/bool_tensor.rs b/burn-autodiff/src/ops/bool_tensor.rs index bc5685d642..59b342f840 100644 --- a/burn-autodiff/src/ops/bool_tensor.rs +++ b/burn-autodiff/src/ops/bool_tensor.rs @@ -1,92 +1,95 @@ use crate::{tensor::AutodiffTensor, Autodiff}; use burn_tensor::{ - backend::Backend, - ops::{BoolTensor, BoolTensorOps, IntTensor}, - Data, Device, Reader, Shape, + backend::Backend, + ops::{BoolTensor, BoolTensorOps, IntTensor}, + Data, Device, Reader, Shape, }; impl BoolTensorOps for Autodiff { - fn bool_from_data(data: Data, device: &Device) -> BoolTensor { - B::bool_from_data(data, device) - } - - fn bool_shape(tensor: &BoolTensor) -> Shape { - B::bool_shape(tensor) - } - - fn bool_to_data(tensor: &BoolTensor) -> Reader> { - B::bool_to_data(tensor) - } - - fn bool_into_data(tensor: BoolTensor) -> Reader> { - B::bool_into_data(tensor) - } - - fn bool_into_int(tensor: BoolTensor) -> IntTensor { - B::bool_into_int(tensor) - } - - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - B::bool_to_device(tensor, device) - } - - fn bool_device(tensor: &BoolTensor) -> Device { - B::bool_device(tensor) - } - - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - B::bool_reshape(tensor, shape) - } - - fn bool_slice( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - ) -> BoolTensor { - B::bool_slice(tensor, ranges) - } - - fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - B::bool_empty(shape, device) - } - - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - value: BoolTensor, - ) -> BoolTensor { - B::bool_slice_assign(tensor, ranges, value) - } - - fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { - B::bool_cat(tensors, dim) - } - - fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { - B::bool_equal(lhs, rhs) - } - - fn bool_not(tensor: BoolTensor) -> BoolTensor { - B::bool_not(tensor) - } - - fn bool_into_float( - tensor: BoolTensor, - ) -> as Backend>::TensorPrimitive { - AutodiffTensor::new(B::bool_into_float(tensor)) - } - - fn bool_swap_dims( - tensor: as Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::BoolTensorPrimitive { - B::bool_swap_dims(tensor, dim1, dim2) - } + fn bool_from_data(data: Data, device: &Device) -> BoolTensor { + B::bool_from_data(data, device) + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + B::bool_shape(tensor) + } + + fn bool_to_data(tensor: &BoolTensor) -> Reader> { + B::bool_to_data(tensor) + } + + fn bool_into_data(tensor: BoolTensor) -> Reader> { + B::bool_into_data(tensor) + } + + fn bool_into_int(tensor: BoolTensor) -> IntTensor { + B::bool_into_int(tensor) + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + B::bool_to_device(tensor, device) + } + + fn bool_device(tensor: &BoolTensor) -> Device { + B::bool_device(tensor) + } + + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor { + B::bool_reshape(tensor, shape) + } + + fn bool_slice( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + ) -> BoolTensor { + B::bool_slice(tensor, ranges) + } + + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + B::bool_empty(shape, device) + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + value: BoolTensor, + ) -> BoolTensor { + B::bool_slice_assign(tensor, ranges, value) + } + + fn bool_cat(tensors: Vec>, dim: usize) -> 
BoolTensor { + B::bool_cat(tensors, dim) + } + + fn bool_equal( + lhs: BoolTensor, + rhs: BoolTensor, + ) -> BoolTensor { + B::bool_equal(lhs, rhs) + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + B::bool_not(tensor) + } + + fn bool_into_float( + tensor: BoolTensor, + ) -> as Backend>::TensorPrimitive { + AutodiffTensor::new(B::bool_into_float(tensor)) + } + + fn bool_swap_dims( + tensor: as Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::BoolTensorPrimitive { + B::bool_swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-autodiff/src/ops/int_tensor.rs b/burn-autodiff/src/ops/int_tensor.rs index 48d6ab460c..a7af0420e1 100644 --- a/burn-autodiff/src/ops/int_tensor.rs +++ b/burn-autodiff/src/ops/int_tensor.rs @@ -1,319 +1,319 @@ use crate::{tensor::AutodiffTensor, Autodiff}; use burn_tensor::{ - backend::Backend, - ops::{BoolTensor, IntTensor, IntTensorOps}, - Data, Device, Reader, Shape, + backend::Backend, + ops::{BoolTensor, IntTensor, IntTensorOps}, + Data, Device, Reader, Shape, }; impl IntTensorOps> for Autodiff { - fn int_from_data( - data: Data, - device: &Device, - ) -> IntTensor { - B::int_from_data(data, device) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - B::int_shape(tensor) - } - - fn int_to_data(tensor: &IntTensor) -> Reader> { - B::int_to_data(tensor) - } - - fn int_into_data(tensor: IntTensor) -> Reader> { - B::int_into_data(tensor) - } - - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - B::int_to_device(tensor, device) - } - - fn int_device(tensor: &IntTensor) -> Device { - B::int_device(tensor) - } - - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - B::int_reshape(tensor, shape) - } - - fn int_slice( - tensor: IntTensor, - ranges: [std::ops::Range; D2], - ) -> IntTensor { - B::int_slice(tensor, ranges) - } - - fn int_empty( - shape: Shape, - device: & as Backend>::Device, - ) -> IntTensor { - B::int_empty(shape, device) - } - - fn int_slice_assign( - tensor: IntTensor, - ranges: [std::ops::Range; D2], - value: IntTensor, - ) -> IntTensor { - B::int_slice_assign(tensor, ranges, value) - } - - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - B::int_cat(tensors, dim) - } - - fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - B::int_equal(lhs, rhs) - } - - fn int_equal_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { - B::int_equal_elem(lhs, rhs) - } - - fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_add(lhs, rhs) - } - - fn int_add_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_add_scalar(lhs, rhs) - } - - fn int_clamp_min(tensor: IntTensor, min: B::IntElem) -> IntTensor { - B::int_clamp_min(tensor, min) - } - - fn int_clamp_max(tensor: IntTensor, max: B::IntElem) -> IntTensor { - B::int_clamp_max(tensor, max) - } - - fn int_clamp( - tensor: IntTensor, - min: B::IntElem, - max: B::IntElem, - ) -> IntTensor { - B::int_clamp(tensor, min, max) - } - - fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_sub(lhs, rhs) - } - - fn int_sub_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_sub_scalar(lhs, rhs) - } - - fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_mul(lhs, rhs) - } - - fn int_mul_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_mul_scalar(lhs, rhs) - } - - fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - B::int_div(lhs, rhs) - } - - fn int_div_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { - B::int_div_scalar(lhs, rhs) - 
} - - fn int_neg(tensor: IntTensor) -> IntTensor { - B::int_neg(tensor) - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - B::int_zeros(shape, device) - } - - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - B::int_ones(shape, device) - } - - fn int_full( - shape: Shape, - fill_value: B::IntElem, - device: &Device, - ) -> IntTensor { - B::int_full(shape, fill_value, device) - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - B::int_sum(tensor) - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_sum_dim(tensor, dim) - } - - fn int_mean(tensor: IntTensor) -> IntTensor { - B::int_mean(tensor) - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_mean_dim(tensor, dim) - } - - fn int_repeat( - tensor: IntTensor, - dim: usize, - times: usize, - ) -> IntTensor { - B::int_repeat(tensor, dim, times) - } - - fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - B::int_greater(lhs, rhs) - } - - fn int_greater_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { - B::int_greater_elem(lhs, rhs) - } - - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - B::int_greater_equal(lhs, rhs) - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: B::IntElem, - ) -> BoolTensor { - B::int_greater_equal_elem(lhs, rhs) - } - - fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - B::int_lower(lhs, rhs) - } - - fn int_lower_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { - B::int_lower_elem(lhs, rhs) - } - - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - B::int_lower_equal(lhs, rhs) - } - - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: B::IntElem, - ) -> BoolTensor { - B::int_lower_equal_elem(lhs, rhs) - } - - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - B::int_gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - B::int_scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - B::int_select(tensor, dim, indices) - } - - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - B::int_select_assign(tensor, dim, indices, value) - } - - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - value: IntTensor, - ) -> as Backend>::IntTensorPrimitive { - B::int_mask_where(tensor, mask, value) - } - - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: B::IntElem, - ) -> as Backend>::IntTensorPrimitive { - B::int_mask_fill(tensor, mask, value) - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_argmax(tensor, dim) - } - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - B::int_argmin(tensor, dim) - } - fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { - B::int_max(tensor) - } - fn int_max_dim( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> B::IntTensorPrimitive { - B::int_max_dim(tensor, dim) - } - fn int_max_dim_with_indices( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { - B::int_max_dim_with_indices(tensor, dim) - } - fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { - B::int_min(tensor) - } - fn int_min_dim( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> B::IntTensorPrimitive { - 
B::int_min_dim(tensor, dim) - } - fn int_min_dim_with_indices( - tensor: B::IntTensorPrimitive, - dim: usize, - ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { - B::int_min_dim_with_indices(tensor, dim) - } - fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive { - B::int_abs(tensor) - } - fn int_into_float( - tensor: as Backend>::IntTensorPrimitive, - ) -> as Backend>::TensorPrimitive { - AutodiffTensor::new(B::int_into_float(tensor)) - } - - fn int_swap_dims( - tensor: as Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::IntTensorPrimitive { - B::int_swap_dims(tensor, dim1, dim2) - } + fn int_from_data( + data: Data, + device: &Device, + ) -> IntTensor { + B::int_from_data(data, device) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + B::int_shape(tensor) + } + + fn int_to_data(tensor: &IntTensor) -> Reader> { + B::int_to_data(tensor) + } + + fn int_into_data(tensor: IntTensor) -> Reader> { + B::int_into_data(tensor) + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + B::int_to_device(tensor, device) + } + + fn int_device(tensor: &IntTensor) -> Device { + B::int_device(tensor) + } + + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + B::int_reshape(tensor, shape) + } + + fn int_slice( + tensor: IntTensor, + ranges: [std::ops::Range; D2], + ) -> IntTensor { + B::int_slice(tensor, ranges) + } + + fn int_empty( + shape: Shape, + device: & as Backend>::Device, + ) -> IntTensor { + B::int_empty(shape, device) + } + + fn int_slice_assign( + tensor: IntTensor, + ranges: [std::ops::Range; D2], + value: IntTensor, + ) -> IntTensor { + B::int_slice_assign(tensor, ranges, value) + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + B::int_cat(tensors, dim) + } + + fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + B::int_equal(lhs, rhs) + } + + fn int_equal_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { + B::int_equal_elem(lhs, rhs) + } + + fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_add(lhs, rhs) + } + + fn int_add_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_add_scalar(lhs, rhs) + } + + fn int_clamp_min(tensor: IntTensor, min: B::IntElem) -> IntTensor { + B::int_clamp_min(tensor, min) + } + + fn int_clamp_max(tensor: IntTensor, max: B::IntElem) -> IntTensor { + B::int_clamp_max(tensor, max) + } + + fn int_clamp( + tensor: IntTensor, + min: B::IntElem, + max: B::IntElem, + ) -> IntTensor { + B::int_clamp(tensor, min, max) + } + + fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_sub(lhs, rhs) + } + + fn int_sub_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_sub_scalar(lhs, rhs) + } + + fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_mul(lhs, rhs) + } + + fn int_mul_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_mul_scalar(lhs, rhs) + } + + fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_div(lhs, rhs) + } + + fn int_div_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::int_div_scalar(lhs, rhs) + } + + fn int_neg(tensor: IntTensor) -> IntTensor { + B::int_neg(tensor) + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + B::int_zeros(shape, device) + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + B::int_ones(shape, device) + } + + fn int_full( + shape: Shape, + fill_value: B::IntElem, + device: &Device, + ) -> IntTensor { + B::int_full(shape, fill_value, device) + } + + fn int_sum(tensor: 
IntTensor) -> IntTensor { + B::int_sum(tensor) + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_sum_dim(tensor, dim) + } + + fn int_mean(tensor: IntTensor) -> IntTensor { + B::int_mean(tensor) + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_mean_dim(tensor, dim) + } + + fn int_repeat( + tensor: IntTensor, + dim: usize, + times: usize, + ) -> IntTensor { + B::int_repeat(tensor, dim, times) + } + + fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + B::int_greater(lhs, rhs) + } + + fn int_greater_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { + B::int_greater_elem(lhs, rhs) + } + + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + B::int_greater_equal(lhs, rhs) + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: B::IntElem, + ) -> BoolTensor { + B::int_greater_equal_elem(lhs, rhs) + } + + fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + B::int_lower(lhs, rhs) + } + + fn int_lower_elem(lhs: IntTensor, rhs: B::IntElem) -> BoolTensor { + B::int_lower_elem(lhs, rhs) + } + + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + B::int_lower_equal(lhs, rhs) + } + + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: B::IntElem, + ) -> BoolTensor { + B::int_lower_equal_elem(lhs, rhs) + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + B::int_gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + B::int_scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + B::int_select(tensor, dim, indices) + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + B::int_select_assign(tensor, dim, indices, value) + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + value: IntTensor, + ) -> as Backend>::IntTensorPrimitive { + B::int_mask_where(tensor, mask, value) + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: B::IntElem, + ) -> as Backend>::IntTensorPrimitive { + B::int_mask_fill(tensor, mask, value) + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_argmax(tensor, dim) + } + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_argmin(tensor, dim) + } + fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { + B::int_max(tensor) + } + fn int_max_dim( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> B::IntTensorPrimitive { + B::int_max_dim(tensor, dim) + } + fn int_max_dim_with_indices( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { + B::int_max_dim_with_indices(tensor, dim) + } + fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive<1> { + B::int_min(tensor) + } + fn int_min_dim( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> B::IntTensorPrimitive { + B::int_min_dim(tensor, dim) + } + fn int_min_dim_with_indices( + tensor: B::IntTensorPrimitive, + dim: usize, + ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { + B::int_min_dim_with_indices(tensor, dim) + } + fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive { + B::int_abs(tensor) + } + fn int_into_float( + tensor: as Backend>::IntTensorPrimitive, + ) -> as Backend>::TensorPrimitive { + 
AutodiffTensor::new(B::int_into_float(tensor)) + } + + fn int_swap_dims( + tensor: as Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::IntTensorPrimitive { + B::int_swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-autodiff/src/ops/maxmin.rs b/burn-autodiff/src/ops/maxmin.rs index 4e788e3681..3371c03eb2 100644 --- a/burn-autodiff/src/ops/maxmin.rs +++ b/burn-autodiff/src/ops/maxmin.rs @@ -6,15 +6,15 @@ use burn_tensor::{backend::Backend, Shape}; pub(crate) struct MaxMinDim; impl Backward for MaxMinDim { - type State = (B::IntTensorPrimitive, Shape); + type State = (B::IntTensorPrimitive, Shape); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let (indices, shape) = ops.state; - let device = B::device(&grad); - let zeros = B::zeros(shape, &device); + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let (indices, shape) = ops.state; + let device = B::device(&grad); + let zeros = B::zeros(shape, &device); - B::scatter(D - 1, zeros, indices, grad) - }); - } + B::scatter(D - 1, zeros, indices, grad) + }); + } } diff --git a/burn-autodiff/src/ops/module.rs b/burn-autodiff/src/ops/module.rs index 7a4d3bd56e..9da1440f3f 100644 --- a/burn-autodiff/src/ops/module.rs +++ b/burn-autodiff/src/ops/module.rs @@ -9,899 +9,948 @@ use burn_tensor::ops::*; use super::OpsKind; impl ModuleOps> for Autodiff { - fn embedding(weights: AutodiffTensor, indices: IntTensor) -> AutodiffTensor { - #[derive(Debug)] - struct Embedding; + fn embedding(weights: AutodiffTensor, indices: IntTensor) -> AutodiffTensor { + #[derive(Debug)] + struct Embedding; - impl Backward for Embedding { - type State = (B::TensorPrimitive<2>, IntTensor); + impl Backward for Embedding { + type State = (B::TensorPrimitive<2>, IntTensor); - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (weights, indices) = ops.state; + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (weights, indices) = ops.state; - unary::(ops.parents, ops.node, grads, |grad| { - B::embedding_backward(weights, grad, indices) - }); - } - } - - match Embedding - .prepare([weights.node], [weights.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - (weights.primitive.clone(), indices.clone()), - B::embedding(weights.primitive, indices), - ), - OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)), - } - } - - fn embedding_backward( - _weights: AutodiffTensor, - _output: AutodiffTensor, - _indices: IntTensor, - ) -> AutodiffTensor { - panic!("Can't differentiate embedding backward."); - } - - fn conv2d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct Conv2DWithBias; - #[derive(Debug)] - struct Conv2DNoBias; - - impl Backward for Conv2DWithBias { - type State = ( - B::TensorPrimitive<4>, - B::TensorPrimitive<4>, - B::TensorPrimitive<1>, - ConvOptions<2>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, bias, options) = ops.state; - let backward = B::conv2d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) + unary::(ops.parents, ops.node, grads, |grad| { + B::embedding_backward(weights, grad, 
indices) + }); + } } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) - } - } - } - - impl Backward for Conv2DNoBias { - type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, ConvOptions<2>); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = ops.state; - let backward = B::conv2d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - } - } - match bias { - Some(bias) => { - match Conv2DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() + match Embedding + .prepare([weights.node], [weights.graph]) + .stateful() { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), + OpsKind::Tracked(prep) => prep.finish( + (weights.primitive.clone(), indices.clone()), + B::embedding(weights.primitive, indices), ), - B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv2d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), + OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)), } - } - None => { - match Conv2DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - B::conv2d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::conv2d(x.primitive, weight.primitive, None, options)) - } - } - } } - } - - fn conv_transpose2d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct ConvTranspose2DWithBias; - #[derive(Debug)] - struct ConvTranspose2DNoBias; - - impl Backward for ConvTranspose2DWithBias { - type State = ( - B::TensorPrimitive<4>, - B::TensorPrimitive<4>, - B::TensorPrimitive<1>, - ConvTransposeOptions<2>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, weight, bias, options) = ops.state; - let backward = B::conv_transpose2d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) - } - } + fn embedding_backward( + _weights: AutodiffTensor, + _output: AutodiffTensor, + _indices: IntTensor, + ) -> AutodiffTensor { + panic!("Can't differentiate embedding backward."); } - impl Backward for ConvTranspose2DNoBias { - type State = ( - B::TensorPrimitive<4>, - B::TensorPrimitive<4>, - ConvTransposeOptions<2>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = ops.state; - let backward = B::conv_transpose2d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) + fn conv2d( + x: 
AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct Conv2DWithBias; + #[derive(Debug)] + struct Conv2DNoBias; + + impl Backward for Conv2DWithBias { + type State = ( + B::TensorPrimitive<4>, + B::TensorPrimitive<4>, + B::TensorPrimitive<1>, + ConvOptions<2>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, bias, options) = ops.state; + let backward = B::conv2d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) + } + } + } + + impl Backward for Conv2DNoBias { + type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, ConvOptions<2>); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv2d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + } + } + + match bias { + Some(bias) => { + match Conv2DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), + ), + B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv2d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), + } + } + None => { + match Conv2DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv2d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::conv2d(x.primitive, weight.primitive, None, options)) + } + } + } } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - } } - match bias { - Some(bias) => { - match ConvTranspose2DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), - ), - B::conv_transpose2d(x.primitive, weight.primitive, Some(bias.primitive), options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), - } - } - None => { - match ConvTranspose2DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - B::conv_transpose2d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( - x.primitive, - weight.primitive, - None, - options, - )), + fn conv_transpose2d( + x: AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + 
options: ConvTransposeOptions<2>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct ConvTranspose2DWithBias; + #[derive(Debug)] + struct ConvTranspose2DNoBias; + + impl Backward for ConvTranspose2DWithBias { + type State = ( + B::TensorPrimitive<4>, + B::TensorPrimitive<4>, + B::TensorPrimitive<1>, + ConvTransposeOptions<2>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, bias, options) = ops.state; + let backward = B::conv_transpose2d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) + } + } + } + + impl Backward for ConvTranspose2DNoBias { + type State = ( + B::TensorPrimitive<4>, + B::TensorPrimitive<4>, + ConvTransposeOptions<2>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv_transpose2d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + } + } + + match bias { + Some(bias) => { + match ConvTranspose2DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), + ), + B::conv_transpose2d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + ), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), + } + } + None => { + match ConvTranspose2DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv_transpose2d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( + x.primitive, + weight.primitive, + None, + options, + )), + } + } } - } } - } - - fn conv1d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct Conv1DWithBias; - #[derive(Debug)] - struct Conv1DNoBias; - - impl Backward for Conv1DWithBias { - type State = ( - B::TensorPrimitive<3>, - B::TensorPrimitive<3>, - B::TensorPrimitive<1>, - ConvOptions<1>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, weight, bias, options) = ops.state; - let backward = B::conv1d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) + fn conv1d( + x: AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct 
Conv1DWithBias; + #[derive(Debug)] + struct Conv1DNoBias; + + impl Backward for Conv1DWithBias { + type State = ( + B::TensorPrimitive<3>, + B::TensorPrimitive<3>, + B::TensorPrimitive<1>, + ConvOptions<1>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, bias, options) = ops.state; + let backward = B::conv1d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) + } + } + } + + impl Backward for Conv1DNoBias { + type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, ConvOptions<1>); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv1d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + } + } + match bias { + Some(bias) => { + match Conv1DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), + ), + B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv1d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), + } + } + None => { + match Conv1DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv1d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::conv1d(x.primitive, weight.primitive, None, options)) + } + } + } } - } } - impl Backward for Conv1DNoBias { - type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, ConvOptions<1>); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = ops.state; - let backward = B::conv1d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - } - } - match bias { - Some(bias) => { - match Conv1DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), - ), - B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv1d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), - } - } - None => { - match Conv1DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - 
B::conv1d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::conv1d(x.primitive, weight.primitive, None, options)) - } + fn conv_transpose1d( + x: AutodiffTensor, + weight: AutodiffTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct ConvTranspose1DWithBias; + #[derive(Debug)] + struct ConvTranspose1DNoBias; + + impl Backward for ConvTranspose1DWithBias { + type State = ( + B::TensorPrimitive<3>, + B::TensorPrimitive<3>, + B::TensorPrimitive<1>, + ConvTransposeOptions<1>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, bias, options) = ops.state; + let backward = B::conv_transpose1d_backward(x, weight, Some(bias), grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + if let Some(node) = node_bias { + grads.register::(node, backward.bias_grad.unwrap()) + } + } + } + + impl Backward for ConvTranspose1DNoBias { + type State = ( + B::TensorPrimitive<3>, + B::TensorPrimitive<3>, + ConvTransposeOptions<1>, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_x, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x, weight, options) = ops.state; + let backward = B::conv_transpose1d_backward(x, weight, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node, backward.x_grad) + } + if let Some(node) = node_weight { + grads.register::(node, backward.weights_grad) + } + } + } + + match bias { + Some(bias) => { + match ConvTranspose1DWithBias + .prepare( + [x.node, weight.node, bias.node], + [x.graph, weight.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + bias.primitive.clone(), + options.clone(), + ), + B::conv_transpose1d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + ), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( + x.primitive, + weight.primitive, + Some(bias.primitive), + options, + )), + } + } + None => { + match ConvTranspose1DNoBias + .prepare([x.node, weight.node], [x.graph, weight.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + x.primitive.clone(), + weight.primitive.clone(), + options.clone(), + ), + B::conv_transpose1d(x.primitive, weight.primitive, None, options), + ), + OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( + x.primitive, + weight.primitive, + None, + options, + )), + } + } } - } } - } - - fn conv_transpose1d( - x: AutodiffTensor, - weight: AutodiffTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> AutodiffTensor { - #[derive(Debug)] - struct ConvTranspose1DWithBias; - #[derive(Debug)] - struct ConvTranspose1DNoBias; - - impl Backward for ConvTranspose1DWithBias { - type State = ( - B::TensorPrimitive<3>, - B::TensorPrimitive<3>, - B::TensorPrimitive<1>, - ConvTransposeOptions<1>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight, node_bias] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, weight, bias, options) = ops.state; - let backward = B::conv_transpose1d_backward(x, weight, Some(bias), grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let 
Some(node) = node_weight { - grads.register::(node, backward.weights_grad) - } - if let Some(node) = node_bias { - grads.register::(node, backward.bias_grad.unwrap()) + // TODO: Support a custom unfold4d operation by overriding the default implementation. + // + // We don't override it now because the fold operation isn't available for the backward pass. + // This implies that when autodiff is enabled, custom unfold operations defined by backends + // won't be used. Instead, the conv2d operation with custom weights matrix will be used. + // Therefore, the conv2d backward pass will be used for the unfold4d backward pass. + // + // fn unfold4d( + // x: AutodiffTensor, + // kernel_size: [usize; 2], + // options: UnfoldOptions, + // ) -> AutodiffTensor { + // todo!() + // } + + fn avg_pool1d( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> AutodiffTensor { + #[derive(Debug)] + struct AvgPool1D; + + impl Backward for AvgPool1D { + type State = (B::TensorPrimitive<3>, usize, usize, usize, bool); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x, kernel_size, stride, padding, count_include_pad) = ops.state; + + if let Some(node) = node_parent { + let grad = B::avg_pool1d_backward( + x, + grad, + kernel_size, + stride, + padding, + count_include_pad, + ); + grads.register::(node, grad); + } + } + } + + match AvgPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::avg_pool1d( + x.primitive.clone(), + kernel_size, + stride, + padding, + count_include_pad, + ); + prep.finish( + (x.primitive, kernel_size, stride, padding, count_include_pad), + output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )), } - } } - impl Backward for ConvTranspose1DNoBias { - type State = ( - B::TensorPrimitive<3>, - B::TensorPrimitive<3>, - ConvTransposeOptions<1>, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_x, node_weight] = ops.parents; - let grad = grads.consume::(&ops.node); - - let (x, weight, options) = ops.state; - let backward = B::conv_transpose1d_backward(x, weight, None, grad, options); - - if let Some(node) = node_x { - grads.register::(node, backward.x_grad) - } - if let Some(node) = node_weight { - grads.register::(node, backward.weights_grad) + fn avg_pool2d( + x: AutodiffTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> AutodiffTensor { + #[derive(Debug)] + struct AvgPool2D; + + impl Backward for AvgPool2D { + type State = ( + B::TensorPrimitive<4>, + [usize; 2], + [usize; 2], + [usize; 2], + bool, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x, kernel_size, stride, padding, count_include_pad) = ops.state; + + if let Some(node) = node_parent { + let grad = B::avg_pool2d_backward( + x, + grad, + kernel_size, + stride, + padding, + count_include_pad, + ); + grads.register::(node, grad); + } + } + } + + match AvgPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::avg_pool2d( + x.primitive.clone(), + kernel_size, + stride, + padding, + count_include_pad, + ); + prep.finish( + (x.primitive, kernel_size, stride, padding, count_include_pad), + output, + ) + } + OpsKind::UnTracked(prep) => 
prep.finish(B::avg_pool2d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )), } - } } - match bias { - Some(bias) => { - match ConvTranspose1DWithBias - .prepare( - [x.node, weight.node, bias.node], - [x.graph, weight.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - bias.primitive.clone(), - options.clone(), - ), - B::conv_transpose1d(x.primitive, weight.primitive, Some(bias.primitive), options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( - x.primitive, - weight.primitive, - Some(bias.primitive), - options, - )), - } - } - None => { - match ConvTranspose1DNoBias - .prepare([x.node, weight.node], [x.graph, weight.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - x.primitive.clone(), - weight.primitive.clone(), - options.clone(), - ), - B::conv_transpose1d(x.primitive, weight.primitive, None, options), - ), - OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( - x.primitive, - weight.primitive, - None, - options, - )), - } - } + fn avg_pool2d_backward( + _x: AutodiffTensor, + _grad: AutodiffTensor, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + _count_include_pad: bool, + ) -> AutodiffTensor { + panic!("Can't differentiate avg pool 2d backward."); } - } - - // TODO: Support a custom unfold4d operation by overriding the default implementation. - // - // We don't override it now because the fold operation isn't available for the backward pass. - // This implies that when autodiff is enabled, custom unfold operations defined by backends - // won't be used. Instead, the conv2d operation with custom weights matrix will be used. - // Therefore, the conv2d backward pass will be used for the unfold4d backward pass. 
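// Illustrative sketch, not part of the patch: the comment above refers to expressing
// `unfold4d` as a conv2d whose weights one-hot select each (channel, ky, kx) patch
// position, so conv2d's backward pass can be reused. A minimal stand-alone
// construction of such a selection kernel as nested vectors; the assumed layout
// [c_in * k_h * k_w, c_in, k_h, k_w] is for illustration only:
fn unfold_selection_weights(c_in: usize, k_h: usize, k_w: usize) -> Vec<Vec<Vec<Vec<f32>>>> {
    let c_out = c_in * k_h * k_w;
    let mut w = vec![vec![vec![vec![0.0; k_w]; k_h]; c_in]; c_out];
    for c in 0..c_in {
        for ky in 0..k_h {
            for kx in 0..k_w {
                // Output channel `c * k_h * k_w + ky * k_w + kx` copies exactly one
                // input position per patch, which is what im2col/unfold produces.
                w[c * k_h * k_w + ky * k_w + kx][c][ky][kx] = 1.0;
            }
        }
    }
    w
}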
- // - // fn unfold4d( - // x: AutodiffTensor, - // kernel_size: [usize; 2], - // options: UnfoldOptions, - // ) -> AutodiffTensor { - // todo!() - // } - - fn avg_pool1d( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> AutodiffTensor { - #[derive(Debug)] - struct AvgPool1D; - - impl Backward for AvgPool1D { - type State = (B::TensorPrimitive<3>, usize, usize, usize, bool); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, kernel_size, stride, padding, count_include_pad) = ops.state; - if let Some(node) = node_parent { - let grad = - B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad); - grads.register::(node, grad); + fn max_pool1d( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> AutodiffTensor { + match MaxPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::max_pool1d_with_indices( + x.primitive.clone(), + kernel_size, + stride, + padding, + dilation, + ); + prep.finish( + ( + x.primitive, + output.indices, + kernel_size, + stride, + padding, + dilation, + ), + output.output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )), } - } } - match AvgPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::avg_pool1d( - x.primitive.clone(), - kernel_size, - stride, - padding, - count_include_pad, - ); - prep.finish( - (x.primitive, kernel_size, stride, padding, count_include_pad), - output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )), - } - } - - fn avg_pool2d( - x: AutodiffTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> AutodiffTensor { - #[derive(Debug)] - struct AvgPool2D; - - impl Backward for AvgPool2D { - type State = ( - B::TensorPrimitive<4>, - [usize; 2], - [usize; 2], - [usize; 2], - bool, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, kernel_size, stride, padding, count_include_pad) = ops.state; - - if let Some(node) = node_parent { - let grad = - B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad); - grads.register::(node, grad); + fn max_pool1d_with_indices( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices> { + match MaxPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::max_pool1d_with_indices( + x.primitive.clone(), + kernel_size, + stride, + padding, + dilation, + ); + + let output_tensor = prep.finish( + ( + x.primitive, + output.indices.clone(), + kernel_size, + stride, + padding, + dilation, + ), + output.output, + ); + + MaxPool1dWithIndices::new(output_tensor, output.indices) + } + OpsKind::UnTracked(prep) => { + let output = + B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output_tensor = prep.finish(output.output); + + MaxPool1dWithIndices::new(output_tensor, output.indices) + } } - } } - match AvgPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::avg_pool2d( - 
x.primitive.clone(), - kernel_size, - stride, - padding, - count_include_pad, - ); - prep.finish( - (x.primitive, kernel_size, stride, padding, count_include_pad), - output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::avg_pool2d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )), - } - } - - fn avg_pool2d_backward( - _x: AutodiffTensor, - _grad: AutodiffTensor, - _kernel_size: [usize; 2], - _stride: [usize; 2], - _padding: [usize; 2], - _count_include_pad: bool, - ) -> AutodiffTensor { - panic!("Can't differentiate avg pool 2d backward."); - } - - fn max_pool1d( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> AutodiffTensor { - match MaxPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = - B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); - prep.finish( - ( - x.primitive, - output.indices, - kernel_size, - stride, - padding, - dilation, - ), - output.output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )), - } - } - - fn max_pool1d_with_indices( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices> { - match MaxPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = - B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); - - let output_tensor = prep.finish( - ( + fn max_pool1d_with_indices_backward( + x: AutodiffTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: AutodiffTensor, + indices: IntTensor, + ) -> MaxPool1dBackward> { + let output = B::max_pool1d_with_indices_backward( x.primitive, - output.indices.clone(), kernel_size, stride, padding, dilation, - ), - output.output, + output_grad.primitive, + indices, ); + MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad)) + } - MaxPool1dWithIndices::new(output_tensor, output.indices) - } - OpsKind::UnTracked(prep) => { - let output = - B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - let output_tensor = prep.finish(output.output); + fn max_pool2d( + x: AutodiffTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> AutodiffTensor { + match MaxPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::max_pool2d_with_indices( + x.primitive.clone(), + kernel_size, + stride, + padding, + dilation, + ); + prep.finish( + ( + x.primitive, + output.indices, + kernel_size, + stride, + padding, + dilation, + ), + output.output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )), + } + } - MaxPool1dWithIndices::new(output_tensor, output.indices) - } + fn max_pool2d_with_indices( + x: AutodiffTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + match MaxPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::max_pool2d_with_indices( + x.primitive.clone(), + kernel_size, + stride, + padding, + dilation, + ); + + let output_tensor = prep.finish( + ( + x.primitive, + output.indices.clone(), + kernel_size, + stride, + padding, + dilation, + ), + output.output, + ); + + 
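// For the MaxPool1D / MaxPool2D ops in this hunk, the tracked forward pass stores
// `output.indices` (the argmax position per output element) in the op state, and the backward
// pass hands them to B::max_pool1d_with_indices_backward / B::max_pool2d_with_indices_backward,
// which routes the upstream gradient back to those positions only. A minimal standalone sketch
// of that routing step, using plain slices and ignoring padding/dilation (an illustration,
// not Burn's kernel):
fn max_pool_backward_sketch(input_len: usize, indices: &[usize], grad_out: &[f32]) -> Vec<f32> {
    let mut grad_in = vec![0.0f32; input_len];
    for (&idx, &g) in indices.iter().zip(grad_out) {
        // Only the element that produced the max receives gradient.
        grad_in[idx] += g;
    }
    grad_in
}

#[test]
fn max_pool_backward_sketch_routes_gradient_to_argmax() {
    // Two windows over a length-4 input; the maxima were at positions 1 and 2.
    let grad_in = max_pool_backward_sketch(4, &[1, 2], &[1.0, 3.0]);
    assert_eq!(grad_in, vec![0.0, 1.0, 3.0, 0.0]);
}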
MaxPool2dWithIndices::new(output_tensor, output.indices) + } + OpsKind::UnTracked(prep) => { + let output = + B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output_tensor = prep.finish(output.output); + + MaxPool2dWithIndices::new(output_tensor, output.indices) + } + } } - } - - fn max_pool1d_with_indices_backward( - x: AutodiffTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - output_grad: AutodiffTensor, - indices: IntTensor, - ) -> MaxPool1dBackward> { - let output = B::max_pool1d_with_indices_backward( - x.primitive, - kernel_size, - stride, - padding, - dilation, - output_grad.primitive, - indices, - ); - MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad)) - } - - fn max_pool2d( - x: AutodiffTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> AutodiffTensor { - match MaxPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = - B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); - prep.finish( - ( - x.primitive, - output.indices, - kernel_size, - stride, - padding, - dilation, - ), - output.output, - ) - } - OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )), + + fn max_pool2d_with_indices_backward( + _x: AutodiffTensor, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + _dilation: [usize; 2], + _output_grad: AutodiffTensor, + _indices: IntTensor, + ) -> MaxPool2dBackward> { + panic!("Can't differentiate max pool2d with indices backward."); } - } - - fn max_pool2d_with_indices( - x: AutodiffTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - match MaxPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = - B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding, dilation); - - let output_tensor = prep.finish( - ( - x.primitive, - output.indices.clone(), - kernel_size, - stride, - padding, - dilation, - ), - output.output, - ); + fn adaptive_avg_pool1d(x: AutodiffTensor, output_size: usize) -> AutodiffTensor { + #[derive(Debug)] + struct AdaptiveAvgPool1D; - MaxPool2dWithIndices::new(output_tensor, output.indices) - } - OpsKind::UnTracked(prep) => { - let output = - B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - let output_tensor = prep.finish(output.output); + impl Backward for AdaptiveAvgPool1D { + type State = B::TensorPrimitive<3>; - MaxPool2dWithIndices::new(output_tensor, output.indices) - } - } - } - - fn max_pool2d_with_indices_backward( - _x: AutodiffTensor, - _kernel_size: [usize; 2], - _stride: [usize; 2], - _padding: [usize; 2], - _dilation: [usize; 2], - _output_grad: AutodiffTensor, - _indices: IntTensor, - ) -> MaxPool2dBackward> { - panic!("Can't differentiate max pool2d with indices backward."); - } - fn adaptive_avg_pool1d(x: AutodiffTensor, output_size: usize) -> AutodiffTensor { - #[derive(Debug)] - struct AdaptiveAvgPool1D; - - impl Backward for AdaptiveAvgPool1D { - type State = B::TensorPrimitive<3>; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); - if let Some(node) 
= node_parent { - let grad = B::adaptive_avg_pool1d_backward(ops.state, grad); - grads.register::(node, grad); + if let Some(node) = node_parent { + let grad = B::adaptive_avg_pool1d_backward(ops.state, grad); + grads.register::(node, grad); + } + } } - } - } - match AdaptiveAvgPool1D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - x.primitive.clone(), - B::adaptive_avg_pool1d(x.primitive, output_size), - ), - OpsKind::UnTracked(prep) => prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size)), + match AdaptiveAvgPool1D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + x.primitive.clone(), + B::adaptive_avg_pool1d(x.primitive, output_size), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size)) + } + } } - } - fn adaptive_avg_pool2d(x: AutodiffTensor, output_size: [usize; 2]) -> AutodiffTensor { - #[derive(Debug)] - struct AdaptiveAvgPool2D; + fn adaptive_avg_pool2d( + x: AutodiffTensor, + output_size: [usize; 2], + ) -> AutodiffTensor { + #[derive(Debug)] + struct AdaptiveAvgPool2D; - impl Backward for AdaptiveAvgPool2D { - type State = B::TensorPrimitive<4>; + impl Backward for AdaptiveAvgPool2D { + type State = B::TensorPrimitive<4>; - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); - if let Some(node) = node_parent { - let grad = B::adaptive_avg_pool2d_backward(ops.state, grad); - grads.register::(node, grad); + if let Some(node) = node_parent { + let grad = B::adaptive_avg_pool2d_backward(ops.state, grad); + grads.register::(node, grad); + } + } + } + + match AdaptiveAvgPool2D.prepare([x.node], [x.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + x.primitive.clone(), + B::adaptive_avg_pool2d(x.primitive, output_size), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size)) + } } - } } - match AdaptiveAvgPool2D.prepare([x.node], [x.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - x.primitive.clone(), - B::adaptive_avg_pool2d(x.primitive, output_size), - ), - OpsKind::UnTracked(prep) => prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size)), + fn adaptive_avg_pool2d_backward( + _x: AutodiffTensor, + _grad: AutodiffTensor, + ) -> as Backend>::TensorPrimitive<4> { + panic!("Can't differentiate adaptive avg pool2d backward."); } - } - - fn adaptive_avg_pool2d_backward( - _x: AutodiffTensor, - _grad: AutodiffTensor, - ) -> as Backend>::TensorPrimitive<4> { - panic!("Can't differentiate adaptive avg pool2d backward."); - } } #[derive(Debug)] struct MaxPool1D; impl Backward for MaxPool1D { - type State = ( - B::TensorPrimitive<3>, - IntTensor, - usize, - usize, - usize, - usize, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, indices, kernel_size, stride, padding, dilation) = ops.state; - - if let Some(node) = node_parent { - let grad = B::max_pool1d_with_indices_backward( - x, - kernel_size, - stride, - padding, - dilation, - grad, - indices, - ); - - grads.register::(node, grad.x_grad); + type State = ( + B::TensorPrimitive<3>, + IntTensor, + usize, + usize, + usize, + usize, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = 
grads.consume::(&ops.node); + let (x, indices, kernel_size, stride, padding, dilation) = ops.state; + + if let Some(node) = node_parent { + let grad = B::max_pool1d_with_indices_backward( + x, + kernel_size, + stride, + padding, + dilation, + grad, + indices, + ); + + grads.register::(node, grad.x_grad); + } } - } } #[derive(Debug)] struct MaxPool2D; impl Backward for MaxPool2D { - type State = ( - B::TensorPrimitive<4>, - IntTensor, - [usize; 2], - [usize; 2], - [usize; 2], - [usize; 2], - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let [node_parent] = ops.parents; - let grad = grads.consume::(&ops.node); - let (x, indices, kernel_size, stride, padding, dilation) = ops.state; - - if let Some(node) = node_parent { - let grad = B::max_pool2d_with_indices_backward( - x, - kernel_size, - stride, - padding, - dilation, - grad, - indices, - ); - - grads.register::(node, grad.x_grad); + type State = ( + B::TensorPrimitive<4>, + IntTensor, + [usize; 2], + [usize; 2], + [usize; 2], + [usize; 2], + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x, indices, kernel_size, stride, padding, dilation) = ops.state; + + if let Some(node) = node_parent { + let grad = B::max_pool2d_with_indices_backward( + x, + kernel_size, + stride, + padding, + dilation, + grad, + indices, + ); + + grads.register::(node, grad.x_grad); + } } - } } diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs index d5ae02ecba..e4cd4755e7 100644 --- a/burn-autodiff/src/ops/tensor.rs +++ b/burn-autodiff/src/ops/tensor.rs @@ -1,1534 +1,1554 @@ use std::marker::PhantomData; use crate::{ - grads::Gradients, - graph::{NodeRef, Requirement, Step}, - ops::{binary, broadcast_shape, unary, unary_different_backend, Backward, Ops, OpsKind}, - tensor::AutodiffTensor, - utils::duplicate, - Autodiff, + grads::Gradients, + graph::{NodeRef, Requirement, Step}, + ops::{binary, broadcast_shape, unary, unary_different_backend, Backward, Ops, OpsKind}, + tensor::AutodiffTensor, + utils::duplicate, + Autodiff, }; use burn_tensor::{ - backend::Backend, - ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, - Data, Device, ElementConversion, Reader, Shape, Tensor, + backend::Backend, + ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, + Data, Device, ElementConversion, Reader, Shape, Tensor, }; use super::maxmin::MaxMinDim; impl TensorOps for Autodiff { - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor { - AutodiffTensor::new(B::from_data(data, device)) - } - - fn random( - shape: Shape, - distribution: burn_tensor::Distribution>, - device: &Device, - ) -> FloatTensor { - AutodiffTensor::new(B::random(shape, distribution, device)) - } - - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::zeros(shape), device) - } - - fn ones(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::ones(shape), device) - } - - fn shape(tensor: &FloatTensor) -> Shape { - B::shape(&tensor.primitive) - } - - fn to_data(tensor: &FloatTensor) -> Reader, D>> { - B::to_data(&tensor.primitive) - } - - fn into_data(tensor: FloatTensor) -> Reader, D>> { - B::into_data(tensor.primitive) - } - - fn device(tensor: &FloatTensor) -> Device { - B::device(&tensor.primitive) - } - - fn to_device( - tensor: FloatTensor, - device: &Device, - ) -> FloatTensor { - #[derive(Debug)] - struct ToDevice; - - impl Backward for 
ToDevice { - type State = B::Device; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::to_device(grad, &ops.state) - }); - } - } - - match ToDevice.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let device_old = B::device(&tensor.primitive); - prep.finish(device_old, B::to_device(tensor.primitive, device)) - } - OpsKind::UnTracked(prep) => prep.finish(B::to_device(tensor.primitive, device)), - } - } - - fn arange(range: std::ops::Range, device: &Device) -> IntTensor { - B::arange(range, device) - } - - fn empty(shape: Shape, device: &Device) -> FloatTensor { - AutodiffTensor::new(B::empty(shape, device)) - } - - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Add; - - impl Backward for Add { - type State = (Shape, Shape); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape_lhs, shape_rhs) = ops.state; - - binary::( - ops.parents, - ops.node, - grads, - |grad| broadcast_shape::(grad, &shape_lhs), - |grad| broadcast_shape::(grad, &shape_rhs), - ); - } - } - - match Add - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(preps) => preps.finish( - (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), - B::add(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(preps) => preps.finish(B::add(lhs.primitive, rhs.primitive)), - } - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct AddScalar; - - impl Backward for AddScalar { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| grad); - } - } - - AddScalar - .prepare([lhs.node], [lhs.graph]) - .stateless(B::add_scalar(lhs.primitive, rhs)) - } - - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Sub; - - impl Backward for Sub { - type State = (Shape, Shape); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape_lhs, shape_rhs) = ops.state; - - binary::( - ops.parents, - ops.node, - grads, - |grad| broadcast_shape::(grad, &shape_lhs), - |grad| broadcast_shape::(B::neg(grad), &shape_rhs), - ); - } - } - - match Sub - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(preps) => preps.finish( - (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), - B::sub(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(preps) => preps.finish(B::sub(lhs.primitive, rhs.primitive)), - } - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct SubScalar; - - impl Backward for SubScalar { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| grad); - } - } - - SubScalar - .prepare([lhs.node], [lhs.graph]) - .stateless(B::sub_scalar(lhs.primitive, rhs)) - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Mul; - - impl Backward for Mul { - type State = ( - Option>, - Option>, - BinaryOpsBroadcast, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (lhs, rhs, broadcast) = ops.state; - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let grad = B::mul(grad, rhs.unwrap()); - broadcast.backward_lhs::(grad) - }, - |grad| { - let grad = B::mul(grad, lhs.unwrap()); - broadcast.backward_rhs::(grad) - }, - ); - } - } - - let lhs_tracked = 
lhs.is_tracked(); - let rhs_tracked = rhs.is_tracked(); - let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); - - match Mul - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - rhs_tracked.then(|| lhs.primitive.clone()), - lhs_tracked.then(|| rhs.primitive.clone()), - broadcast, - ), - B::mul(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::mul(lhs.primitive, rhs.primitive)), - } - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct MulScalar; - - impl Backward for MulScalar { - type State = FloatElem; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::mul_scalar(grad, ops.state) - }); - } - } - - match MulScalar.prepare([lhs.node], [lhs.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(rhs, B::mul_scalar(lhs.primitive, rhs)), - OpsKind::UnTracked(prep) => prep.finish(B::mul_scalar(lhs.primitive, rhs)), - } - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Div; - - impl Backward for Div { - type State = ( - Option>, - Option>, - BinaryOpsBroadcast, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (lhs, rhs, broadcast) = ops.state; - let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, rhs); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let rhs = rhs_4lhs.unwrap(); - let value = B::powf(rhs, -1.0); - let grad = B::mul(grad, value); - - broadcast.backward_lhs::(grad) - }, - |grad| { - let rhs = rhs_4rhs.unwrap(); - let lhs = lhs.unwrap(); - let value = B::div(B::neg(lhs), B::powf(rhs, 2.0)); - let grad = B::mul(grad, value); - - broadcast.backward_rhs::(grad) - }, - ); - } - } - - let lhs_tracked = lhs.is_tracked(); - let rhs_tracked = rhs.is_tracked(); - let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); - - match Div - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - rhs_tracked.then(|| lhs.primitive.clone()), - (lhs_tracked || rhs_tracked).then(|| rhs.primitive.clone()), - broadcast, - ), - B::div(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::div(lhs.primitive, rhs.primitive)), - } - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct DivScalar; - - impl Backward for DivScalar { - type State = FloatElem; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let tmp = 1.0 / ops.state.elem::(); - B::mul_scalar(grad, tmp.elem()) - }); - } - } - - match DivScalar.prepare([lhs.node], [lhs.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(rhs, B::div_scalar(lhs.primitive, rhs)), - OpsKind::UnTracked(prep) => prep.finish(B::div_scalar(lhs.primitive, rhs)), - } - } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Matmul; - - impl Backward for Matmul { - type State = ( - Option>, - Option>, - BinaryOpsBroadcast, - ); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (lhs, rhs, broadcast) = ops.state; - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let rhs = B::transpose(rhs.unwrap()); - let grad = B::matmul(grad, rhs); - - broadcast.backward_lhs::(grad) - }, - |grad| { - let lhs = B::transpose(lhs.unwrap()); - let grad = B::matmul(lhs, 
grad); - - broadcast.backward_rhs::(grad) - }, - ); - } - } - - let lhs_tracked = lhs.is_tracked(); - let rhs_tracked = rhs.is_tracked(); - let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); - - match Matmul - .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - rhs_tracked.then(|| lhs.primitive.clone()), - lhs_tracked.then(|| rhs.primitive.clone()), - broadcast, - ), - B::matmul(lhs.primitive, rhs.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::matmul(lhs.primitive, rhs.primitive)), - } - } - - fn neg(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Neg; - - impl Backward for Neg { - type State = (); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| B::neg(grad)); - } - } - - Neg - .prepare([tensor.node], [tensor.graph]) - .stateless(B::neg(tensor.primitive)) - } - - fn recip(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Recip; - - impl Backward for Recip { - type State = B::TensorPrimitive; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let tensor = ops.state; - unary::(ops.parents, ops.node, grads, |grad| { - let tmp = B::powf(tensor, -2.0); - let value = B::neg(tmp); - - B::mul(grad, value) - }); - } + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor { + AutodiffTensor::new(B::from_data(data, device)) } - match Recip.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)), + fn random( + shape: Shape, + distribution: burn_tensor::Distribution>, + device: &Device, + ) -> FloatTensor { + AutodiffTensor::new(B::random(shape, distribution, device)) } - } - - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor { - #[derive(Debug)] - struct SwapDim; - impl Backward for SwapDim { - type State = (usize, usize); + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::zeros(shape), device) + } - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim1, dim2) = ops.state; + fn ones(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::ones(shape), device) + } - unary::(ops.parents, ops.node, grads, |grad| { - B::swap_dims(grad, dim2, dim1) - }); - } + fn shape(tensor: &FloatTensor) -> Shape { + B::shape(&tensor.primitive) } - let output = B::swap_dims(tensor.primitive, dim1, dim2); + fn to_data(tensor: &FloatTensor) -> Reader, D>> { + B::to_data(&tensor.primitive) + } - match SwapDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish((dim1, dim2), output), - OpsKind::UnTracked(prep) => prep.finish(output), + fn into_data(tensor: FloatTensor) -> Reader, D>> { + B::into_data(tensor.primitive) } - } - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - #[derive(Debug)] - struct ReshapeDim; + fn device(tensor: &FloatTensor) -> Device { + B::device(&tensor.primitive) + } - impl Backward for ReshapeDim { - type State = (Shape, Shape); + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor { + #[derive(Debug)] + struct ToDevice; - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape_original, shape) = ops.state; + impl Backward for ToDevice { + type State = B::Device; - unary::(ops.parents, ops.node, grads, |grad| { - let shape_grad = 
B::shape(&grad); - let mut grad = grad; + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::to_device(grad, &ops.state) + }); + } + } - for i in 0..D2 { - if shape.dims[i] == 1 && shape_grad.dims[i] != 1 { - grad = B::sum_dim(grad, i); + match ToDevice.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let device_old = B::device(&tensor.primitive); + prep.finish(device_old, B::to_device(tensor.primitive, device)) } - } + OpsKind::UnTracked(prep) => prep.finish(B::to_device(tensor.primitive, device)), + } + } - B::reshape(grad, shape_original) - }); - } + fn arange(range: std::ops::Range, device: &Device) -> IntTensor { + B::arange(range, device) } - match ReshapeDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (B::shape(&tensor.primitive), shape.clone()), - B::reshape(tensor.primitive, shape), - ), - OpsKind::UnTracked(prep) => prep.finish(B::reshape(tensor.primitive, shape)), + fn empty(shape: Shape, device: &Device) -> FloatTensor { + AutodiffTensor::new(B::empty(shape, device)) } - } - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Gather; + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Add; - impl Backward for Gather { - type State = (usize, IntTensor, Shape, B::Device); + impl Backward for Add { + type State = (Shape, Shape); - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape, device) = ops.state; + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape_lhs, shape_rhs) = ops.state; - unary::(ops.parents, ops.node, grads, |grad| { - let zeros = B::zeros(shape, &device); - B::scatter(dim, zeros, indices, grad) - }); - } - } - - match Gather.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::device(&tensor.primitive), - ), - B::gather(dim, tensor.primitive, indices), - ), - OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indices)), - } - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct Scatter; - - impl Backward for Scatter { - type State = (usize, IntTensor, Shape, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape_lhs, shape_rhs, device) = ops.state; - let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_lhs, &device); - B::scatter(dim, grad, indices_4lhs.unwrap(), zeros) - }, - |grad| { - let zeros = B::zeros(shape_rhs, &device); - B::scatter(dim, zeros, indices_4rhs.unwrap(), grad) - }, - ); - } - } - - match Scatter - .prepare([tensor.node, value.node], [tensor.graph, value.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::shape(&value.primitive), - B::device(&value.primitive), - ), - B::scatter(dim, tensor.primitive, indices, value.primitive), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::scatter(dim, tensor.primitive, indices, value.primitive)) - } - } - } - - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct IndexSelectDim; - - impl 
Backward for IndexSelectDim { - type State = (usize, IntTensor, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape, device) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - let zeros = B::zeros(shape, &device); - B::select_assign(zeros, dim, indices, grad) - }); - } - } - - match IndexSelectDim - .prepare([tensor.node], [tensor.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::device(&tensor.primitive), - ), - B::select(tensor.primitive, dim, indices), - ), - OpsKind::UnTracked(prep) => prep.finish(B::select(tensor.primitive, dim, indices)), - } - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct IndexSelectDimAssign; - - impl Backward for IndexSelectDimAssign { - type State = (usize, IntTensor, Shape, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (dim, indices, shape_lhs, shape_rhs, device) = ops.state; - let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_lhs, &device); - B::select_assign(grad, dim, indices_4lhs.unwrap(), zeros) - }, - |grad| { - let zeros = B::zeros(shape_rhs, &device); - B::select_assign(zeros, dim, indices_4rhs.unwrap(), grad) - }, - ); - } - } - - match IndexSelectDimAssign:: - .prepare([tensor.node, value.node], [tensor.graph, value.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - dim, - indices.clone(), - B::shape(&tensor.primitive), - B::shape(&value.primitive), - B::device(&value.primitive), - ), - B::select_assign(tensor.primitive, dim, indices, value.primitive), - ), - OpsKind::UnTracked(prep) => prep.finish(B::select_assign( - tensor.primitive, - dim, - indices, - value.primitive, - )), - } - } - - fn slice( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - ) -> FloatTensor { - #[derive(Debug)] - struct Index; - - impl Backward for Index { - type State = ([std::ops::Range; D2], Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (ranges, shape, device) = ops.state; - - unary::(ops.parents, ops.node, grads, |grad| { - let zeros = B::zeros(shape, &device); - B::slice_assign(zeros, ranges, grad) - }); - } - } - - match Index.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - ( - ranges.clone(), - B::shape(&tensor.primitive), - B::device(&tensor.primitive), - ), - B::slice(tensor.primitive, ranges), - ), - OpsKind::UnTracked(prep) => prep.finish(B::slice(tensor.primitive, ranges)), - } - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - value: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct IndexAssign; - - impl Backward for IndexAssign { - type State = ([std::ops::Range; D2], Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (ranges, shape_rhs, device) = ops.state; - let [ranges_4lhs, ranges_4rhs] = duplicate(&ops.parents, Some(ranges)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_rhs, &device); - B::slice_assign(grad, ranges_4lhs.unwrap(), zeros) - }, - |grad| B::slice(grad, ranges_4rhs.unwrap()), - ); - } - } - - match IndexAssign - .prepare([tensor.node, value.node], [tensor.graph, value.graph]) - .stateful() - { - OpsKind::Tracked(prep) 
=> prep.finish( - ( - ranges.clone(), - B::shape(&value.primitive), - B::device(&value.primitive), - ), - B::slice_assign(tensor.primitive, ranges, value.primitive), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::slice_assign(tensor.primitive, ranges, value.primitive)) - } - } - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - source: FloatTensor, - ) -> FloatTensor { - #[derive(Debug)] - struct MaskWhere; - - impl Backward for MaskWhere { - type State = (BoolTensor, Shape, Shape, B::Device); - - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (mask, shape_lhs, shape_rhs, device) = ops.state; - let [mask_4lhs, mask_4rhs] = duplicate(&ops.parents, Some(mask)); - - binary::( - ops.parents, - ops.node, - grads, - |grad| { - let zeros = B::zeros(shape_lhs.clone(), &device); - let grad = B::mask_where(grad, mask_4lhs.unwrap(), zeros); - - broadcast_shape::(grad, &shape_lhs) - }, - |grad| { - let zeros = B::zeros(shape_rhs.clone(), &device); - let grad = B::mask_where(zeros, mask_4rhs.unwrap(), grad); - - broadcast_shape::(grad, &shape_rhs) - }, - ); - } - } - - match MaskWhere - .prepare([tensor.node, source.node], [tensor.graph, source.graph]) - .stateful() - { - OpsKind::Tracked(prep) => prep.finish( - ( - mask.clone(), - B::shape(&tensor.primitive), - B::shape(&source.primitive), - B::device(&source.primitive), - ), - B::mask_where(tensor.primitive, mask, source.primitive), - ), - OpsKind::UnTracked(prep) => { - prep.finish(B::mask_where(tensor.primitive, mask, source.primitive)) - } - } - } - - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - #[derive(Debug)] - struct MaskFill; - - impl Backward for MaskFill { - type State = BoolTensor; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - B::mask_fill(grad, ops.state, 0.elem()) - }); - } - } - - match MaskFill.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - prep.finish(mask.clone(), B::mask_fill(tensor.primitive, mask, value)) - } - OpsKind::UnTracked(prep) => prep.finish(B::mask_fill(tensor.primitive, mask, value)), - } - } - - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::equal(lhs.primitive, rhs.primitive) - } - - fn equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { - B::equal_elem(lhs.primitive, rhs) - } - - fn greater( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::greater(lhs.primitive, rhs.primitive) - } - - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::greater_elem(lhs.primitive, rhs) - } - - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::greater_equal(lhs.primitive, rhs.primitive) - } - - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::greater_equal_elem(lhs.primitive, rhs) - } - - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::lower(lhs.primitive, rhs.primitive) - } - - fn lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { - B::lower_elem(lhs.primitive, rhs) - } - - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - B::lower_equal(lhs.primitive, rhs.primitive) - } - - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - B::lower_equal_elem(lhs.primitive, rhs) - } - - fn detach(tensor: FloatTensor) -> FloatTensor { - // When we detach a tensor, we remove it from the graph, but we still want to keep the - 
// `require_grad` setting. - let is_require_grad = Self::is_require_grad(&tensor); - let tensor = AutodiffTensor::new(tensor.primitive); - - match is_require_grad { - true => tensor.require_grad(), - false => tensor, - } - } - - fn set_require_grad( - tensor: FloatTensor, - require_grad: bool, - ) -> FloatTensor { - if require_grad { - return tensor.require_grad(); - } - - AutodiffTensor::new(tensor.primitive) - } - - fn is_require_grad(tensor: &FloatTensor) -> bool { - matches!(tensor.node.requirement, Requirement::Grad) - } - - fn mean(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Mean; - - impl Backward for Mean { - type State = Shape; - - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let shape = ops.state; - let val = 1_f64 / shape.num_elements() as f64; - let ones = B::ones(shape, &B::device(&grad)); - let val = B::mul_scalar(ones, val.elem()); - - let grad: Tensor = Tensor::from_primitive(grad); - let val: Tensor = Tensor::from_primitive(val); - - val.mul(grad.unsqueeze()).into_primitive() - }); - } - } + binary::( + ops.parents, + ops.node, + grads, + |grad| broadcast_shape::(grad, &shape_lhs), + |grad| broadcast_shape::(grad, &shape_rhs), + ); + } + } + + match Add + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(preps) => preps.finish( + (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), + B::add(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(preps) => preps.finish(B::add(lhs.primitive, rhs.primitive)), + } + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct AddScalar; + + impl Backward for AddScalar { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| grad); + } + } - match Mean.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(B::shape(&tensor.primitive), B::mean(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::mean(tensor.primitive)), + AddScalar + .prepare([lhs.node], [lhs.graph]) + .stateless(B::add_scalar(lhs.primitive, rhs)) } - } - fn sum(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Sum; + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Sub; - impl Backward for Sum { - type State = Shape; + impl Backward for Sub { + type State = (Shape, Shape); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let val = B::ones(ops.state, &B::device(&grad)); + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape_lhs, shape_rhs) = ops.state; - let grad: Tensor = Tensor::from_primitive(grad); - let val: Tensor = Tensor::from_primitive(val); + binary::( + ops.parents, + ops.node, + grads, + |grad| broadcast_shape::(grad, &shape_lhs), + |grad| broadcast_shape::(B::neg(grad), &shape_rhs), + ); + } + } + + match Sub + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(preps) => preps.finish( + (B::shape(&lhs.primitive), B::shape(&rhs.primitive)), + B::sub(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(preps) => preps.finish(B::sub(lhs.primitive, rhs.primitive)), + } + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct SubScalar; + + impl Backward for SubScalar { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) 
{ + unary::(ops.parents, ops.node, grads, |grad| grad); + } + } + + SubScalar + .prepare([lhs.node], [lhs.graph]) + .stateless(B::sub_scalar(lhs.primitive, rhs)) + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Mul; + + impl Backward for Mul { + type State = ( + Option>, + Option>, + BinaryOpsBroadcast, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (lhs, rhs, broadcast) = ops.state; + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let grad = B::mul(grad, rhs.unwrap()); + broadcast.backward_lhs::(grad) + }, + |grad| { + let grad = B::mul(grad, lhs.unwrap()); + broadcast.backward_rhs::(grad) + }, + ); + } + } + + let lhs_tracked = lhs.is_tracked(); + let rhs_tracked = rhs.is_tracked(); + let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); + + match Mul + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + rhs_tracked.then(|| lhs.primitive.clone()), + lhs_tracked.then(|| rhs.primitive.clone()), + broadcast, + ), + B::mul(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::mul(lhs.primitive, rhs.primitive)), + } + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct MulScalar; + + impl Backward for MulScalar { + type State = FloatElem; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::mul_scalar(grad, ops.state) + }); + } + } + + match MulScalar.prepare([lhs.node], [lhs.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(rhs, B::mul_scalar(lhs.primitive, rhs)), + OpsKind::UnTracked(prep) => prep.finish(B::mul_scalar(lhs.primitive, rhs)), + } + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Div; + + impl Backward for Div { + type State = ( + Option>, + Option>, + BinaryOpsBroadcast, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (lhs, rhs, broadcast) = ops.state; + let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, rhs); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let rhs = rhs_4lhs.unwrap(); + let value = B::powf(rhs, -1.0); + let grad = B::mul(grad, value); + + broadcast.backward_lhs::(grad) + }, + |grad| { + let rhs = rhs_4rhs.unwrap(); + let lhs = lhs.unwrap(); + let value = B::div(B::neg(lhs), B::powf(rhs, 2.0)); + let grad = B::mul(grad, value); + + broadcast.backward_rhs::(grad) + }, + ); + } + } + + let lhs_tracked = lhs.is_tracked(); + let rhs_tracked = rhs.is_tracked(); + let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); + + match Div + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + rhs_tracked.then(|| lhs.primitive.clone()), + (lhs_tracked || rhs_tracked).then(|| rhs.primitive.clone()), + broadcast, + ), + B::div(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::div(lhs.primitive, rhs.primitive)), + } + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct DivScalar; + + impl Backward for DivScalar { + type State = FloatElem; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let tmp = 1.0 / ops.state.elem::(); + B::mul_scalar(grad, tmp.elem()) + }); + } + } + + match DivScalar.prepare([lhs.node], [lhs.graph]).stateful() { 
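// The Mul and Div backwards in this hunk are the product and quotient rules applied
// element-wise (with broadcasting handled by BinaryOpsBroadcast): for l / r the lhs gradient
// is grad * r^-1 and the rhs gradient is grad * (-l / r^2), which is what the closures above
// compute via B::powf, B::neg and B::mul. A small scalar sanity check of those formulas
// against a central finite difference (illustrative only, independent of Burn's types):
fn div_grads(l: f64, r: f64) -> (f64, f64) {
    (1.0 / r, -l / (r * r))
}

#[test]
fn div_grads_match_finite_differences() {
    let (l, r, eps) = (3.0_f64, 2.0_f64, 1e-6);
    let (dl, dr) = div_grads(l, r);

    let fd_l = ((l + eps) / r - (l - eps) / r) / (2.0 * eps);
    let fd_r = (l / (r + eps) - l / (r - eps)) / (2.0 * eps);

    assert!((dl - fd_l).abs() < 1e-6); // expect 0.5
    assert!((dr - fd_r).abs() < 1e-6); // expect -0.75
}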
+ OpsKind::Tracked(prep) => prep.finish(rhs, B::div_scalar(lhs.primitive, rhs)), + OpsKind::UnTracked(prep) => prep.finish(B::div_scalar(lhs.primitive, rhs)), + } + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Matmul; + + impl Backward for Matmul { + type State = ( + Option>, + Option>, + BinaryOpsBroadcast, + ); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (lhs, rhs, broadcast) = ops.state; + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let rhs = B::transpose(rhs.unwrap()); + let grad = B::matmul(grad, rhs); + + broadcast.backward_lhs::(grad) + }, + |grad| { + let lhs = B::transpose(lhs.unwrap()); + let grad = B::matmul(lhs, grad); + + broadcast.backward_rhs::(grad) + }, + ); + } + } + + let lhs_tracked = lhs.is_tracked(); + let rhs_tracked = rhs.is_tracked(); + let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); + + match Matmul + .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + rhs_tracked.then(|| lhs.primitive.clone()), + lhs_tracked.then(|| rhs.primitive.clone()), + broadcast, + ), + B::matmul(lhs.primitive, rhs.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::matmul(lhs.primitive, rhs.primitive)), + } + } + + fn neg(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Neg; + + impl Backward for Neg { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| B::neg(grad)); + } + } - val.mul(grad.unsqueeze()).into_primitive() - }); - } + Neg.prepare([tensor.node], [tensor.graph]) + .stateless(B::neg(tensor.primitive)) } - match Sum.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(B::shape(&tensor.primitive), B::sum(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::sum(tensor.primitive)), + fn recip(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Recip; + + impl Backward for Recip { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let tensor = ops.state; + unary::(ops.parents, ops.node, grads, |grad| { + let tmp = B::powf(tensor, -2.0); + let value = B::neg(tmp); + + B::mul(grad, value) + }); + } + } + + match Recip.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)), + } } - } - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[derive(Debug)] - struct MeamDim; + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + #[derive(Debug)] + struct SwapDim; - impl Backward for MeamDim { - type State = (Shape, usize); + impl Backward for SwapDim { + type State = (usize, usize); - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape, dim) = ops.state; + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim1, dim2) = ops.state; - unary::(ops.parents, ops.node, grads, |grad| { - let val = 1_f64 / shape.dims[dim] as f64; - let ones = B::ones(shape, &B::device(&grad)); - let val = B::mul_scalar(ones, B::FloatElem::from_elem(val)); + unary::(ops.parents, ops.node, grads, |grad| { + B::swap_dims(grad, dim2, dim1) + }); + } + } - let grad = B::sum_dim(grad, dim); - B::mul(val, grad) - }); - } + let output = B::swap_dims(tensor.primitive, dim1, dim2); + + match 
SwapDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish((dim1, dim2), output), + OpsKind::UnTracked(prep) => prep.finish(output), + } } - match MeamDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (B::shape(&tensor.primitive), dim), - B::mean_dim(tensor.primitive, dim), - ), - OpsKind::UnTracked(prep) => prep.finish(B::mean_dim(tensor.primitive, dim)), + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + #[derive(Debug)] + struct ReshapeDim; + + impl Backward for ReshapeDim { + type State = (Shape, Shape); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape_original, shape) = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + let shape_grad = B::shape(&grad); + let mut grad = grad; + + for i in 0..D2 { + if shape.dims[i] == 1 && shape_grad.dims[i] != 1 { + grad = B::sum_dim(grad, i); + } + } + + B::reshape(grad, shape_original) + }); + } + } + + match ReshapeDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + (B::shape(&tensor.primitive), shape.clone()), + B::reshape(tensor.primitive, shape), + ), + OpsKind::UnTracked(prep) => prep.finish(B::reshape(tensor.primitive, shape)), + } + } + + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Gather; + + impl Backward for Gather { + type State = (usize, IntTensor, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape, device) = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + let zeros = B::zeros(shape, &device); + B::scatter(dim, zeros, indices, grad) + }); + } + } + + match Gather.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::device(&tensor.primitive), + ), + B::gather(dim, tensor.primitive, indices), + ), + OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indices)), + } + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct Scatter; + + impl Backward for Scatter { + type State = (usize, IntTensor, Shape, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape_lhs, shape_rhs, device) = ops.state; + let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_lhs, &device); + B::scatter(dim, grad, indices_4lhs.unwrap(), zeros) + }, + |grad| { + let zeros = B::zeros(shape_rhs, &device); + B::scatter(dim, zeros, indices_4rhs.unwrap(), grad) + }, + ); + } + } + + match Scatter + .prepare([tensor.node, value.node], [tensor.graph, value.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::shape(&value.primitive), + B::device(&value.primitive), + ), + B::scatter(dim, tensor.primitive, indices, value.primitive), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::scatter(dim, tensor.primitive, indices, value.primitive)) + } + } } - } - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[derive(Debug)] - struct SumDim; + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct IndexSelectDim; + + impl 
Backward for IndexSelectDim { + type State = (usize, IntTensor, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape, device) = ops.state; - impl Backward for SumDim { - type State = (Shape, usize); + unary::(ops.parents, ops.node, grads, |grad| { + let zeros = B::zeros(shape, &device); + B::select_assign(zeros, dim, indices, grad) + }); + } + } + + match IndexSelectDim + .prepare([tensor.node], [tensor.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::device(&tensor.primitive), + ), + B::select(tensor.primitive, dim, indices), + ), + OpsKind::UnTracked(prep) => prep.finish(B::select(tensor.primitive, dim, indices)), + } + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct IndexSelectDimAssign; + + impl Backward for IndexSelectDimAssign { + type State = (usize, IntTensor, Shape, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (dim, indices, shape_lhs, shape_rhs, device) = ops.state; + let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_lhs, &device); + B::select_assign(grad, dim, indices_4lhs.unwrap(), zeros) + }, + |grad| { + let zeros = B::zeros(shape_rhs, &device); + B::select_assign(zeros, dim, indices_4rhs.unwrap(), grad) + }, + ); + } + } + + match IndexSelectDimAssign:: + .prepare([tensor.node, value.node], [tensor.graph, value.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + dim, + indices.clone(), + B::shape(&tensor.primitive), + B::shape(&value.primitive), + B::device(&value.primitive), + ), + B::select_assign(tensor.primitive, dim, indices, value.primitive), + ), + OpsKind::UnTracked(prep) => prep.finish(B::select_assign( + tensor.primitive, + dim, + indices, + value.primitive, + )), + } + } + + fn slice( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + ) -> FloatTensor { + #[derive(Debug)] + struct Index; + + impl Backward for Index { + type State = ([std::ops::Range; D2], Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (ranges, shape, device) = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + let zeros = B::zeros(shape, &device); + B::slice_assign(zeros, ranges, grad) + }); + } + } + + match Index.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + ( + ranges.clone(), + B::shape(&tensor.primitive), + B::device(&tensor.primitive), + ), + B::slice(tensor.primitive, ranges), + ), + OpsKind::UnTracked(prep) => prep.finish(B::slice(tensor.primitive, ranges)), + } + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + value: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct IndexAssign; + + impl Backward for IndexAssign { + type State = ([std::ops::Range; D2], Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (ranges, shape_rhs, device) = ops.state; + let [ranges_4lhs, ranges_4rhs] = duplicate(&ops.parents, Some(ranges)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_rhs, &device); + B::slice_assign(grad, ranges_4lhs.unwrap(), zeros) + }, + |grad| B::slice(grad, ranges_4rhs.unwrap()), + ); + } + } + + match IndexAssign + .prepare([tensor.node, value.node], [tensor.graph, 
value.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + ranges.clone(), + B::shape(&value.primitive), + B::device(&value.primitive), + ), + B::slice_assign(tensor.primitive, ranges, value.primitive), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::slice_assign(tensor.primitive, ranges, value.primitive)) + } + } + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + source: FloatTensor, + ) -> FloatTensor { + #[derive(Debug)] + struct MaskWhere; + + impl Backward for MaskWhere { + type State = (BoolTensor, Shape, Shape, B::Device); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (mask, shape_lhs, shape_rhs, device) = ops.state; + let [mask_4lhs, mask_4rhs] = duplicate(&ops.parents, Some(mask)); + + binary::( + ops.parents, + ops.node, + grads, + |grad| { + let zeros = B::zeros(shape_lhs.clone(), &device); + let grad = B::mask_where(grad, mask_4lhs.unwrap(), zeros); + + broadcast_shape::(grad, &shape_lhs) + }, + |grad| { + let zeros = B::zeros(shape_rhs.clone(), &device); + let grad = B::mask_where(zeros, mask_4rhs.unwrap(), grad); + + broadcast_shape::(grad, &shape_rhs) + }, + ); + } + } + + match MaskWhere + .prepare([tensor.node, source.node], [tensor.graph, source.graph]) + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + ( + mask.clone(), + B::shape(&tensor.primitive), + B::shape(&source.primitive), + B::device(&source.primitive), + ), + B::mask_where(tensor.primitive, mask, source.primitive), + ), + OpsKind::UnTracked(prep) => { + prep.finish(B::mask_where(tensor.primitive, mask, source.primitive)) + } + } + } - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (shape, dim) = ops.state; + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + #[derive(Debug)] + struct MaskFill; - unary::(ops.parents, ops.node, grads, |grad| { - let ones = B::ones(shape, &B::device(&grad)); - let grad = B::sum_dim(grad, dim); + impl Backward for MaskFill { + type State = BoolTensor; - B::mul(ones, grad) - }); - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + B::mask_fill(grad, ops.state, 0.elem()) + }); + } + } + + match MaskFill.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(mask.clone(), B::mask_fill(tensor.primitive, mask, value)) + } + OpsKind::UnTracked(prep) => prep.finish(B::mask_fill(tensor.primitive, mask, value)), + } } - match SumDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (B::shape(&tensor.primitive), dim), - B::sum_dim(tensor.primitive, dim), - ), - OpsKind::UnTracked(prep) => prep.finish(B::sum_dim(tensor.primitive, dim)), + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::equal(lhs.primitive, rhs.primitive) } - } - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - #[derive(Debug)] - struct ToFullPrecision { - phantom: PhantomData, + fn equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + B::equal_elem(lhs.primitive, rhs) } - impl Backward for ToFullPrecision { - type State = (); + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::greater(lhs.primitive, rhs.primitive) + } - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary_different_backend::( - ops.parents, - ops.node, - grads, - |grad| B::from_full_precision(grad), - ); - } + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + 
B::greater_elem(lhs.primitive, rhs) } - let ops = ToFullPrecision:: { - phantom: PhantomData, - }; - ops - .prepare([tensor.node.clone()], [tensor.graph.clone()]) - .stateless(B::to_full_precision(&tensor.primitive)) - } + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::greater_equal(lhs.primitive, rhs.primitive) + } - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - #[derive(Debug)] - struct FromFullPrecision { - phantom: PhantomData, + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + B::greater_equal_elem(lhs.primitive, rhs) } - impl Backward for FromFullPrecision { - type State = (); + fn lower( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::lower(lhs.primitive, rhs.primitive) + } - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary_different_backend::( - ops.parents, - ops.node, - grads, - |grad| B::to_full_precision(&grad), - ); - } + fn lower_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + B::lower_elem(lhs.primitive, rhs) } - let ops = FromFullPrecision:: { - phantom: PhantomData, - }; + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + B::lower_equal(lhs.primitive, rhs.primitive) + } - ops - .prepare([tensor.node.clone()], [tensor.graph]) - .stateless(B::from_full_precision(tensor.primitive)) - } + fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + B::lower_equal_elem(lhs.primitive, rhs) + } - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - B::argmax(tensor.primitive, dim) - } + fn detach(tensor: FloatTensor) -> FloatTensor { + // When we detach a tensor, we remove it from the graph, but we still want to keep the + // `require_grad` setting. 
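+        // In other words, `detach` rebuilds the tensor as a fresh leaf: it keeps the same
+        // primitive values but drops all recorded parents, so no gradient can flow back
+        // through it, while the `require_grad` flag is carried over. A rough usage sketch
+        // (hypothetical values, reusing the test alias from this crate's tests):
+        //
+        //     let a = TestAutodiffTensor::from_floats([2.0, 5.0]).require_grad();
+        //     let b = a.clone().detach(); // same data, no graph history, still requires grad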
+ let is_require_grad = Self::is_require_grad(&tensor); + let tensor = AutodiffTensor::new(tensor.primitive); - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - B::argmin(tensor.primitive, dim) - } + match is_require_grad { + true => tensor.require_grad(), + false => tensor, + } + } - fn exp(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Exp; + fn set_require_grad( + tensor: FloatTensor, + require_grad: bool, + ) -> FloatTensor { + if require_grad { + return tensor.require_grad(); + } - impl Backward for Exp { - type State = B::TensorPrimitive; + AutodiffTensor::new(tensor.primitive) + } - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); - } + fn is_require_grad(tensor: &FloatTensor) -> bool { + matches!(tensor.node.requirement, Requirement::Grad) } - let output = B::exp(tensor.primitive); + fn mean(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Mean; + + impl Backward for Mean { + type State = Shape; - match Exp.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(output.clone(), output), - OpsKind::UnTracked(prep) => prep.finish(output), + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let shape = ops.state; + let val = 1_f64 / shape.num_elements() as f64; + let ones = B::ones(shape, &B::device(&grad)); + let val = B::mul_scalar(ones, val.elem()); + + let grad: Tensor = Tensor::from_primitive(grad); + let val: Tensor = Tensor::from_primitive(val); + + val.mul(grad.unsqueeze()).into_primitive() + }); + } + } + + match Mean.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(B::shape(&tensor.primitive), B::mean(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::mean(tensor.primitive)), + } } - } - fn log(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Log; + fn sum(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Sum; - impl Backward for Log { - type State = B::TensorPrimitive; + impl Backward for Sum { + type State = Shape; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::powf(ops.state, -1.0); - B::mul(grad, value) - }); - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let val = B::ones(ops.state, &B::device(&grad)); + + let grad: Tensor = Tensor::from_primitive(grad); + let val: Tensor = Tensor::from_primitive(val); + + val.mul(grad.unsqueeze()).into_primitive() + }); + } + } + + match Sum.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(B::shape(&tensor.primitive), B::sum(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::sum(tensor.primitive)), + } } - match Log.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::log(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::log(tensor.primitive)), + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[derive(Debug)] + struct MeamDim; + + impl Backward for MeamDim { + type State = (Shape, usize); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape, dim) = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + let val = 1_f64 / shape.dims[dim] as f64; + let ones = B::ones(shape, &B::device(&grad)); + let val = 
B::mul_scalar(ones, B::FloatElem::from_elem(val)); + + let grad = B::sum_dim(grad, dim); + B::mul(val, grad) + }); + } + } + + match MeamDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + (B::shape(&tensor.primitive), dim), + B::mean_dim(tensor.primitive, dim), + ), + OpsKind::UnTracked(prep) => prep.finish(B::mean_dim(tensor.primitive, dim)), + } } - } - fn log1p(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Log1P; + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[derive(Debug)] + struct SumDim; - impl Backward for Log1P { - type State = B::TensorPrimitive; + impl Backward for SumDim { + type State = (Shape, usize); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::add_scalar(ops.state, 1.elem()); - let value = B::powf(value, -1.0); + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (shape, dim) = ops.state; - B::mul(grad, value) - }); - } + unary::(ops.parents, ops.node, grads, |grad| { + let ones = B::ones(shape, &B::device(&grad)); + let grad = B::sum_dim(grad, dim); + + B::mul(ones, grad) + }); + } + } + + match SumDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + (B::shape(&tensor.primitive), dim), + B::sum_dim(tensor.primitive, dim), + ), + OpsKind::UnTracked(prep) => prep.finish(B::sum_dim(tensor.primitive, dim)), + } + } + + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D> { + #[derive(Debug)] + struct ToFullPrecision { + phantom: PhantomData, + } + + impl Backward for ToFullPrecision { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary_different_backend::( + ops.parents, + ops.node, + grads, + |grad| B::from_full_precision(grad), + ); + } + } + + let ops = ToFullPrecision:: { + phantom: PhantomData, + }; + ops.prepare([tensor.node.clone()], [tensor.graph.clone()]) + .stateless(B::to_full_precision(&tensor.primitive)) + } + + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + #[derive(Debug)] + struct FromFullPrecision { + phantom: PhantomData, + } + + impl Backward for FromFullPrecision { + type State = (); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary_different_backend::( + ops.parents, + ops.node, + grads, + |grad| B::to_full_precision(&grad), + ); + } + } + + let ops = FromFullPrecision:: { + phantom: PhantomData, + }; + + ops.prepare([tensor.node.clone()], [tensor.graph]) + .stateless(B::from_full_precision(tensor.primitive)) } - match Log1P.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::log1p(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::log1p(tensor.primitive)), + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + B::argmax(tensor.primitive, dim) } - } - fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { - #[derive(Debug)] - struct PowF; + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + B::argmin(tensor.primitive, dim) + } - impl Backward for PowF { - type State = (B::TensorPrimitive, f32); + fn exp(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Exp; - fn backward(self, ops: Ops, grads: &mut Gradients) { - let (tensor, value) = ops.state; + impl Backward for Exp { + type State = B::TensorPrimitive; - unary::(ops.parents, ops.node, grads, |grad| { - let tmp = B::powf(tensor, value - 1.0); - let value = B::mul_scalar(tmp, 
value.elem()); + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); + } + } - B::mul(grad, value) - }); - } - } + let output = B::exp(tensor.primitive); - match PowF.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish( - (tensor.primitive.clone(), value), - B::powf(tensor.primitive, value), - ), - OpsKind::UnTracked(prep) => prep.finish(B::powf(tensor.primitive, value)), + match Exp.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish(output.clone(), output), + OpsKind::UnTracked(prep) => prep.finish(output), + } } - } - fn sqrt(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Sqrt; + fn log(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Log; - impl Backward for Sqrt { - type State = B::TensorPrimitive; + impl Backward for Log { + type State = B::TensorPrimitive; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let input = ops.state; - let value = B::div_scalar(B::powf(input, -0.5), 2.elem()); + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::powf(ops.state, -1.0); + B::mul(grad, value) + }); + } + } - B::mul(grad, value) - }); - } + match Log.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.clone(), B::log(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::log(tensor.primitive)), + } } - match Sqrt.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::sqrt(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::sqrt(tensor.primitive)), + fn log1p(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Log1P; + + impl Backward for Log1P { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::add_scalar(ops.state, 1.elem()); + let value = B::powf(value, -1.0); + + B::mul(grad, value) + }); + } + } + + match Log1P.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.clone(), B::log1p(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::log1p(tensor.primitive)), + } } - } - fn abs(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Abs; + fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { + #[derive(Debug)] + struct PowF; + + impl Backward for PowF { + type State = (B::TensorPrimitive, f32); + + fn backward(self, ops: Ops, grads: &mut Gradients) { + let (tensor, value) = ops.state; - impl Backward for Abs { - type State = B::TensorPrimitive; + unary::(ops.parents, ops.node, grads, |grad| { + let tmp = B::powf(tensor, value - 1.0); + let value = B::mul_scalar(tmp, value.elem()); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); - } + B::mul(grad, value) + }); + } + } + + match PowF.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => prep.finish( + (tensor.primitive.clone(), value), + B::powf(tensor.primitive, value), + ), + OpsKind::UnTracked(prep) => prep.finish(B::powf(tensor.primitive, value)), + } } - match Abs.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = 
B::abs(tensor.primitive.clone()); - let state = B::div(tensor.primitive, output.clone()); - prep.finish(state, output) - } - OpsKind::UnTracked(prep) => prep.finish(B::abs(tensor.primitive)), + fn sqrt(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Sqrt; + + impl Backward for Sqrt { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let input = ops.state; + let value = B::div_scalar(B::powf(input, -0.5), 2.elem()); + + B::mul(grad, value) + }); + } + } + + match Sqrt.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.clone(), B::sqrt(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::sqrt(tensor.primitive)), + } } - } - fn cos(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Cos; + fn abs(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Abs; - impl Backward for Cos { - type State = B::TensorPrimitive; + impl Backward for Abs { + type State = B::TensorPrimitive; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let input = ops.state; - let value = B::neg(B::sin(input)); + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| B::mul(grad, ops.state)); + } + } - B::mul(grad, value) - }); - } + match Abs.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::abs(tensor.primitive.clone()); + let state = B::div(tensor.primitive, output.clone()); + prep.finish(state, output) + } + OpsKind::UnTracked(prep) => prep.finish(B::abs(tensor.primitive)), + } } - match Cos.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::cos(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::cos(tensor.primitive)), + fn cos(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Cos; + + impl Backward for Cos { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let input = ops.state; + let value = B::neg(B::sin(input)); + + B::mul(grad, value) + }); + } + } + + match Cos.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.clone(), B::cos(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::cos(tensor.primitive)), + } } - } - fn sin(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Sin; + fn sin(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Sin; - impl Backward for Sin { - type State = B::TensorPrimitive; + impl Backward for Sin { + type State = B::TensorPrimitive; - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::cos(ops.state); - B::mul(grad, value) - }); - } + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::cos(ops.state); + B::mul(grad, value) + }); + } + } + + match Sin.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.clone(), B::sin(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::sin(tensor.primitive)), + } } - match Sin.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), 
B::sin(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::sin(tensor.primitive)), + fn tanh(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Tanh; + + impl Backward for Tanh { + type State = B::TensorPrimitive; + + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let value = B::add_scalar(B::neg(B::powf(ops.state, 2.0)), 1.elem()); + B::mul(grad, value) + }); + } + } + + match Tanh.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let output = B::tanh(tensor.primitive); + prep.finish(output.clone(), output) + } + OpsKind::UnTracked(prep) => prep.finish(B::tanh(tensor.primitive)), + } } - } - fn tanh(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Tanh; + fn erf(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Erf; + + impl Backward for Erf { + type State = B::TensorPrimitive; - impl Backward for Tanh { - type State = B::TensorPrimitive; + fn backward(self, ops: Ops, grads: &mut Gradients) { + unary::(ops.parents, ops.node, grads, |grad| { + let exponent = B::neg(B::powf(ops.state, 2.0)); + let numerator = B::mul_scalar(B::exp(exponent), 2.0.elem()); + let denominator = std::f64::consts::PI.sqrt().elem(); + let value = B::div_scalar(numerator, denominator); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let value = B::add_scalar(B::neg(B::powf(ops.state, 2.0)), 1.elem()); - B::mul(grad, value) + B::mul(grad, value) + }); + } + } + + match Erf.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.clone(), B::erf(tensor.primitive)) + } + OpsKind::UnTracked(prep) => prep.finish(B::erf(tensor.primitive)), + } + } + + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + #[derive(new, Debug)] + struct CatStep { + nodes: Vec>, + // The dimension of each tensor along the dim dimension. + // This indicates the number of dimension concatenated for each tensor. 
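+            // Concretely, `dim_sizes[i]` holds the size of the i-th input along `dim`; the
+            // backward step walks these sizes to slice the concatenated gradient back into
+            // per-input chunks. For example, concatenating shapes [2, 3] and [2, 5] along
+            // dim 1 stores [3, 5], and the step later hands grad[.., 0..3] to the first
+            // parent and grad[.., 3..8] to the second.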
+ dim_sizes: Vec, + output: NodeRef, + phantom: PhantomData, + dim: usize, + } + + impl Step for CatStep { + fn step(self: Box, grads: &mut Gradients) { + let grad = grads.consume::(&self.output); + let ranges: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect(); + let ranges: [std::ops::Range; D] = ranges.try_into().unwrap(); + + let mut current_index = 0; + + self.nodes + .into_iter() + .zip(self.dim_sizes) + .filter_map(|(node, dim_size)| node.map(|node| (node, dim_size))) + .for_each(|(node, dim_size)| { + let mut ranges = ranges.clone(); + ranges[self.dim] = current_index..dim_size + current_index; + current_index += dim_size; + grads.register::(node, B::slice(grad.clone(), ranges)); + }); + } + + fn node(&self) -> NodeRef { + self.output.clone() + } + } + + let mut nodes = Vec::with_capacity(tensors.len()); + let mut graphs = Vec::with_capacity(tensors.len()); + let mut primitives = Vec::with_capacity(tensors.len()); + let mut dim_sizes = Vec::with_capacity(tensors.len()); + + tensors.into_iter().for_each(|tensor| { + dim_sizes.push(B::shape(&tensor.primitive).dims[dim]); + nodes.push(tensor.node); + primitives.push(tensor.primitive); + graphs.push(tensor.graph); }); - } - } - match Tanh.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let output = B::tanh(tensor.primitive); - prep.finish(output.clone(), output) - } - OpsKind::UnTracked(prep) => prep.finish(B::tanh(tensor.primitive)), + let requirement = Requirement::from_nodes(&nodes); + + let output = B::cat(primitives, dim); + if requirement.is_none() { + return AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); + } + + let output = AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); + let nodes = nodes + .into_iter() + .map(|node| node.clone_if_require_grad()) + .collect::>(); + + let ops = CatStep::::new(nodes, dim_sizes, output.node.clone(), dim); + output.register_step(ops) } - } - fn erf(tensor: FloatTensor) -> FloatTensor { - #[derive(Debug)] - struct Erf; + fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); + prep.finish((index, shape), tensor) + } + OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)), + } + } + fn max_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish((index.clone(), shape), tensor); + + (tensor, index) + } + OpsKind::UnTracked(prep) => { + let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish(tensor); - impl Backward for Erf { - type State = B::TensorPrimitive; + (tensor, index) + } + } + } + fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); + prep.finish((index, shape), tensor) + } + OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)), + } + } + fn min_dim_with_indices( + tensor: FloatTensor, + dim: 
usize, + ) -> (FloatTensor, IntTensor) { + match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { + OpsKind::Tracked(prep) => { + let shape = B::shape(&tensor.primitive); + let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish((index.clone(), shape), tensor); + + (tensor, index) + } + OpsKind::UnTracked(prep) => { + let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); + let tensor = prep.finish(tensor); - fn backward(self, ops: Ops, grads: &mut Gradients) { - unary::(ops.parents, ops.node, grads, |grad| { - let exponent = B::neg(B::powf(ops.state, 2.0)); - let numerator = B::mul_scalar(B::exp(exponent), 2.0.elem()); - let denominator = std::f64::consts::PI.sqrt().elem(); - let value = B::div_scalar(numerator, denominator); + (tensor, index) + } + } + } - B::mul(grad, value) - }); - } - } - - match Erf.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => prep.finish(tensor.primitive.clone(), B::erf(tensor.primitive)), - OpsKind::UnTracked(prep) => prep.finish(B::erf(tensor.primitive)), - } - } - - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - #[derive(new, Debug)] - struct CatStep { - nodes: Vec>, - // The dimension of each tensor along the dim dimension. - // This indicates the number of dimension concatenated for each tensor. - dim_sizes: Vec, - output: NodeRef, - phantom: PhantomData, - dim: usize, - } - - impl Step for CatStep { - fn step(self: Box, grads: &mut Gradients) { - let grad = grads.consume::(&self.output); - let ranges: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect(); - let ranges: [std::ops::Range; D] = ranges.try_into().unwrap(); - - let mut current_index = 0; - - self - .nodes - .into_iter() - .zip(self.dim_sizes) - .filter_map(|(node, dim_size)| node.map(|node| (node, dim_size))) - .for_each(|(node, dim_size)| { - let mut ranges = ranges.clone(); - ranges[self.dim] = current_index..dim_size + current_index; - current_index += dim_size; - grads.register::(node, B::slice(grad.clone(), ranges)); - }); - } - - fn node(&self) -> NodeRef { - self.output.clone() - } - } - - let mut nodes = Vec::with_capacity(tensors.len()); - let mut graphs = Vec::with_capacity(tensors.len()); - let mut primitives = Vec::with_capacity(tensors.len()); - let mut dim_sizes = Vec::with_capacity(tensors.len()); - - tensors.into_iter().for_each(|tensor| { - dim_sizes.push(B::shape(&tensor.primitive).dims[dim]); - nodes.push(tensor.node); - primitives.push(tensor.primitive); - graphs.push(tensor.graph); - }); - - let requirement = Requirement::from_nodes(&nodes); - - let output = B::cat(primitives, dim); - if requirement.is_none() { - return AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); - } - - let output = AutodiffTensor::from_parents(output, &nodes, graphs.into_iter(), requirement); - let nodes = nodes - .into_iter() - .map(|node| node.clone_if_require_grad()) - .collect::>(); - - let ops = CatStep::::new(nodes, dim_sizes, output.node.clone(), dim); - output.register_step(ops) - } - - fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); - prep.finish((index, shape), tensor) - } - OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)), - } - } - fn max_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) 
-> (FloatTensor, IntTensor) { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish((index.clone(), shape), tensor); - - (tensor, index) - } - OpsKind::UnTracked(prep) => { - let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish(tensor); - - (tensor, index) - } - } - } - fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); - prep.finish((index, shape), tensor) - } - OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)), - } - } - fn min_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - match MaxMinDim.prepare([tensor.node], [tensor.graph]).stateful() { - OpsKind::Tracked(prep) => { - let shape = B::shape(&tensor.primitive); - let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish((index.clone(), shape), tensor); - - (tensor, index) - } - OpsKind::UnTracked(prep) => { - let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim); - let tensor = prep.finish(tensor); - - (tensor, index) - } - } - } - - fn into_int( - tensor: FloatTensor, - ) -> as Backend>::IntTensorPrimitive { - B::into_int(tensor.primitive) - } + fn into_int( + tensor: FloatTensor, + ) -> as Backend>::IntTensorPrimitive { + B::into_int(tensor.primitive) + } } #[derive(Debug, Clone)] enum BinaryOpsBroadcast { - Broadcasted(Shape, Shape), - None, + Broadcasted(Shape, Shape), + None, } impl BinaryOpsBroadcast { - fn new(lhs: &B::TensorPrimitive, rhs: &B::TensorPrimitive) -> Self { - let shape_lhs = B::shape(lhs); - let shape_rhs = B::shape(rhs); + fn new(lhs: &B::TensorPrimitive, rhs: &B::TensorPrimitive) -> Self { + let shape_lhs = B::shape(lhs); + let shape_rhs = B::shape(rhs); - for i in 0..D { - if shape_rhs.dims[i] != shape_lhs.dims[i] { - return Self::Broadcasted(shape_lhs, shape_rhs); - } - } + for i in 0..D { + if shape_rhs.dims[i] != shape_lhs.dims[i] { + return Self::Broadcasted(shape_lhs, shape_rhs); + } + } - Self::None - } + Self::None + } - fn backward_lhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { - match self { - BinaryOpsBroadcast::Broadcasted(lhs, _rhs) => broadcast_shape::(grad, lhs), - BinaryOpsBroadcast::None => grad, + fn backward_lhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { + match self { + BinaryOpsBroadcast::Broadcasted(lhs, _rhs) => broadcast_shape::(grad, lhs), + BinaryOpsBroadcast::None => grad, + } } - } - fn backward_rhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { - match self { - BinaryOpsBroadcast::Broadcasted(_lhs, rhs) => broadcast_shape::(grad, rhs), - BinaryOpsBroadcast::None => grad, + fn backward_rhs(&self, grad: B::TensorPrimitive) -> B::TensorPrimitive { + match self { + BinaryOpsBroadcast::Broadcasted(_lhs, rhs) => broadcast_shape::(grad, rhs), + BinaryOpsBroadcast::None => grad, + } } - } } diff --git a/burn-autodiff/src/tensor.rs b/burn-autodiff/src/tensor.rs index ab5465a5fd..84c6c80b73 100644 --- a/burn-autodiff/src/tensor.rs +++ b/burn-autodiff/src/tensor.rs @@ -1,106 +1,106 @@ use burn_tensor::backend::Backend; use crate::{ - grads::Gradients, - graph::{ - Node, NodeID, NodeRef, 
Requirement, {Graph, Step}, - }, + grads::Gradients, + graph::{ + Node, NodeID, NodeRef, Requirement, {Graph, Step}, + }, }; #[derive(Debug, Clone)] pub struct AutodiffTensor { - pub primitive: B::TensorPrimitive, - pub node: NodeRef, - pub graph: Graph, + pub primitive: B::TensorPrimitive, + pub node: NodeRef, + pub graph: Graph, } #[derive(new, Debug)] struct RootStep { - node: NodeRef, + node: NodeRef, } impl Step for RootStep { - fn step(self: Box, _grads: &mut Gradients) { - // Nothing to do - } + fn step(self: Box, _grads: &mut Gradients) { + // Nothing to do + } - fn node(&self) -> NodeRef { - self.node.clone() - } + fn node(&self) -> NodeRef { + self.node.clone() + } } impl AutodiffTensor { - /// Create a new leaf tensor. - pub fn new(primitive: B::TensorPrimitive) -> Self { - let id = NodeID::new(); - let node = Node::new(vec![], 0, id, Requirement::None); + /// Create a new leaf tensor. + pub fn new(primitive: B::TensorPrimitive) -> Self { + let id = NodeID::new(); + let node = Node::new(vec![], 0, id, Requirement::None); - Self { - primitive, - node: node.into(), - graph: Graph::new(), + Self { + primitive, + node: node.into(), + graph: Graph::new(), + } } - } - pub fn is_tracked(&self) -> bool { - !self.node.requirement.is_none() - } + pub fn is_tracked(&self) -> bool { + !self.node.requirement.is_none() + } - /// Mark the tensor as requirering gradients. - /// - /// # Panics - /// - /// It panics if the tensor is non a leaf. - pub fn require_grad(mut self) -> Self { - match self.node.requirement { - Requirement::Grad => self, - Requirement::GradInBackward => { - panic!("Can't convert a non leaf tensor into a tracked tensor") - } - Requirement::None => { - self.node = Node::new(vec![], 0, self.node.id.clone(), Requirement::Grad).into(); - let ops = RootStep::new(self.node.clone()); + /// Mark the tensor as requirering gradients. + /// + /// # Panics + /// + /// It panics if the tensor is non a leaf. + pub fn require_grad(mut self) -> Self { + match self.node.requirement { + Requirement::Grad => self, + Requirement::GradInBackward => { + panic!("Can't convert a non leaf tensor into a tracked tensor") + } + Requirement::None => { + self.node = Node::new(vec![], 0, self.node.id.clone(), Requirement::Grad).into(); + let ops = RootStep::new(self.node.clone()); - self.register_step(ops) - } + self.register_step(ops) + } + } } - } - /// Create a tensor from parent infos. - pub fn from_parents>( - output: B::TensorPrimitive, - parent_nodes: &[NodeRef], - parent_graphs: I, - requirement: Requirement, - ) -> Self { - let graph = parent_graphs - .reduce(|acc, graph| acc.merge(graph)) - .unwrap_or_else(Graph::new); + /// Create a tensor from parent infos. 
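+    ///
+    /// The parent graphs are merged into a single graph, the parent node ids are recorded on
+    /// the new node, and the node's `order` is set to one more than the largest parent order
+    /// (or 1 when there are no parents), which keeps every node ordered after its parents.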
+ pub fn from_parents>( + output: B::TensorPrimitive, + parent_nodes: &[NodeRef], + parent_graphs: I, + requirement: Requirement, + ) -> Self { + let graph = parent_graphs + .reduce(|acc, graph| acc.merge(graph)) + .unwrap_or_else(Graph::new); - let order = parent_nodes - .iter() - .map(|node| node.order) - .reduce(usize::max) - .unwrap_or(0) - + 1; + let order = parent_nodes + .iter() + .map(|node| node.order) + .reduce(usize::max) + .unwrap_or(0) + + 1; - let node = Node::new( - parent_nodes.iter().map(|node| node.id.clone()).collect(), - order, - NodeID::new(), - requirement, - ); + let node = Node::new( + parent_nodes.iter().map(|node| node.id.clone()).collect(), + order, + NodeID::new(), + requirement, + ); - Self { - primitive: output, - node: node.into(), - graph, + Self { + primitive: output, + node: node.into(), + graph, + } } - } - /// Register a step into a graph for that tensor. - pub fn register_step(mut self, ops: O) -> Self { - self.graph = self.graph.register(&self.node.id, Box::new(ops)); - self - } + /// Register a step into a graph for that tensor. + pub fn register_step(mut self, ops: O) -> Self { + self.graph = self.graph.register(&self.node.id, Box::new(ops)); + self + } } diff --git a/burn-autodiff/src/tests/abs.rs b/burn-autodiff/src/tests/abs.rs index 275761a7e1..02c40d135b 100644 --- a/burn-autodiff/src/tests/abs.rs +++ b/burn-autodiff/src/tests/abs.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_abs)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_abs() { - let data_1 = Data::::from([[0.0, -1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, -10.0]]); + #[test] + fn should_diff_abs() { + let data_1 = Data::::from([[0.0, -1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, -10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[71.0, 107.0], [71.0, 107.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[84.0, 42.0], [90.0, 54.0]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[71.0, 107.0], [71.0, 107.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[84.0, 42.0], [90.0, 54.0]]), 3); + } } diff --git a/burn-autodiff/src/tests/adaptive_avgpool1d.rs b/burn-autodiff/src/tests/adaptive_avgpool1d.rs index aaec2a2c1b..60caee893b 100644 --- a/burn-autodiff/src/tests/adaptive_avgpool1d.rs +++ b/burn-autodiff/src/tests/adaptive_avgpool1d.rs @@ -1,48 +1,48 @@ #[burn_tensor_testgen::testgen(ad_adaptive_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - 
fn test_avg_pool1d_simple() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 5, - output_size: 3, - }; + #[test] + fn test_avg_pool1d_simple() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 5, + output_size: 3, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], - [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], + [0.5000, 0.8333, 0.3333, 0.8333, 0.5000], + ]])); + } - struct AdaptiveAvgPool1dTestCase { - batch_size: usize, - channels: usize, - length: usize, - output_size: usize, - } + struct AdaptiveAvgPool1dTestCase { + batch_size: usize, + channels: usize, + length: usize, + output_size: usize, + } - impl AdaptiveAvgPool1dTestCase { - fn assert_output(self, x_grad: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = adaptive_avg_pool1d(x.clone(), self.output_size); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AdaptiveAvgPool1dTestCase { + fn assert_output(self, x_grad: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = adaptive_avg_pool1d(x.clone(), self.output_size); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/adaptive_avgpool2d.rs b/burn-autodiff/src/tests/adaptive_avgpool2d.rs index a77974fe2f..4e09a63891 100644 --- a/burn-autodiff/src/tests/adaptive_avgpool2d.rs +++ b/burn-autodiff/src/tests/adaptive_avgpool2d.rs @@ -1,64 +1,64 @@ #[burn_tensor_testgen::testgen(ad_adaptive_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool2d_simple() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 5, - width: 3, - output_size_1: 3, - output_size_2: 2, - }; + #[test] + fn test_avg_pool2d_simple() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 5, + width: 3, + output_size_1: 3, + output_size_2: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [0.2500, 0.5000, 0.2500], - [0.4167, 0.8333, 0.4167], - [0.1667, 0.3333, 0.1667], - [0.4167, 0.8333, 0.4167], - [0.2500, 0.5000, 0.2500], - ], - [ - [0.2500, 0.5000, 0.2500], - [0.4167, 0.8333, 0.4167], - [0.1667, 0.3333, 0.1667], - [0.4167, 0.8333, 0.4167], - [0.2500, 0.5000, 0.2500], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [0.2500, 0.5000, 0.2500], + [0.4167, 0.8333, 0.4167], + [0.1667, 0.3333, 0.1667], + [0.4167, 0.8333, 0.4167], + [0.2500, 0.5000, 0.2500], + ], + [ + [0.2500, 0.5000, 0.2500], + [0.4167, 0.8333, 0.4167], + [0.1667, 0.3333, 0.1667], + [0.4167, 0.8333, 0.4167], + [0.2500, 0.5000, 0.2500], + ], + ]])); + } 
- struct AdaptiveAvgPool2dTestCase { - batch_size: usize, - channels: usize, - height: usize, - width: usize, - output_size_1: usize, - output_size_2: usize, - } + struct AdaptiveAvgPool2dTestCase { + batch_size: usize, + channels: usize, + height: usize, + width: usize, + output_size_1: usize, + output_size_2: usize, + } - impl AdaptiveAvgPool2dTestCase { - fn assert_output(self, x_grad: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AdaptiveAvgPool2dTestCase { + fn assert_output(self, x_grad: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/add.rs b/burn-autodiff/src/tests/add.rs index b21b6e5fd1..884ced38af 100644 --- a/burn-autodiff/src/tests/add.rs +++ b/burn-autodiff/src/tests/add.rs @@ -1,62 +1,62 @@ #[burn_tensor_testgen::testgen(ad_add)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_diff_add() { - let tensor_1 = TestAutodiffTensor::from_floats([2.0, 5.0]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0]).require_grad(); + #[test] + fn should_diff_add() { + let tensor_1 = TestAutodiffTensor::from_floats([2.0, 5.0]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0]).require_grad(); - let tensor_3 = tensor_1.clone() + tensor_2.clone(); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone() + tensor_2.clone(); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); - assert_eq!(grad_2.to_data(), Data::from([1.0, 1.0])); - assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0])); - } + assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); + assert_eq!(grad_2.to_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0])); + } - #[test] - fn should_diff_add_scalar() { - let data = Data::from([2.0, 10.0]); + #[test] + fn should_diff_add_scalar() { + let data = Data::from([2.0, 10.0]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().add_scalar(5.0); - let grads = tensor_out.backward(); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().add_scalar(5.0); + let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(grad.to_data(), Data::from([1.0, 
1.0])); - assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0])); - } + assert_eq!(grad.to_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0])); + } - #[test] - fn test_add_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + #[test] + fn test_add_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - let tensor_4 = tensor_1.clone().add(tensor_2.clone()); - let tensor_5 = tensor_4 - .add(tensor_3) - .add_scalar(5.0) - .add(tensor_1.clone()) - .add(tensor_2.clone()); - let tensor_6 = tensor_1.clone().add(tensor_5); + let tensor_4 = tensor_1.clone().add(tensor_2.clone()); + let tensor_5 = tensor_4 + .add(tensor_3) + .add_scalar(5.0) + .add(tensor_1.clone()) + .add(tensor_2.clone()); + let tensor_6 = tensor_1.clone().add(tensor_5); - let grads = tensor_6.backward(); + let grads = tensor_6.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[3.0, 3.0], [3.0, 3.0]])); - assert_eq!(grad_2.to_data(), Data::from([[2.0, 2.0], [2.0, 2.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[3.0, 3.0], [3.0, 3.0]])); + assert_eq!(grad_2.to_data(), Data::from([[2.0, 2.0], [2.0, 2.0]])); + } } diff --git a/burn-autodiff/src/tests/aggregation.rs b/burn-autodiff/src/tests/aggregation.rs index d57b182051..a546a01469 100644 --- a/burn-autodiff/src/tests/aggregation.rs +++ b/burn-autodiff/src/tests/aggregation.rs @@ -1,121 +1,121 @@ #[burn_tensor_testgen::testgen(ad_aggregation)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_mean() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[3.5, 9.5], [3.5, 9.5]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-0.75, -0.75], [3.0, 3.0]]), 5); - } - - #[test] - fn should_diff_sum_1() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze()); - let grads 
= tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[14.0, 38.0], [14.0, 38.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-3.0, -3.0], [12.0, 12.0]]), 5); - } - - #[test] - fn should_diff_sum_2() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.clone().sum_dim(1); - let tensor_5 = tensor_4.mul(tensor_3); - - let grads = tensor_5.sum().backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[494.0, 722.0], [2990.0, 4370.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[690.0, 690.0], [958.0, 958.0]]), 3); - } - - #[test] - fn should_diff_mean_dim() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[4.0, 36.0], [3.0, -17.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[9.0, 9.0], [35.5, 35.5]]), 5); - } - - #[test] - fn should_diff_sum_dim() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[8.0, 72.0], [6.0, -34.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[18.0, 18.0], [71.0, 71.0]]), 5); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_mean() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[3.5, 9.5], [3.5, 9.5]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-0.75, -0.75], [3.0, 3.0]]), 5); + } + + #[test] + fn should_diff_sum_1() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = 
TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[14.0, 38.0], [14.0, 38.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-3.0, -3.0], [12.0, 12.0]]), 5); + } + + #[test] + fn should_diff_sum_2() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.clone().sum_dim(1); + let tensor_5 = tensor_4.mul(tensor_3); + + let grads = tensor_5.sum().backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[494.0, 722.0], [2990.0, 4370.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[690.0, 690.0], [958.0, 958.0]]), 3); + } + + #[test] + fn should_diff_mean_dim() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[4.0, 36.0], [3.0, -17.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[9.0, 9.0], [35.5, 35.5]]), 5); + } + + #[test] + fn should_diff_sum_dim() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[8.0, 72.0], [6.0, -34.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[18.0, 18.0], [71.0, 71.0]]), 5); + } } diff --git a/burn-autodiff/src/tests/avgpool1d.rs b/burn-autodiff/src/tests/avgpool1d.rs index feb9891175..a0224cf11f 100644 --- a/burn-autodiff/src/tests/avgpool1d.rs +++ b/burn-autodiff/src/tests/avgpool1d.rs @@ -1,95 +1,95 @@ #[burn_tensor_testgen::testgen(ad_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool1d_simple() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 1, - kernel_size: 3, - padding: 0, - stride: 1, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_simple() { 
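+        // Hand-computed expectation: with kernel 3, stride 1 and no padding over a length-6
+        // input, the positions are covered by 1, 2, 3, 3, 2, 1 pooling windows respectively,
+        // and each window passes 1/3 of its upstream gradient to every input it covers, so a
+        // unit upstream gradient yields [0.3333, 0.6667, 1.0, 1.0, 0.6667, 0.3333] per channel.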
+ let test = AvgPool1dTestCase { + batch_size: 1, + channels: 1, + kernel_size: 3, + padding: 0, + stride: 1, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - 0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333, - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + 0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333, + ]]])); + } - #[test] - fn test_avg_pool1d_complex() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_complex() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[ - [0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333], - [0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333], + [0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333], + ]])); + } - #[test] - fn test_avg_pool1d_complex_dont_count_pad() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool1d_complex_dont_count_pad() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], - [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], + [0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333], + ]])); + } - struct AvgPool1dTestCase { - batch_size: usize, - channels: usize, - kernel_size: usize, - padding: usize, - stride: usize, - length: usize, - count_include_pad: bool, - } + struct AvgPool1dTestCase { + batch_size: usize, + channels: usize, + kernel_size: usize, + padding: usize, + stride: usize, + length: usize, + count_include_pad: bool, + } - impl AvgPool1dTestCase { - fn assert_output(self, x_grad: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = avg_pool1d( - x.clone(), - self.kernel_size, - self.stride, - self.padding, - self.count_include_pad, - ); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AvgPool1dTestCase { + fn assert_output(self, x_grad: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = avg_pool1d( + x.clone(), + self.kernel_size, + self.stride, + self.padding, + self.count_include_pad, + ); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/avgpool2d.rs b/burn-autodiff/src/tests/avgpool2d.rs index aba6936a52..5ad2aa50a3 100644 --- 
a/burn-autodiff/src/tests/avgpool2d.rs +++ b/burn-autodiff/src/tests/avgpool2d.rs @@ -1,120 +1,120 @@ #[burn_tensor_testgen::testgen(ad_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool2d_simple() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - height: 6, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_simple() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + height: 6, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], - [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], - [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], - [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], - [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], - [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], + [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], + [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], + [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333], + [0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222], + [0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111], + ]]])); + } - #[test] - fn test_avg_pool2d_complex() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_complex() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], - [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], - [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], - [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], + [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], + [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], + [0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333], + ]]])); + } - #[test] - fn test_avg_pool2d_complex_dont_include_pad() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool2d_complex_dont_include_pad() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], - [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], - [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], - [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], - ]]])); - } + 
test.assert_output(TestTensor::from_floats([[[ + [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], + [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], + [0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750], + [0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250], + ]]])); + } - struct AvgPool2dTestCase { - batch_size: usize, - channels: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - stride_1: usize, - stride_2: usize, - height: usize, - width: usize, - count_include_pad: bool, - } + struct AvgPool2dTestCase { + batch_size: usize, + channels: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + height: usize, + width: usize, + count_include_pad: bool, + } - impl AvgPool2dTestCase { - fn assert_output(self, x_grad: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = avg_pool2d( - x.clone(), - [self.kernel_size_1, self.kernel_size_2], - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - self.count_include_pad, - ); - let grads = output.backward(); - let x_grad_actual = x.grad(&grads).unwrap(); + impl AvgPool2dTestCase { + fn assert_output(self, x_grad: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = avg_pool2d( + x.clone(), + [self.kernel_size_1, self.kernel_size_2], + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + self.count_include_pad, + ); + let grads = output.backward(); + let x_grad_actual = x.grad(&grads).unwrap(); - x_grad - .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + x_grad + .to_data() + .assert_approx_eq(&x_grad_actual.into_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/backward.rs b/burn-autodiff/src/tests/backward.rs index e75bc30da6..ca25e71da9 100644 --- a/burn-autodiff/src/tests/backward.rs +++ b/burn-autodiff/src/tests/backward.rs @@ -1,27 +1,29 @@ #[burn_tensor_testgen::testgen(module_backward)] mod tests { - use super::*; - use burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; + use super::*; + use burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; - #[test] - fn test_embedding_backward() { - let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = Data::from([[0, 1], [1, 1]]); - let x = Data::from([ - [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]], - [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]], - ]); - let weights = Tensor::::from_data(weights).require_grad(); - let indices = Tensor::::from_data(indices); - let x = Tensor::::from_data(x).require_grad(); + #[test] + fn test_embedding_backward() { + let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = Data::from([[0, 1], [1, 1]]); + let x = Data::from([ + [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]], + [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]], + ]); + let weights = Tensor::::from_data(weights).require_grad(); + let indices = Tensor::::from_data(indices); + let x = Tensor::::from_data(x).require_grad(); - let output = embedding(weights.clone(), indices); - let output = output.matmul(x); - let grads = 
output.backward(); + let output = embedding(weights.clone(), indices); + let output = output.matmul(x); + let grads = output.backward(); - let grad = weights.grad(&grads).unwrap(); - let expected = - Data::<::FloatElem, 2>::from([[3., 9., 7.], [21., 35., 27.]]); - assert_eq!(grad.to_data(), expected); - } + let grad = weights.grad(&grads).unwrap(); + let expected = Data::<::FloatElem, 2>::from([ + [3., 9., 7.], + [21., 35., 27.], + ]); + assert_eq!(grad.to_data(), expected); + } } diff --git a/burn-autodiff/src/tests/broadcast.rs b/burn-autodiff/src/tests/broadcast.rs index 428df71176..324c538d9c 100644 --- a/burn-autodiff/src/tests/broadcast.rs +++ b/burn-autodiff/src/tests/broadcast.rs @@ -1,56 +1,56 @@ #[burn_tensor_testgen::testgen(ad_broadcast)] mod tests { - use super::*; - use burn_tensor::{Data, Distribution, Int, Shape, Tensor}; - - #[test] - fn mul_broadcast() { - test_ops_broadcast_backward(|x, y| x * y); - } - - #[test] - fn div_broadcast() { - test_ops_broadcast_backward(|x, y| x / y); - } - - #[test] - fn sub_broadcast() { - test_ops_broadcast_backward(|x, y| x - y); - } - - #[test] - fn add_broadcast() { - test_ops_broadcast_backward(|x, y| x + y); - } - - #[test] - fn matmul_broadcast() { - test_ops_broadcast_backward(|x, y| x.matmul(y)); - } - - #[test] - fn mask_where_broadcast() { - test_ops_broadcast_backward(|x, y| x.mask_where(y.clone().equal_elem(4), y)); - } - - fn test_ops_broadcast_backward(func: F) - where - F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>, - { - let w = TestAutodiffTensor::zeros([16, 5, 5]).require_grad(); - let x = TestAutodiffTensor::zeros([4, 5, 5]).require_grad(); - - // Slice isn't a broadcastable operation, so it will fail when the previous backward pass - // of an operation that support broadcast doesn't support it during the backward pass. - let y = func(w.clone().slice([0..1]), x.clone()); - - // Will panic if broadcast isn't supported! - let grads = y.backward(); - - let w_grad = w.grad(&grads).unwrap(); - let x_grad = x.grad(&grads).unwrap(); - - assert_eq!(w_grad.shape(), w.shape()); - assert_eq!(x_grad.shape(), x.shape()); - } + use super::*; + use burn_tensor::{Data, Distribution, Int, Shape, Tensor}; + + #[test] + fn mul_broadcast() { + test_ops_broadcast_backward(|x, y| x * y); + } + + #[test] + fn div_broadcast() { + test_ops_broadcast_backward(|x, y| x / y); + } + + #[test] + fn sub_broadcast() { + test_ops_broadcast_backward(|x, y| x - y); + } + + #[test] + fn add_broadcast() { + test_ops_broadcast_backward(|x, y| x + y); + } + + #[test] + fn matmul_broadcast() { + test_ops_broadcast_backward(|x, y| x.matmul(y)); + } + + #[test] + fn mask_where_broadcast() { + test_ops_broadcast_backward(|x, y| x.mask_where(y.clone().equal_elem(4), y)); + } + + fn test_ops_broadcast_backward(func: F) + where + F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>, + { + let w = TestAutodiffTensor::zeros([16, 5, 5]).require_grad(); + let x = TestAutodiffTensor::zeros([4, 5, 5]).require_grad(); + + // Slice isn't a broadcastable operation, so it will fail when the previous backward pass + // of an operation that support broadcast doesn't support it during the backward pass. + let y = func(w.clone().slice([0..1]), x.clone()); + + // Will panic if broadcast isn't supported! 
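// Illustrative sketch, not part of the patch: the embedding test above checks that each row
// of `weights` accumulates one gradient contribution per lookup. With indices [[0, 1], [1, 1]],
// row 0 is gathered once and row 1 three times, which is why the expected gradient is
// [[3., 9., 7.], [21., 35., 27.]]. TestAutodiffBackend is assumed to be the backend alias the
// test harness provides alongside TestAutodiffTensor; the helper name is only for illustration.
fn sketch_embedding_grad() {
    use burn_tensor::{module::embedding, Data, Int, Tensor};

    let weights = Tensor::<TestAutodiffBackend, 2>::from_data(Data::from([
        [0.0, 1.0, 2.0],
        [3.0, 4.0, 5.0],
    ]))
    .require_grad();
    let indices =
        Tensor::<TestAutodiffBackend, 2, Int>::from_data(Data::from([[0, 1], [1, 1]]));
    let x = Tensor::<TestAutodiffBackend, 3>::from_data(Data::from([
        [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
        [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
    ]))
    .require_grad();

    // Forward: lookup then batched matmul; backward seeds the output with ones.
    let grads = embedding(weights.clone(), indices).matmul(x).backward();

    // The gradient has the same shape as the embedding table.
    let weights_grad = weights.grad(&grads).unwrap();
    assert_eq!(weights_grad.dims(), [2, 3]);
}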
+ let grads = y.backward(); + + let w_grad = w.grad(&grads).unwrap(); + let x_grad = x.grad(&grads).unwrap(); + + assert_eq!(w_grad.shape(), w.shape()); + assert_eq!(x_grad.shape(), x.shape()); + } } diff --git a/burn-autodiff/src/tests/cat.rs b/burn-autodiff/src/tests/cat.rs index 3668fcbbd5..3a27c42135 100644 --- a/burn-autodiff/src/tests/cat.rs +++ b/burn-autodiff/src/tests/cat.rs @@ -1,76 +1,76 @@ #[burn_tensor_testgen::testgen(ad_cat)] mod tests { - use super::*; - use burn_tensor::{Data, Float}; + use super::*; + use burn_tensor::{Data, Float}; - #[test] - fn should_diff_cat() { - let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]]).require_grad(); + #[test] + fn should_diff_cat() { + let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]]).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - let mut tensor_1_list = Vec::new(); - let mut tensor_2_list = Vec::new(); + let mut tensor_1_list = Vec::new(); + let mut tensor_2_list = Vec::new(); - for i in 0..2 { - tensor_1_list.push(tensor_1.clone().slice([i..i + 1])); - tensor_2_list.push(tensor_2.clone().slice([i..i + 1])); - } + for i in 0..2 { + tensor_1_list.push(tensor_1.clone().slice([i..i + 1])); + tensor_2_list.push(tensor_2.clone().slice([i..i + 1])); + } - let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0); - let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0); + let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0); + let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0); - let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone()); - let grads = tensor_3_cat.backward(); + let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone()); + let grads = tensor_3_cat.backward(); - let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]); - let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]); + let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]); + let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]); - let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]); - let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]); + let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]); + let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]); - grad_1 - .clone() - .slice([0..1]) - .to_data() - .assert_approx_eq(&grad_1_slice_1.to_data(), 3); - grad_1 - .slice([1..2]) - .to_data() - .assert_approx_eq(&grad_1_slice_2.to_data(), 3); + grad_1 + .clone() + .slice([0..1]) + .to_data() + .assert_approx_eq(&grad_1_slice_1.to_data(), 3); + grad_1 + .slice([1..2]) + .to_data() + .assert_approx_eq(&grad_1_slice_2.to_data(), 3); - grad_2 - .clone() - .slice([0..1]) - .to_data() - .assert_approx_eq(&grad_2_slice_1.to_data(), 3); - grad_2 - .slice([1..2]) - .to_data() - .assert_approx_eq(&grad_2_slice_2.to_data(), 3); - } + grad_2 + .clone() + .slice([0..1]) + .to_data() + .assert_approx_eq(&grad_2_slice_1.to_data(), 3); + grad_2 + 
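// Sketch only, not part of the patch: the broadcast tests above all reduce to the same check,
// shown here with elementwise multiplication as a representative op. A [1, 5, 5] slice of `w`
// is broadcast against a [4, 5, 5] tensor in the forward pass; on the way back, the broadcast
// dimension is summed and the slice backward zero-pads, so both gradients come back at the
// full parameter shapes. TestAutodiffTensor is the assumed harness alias.
fn sketch_broadcast_grad_shapes() {
    let w = TestAutodiffTensor::zeros([16, 5, 5]).require_grad();
    let x = TestAutodiffTensor::zeros([4, 5, 5]).require_grad();

    // [1, 5, 5] * [4, 5, 5]: broadcasting happens along the first dimension.
    let y = w.clone().slice([0..1]) * x.clone();
    let grads = y.backward();

    assert_eq!(w.grad(&grads).unwrap().shape(), w.shape());
    assert_eq!(x.grad(&grads).unwrap().shape(), x.shape());
}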
.slice([1..2]) + .to_data() + .assert_approx_eq(&grad_2_slice_2.to_data(), 3); + } - #[test] - fn should_diff_cat_more_than_1_dim() { - let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); - let tensor_2 = - TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]]).require_grad(); + #[test] + fn should_diff_cat_more_than_1_dim() { + let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad(); + let tensor_2 = + TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]]).require_grad(); - // Concat a tensor [2, 2] with another tensor [3, 2] along dim 0. - // The resulting tensor should be [5, 2] - let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0); - assert_eq!(tensor_3.dims(), [5, 2]); - let grads = tensor_3.backward(); + // Concat a tensor [2, 2] with another tensor [3, 2] along dim 0. + // The resulting tensor should be [5, 2] + let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0); + assert_eq!(tensor_3.dims(), [5, 2]); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(tensor_1.dims(), grad_1.dims()); - assert_eq!(tensor_2.dims(), grad_2.dims()); - } + assert_eq!(tensor_1.dims(), grad_1.dims()); + assert_eq!(tensor_2.dims(), grad_2.dims()); + } } diff --git a/burn-autodiff/src/tests/complex.rs b/burn-autodiff/src/tests/complex.rs index 40f7db9fa5..aa15d53213 100644 --- a/burn-autodiff/src/tests/complex.rs +++ b/burn-autodiff/src/tests/complex.rs @@ -1,81 +1,81 @@ #[burn_tensor_testgen::testgen(ad_complex)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_full_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.matmul(tensor_1.clone()); - let tensor_5 = tensor_4.mul(tensor_2.clone()); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[593., 463.0], [487.0, 539.0]]) - ); - assert_eq!( - grad_2.to_data(), - Data::from([[734.0, 294.0], [1414.0, 242.0]]) - ); - } - - #[test] - fn should_diff_full_complex_2() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.matmul(tensor_1.clone()); - let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone()); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[166.0, 110.0], [212.0, 156.0]]) - ); - assert_eq!(grad_2.to_data(), Data::from([[113.0, 141.0], [33.0, 41.0]])); - } - - #[test] - fn should_diff_full_complex_3() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], 
[2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.matmul(tensor_1.clone()); - let tensor_5 = tensor_4.clone().sub(tensor_2.clone()); - let tensor_6 = tensor_5.add(tensor_4); - - let grads = tensor_6.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[332.0, 220.0], [424.0, 312.0]]) - ); - assert_eq!(grad_2.to_data(), Data::from([[223.0, 279.0], [63.0, 79.0]])); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_full_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.matmul(tensor_1.clone()); + let tensor_5 = tensor_4.mul(tensor_2.clone()); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[593., 463.0], [487.0, 539.0]]) + ); + assert_eq!( + grad_2.to_data(), + Data::from([[734.0, 294.0], [1414.0, 242.0]]) + ); + } + + #[test] + fn should_diff_full_complex_2() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.matmul(tensor_1.clone()); + let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone()); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[166.0, 110.0], [212.0, 156.0]]) + ); + assert_eq!(grad_2.to_data(), Data::from([[113.0, 141.0], [33.0, 41.0]])); + } + + #[test] + fn should_diff_full_complex_3() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.matmul(tensor_1.clone()); + let tensor_5 = tensor_4.clone().sub(tensor_2.clone()); + let tensor_6 = tensor_5.add(tensor_4); + + let grads = tensor_6.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[332.0, 220.0], [424.0, 312.0]]) + ); + assert_eq!(grad_2.to_data(), Data::from([[223.0, 279.0], [63.0, 79.0]])); + } } diff --git a/burn-autodiff/src/tests/conv1d.rs b/burn-autodiff/src/tests/conv1d.rs index 55a2d473bb..3ff44aa0d0 100644 --- a/burn-autodiff/src/tests/conv1d.rs +++ b/burn-autodiff/src/tests/conv1d.rs @@ -1,240 +1,240 @@ #[burn_tensor_testgen::testgen(ad_conv1d)] mod tests { - use super::*; - use burn_tensor::{module::conv1d, ops::ConvOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv1d, ops::ConvOptions, Data, 
Shape}; - #[test] - fn test_conv1d_basic() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[14., 24., 24., 18.], [26., 42., 42., 30.]], - [[14., 24., 24., 18.], [26., 42., 42., 30.]], - ]), - weight: TestTensor::from_floats([ - [[30., 44., 36.], [54., 76., 60.]], - [[30., 44., 36.], [54., 76., 60.]], - ]), - bias: TestTensor::from_floats([8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_basic() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[14., 24., 24., 18.], [26., 42., 42., 30.]], + [[14., 24., 24., 18.], [26., 42., 42., 30.]], + ]), + weight: TestTensor::from_floats([ + [[30., 44., 36.], [54., 76., 60.]], + [[30., 44., 36.], [54., 76., 60.]], + ]), + bias: TestTensor::from_floats([8., 8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_different_channels() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 3, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[39., 63., 63., 45.], [57., 90., 90., 63.]], - [[39., 63., 63., 45.], [57., 90., 90., 63.]], - ]), - weight: TestTensor::from_floats([ - [[30., 44., 36.], [54., 76., 60.]], - [[30., 44., 36.], [54., 76., 60.]], - [[30., 44., 36.], [54., 76., 60.]], - ]), - bias: TestTensor::from_floats([8., 8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_different_channels() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 3, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[39., 63., 63., 45.], [57., 90., 90., 63.]], + [[39., 63., 63., 45.], [57., 90., 90., 63.]], + ]), + weight: TestTensor::from_floats([ + [[30., 44., 36.], [54., 76., 60.]], + [[30., 44., 36.], [54., 76., 60.]], + [[30., 44., 36.], [54., 76., 60.]], + ]), + bias: TestTensor::from_floats([8., 8., 8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_with_padding() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 2, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[24., 24., 24., 24.], [42., 42., 42., 42.]], - [[24., 24., 24., 24.], [42., 42., 42., 42.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [76., 76., 76.]], - [[44., 44., 44.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([12., 12.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_with_padding() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 2, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[24., 24., 24., 24.], [42., 42., 42., 42.]], + [[24., 24., 24., 24.], [42., 42., 42., 42.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [76., 76., 76.]], + [[44., 44., 44.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([12., 12.]), + }; + test.assert_grads(grads); + } - #[test] - fn 
test_conv1d_with_stride() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 2, - dilation: 1, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[8., 16., 8., 10.], [14., 28., 14., 16.]], - [[8., 16., 8., 10.], [14., 28., 14., 16.]], - ]), - weight: TestTensor::from_floats([ - [[10., 20., 24.], [18., 36., 40.]], - [[10., 20., 24.], [18., 36., 40.]], - ]), - bias: TestTensor::from_floats([4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_with_stride() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 2, + dilation: 1, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[8., 16., 8., 10.], [14., 28., 14., 16.]], + [[8., 16., 8., 10.], [14., 28., 14., 16.]], + ]), + weight: TestTensor::from_floats([ + [[10., 20., 24.], [18., 36., 40.]], + [[10., 20., 24.], [18., 36., 40.]], + ]), + bias: TestTensor::from_floats([4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_dilation() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 2, - groups: 1, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[6., 8., 8., 10.], [12., 14., 14., 16.]], - [[6., 8., 8., 10.], [12., 14., 14., 16.]], - ]), - weight: TestTensor::from_floats([ - [[8., 22., 14.], [16., 38., 22.]], - [[8., 22., 14.], [16., 38., 22.]], - ]), - bias: TestTensor::from_floats([4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_dilation() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 2, + groups: 1, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[6., 8., 8., 10.], [12., 14., 14., 16.]], + [[6., 8., 8., 10.], [12., 14., 14., 16.]], + ]), + weight: TestTensor::from_floats([ + [[8., 22., 14.], [16., 38., 22.]], + [[8., 22., 14.], [16., 38., 22.]], + ]), + bias: TestTensor::from_floats([4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv1d_groups() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 2, - length: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[1., 3., 3., 3.], [7., 12., 12., 9.]], - [[1., 3., 3., 3.], [7., 12., 12., 9.]], - ]), - weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]]), - bias: TestTensor::from_floats([8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv1d_groups() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 2, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[1., 3., 3., 3.], [7., 12., 12., 9.]], + [[1., 3., 3., 3.], [7., 12., 12., 9.]], + ]), + weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]]), + bias: TestTensor::from_floats([8., 8.]), + }; + test.assert_grads(grads); + } - struct Conv1dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size: usize, - padding: usize, - stride: usize, - dilation: usize, - groups: usize, - length: usize, - } + struct Conv1dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: 
usize, + kernel_size: usize, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + length: usize, + } - struct Grads { - x: TestTensor<3>, - weight: TestTensor<3>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<3>, + weight: TestTensor<3>, + bias: TestTensor<1>, + } - impl Conv1dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size, - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); + impl Conv1dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size, + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); - let output = conv1d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), - ); - let grads = output.backward(); + let output = conv1d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), + ); + let grads = output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - .bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .weight + .to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/conv2d.rs b/burn-autodiff/src/tests/conv2d.rs index 9afe7a86f4..7e6d2a801a 100644 --- a/burn-autodiff/src/tests/conv2d.rs +++ b/burn-autodiff/src/tests/conv2d.rs @@ -1,750 +1,750 @@ #[burn_tensor_testgen::testgen(ad_conv2d)] mod tests { - use super::*; - use burn_tensor::{module::conv2d, ops::ConvOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv2d, ops::ConvOptions, Data, Shape}; - #[test] - fn 
test_conv2d_basic() { - let test = Conv2dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [88., 138., 138., 96.], - [150., 234., 234., 162.], - [150., 234., 234., 162.], - [112., 174., 174., 120.], - ], - [ - [160., 246., 246., 168.], - [258., 396., 396., 270.], - [258., 396., 396., 270.], - [184., 282., 282., 192.], - ], - ], - [ - [ - [88., 138., 138., 96.], - [150., 234., 234., 162.], - [150., 234., 234., 162.], - [112., 174., 174., 120.], - ], - [ - [160., 246., 246., 168.], - [258., 396., 396., 270.], - [258., 396., 396., 270.], - [184., 282., 282., 192.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - ]), - bias: TestTensor::from_floats([32., 32.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_basic() { + let test = Conv2dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [88., 138., 138., 96.], + [150., 234., 234., 162.], + [150., 234., 234., 162.], + [112., 174., 174., 120.], + ], + [ + [160., 246., 246., 168.], + [258., 396., 396., 270.], + [258., 396., 396., 270.], + [184., 282., 282., 192.], + ], + ], + [ + [ + [88., 138., 138., 96.], + [150., 234., 234., 162.], + [150., 234., 234., 162.], + [112., 174., 174., 120.], + ], + [ + [160., 246., 246., 168.], + [258., 396., 396., 270.], + [258., 396., 396., 270.], + [184., 282., 282., 192.], + ], + ], + ]), + weight: TestTensor::from_floats([ + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + ]), + bias: TestTensor::from_floats([32., 32.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_channels() { - let test = Conv2dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 3, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [240., 369., 369., 252.], - [387., 594., 594., 405.], - [387., 594., 594., 405.], - [276., 423., 423., 288.], - ], - [ - [348., 531., 531., 360.], - [549., 837., 837., 567.], - [549., 837., 837., 567.], - [384., 585., 585., 396.], - ], - ], - [ - [ - [240., 369., 369., 252.], - [387., 594., 594., 405.], - [387., 594., 594., 405.], - [276., 423., 423., 288.], - ], - [ - [348., 531., 531., 360.], - [549., 837., 837., 567.], - [549., 837., 837., 567.], - [384., 585., 585., 396.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - [ - 
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - [ - [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], - [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], - ], - ]), - bias: TestTensor::from_floats([32., 32., 32.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_channels() { + let test = Conv2dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 3, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [240., 369., 369., 252.], + [387., 594., 594., 405.], + [387., 594., 594., 405.], + [276., 423., 423., 288.], + ], + [ + [348., 531., 531., 360.], + [549., 837., 837., 567.], + [549., 837., 837., 567.], + [384., 585., 585., 396.], + ], + ], + [ + [ + [240., 369., 369., 252.], + [387., 594., 594., 405.], + [387., 594., 594., 405.], + [276., 423., 423., 288.], + ], + [ + [348., 531., 531., 360.], + [549., 837., 837., 567.], + [549., 837., 837., 567.], + [384., 585., 585., 396.], + ], + ], + ]), + weight: TestTensor::from_floats([ + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + [ + [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], + [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], + ], + ]), + bias: TestTensor::from_floats([32., 32., 32.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_kernel_size() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [116., 180., 192., 132.], - [198., 306., 324., 222.], - [198., 306., 324., 222.], - [148., 228., 240., 164.], - ], - [ - [212., 324., 336., 228.], - [342., 522., 540., 366.], - [342., 522., 540., 366.], - [244., 372., 384., 260.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [ - [27., 45., 54., 39.], - [52., 84., 96., 68.], - [51., 81., 90., 63.], - ], - [ - [123., 189., 198., 135.], - [180., 276., 288., 196.], - [147., 225., 234., 159.], - ], - ], - [ - [ - [27., 45., 54., 39.], - [52., 84., 96., 68.], - [51., 81., 90., 63.], - ], - [ - [123., 189., 198., 135.], - [180., 276., 288., 196.], - [147., 225., 234., 159.], - ], - ], - ]), - bias: TestTensor::from_floats([12., 12.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_kernel_size() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [116., 180., 192., 132.], + [198., 306., 324., 222.], + [198., 306., 324., 222.], + [148., 228., 240., 164.], + ], + [ + [212., 324., 336., 228.], + [342., 522., 540., 366.], + [342., 522., 540., 366.], + [244., 372., 384., 260.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [ + 
[27., 45., 54., 39.], + [52., 84., 96., 68.], + [51., 81., 90., 63.], + ], + [ + [123., 189., 198., 135.], + [180., 276., 288., 196.], + [147., 225., 234., 159.], + ], + ], + [ + [ + [27., 45., 54., 39.], + [52., 84., 96., 68.], + [51., 81., 90., 63.], + ], + [ + [123., 189., 198., 135.], + [180., 276., 288., 196.], + [147., 225., 234., 159.], + ], + ], + ]), + bias: TestTensor::from_floats([12., 12.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_padding() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [138., 138., 138., 138.], - [234., 234., 234., 234.], - [234., 234., 234., 234.], - [174., 174., 174., 174.], - ], - [ - [246., 246., 246., 246.], - [396., 396., 396., 396.], - [396., 396., 396., 396.], - [282., 282., 282., 282.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], - [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], - ], - [ - [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], - [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], - ], - ]), - bias: TestTensor::from_floats([24., 24.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_padding() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [138., 138., 138., 138.], + [234., 234., 234., 234.], + [234., 234., 234., 234.], + [174., 174., 174., 174.], + ], + [ + [246., 246., 246., 246.], + [396., 396., 396., 396.], + [396., 396., 396., 396.], + [282., 282., 282., 282.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], + [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], + ], + [ + [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], + [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], + ], + ]), + bias: TestTensor::from_floats([24., 24.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_width() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 5, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [88., 138., 138., 138., 96.], - [150., 234., 234., 234., 162.], - [150., 234., 234., 234., 162.], - [112., 174., 174., 174., 120.], - ], - [ - [160., 246., 246., 246., 168.], - [258., 396., 396., 396., 270.], - [258., 396., 396., 396., 270.], - [184., 282., 282., 282., 192.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], - [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], - ], - [ - [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], - [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], - ], - ]), - bias: TestTensor::from_floats([20., 20.]), - }; - test.assert_grads(grads); - } + #[test] + fn 
test_conv2d_different_width() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 5, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [88., 138., 138., 138., 96.], + [150., 234., 234., 234., 162.], + [150., 234., 234., 234., 162.], + [112., 174., 174., 174., 120.], + ], + [ + [160., 246., 246., 246., 168.], + [258., 396., 396., 396., 270.], + [258., 396., 396., 396., 270.], + [184., 282., 282., 282., 192.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], + [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], + ], + [ + [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], + [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], + ], + ]), + bias: TestTensor::from_floats([20., 20.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_stride_2() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 2, - stride_2: 2, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 6, - width: 6, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [26., 52., 26., 52., 26., 28.], - [52., 104., 52., 104., 52., 56.], - [26., 52., 26., 52., 26., 28.], - [52., 104., 52., 104., 52., 56.], - [26., 52., 26., 52., 26., 28.], - [32., 64., 32., 64., 32., 34.], - ], - [ - [44., 88., 44., 88., 44., 46.], - [88., 176., 88., 176., 88., 92.], - [44., 88., 44., 88., 44., 46.], - [88., 176., 88., 176., 88., 92.], - [44., 88., 44., 88., 44., 46.], - [50., 100., 50., 100., 50., 52.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], - [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], - ], - [ - [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], - [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], - ], - ]), - bias: TestTensor::from_floats([9., 9.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_stride_2() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 2, + stride_2: 2, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 6, + width: 6, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [26., 52., 26., 52., 26., 28.], + [52., 104., 52., 104., 52., 56.], + [26., 52., 26., 52., 26., 28.], + [52., 104., 52., 104., 52., 56.], + [26., 52., 26., 52., 26., 28.], + [32., 64., 32., 64., 32., 34.], + ], + [ + [44., 88., 44., 88., 44., 46.], + [88., 176., 88., 176., 88., 92.], + [44., 88., 44., 88., 44., 46.], + [88., 176., 88., 176., 88., 92.], + [44., 88., 44., 88., 44., 46.], + [50., 100., 50., 100., 50., 52.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], + [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], + ], + [ + [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], + [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], + ], + ]), + bias: TestTensor::from_floats([9., 9.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_stride() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - 
kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 3, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 8, - width: 8, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [50., 78., 78., 78., 78., 78., 78., 54.], - [62., 96., 96., 96., 96., 96., 96., 66.], - [38., 60., 60., 60., 60., 60., 60., 42.], - [50., 78., 78., 78., 78., 78., 78., 54.], - [62., 96., 96., 96., 96., 96., 96., 66.], - [38., 60., 60., 60., 60., 60., 60., 42.], - [50., 78., 78., 78., 78., 78., 78., 54.], - [62., 96., 96., 96., 96., 96., 96., 66.], - ], - [ - [86., 132., 132., 132., 132., 132., 132., 90.], - [98., 150., 150., 150., 150., 150., 150., 102.], - [74., 114., 114., 114., 114., 114., 114., 78.], - [86., 132., 132., 132., 132., 132., 132., 90.], - [98., 150., 150., 150., 150., 150., 150., 102.], - [74., 114., 114., 114., 114., 114., 114., 78.], - [86., 132., 132., 132., 132., 132., 132., 90.], - [98., 150., 150., 150., 150., 150., 150., 102.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], - [ - [1330., 1528., 1344.], - [1911., 2196., 1932.], - [2079., 2388., 2100.], - ], - ], - [ - [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], - [ - [1330., 1528., 1344.], - [1911., 2196., 1932.], - [2079., 2388., 2100.], - ], - ], - ]), - bias: TestTensor::from_floats([24., 24.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_stride() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 3, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 8, + width: 8, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [50., 78., 78., 78., 78., 78., 78., 54.], + [62., 96., 96., 96., 96., 96., 96., 66.], + [38., 60., 60., 60., 60., 60., 60., 42.], + [50., 78., 78., 78., 78., 78., 78., 54.], + [62., 96., 96., 96., 96., 96., 96., 66.], + [38., 60., 60., 60., 60., 60., 60., 42.], + [50., 78., 78., 78., 78., 78., 78., 54.], + [62., 96., 96., 96., 96., 96., 96., 66.], + ], + [ + [86., 132., 132., 132., 132., 132., 132., 90.], + [98., 150., 150., 150., 150., 150., 150., 102.], + [74., 114., 114., 114., 114., 114., 114., 78.], + [86., 132., 132., 132., 132., 132., 132., 90.], + [98., 150., 150., 150., 150., 150., 150., 102.], + [74., 114., 114., 114., 114., 114., 114., 78.], + [86., 132., 132., 132., 132., 132., 132., 90.], + [98., 150., 150., 150., 150., 150., 150., 102.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], + [ + [1330., 1528., 1344.], + [1911., 2196., 1932.], + [2079., 2388., 2100.], + ], + ], + [ + [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], + [ + [1330., 1528., 1344.], + [1911., 2196., 1932.], + [2079., 2388., 2100.], + ], + ], + ]), + bias: TestTensor::from_floats([24., 24.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_dilation_2() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 2, - dilation_2: 2, - groups: 1, - height: 6, - width: 6, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [18., 38., 38., 42., 42., 22.], - [42., 88., 88., 96., 96., 50.], - [42., 88., 88., 96., 96., 50.], - [54., 112., 112., 120., 120., 62.], - [54., 112., 112., 
120., 120., 62.], - [30., 62., 62., 66., 66., 34.], - ], - [ - [36., 74., 74., 78., 78., 40.], - [78., 160., 160., 168., 168., 86.], - [78., 160., 160., 168., 168., 86.], - [90., 184., 184., 192., 192., 98.], - [90., 184., 184., 192., 192., 98.], - [48., 98., 98., 102., 102., 52.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], - [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], - ], - [ - [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], - [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], - ], - ]), - bias: TestTensor::from_floats([16., 16.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_dilation_2() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 2, + dilation_2: 2, + groups: 1, + height: 6, + width: 6, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [18., 38., 38., 42., 42., 22.], + [42., 88., 88., 96., 96., 50.], + [42., 88., 88., 96., 96., 50.], + [54., 112., 112., 120., 120., 62.], + [54., 112., 112., 120., 120., 62.], + [30., 62., 62., 66., 66., 34.], + ], + [ + [36., 74., 74., 78., 78., 40.], + [78., 160., 160., 168., 168., 86.], + [78., 160., 160., 168., 168., 86.], + [90., 184., 184., 192., 192., 98.], + [90., 184., 184., 192., 192., 98.], + [48., 98., 98., 102., 102., 52.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], + [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], + ], + [ + [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], + [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], + ], + ]), + bias: TestTensor::from_floats([16., 16.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_different_dilation() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 2, - dilation_2: 3, - groups: 1, - height: 6, - width: 6, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [18., 0., 20., 20., 0., 22.], - [42., 0., 46., 46., 0., 50.], - [42., 0., 46., 46., 0., 50.], - [54., 0., 58., 58., 0., 62.], - [54., 0., 58., 58., 0., 62.], - [30., 0., 32., 32., 0., 34.], - ], - [ - [36., 0., 38., 38., 0., 40.], - [78., 0., 82., 82., 0., 86.], - [78., 0., 82., 82., 0., 86.], - [90., 0., 94., 94., 0., 98.], - [90., 0., 94., 94., 0., 98.], - [48., 0., 50., 50., 0., 52.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], - [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], - ], - [ - [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], - [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], - ], - ]), - bias: TestTensor::from_floats([8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_different_dilation() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 2, + dilation_2: 3, + groups: 1, + height: 6, + width: 6, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [18., 0., 20., 20., 0., 22.], + [42., 0., 46., 46., 0., 50.], + [42., 0., 46., 46., 0., 50.], + [54., 0., 58., 58., 0., 
62.], + [54., 0., 58., 58., 0., 62.], + [30., 0., 32., 32., 0., 34.], + ], + [ + [36., 0., 38., 38., 0., 40.], + [78., 0., 82., 82., 0., 86.], + [78., 0., 82., 82., 0., 86.], + [90., 0., 94., 94., 0., 98.], + [90., 0., 94., 94., 0., 98.], + [48., 0., 50., 50., 0., 52.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], + [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], + ], + [ + [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], + [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], + ], + ]), + bias: TestTensor::from_floats([8., 8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_groups() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 5, - width: 5, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [0., 1., 3., 3., 2.], - [3., 8., 15., 12., 7.], - [9., 21., 36., 27., 15.], - [9., 20., 33., 24., 13.], - [6., 13., 21., 15., 8.], - ], - [ - [9., 19., 30., 21., 11.], - [21., 44., 69., 48., 25.], - [36., 75., 117., 81., 42.], - [27., 56., 87., 60., 31.], - [15., 31., 48., 33., 17.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]], - [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]], - ]), - bias: TestTensor::from_floats([9., 9.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_groups() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 5, + width: 5, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [0., 1., 3., 3., 2.], + [3., 8., 15., 12., 7.], + [9., 21., 36., 27., 15.], + [9., 20., 33., 24., 13.], + [6., 13., 21., 15., 8.], + ], + [ + [9., 19., 30., 21., 11.], + [21., 44., 69., 48., 25.], + [36., 75., 117., 81., 42.], + [27., 56., 87., 60., 31.], + [15., 31., 48., 33., 17.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]], + [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]], + ]), + bias: TestTensor::from_floats([9., 9.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_groups_different_channels() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 3, - channels_out: 6, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 3, - height: 4, - width: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [9., 20., 24., 13.], - [24., 52., 60., 32.], - [36., 76., 84., 44.], - [21., 44., 48., 25.], - ], - [ - [45., 92., 96., 49.], - [96., 196., 204., 104.], - [108., 220., 228., 116.], - [57., 116., 120., 61.], - ], - [ - [81., 164., 168., 85.], - [168., 340., 348., 176.], - [180., 364., 372., 188.], - [93., 188., 192., 97.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], - [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], - [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], - [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], - [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], - [[[138., 142., 
146.], [154., 158., 162.], [170., 174., 178.]]], - ]), - bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_groups_different_channels() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 3, + channels_out: 6, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 3, + height: 4, + width: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [9., 20., 24., 13.], + [24., 52., 60., 32.], + [36., 76., 84., 44.], + [21., 44., 48., 25.], + ], + [ + [45., 92., 96., 49.], + [96., 196., 204., 104.], + [108., 220., 228., 116.], + [57., 116., 120., 61.], + ], + [ + [81., 164., 168., 85.], + [168., 340., 348., 176.], + [180., 364., 372., 188.], + [93., 188., 192., 97.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], + [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], + [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], + [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], + [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], + [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], + ]), + bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv2d_complex() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 3, - kernel_size_1: 2, - kernel_size_2: 3, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - dilation_1: 2, - dilation_2: 3, - groups: 1, - height: 4, - width: 5, - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [36., 39., 0., 39., 42.], - [81., 87., 0., 87., 93.], - [81., 87., 0., 87., 93.], - [45., 48., 0., 48., 51.], - ], - [ - [54., 57., 0., 57., 60.], - [117., 123., 0., 123., 129.], - [117., 123., 0., 123., 129.], - [63., 66., 0., 66., 69.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[15., 42., 27.], [30., 72., 42.]], - [[75., 162., 87.], [90., 192., 102.]], - ], - [ - [[15., 42., 27.], [30., 72., 42.]], - [[75., 162., 87.], [90., 192., 102.]], - ], - [ - [[15., 42., 27.], [30., 72., 42.]], - [[75., 162., 87.], [90., 192., 102.]], - ], - ]), - bias: TestTensor::from_floats([8., 8., 8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv2d_complex() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 3, + kernel_size_1: 2, + kernel_size_2: 3, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + dilation_1: 2, + dilation_2: 3, + groups: 1, + height: 4, + width: 5, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [36., 39., 0., 39., 42.], + [81., 87., 0., 87., 93.], + [81., 87., 0., 87., 93.], + [45., 48., 0., 48., 51.], + ], + [ + [54., 57., 0., 57., 60.], + [117., 123., 0., 123., 129.], + [117., 123., 0., 123., 129.], + [63., 66., 0., 66., 69.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[15., 42., 27.], [30., 72., 42.]], + [[75., 162., 87.], [90., 192., 102.]], + ], + [ + [[15., 42., 27.], [30., 72., 42.]], + [[75., 162., 87.], [90., 192., 102.]], + ], + [ + [[15., 42., 27.], [30., 72., 42.]], + [[75., 162., 87.], [90., 192., 102.]], + ], + ]), + bias: TestTensor::from_floats([8., 8., 8.]), + }; + test.assert_grads(grads); + } - struct Conv2dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: 
usize, - stride_1: usize, - stride_2: usize, - dilation_1: usize, - dilation_2: usize, - groups: usize, - height: usize, - width: usize, - } + struct Conv2dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + dilation_1: usize, + dilation_2: usize, + groups: usize, + height: usize, + width: usize, + } - struct Grads { - x: TestTensor<4>, - weight: TestTensor<4>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<4>, + weight: TestTensor<4>, + bias: TestTensor<1>, + } - impl Conv2dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size_1, - self.kernel_size_2, - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = conv2d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvOptions::new( - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - [self.dilation_1, self.dilation_2], - self.groups, - ), - ); - let grads = output.backward(); + impl Conv2dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size_1, + self.kernel_size_2, + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = conv2d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvOptions::new( + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + [self.dilation_1, self.dilation_2], + self.groups, + ), + ); + let grads = output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - .bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + expected_grads + .weight + 
.to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/conv_transpose1d.rs b/burn-autodiff/src/tests/conv_transpose1d.rs index 948e51ad15..4f05a447d8 100644 --- a/burn-autodiff/src/tests/conv_transpose1d.rs +++ b/burn-autodiff/src/tests/conv_transpose1d.rs @@ -1,253 +1,253 @@ #[burn_tensor_testgen::testgen(ad_conv_transpose1d)] mod tests { - use super::*; - use burn_tensor::{module::conv_transpose1d, ops::ConvTransposeOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv_transpose1d, ops::ConvTransposeOptions, Data, Shape}; - #[test] - fn test_conv_transpose1d_basic() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], - [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], - ]), - weight: TestTensor::from_floats([ - [[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]], - [[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]], - ]), - bias: TestTensor::from_floats([12., 12.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_basic() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], + [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], + ]), + weight: TestTensor::from_floats([ + [[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]], + [[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]], + ]), + bias: TestTensor::from_floats([12., 12.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_padding() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 2, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[7., 12., 8., 3.], [19., 36., 32., 15.]], - [[7., 12., 8., 3.], [19., 36., 32., 15.]], - ]), - weight: TestTensor::from_floats([ - [[26., 22., 18.], [26., 22., 18.]], - [[42., 38., 34.], [42., 38., 34.]], - ]), - bias: TestTensor::from_floats([4., 4.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_padding() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 2, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[7., 12., 8., 3.], [19., 36., 32., 15.]], + [[7., 12., 8., 3.], [19., 36., 32., 15.]], + ]), + weight: TestTensor::from_floats([ + [[26., 22., 18.], [26., 22., 18.]], + [[42., 38., 34.], [42., 38., 34.]], + ]), + bias: TestTensor::from_floats([4., 4.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_stride() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 0, - stride: 2, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [44., 44., 44.]], - [[76., 76., 76.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([18., 18.]), - }; - 
test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_stride() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 0, + stride: 2, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [44., 44., 44.]], + [[76., 76., 76.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([18., 18.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_stride_padding_out() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 1, - stride: 2, - dilation: 1, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [44., 44., 44.]], - [[76., 76., 76.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([20., 20.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_stride_padding_out() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 1, + stride: 2, + dilation: 1, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [44., 44., 44.]], + [[76., 76., 76.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([20., 20.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_dilation() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: 3, - padding: 0, - padding_out: 0, - stride: 1, - dilation: 2, - groups: 1, - size: 4, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - [[15., 15., 15., 15.], [51., 51., 51., 51.]], - ]), - weight: TestTensor::from_floats([ - [[44., 44., 44.], [44., 44., 44.]], - [[76., 76., 76.], [76., 76., 76.]], - ]), - bias: TestTensor::from_floats([16., 16.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_dilation() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: 3, + padding: 0, + padding_out: 0, + stride: 1, + dilation: 2, + groups: 1, + size: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + [[15., 15., 15., 15.], [51., 51., 51., 51.]], + ]), + weight: TestTensor::from_floats([ + [[44., 44., 44.], [44., 44., 44.]], + [[76., 76., 76.], [76., 76., 76.]], + ]), + bias: TestTensor::from_floats([16., 16.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose1d_complex() { - let test = ConvTranspose1dTestCase { - batch_size: 2, - channels: [2, 4], - kernel_size: 3, - padding: 1, - padding_out: 1, - stride: 2, - dilation: 2, - groups: 2, - size: 8, - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], - [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], - ], - [ - [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], - [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], - ], - ]), - weight: TestTensor::from_floats([ - [[168.0, 
184.0, 184.0], [168.0, 184.0, 184.0]], - [[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]], - ]), - bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose1d_complex() { + let test = ConvTranspose1dTestCase { + batch_size: 2, + channels: [2, 4], + kernel_size: 3, + padding: 1, + padding_out: 1, + stride: 2, + dilation: 2, + groups: 2, + size: 8, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], + [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], + ], + [ + [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], + [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], + ], + ]), + weight: TestTensor::from_floats([ + [[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]], + [[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]], + ]), + bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0]), + }; + test.assert_grads(grads); + } - struct ConvTranspose1dTestCase { - batch_size: usize, - channels: [usize; 2], - kernel_size: usize, - padding: usize, - padding_out: usize, - stride: usize, - dilation: usize, - groups: usize, - size: usize, - } + struct ConvTranspose1dTestCase { + batch_size: usize, + channels: [usize; 2], + kernel_size: usize, + padding: usize, + padding_out: usize, + stride: usize, + dilation: usize, + groups: usize, + size: usize, + } - struct Grads { - x: TestTensor<3>, - weight: TestTensor<3>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<3>, + weight: TestTensor<3>, + bias: TestTensor<1>, + } - impl ConvTranspose1dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]); - let shape_weight = Shape::new([ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size, - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels[1]) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = conv_transpose1d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvTransposeOptions::new( - [self.stride], - [self.padding], - [self.padding_out], - [self.dilation], - self.groups, - ), - ); - let grads = output.backward(); + impl ConvTranspose1dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]); + let shape_weight = Shape::new([ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size, + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels[1]) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = conv_transpose1d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvTransposeOptions::new( + [self.stride], + [self.padding], + [self.padding_out], + [self.dilation], + self.groups, + ), + ); + let grads = 
output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - .bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + expected_grads + .weight + .to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/conv_transpose2d.rs b/burn-autodiff/src/tests/conv_transpose2d.rs index bedc9d617d..138afbf5fe 100644 --- a/burn-autodiff/src/tests/conv_transpose2d.rs +++ b/burn-autodiff/src/tests/conv_transpose2d.rs @@ -1,643 +1,647 @@ #[burn_tensor_testgen::testgen(ad_conv_transpose2d)] mod tests { - use super::*; - use burn_tensor::{module::conv_transpose2d, ops::ConvTransposeOptions, Data, Shape}; + use super::*; + use burn_tensor::{module::conv_transpose2d, ops::ConvTransposeOptions, Data, Shape}; - #[test] - fn test_conv_transpose2d_basic() { - let test = ConvTranspose2dTestCase { - batch_size: 2, - channels: [2, 2], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - ], - [ - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - ], - ], - [ - [ - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - [153., 153., 153., 153.], - ], - [ - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - [477., 477., 477., 477.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], - [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], - ], - [ - [ - [1264., 1264., 1264.], - [1264., 1264., 1264.], - [1264., 1264., 1264.], - ], - [ - [1264., 1264., 1264.], - [1264., 1264., 1264.], - [1264., 1264., 1264.], - ], - ], - ]), - bias: TestTensor::from_floats([72., 72.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_basic() { + let test = ConvTranspose2dTestCase { + batch_size: 2, + channels: [2, 2], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + ], + [ + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + ], + ], + [ + [ + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + [153., 153., 153., 153.], + ], + [ + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + [477., 477., 477., 477.], + ], + ], + ]), + weight: 
TestTensor::from_floats([ + [ + [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], + [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], + ], + [ + [ + [1264., 1264., 1264.], + [1264., 1264., 1264.], + [1264., 1264., 1264.], + ], + [ + [1264., 1264., 1264.], + [1264., 1264., 1264.], + [1264., 1264., 1264.], + ], + ], + ]), + bias: TestTensor::from_floats([72., 72.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_padding() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [1, 2], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [13., 24., 20., 9.], - [15., 27., 21., 9.], - [15., 27., 21., 9.], - [7., 12., 8., 3.], - ]]]), - weight: TestTensor::from_floats([[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]]), - bias: TestTensor::from_floats([8.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_padding() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [1, 2], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [13., 24., 20., 9.], + [15., 27., 21., 9.], + [15., 27., 21., 9.], + [7., 12., 8., 3.], + ]]]), + weight: TestTensor::from_floats([[[ + [63., 57., 51.], + [68., 60., 52.], + [39., 33., 27.], + ]]]), + bias: TestTensor::from_floats([8.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_stride() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [2, 3], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ]]]), - weight: TestTensor::from_floats([[[ - [120., 120., 120.], - [120., 120., 120.], - [120., 120., 120.], - ]]]), - bias: TestTensor::from_floats([108.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_stride() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [2, 3], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ]]]), + weight: TestTensor::from_floats([[[ + [120., 120., 120.], + [120., 120., 120.], + [120., 120., 120.], + ]]]), + bias: TestTensor::from_floats([108.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_stride_padding_out() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [1, 2], - stride: [2, 3], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ]]]), - weight: TestTensor::from_floats([[[ - [120., 120., 120.], - [120., 120., 120.], - [120., 120., 120.], - ]]]), - bias: TestTensor::from_floats([140.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_stride_padding_out() { + let test = ConvTranspose2dTestCase { + 
batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [1, 2], + stride: [2, 3], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ]]]), + weight: TestTensor::from_floats([[[ + [120., 120., 120.], + [120., 120., 120.], + [120., 120., 120.], + ]]]), + bias: TestTensor::from_floats([140.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_dilation() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [2, 3], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ]]]), - weight: TestTensor::from_floats([[[ - [120., 120., 120.], - [120., 120., 120.], - [120., 120., 120.], - ]]]), - bias: TestTensor::from_floats([80.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_dilation() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [2, 3], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ]]]), + weight: TestTensor::from_floats([[[ + [120., 120., 120.], + [120., 120., 120.], + [120., 120., 120.], + ]]]), + bias: TestTensor::from_floats([80.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_channels() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [2, 3], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [351., 351., 351., 351.], - [351., 351., 351., 351.], - [351., 351., 351., 351.], - [351., 351., 351., 351.], - ], - [ - [1080., 1080., 1080., 1080.], - [1080., 1080., 1080., 1080.], - [1080., 1080., 1080., 1080.], - [1080., 1080., 1080., 1080.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], - [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], - [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], - ], - [ - [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], - [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], - [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], - ], - ]), - bias: TestTensor::from_floats([36., 36., 36.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_channels() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [2, 3], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [351., 351., 351., 351.], + [351., 351., 351., 351.], + [351., 351., 351., 351.], + [351., 351., 351., 351.], + ], + [ + [1080., 1080., 1080., 1080.], + [1080., 1080., 1080., 1080.], + [1080., 1080., 1080., 1080.], + [1080., 1080., 1080., 1080.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], + [[120., 
120., 120.], [120., 120., 120.], [120., 120., 120.]], + [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], + ], + [ + [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], + [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], + [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], + ], + ]), + bias: TestTensor::from_floats([36., 36., 36.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_kernel_size() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [1, 1], - kernel_size: [3, 5], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 1, - size: [6, 6], - }; - let grads = Grads { - x: TestTensor::from_floats([[[ - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - [105., 105., 105., 105., 105., 105.], - ]]]), - weight: TestTensor::from_floats([[[ - [630., 630., 630., 630., 630.], - [630., 630., 630., 630., 630.], - [630., 630., 630., 630., 630.], - ]]]), - bias: TestTensor::from_floats([80.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_kernel_size() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [1, 1], + kernel_size: [3, 5], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 1, + size: [6, 6], + }; + let grads = Grads { + x: TestTensor::from_floats([[[ + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + [105., 105., 105., 105., 105., 105.], + ]]]), + weight: TestTensor::from_floats([[[ + [630., 630., 630., 630., 630.], + [630., 630., 630., 630., 630.], + [630., 630., 630., 630., 630.], + ]]]), + bias: TestTensor::from_floats([80.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_groups() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [2, 2], - kernel_size: [3, 3], - padding: [0, 0], - padding_out: [0, 0], - stride: [1, 1], - dilation: [1, 1], - groups: 2, - size: [4, 4], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - [36., 36., 36., 36.], - ], - [ - [117., 117., 117., 117.], - [117., 117., 117., 117.], - [117., 117., 117., 117.], - [117., 117., 117., 117.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]], - [[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]], - ]), - bias: TestTensor::from_floats([36., 36.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_groups() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [2, 2], + kernel_size: [3, 3], + padding: [0, 0], + padding_out: [0, 0], + stride: [1, 1], + dilation: [1, 1], + groups: 2, + size: [4, 4], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + [36., 36., 36., 36.], + ], + [ + [117., 117., 117., 117.], + [117., 117., 117., 117.], + [117., 117., 117., 117.], + [117., 117., 117., 117.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]], + [[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]], + ]), + bias: 
TestTensor::from_floats([36., 36.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_complex_no_groups() { - let test = ConvTranspose2dTestCase { - batch_size: 2, - channels: [2, 3], - kernel_size: [3, 5], - padding: [1, 2], - padding_out: [1, 2], - stride: [2, 3], - dilation: [2, 3], - groups: 1, - size: [6, 8], - }; - let grads = Grads { - x: TestTensor::from_floats([ - [ - [ - [600., 735., 735., 735., 735., 735., 735., 735.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - ], - [ - [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - ], - ], - [ - [ - [600., 735., 735., 735., 735., 735., 735., 735.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - [810., 990., 990., 990., 990., 990., 990., 990.], - ], - [ - [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], - ], - ], - ]), - weight: TestTensor::from_floats([ - [ - [ - [5320., 6040., 6040., 6040., 6040.], - [6048., 6864., 6864., 6864., 6864.], - [6048., 6864., 6864., 6864., 6864.], - ], - [ - [5320., 6040., 6040., 6040., 6040.], - [6048., 6864., 6864., 6864., 6864.], - [6048., 6864., 6864., 6864., 6864.], - ], - [ - [5320., 6040., 6040., 6040., 6040.], - [6048., 6864., 6864., 6864., 6864.], - [6048., 6864., 6864., 6864., 6864.], - ], - ], - [ - [ - [8680., 9880., 9880., 9880., 9880.], - [10080., 11472., 11472., 11472., 11472.], - [10080., 11472., 11472., 11472., 11472.], - ], - [ - [8680., 9880., 9880., 9880., 9880.], - [10080., 11472., 11472., 11472., 11472.], - [10080., 11472., 11472., 11472., 11472.], - ], - [ - [8680., 9880., 9880., 9880., 9880.], - [10080., 11472., 11472., 11472., 11472.], - [10080., 11472., 11472., 11472., 11472.], - ], - ], - ]), - bias: TestTensor::from_floats([896., 896., 896.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_complex_no_groups() { + let test = ConvTranspose2dTestCase { + batch_size: 2, + channels: [2, 3], + kernel_size: [3, 5], + padding: [1, 2], + padding_out: [1, 2], + stride: [2, 3], + dilation: [2, 3], + groups: 1, + size: [6, 8], + }; + let grads = Grads { + x: TestTensor::from_floats([ + [ + [ + [600., 735., 735., 735., 735., 735., 735., 735.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + ], + [ + [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 
3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + ], + ], + [ + [ + [600., 735., 735., 735., 735., 735., 735., 735.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + [810., 990., 990., 990., 990., 990., 990., 990.], + ], + [ + [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], + ], + ], + ]), + weight: TestTensor::from_floats([ + [ + [ + [5320., 6040., 6040., 6040., 6040.], + [6048., 6864., 6864., 6864., 6864.], + [6048., 6864., 6864., 6864., 6864.], + ], + [ + [5320., 6040., 6040., 6040., 6040.], + [6048., 6864., 6864., 6864., 6864.], + [6048., 6864., 6864., 6864., 6864.], + ], + [ + [5320., 6040., 6040., 6040., 6040.], + [6048., 6864., 6864., 6864., 6864.], + [6048., 6864., 6864., 6864., 6864.], + ], + ], + [ + [ + [8680., 9880., 9880., 9880., 9880.], + [10080., 11472., 11472., 11472., 11472.], + [10080., 11472., 11472., 11472., 11472.], + ], + [ + [8680., 9880., 9880., 9880., 9880.], + [10080., 11472., 11472., 11472., 11472.], + [10080., 11472., 11472., 11472., 11472.], + ], + [ + [8680., 9880., 9880., 9880., 9880.], + [10080., 11472., 11472., 11472., 11472.], + [10080., 11472., 11472., 11472., 11472.], + ], + ], + ]), + bias: TestTensor::from_floats([896., 896., 896.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_complex_no_groups_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [4, 2], - kernel_size: [2, 3], - padding: [1, 2], - padding_out: [1, 2], - stride: [2, 3], - dilation: [1, 2], - groups: 1, - size: [10, 10], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - ], - [ - [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], - ], - [ - [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 
354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], - ], - [ - [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], - ], - ]]), - weight: TestTensor::from_floats([ - [ - [[4455., 4905., 4905.], [4500., 4950., 4950.]], - [[4455., 4905., 4905.], [4500., 4950., 4950.]], - ], - [ - [[12555., 13905., 13905.], [13500., 14950., 14950.]], - [[12555., 13905., 13905.], [13500., 14950., 14950.]], - ], - [ - [[20655., 22905., 22905.], [22500., 24950., 24950.]], - [[20655., 22905., 22905.], [22500., 24950., 24950.]], - ], - [ - [[28755., 31905., 31905.], [31500., 34950., 34950.]], - [[28755., 31905., 31905.], [31500., 34950., 34950.]], - ], - ]), - bias: TestTensor::from_floats([570., 570.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_complex_no_groups_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [4, 2], + kernel_size: [2, 3], + padding: [1, 2], + padding_out: [1, 2], + stride: [2, 3], + dilation: [1, 2], + groups: 1, + size: [10, 10], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + ], + [ + [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], + ], + [ + [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 
354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], + ], + [ + [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], + ], + ]]), + weight: TestTensor::from_floats([ + [ + [[4455., 4905., 4905.], [4500., 4950., 4950.]], + [[4455., 4905., 4905.], [4500., 4950., 4950.]], + ], + [ + [[12555., 13905., 13905.], [13500., 14950., 14950.]], + [[12555., 13905., 13905.], [13500., 14950., 14950.]], + ], + [ + [[20655., 22905., 22905.], [22500., 24950., 24950.]], + [[20655., 22905., 22905.], [22500., 24950., 24950.]], + ], + [ + [[28755., 31905., 31905.], [31500., 34950., 34950.]], + [[28755., 31905., 31905.], [31500., 34950., 34950.]], + ], + ]), + bias: TestTensor::from_floats([570., 570.]), + }; + test.assert_grads(grads); + } - #[test] - fn test_conv_transpose2d_complex_groups() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels: [4, 2], - kernel_size: [2, 3], - padding: [1, 2], - padding_out: [1, 2], - stride: [2, 3], - dilation: [1, 2], - groups: 2, - size: [10, 10], - }; - let grads = Grads { - x: TestTensor::from_floats([[ - [ - [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], - ], - [ - [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], - ], - [ - [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], - [60., 87., 87., 87., 87., 87., 
87., 87., 87., 87.], - ], - [ - [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], - ], - ]]), - weight: TestTensor::from_floats([ - [[[4455., 4905., 4905.], [4500., 4950., 4950.]]], - [[[12555., 13905., 13905.], [13500., 14950., 14950.]]], - [[[20655., 22905., 22905.], [22500., 24950., 24950.]]], - [[[28755., 31905., 31905.], [31500., 34950., 34950.]]], - ]), - bias: TestTensor::from_floats([570., 570.]), - }; - test.assert_grads(grads); - } + #[test] + fn test_conv_transpose2d_complex_groups() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels: [4, 2], + kernel_size: [2, 3], + padding: [1, 2], + padding_out: [1, 2], + stride: [2, 3], + dilation: [1, 2], + groups: 2, + size: [10, 10], + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], + ], + [ + [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], + ], + [ + [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], + ], + [ + [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + 
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[4455., 4905., 4905.], [4500., 4950., 4950.]]], + [[[12555., 13905., 13905.], [13500., 14950., 14950.]]], + [[[20655., 22905., 22905.], [22500., 24950., 24950.]]], + [[[28755., 31905., 31905.], [31500., 34950., 34950.]]], + ]), + bias: TestTensor::from_floats([570., 570.]), + }; + test.assert_grads(grads); + } - struct ConvTranspose2dTestCase { - batch_size: usize, - channels: [usize; 2], - kernel_size: [usize; 2], - padding: [usize; 2], - padding_out: [usize; 2], - stride: [usize; 2], - dilation: [usize; 2], - groups: usize, - size: [usize; 2], - } + struct ConvTranspose2dTestCase { + batch_size: usize, + channels: [usize; 2], + kernel_size: [usize; 2], + padding: [usize; 2], + padding_out: [usize; 2], + stride: [usize; 2], + dilation: [usize; 2], + groups: usize, + size: [usize; 2], + } - struct Grads { - x: TestTensor<4>, - weight: TestTensor<4>, - bias: TestTensor<1>, - } + struct Grads { + x: TestTensor<4>, + weight: TestTensor<4>, + bias: TestTensor<1>, + } - impl ConvTranspose2dTestCase { - fn assert_grads(self, expected_grads: Grads) { - let shape_x = Shape::new([ - self.batch_size, - self.channels[0], - self.size[0], - self.size[1], - ]); - let shape_weight = Shape::new([ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size[0], - self.kernel_size[1], - ]); - let weight = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ) - .require_grad(); - let bias = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..self.channels[1]) - .into_data() - .convert(), - ) - .require_grad(); - let x = TestAutodiffTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ) - .require_grad(); - let output = conv_transpose2d( - x.clone(), - weight.clone(), - Some(bias.clone()), - ConvTransposeOptions::new( - self.stride, - self.padding, - self.padding_out, - self.dilation, - self.groups, - ), - ); - let grads = output.backward(); + impl ConvTranspose2dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let shape_x = Shape::new([ + self.batch_size, + self.channels[0], + self.size[0], + self.size[1], + ]); + let shape_weight = Shape::new([ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size[0], + self.kernel_size[1], + ]); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels[1]) + .into_data() + .convert(), + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ) + .require_grad(); + let output = conv_transpose2d( + x.clone(), + weight.clone(), + Some(bias.clone()), + ConvTransposeOptions::new( + self.stride, + self.padding, + self.padding_out, + self.dilation, + self.groups, + ), + ); + let grads = output.backward(); - // Assert - let x_grad_actual = x.grad(&grads).unwrap(); - let weight_grad_actual = weight.grad(&grads).unwrap(); - let bias_grad_actual = bias.grad(&grads).unwrap(); + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); - expected_grads - 
.bias - .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); - expected_grads - .x - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - expected_grads - .weight - .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + expected_grads + .weight + .to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); + } } - } } diff --git a/burn-autodiff/src/tests/cos.rs b/burn-autodiff/src/tests/cos.rs index a89e08139a..af42104e8e 100644 --- a/burn-autodiff/src/tests/cos.rs +++ b/burn-autodiff/src/tests/cos.rs @@ -1,30 +1,30 @@ #[burn_tensor_testgen::testgen(ad_cos)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_cos() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_cos() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1.to_data().assert_approx_eq_diff( - &Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]), - 2.0e-3, - ); - grad_2.to_data().assert_approx_eq_diff( - &Data::from([[9.222064, -39.123375], [-28.721354, 49.748356]]), - 2.0e-3, - ); - } + grad_1.to_data().assert_approx_eq_diff( + &Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]), + 2.0e-3, + ); + grad_2.to_data().assert_approx_eq_diff( + &Data::from([[9.222064, -39.123375], [-28.721354, 49.748356]]), + 2.0e-3, + ); + } } diff --git a/burn-autodiff/src/tests/cross_entropy.rs b/burn-autodiff/src/tests/cross_entropy.rs index c22a478f4a..f898f6b2f6 100644 --- a/burn-autodiff/src/tests/cross_entropy.rs +++ b/burn-autodiff/src/tests/cross_entropy.rs @@ -1,30 +1,31 @@ #[burn_tensor_testgen::testgen(ad_cross_entropy_loss)] mod tests { - use super::*; - use burn_tensor::{loss, Data, Tensor}; + use super::*; + use burn_tensor::{loss, Data, Tensor}; - #[test] - fn test_cross_entropy_loss_grad() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - let data_targets = Data::from([[0.8, 0.2], [0.9, 0.1]]); + #[test] + fn test_cross_entropy_loss_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + let data_targets = Data::from([[0.8, 0.2], [0.9, 0.1]]); - let tensor_1 = Tensor::::from_data(data_1).require_grad(); - let tensor_2 = Tensor::::from_data(data_2).require_grad(); - let tensor_targets = Tensor::::from_data(data_targets).require_grad(); + let tensor_1 = Tensor::::from_data(data_1).require_grad(); + let 
tensor_2 = Tensor::::from_data(data_2).require_grad(); + let tensor_targets = + Tensor::::from_data(data_targets).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets); - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[0.2655, 0.2655], [0.4496, 0.4496]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-1.3486, 1.3486], [-2.0637, 2.0637]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[0.2655, 0.2655], [0.4496, 0.4496]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-1.3486, 1.3486], [-2.0637, 2.0637]]), 3); + } } diff --git a/burn-autodiff/src/tests/div.rs b/burn-autodiff/src/tests/div.rs index 47402e9d6d..7ab4470921 100644 --- a/burn-autodiff/src/tests/div.rs +++ b/burn-autodiff/src/tests/div.rs @@ -1,89 +1,89 @@ #[burn_tensor_testgen::testgen(ad_div)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_div() { - let data_1 = Data::from([1.0, 7.0]); - let data_2 = Data::from([4.0, 7.0]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().div(tensor_2.clone()); - let grads = tensor_3.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([0.25, 0.1429]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([-0.0625, -0.1429]), 3); - } - - #[test] - fn should_diff_div_scalar() { - let data = Data::from([1.0, 7.0]); - - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().div_scalar(4.0); - - let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); - - assert_eq!(grad.to_data(), Data::from([0.25, 0.25])); - } - - #[test] - fn test_div_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().div(tensor_2.clone()); - let tensor_5 = tensor_4.div(tensor_3); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[0.1250, 0.0714], [0.25, 0.1667]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-0.0312, -0.0714], [-1.6250, 0.1667]]), 3); - } - - #[test] - fn test_div_complex_2() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = 
tensor_3.div(tensor_2.clone()); - - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[2.00, 2.9286], [1.3667, 2.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[0.0833, 0.0959], [-0.0556, -0.0671]]), 3); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_div() { + let data_1 = Data::from([1.0, 7.0]); + let data_2 = Data::from([4.0, 7.0]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().div(tensor_2.clone()); + let grads = tensor_3.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([0.25, 0.1429]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([-0.0625, -0.1429]), 3); + } + + #[test] + fn should_diff_div_scalar() { + let data = Data::from([1.0, 7.0]); + + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().div_scalar(4.0); + + let grads = tensor_out.backward(); + let grad = tensor.grad(&grads).unwrap(); + + assert_eq!(grad.to_data(), Data::from([0.25, 0.25])); + } + + #[test] + fn test_div_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().div(tensor_2.clone()); + let tensor_5 = tensor_4.div(tensor_3); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[0.1250, 0.0714], [0.25, 0.1667]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-0.0312, -0.0714], [-1.6250, 0.1667]]), 3); + } + + #[test] + fn test_div_complex_2() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.div(tensor_2.clone()); + + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[2.00, 2.9286], [1.3667, 2.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[0.0833, 0.0959], [-0.0556, -0.0671]]), 3); + } } diff --git a/burn-autodiff/src/tests/erf.rs b/burn-autodiff/src/tests/erf.rs index 5db398fe67..bd80c347ad 100644 --- a/burn-autodiff/src/tests/erf.rs +++ b/burn-autodiff/src/tests/erf.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_erf)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_erf() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_erf() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = 
Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3); + } } diff --git a/burn-autodiff/src/tests/exp.rs b/burn-autodiff/src/tests/exp.rs index bfd9293bee..bba159bb60 100644 --- a/burn-autodiff/src/tests/exp.rs +++ b/burn-autodiff/src/tests/exp.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_exp)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_exp() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + #[test] + fn should_diff_exp() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[54.5991, 27.4746], [54.5991, 27.4746]]), 3); - grad_2.to_data().assert_approx_eq( - &Data::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]), - 3, - ); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[54.5991, 27.4746], [54.5991, 27.4746]]), 3); + grad_2.to_data().assert_approx_eq( + &Data::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]), + 3, + ); + } } diff --git a/burn-autodiff/src/tests/gather_scatter.rs b/burn-autodiff/src/tests/gather_scatter.rs index e1384c2218..3557f11c8a 100644 --- a/burn-autodiff/src/tests/gather_scatter.rs +++ b/burn-autodiff/src/tests/gather_scatter.rs @@ -1,56 +1,58 @@ #[burn_tensor_testgen::testgen(ad_gather_scatter)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_gather_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); - let indices = Tensor::::from_data(Data::from([ - [2, 1, 0, 1, 2], - [1, 0, 2, 1, 0], - ])); - - let 
tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().gather(1, indices); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[94., 150., 187.], [242., 305., 304.]]) - ); - } - - #[test] - fn test_scatter_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); - let values = - TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad(); - let indices = - Tensor::::from_data(Data::from([[2, 1, 0], [2, 0, 1]])); - - let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().scatter(1, indices, values.clone()); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = values.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[127., 181., 235.], [226., 316., 406.]]) - ); - assert_eq!( - grad_2.into_data(), - Data::from([[19., 19., 19.], [64., 64., 64.]]) - ); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_gather_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) + .require_grad(); + let indices = Tensor::::from_data(Data::from([ + [2, 1, 0, 1, 2], + [1, 0, 2, 1, 0], + ])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().gather(1, indices); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[94., 150., 187.], [242., 305., 304.]]) + ); + } + + #[test] + fn test_scatter_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) + .require_grad(); + let values = TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) + .require_grad(); + let indices = + Tensor::::from_data(Data::from([[2, 1, 0], [2, 0, 1]])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().scatter(1, indices, values.clone()); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = values.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[127., 181., 235.], [226., 316., 406.]]) + ); + assert_eq!( + grad_2.into_data(), + Data::from([[19., 19., 19.], [64., 64., 64.]]) + ); + } } diff --git a/burn-autodiff/src/tests/gelu.rs b/burn-autodiff/src/tests/gelu.rs index c39ff5ddbb..fec6eb3aa0 100644 --- a/burn-autodiff/src/tests/gelu.rs +++ b/burn-autodiff/src/tests/gelu.rs @@ -1,25 +1,25 @@ #[burn_tensor_testgen::testgen(ad_gelu)] mod tests { - use super::*; - use burn_tensor::{activation, Data}; + use super::*; + use burn_tensor::{activation, Data}; - #[test] - fn should_diff_gelu() { - let tensor_1 = TestAutodiffTensor::from_floats([[0.0, 1.0], [-3.0, 4.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]]).require_grad(); + #[test] + fn should_diff_gelu() { + let tensor_1 = TestAutodiffTensor::from_floats([[0.0, 1.0], [-3.0, 4.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]]).require_grad(); - let x = 
tensor_1.clone().matmul(activation::gelu(tensor_2.clone())); - let x = tensor_1.clone().matmul(x); - let grads = x.backward(); + let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone())); + let x = tensor_1.clone().matmul(x); + let grads = x.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[1.4629, 1.4629], [48.2286, 153.4629]]), 2); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[-15.0000, -1.9895], [17.0000, 17.0000]]), 2); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[1.4629, 1.4629], [48.2286, 153.4629]]), 2); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[-15.0000, -1.9895], [17.0000, 17.0000]]), 2); + } } diff --git a/burn-autodiff/src/tests/gradients.rs b/burn-autodiff/src/tests/gradients.rs index 844de4f41c..a1f6eda3c0 100644 --- a/burn-autodiff/src/tests/gradients.rs +++ b/burn-autodiff/src/tests/gradients.rs @@ -1,24 +1,25 @@ #[burn_tensor_testgen::testgen(gradients)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Distribution}; + use super::*; + use burn_tensor::{activation, Data, Distribution}; - #[test] - fn should_update_tensor_when_grad_replace() { - let tensor_1 = TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); - let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default); + #[test] + fn should_update_tensor_when_grad_replace() { + let tensor_1 = TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); + let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default); - let x = tensor_1.clone().matmul(activation::gelu(tensor_2)); - let mut grads = x.backward(); + let x = tensor_1.clone().matmul(activation::gelu(tensor_2)); + let mut grads = x.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_1_updated = TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); - tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner()); + let grad_1_updated = + TestAutodiffTensor::random([32, 32], Distribution::Default).require_grad(); + tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner()); - let grad_1_new = tensor_1.grad(&grads).unwrap(); + let grad_1_new = tensor_1.grad(&grads).unwrap(); - assert_ne!(grad_1_new.to_data(), grad_1.into_data()); - assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data()); - } + assert_ne!(grad_1_new.to_data(), grad_1.into_data()); + assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data()); + } } diff --git a/burn-autodiff/src/tests/log.rs b/burn-autodiff/src/tests/log.rs index 3752aebea5..9c7766da97 100644 --- a/burn-autodiff/src/tests/log.rs +++ b/burn-autodiff/src/tests/log.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_log)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_log() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_log() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = 
TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[60.2652, 72.3130], [60.2652, 72.3130]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[22.8614, 24.5043], [24.5729, 26.8507]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[60.2652, 72.3130], [60.2652, 72.3130]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[22.8614, 24.5043], [24.5729, 26.8507]]), 3); + } } diff --git a/burn-autodiff/src/tests/log1p.rs b/burn-autodiff/src/tests/log1p.rs index 627fd00aeb..d94f5aa176 100644 --- a/burn-autodiff/src/tests/log1p.rs +++ b/burn-autodiff/src/tests/log1p.rs @@ -1,29 +1,29 @@ #[burn_tensor_testgen::testgen(ad_log1p)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_log1p() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_log1p() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[64.80622, 75.49362], [64.80622, 75.49362]]), 3); - grad_2.to_data().assert_approx_eq( - &Data::from([[22.922085, 24.475657], [24.727802, 26.864166]]), - 3, - ); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[64.80622, 75.49362], [64.80622, 75.49362]]), 3); + grad_2.to_data().assert_approx_eq( + &Data::from([[22.922085, 24.475657], [24.727802, 26.864166]]), + 3, + ); + } } diff --git a/burn-autodiff/src/tests/mask.rs b/burn-autodiff/src/tests/mask.rs index fe7b25fd15..d400149cfb 100644 --- a/burn-autodiff/src/tests/mask.rs +++ b/burn-autodiff/src/tests/mask.rs @@ -1,53 +1,54 @@ #[burn_tensor_testgen::testgen(ad_mask)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Tensor}; - - #[test] - fn should_diff_mask_fill() { - let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); - let mask = Data::::from([[true, false], [false, true]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); - let mask = Tensor::::from_bool(mask); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.mask_fill(mask, 2.0); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[7.0, 3.0], [4.0, 2.0]])); - assert_eq!(grad_2.to_data(), Data::from([[2.0, 1.0], [3.0, 7.0]])); - } - - #[test] - fn should_diff_mask_where() { - let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]]).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]]).require_grad(); - let mask = Tensor::::from_data([[true, false], [false, true]]); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.clone().matmul(tensor_3.clone()); - let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone()); - let grads = tensor_6.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - let grad_3 = tensor_3.grad(&grads).unwrap(); - - grad_1 - .into_data() - .assert_approx_eq(&Data::from([[121.8, 55.0], [110.8, 50.0]]), 3); - grad_2 - .into_data() - .assert_approx_eq(&Data::from([[27.4, 33.4], [95.0, 115.0]]), 3); - grad_3 - .into_data() - .assert_approx_eq(&Data::from([[15., 18.], [23., 29.]]), 3); - } + use super::*; + use burn_tensor::{Bool, Data, Tensor}; + + #[test] + fn should_diff_mask_fill() { + let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); + let mask = Data::::from([[true, false], [false, true]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let mask = Tensor::::from_bool(mask); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.mask_fill(mask, 2.0); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[7.0, 3.0], [4.0, 2.0]])); + assert_eq!(grad_2.to_data(), Data::from([[2.0, 1.0], [3.0, 7.0]])); + } + + #[test] + fn should_diff_mask_where() { + let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]]).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]]).require_grad(); + let mask = + Tensor::::from_data([[true, false], [false, true]]); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.clone().matmul(tensor_3.clone()); + let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone()); + let grads = tensor_6.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_3 = tensor_3.grad(&grads).unwrap(); + + grad_1 + .into_data() + .assert_approx_eq(&Data::from([[121.8, 55.0], [110.8, 50.0]]), 3); + grad_2 + .into_data() + .assert_approx_eq(&Data::from([[27.4, 33.4], [95.0, 115.0]]), 3); + grad_3 + .into_data() + .assert_approx_eq(&Data::from([[15., 18.], [23., 29.]]), 3); + } } diff --git a/burn-autodiff/src/tests/matmul.rs b/burn-autodiff/src/tests/matmul.rs index 31c7be72e0..7f5de915e5 100644 --- a/burn-autodiff/src/tests/matmul.rs +++ 
b/burn-autodiff/src/tests/matmul.rs @@ -1,78 +1,78 @@ #[burn_tensor_testgen::testgen(ad_matmul)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_matmul() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let grads = tensor_3.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); - assert_eq!( - tensor_3.clone().into_data(), - Data::from([[18.0, 28.0], [14.0, 23.0]]) - ); - } - - #[test] - fn test_matmul_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]])); - assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]])); - } - - #[test] - fn test_matmul_complex_2() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3.clone()); - let tensor_6 = tensor_1.clone().matmul(tensor_5); - - let grads = tensor_6.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[800.0, 792.0], [360.0, 592.0]]) - ); - assert_eq!( - grad_2.to_data(), - Data::from([[264., 264.0], [344.0, 344.0]]) - ); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_matmul() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let grads = tensor_3.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); + assert_eq!( + tensor_3.clone().into_data(), + Data::from([[18.0, 28.0], [14.0, 23.0]]) + ); + } + + #[test] + fn test_matmul_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 
7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]])); + assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]])); + } + + #[test] + fn test_matmul_complex_2() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3.clone()); + let tensor_6 = tensor_1.clone().matmul(tensor_5); + + let grads = tensor_6.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[800.0, 792.0], [360.0, 592.0]]) + ); + assert_eq!( + grad_2.to_data(), + Data::from([[264., 264.0], [344.0, 344.0]]) + ); + } } diff --git a/burn-autodiff/src/tests/maxmin.rs b/burn-autodiff/src/tests/maxmin.rs index 2f10eac2fc..3edab63f6e 100644 --- a/burn-autodiff/src/tests/maxmin.rs +++ b/burn-autodiff/src/tests/maxmin.rs @@ -1,45 +1,45 @@ #[burn_tensor_testgen::testgen(ad_maxmin)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_max_dim() { - let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[50.0, 34.0], [40.0, -10.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[8.0, 10.0], [56.0, 15.0]]), 5); - } - - #[test] - fn should_diff_min_dim() { - let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); - let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze()); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[-42.0, 38.0], [-34.0, -24.0]]), 5); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_max_dim() { + let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); + + let tensor_3 = 
tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[50.0, 34.0], [40.0, -10.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[8.0, 10.0], [56.0, 15.0]]), 5); + } + + #[test] + fn should_diff_min_dim() { + let tensor_1 = TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad(); + let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze()); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[-42.0, 38.0], [-34.0, -24.0]]), 5); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5); + } } diff --git a/burn-autodiff/src/tests/maxpool1d.rs b/burn-autodiff/src/tests/maxpool1d.rs index 05d8c09573..2ccceafd22 100644 --- a/burn-autodiff/src/tests/maxpool1d.rs +++ b/burn-autodiff/src/tests/maxpool1d.rs @@ -1,110 +1,111 @@ #[burn_tensor_testgen::testgen(ad_max_pool1d)] mod tests { - use super::*; - use burn_tensor::{module::max_pool1d, Data}; - - #[test] - fn test_max_pool1d_simple() { - let kernel_size = 4; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = TestAutodiffTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool1d_with_dilation() { - let kernel_size = 4; - let padding = 0; - let stride = 1; - let dilation = 2; - - let x = TestAutodiffTensor::from_floats([[[ - 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, - 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, - 0.4610, 0.5365, 0.6880, - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - 0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0., 0., - 0., 1., - ]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool1d_complex() { - let kernel_size = 4; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = TestAutodiffTensor::from_floats([[[ - 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, - 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, - 0.4610, 0.5365, 0.6880, - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - 0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., 1., - 1., 1., - ]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - 
let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool1d_complex_with_padding() { - let kernel_size = 4; - let padding = 2; - let stride = 1; - let dilation = 1; - - let x = TestAutodiffTensor::from_floats([[[ - 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, - 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, - 0.4610, 0.5365, 0.6880, - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - 1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., 1., - 1., 3., - ]]]); - - let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } + use super::*; + use burn_tensor::{module::max_pool1d, Data}; + + #[test] + fn test_max_pool1d_simple() { + let kernel_size = 4; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = + TestAutodiffTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool1d_with_dilation() { + let kernel_size = 4; + let padding = 0; + let stride = 1; + let dilation = 2; + + let x = TestAutodiffTensor::from_floats([[[ + 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, + 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, + 0.4610, 0.5365, 0.6880, + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + 0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0., + 0., 0., 1., + ]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool1d_complex() { + let kernel_size = 4; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = TestAutodiffTensor::from_floats([[[ + 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, + 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, + 0.4610, 0.5365, 0.6880, + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + 0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., + 1., 1., 1., + ]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool1d_complex_with_padding() { + let kernel_size = 4; + let padding = 2; + let stride = 1; + let dilation = 1; + + let x = TestAutodiffTensor::from_floats([[[ + 
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, + 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, + 0.4610, 0.5365, 0.6880, + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + 1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., + 1., 1., 3., + ]]]); + + let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } } diff --git a/burn-autodiff/src/tests/maxpool2d.rs b/burn-autodiff/src/tests/maxpool2d.rs index a73e11ff0f..49d66212ae 100644 --- a/burn-autodiff/src/tests/maxpool2d.rs +++ b/burn-autodiff/src/tests/maxpool2d.rs @@ -1,171 +1,171 @@ #[burn_tensor_testgen::testgen(ad_max_pool2d)] mod tests { - use super::*; - use burn_tensor::{module::max_pool2d, Data}; - - #[test] - fn test_max_pool2d_simple_1() { - let kernel_size_1 = 3; - let kernel_size_2 = 3; - let padding_1 = 0; - let padding_2 = 0; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; - - let x = TestAutodiffTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 2.0], - [0.0, 2.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool2d_simple_2() { - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; - - let x = TestAutodiffTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [1., 3., 0., 2.], - [3., 0., 0., 4.], - [1., 4., 0., 1.], - [2., 0., 3., 1.], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool2d_with_dilation() { - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 2; - let dilation_2 = 2; - - let x = TestAutodiffTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [0., 0., 0., 0.], - [1., 1., 1., 2.], - [0., 4., 4., 0.], - [0., 1., 2., 0.], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, 
kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } - - #[test] - fn test_max_pool2d_complex() { - let kernel_size_1 = 4; - let kernel_size_2 = 2; - let padding_1 = 2; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 2; - let dilation_1 = 1; - let dilation_2 = 1; - - let x = TestAutodiffTensor::from_floats([[[ - [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], - [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], - [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], - [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], - [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], - ]]]) - .require_grad(); - let x_grad_expected = TestAutodiffTensor::from_floats([[[ - [0., 0., 0., 3., 0.], - [4., 0., 2., 1., 0.], - [0., 0., 0., 0., 0.], - [2., 4., 0., 0., 0.], - [0., 0., 0., 0., 2.], - ]]]); - - let output = max_pool2d( - x.clone(), - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); - let grads = output.backward(); - - // Asserts - let x_grad_actual = x.grad(&grads).unwrap(); - x_grad_expected - .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); - } + use super::*; + use burn_tensor::{module::max_pool2d, Data}; + + #[test] + fn test_max_pool2d_simple_1() { + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 0; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; + + let x = TestAutodiffTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 2.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_simple_2() { + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; + + let x = TestAutodiffTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [1., 3., 0., 2.], + [3., 0., 0., 4.], + [1., 4., 0., 1.], + [2., 0., 3., 1.], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_with_dilation() { + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 2; + let dilation_2 = 2; + + let x = TestAutodiffTensor::from_floats([[[ + [0.2479, 0.6386, 
0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [0., 0., 0., 0.], + [1., 1., 1., 2.], + [0., 4., 4., 0.], + [0., 1., 2., 0.], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_complex() { + let kernel_size_1 = 4; + let kernel_size_2 = 2; + let padding_1 = 2; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 2; + let dilation_1 = 1; + let dilation_2 = 1; + + let x = TestAutodiffTensor::from_floats([[[ + [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], + [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], + [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], + [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], + [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], + ]]]) + .require_grad(); + let x_grad_expected = TestAutodiffTensor::from_floats([[[ + [0., 0., 0., 3., 0.], + [4., 0., 2., 1., 0.], + [0., 0., 0., 0., 0.], + [2., 4., 0., 0., 0.], + [0., 0., 0., 0., 2.], + ]]]); + + let output = max_pool2d( + x.clone(), + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } } diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs index 15328cc5b8..23e84426d0 100644 --- a/burn-autodiff/src/tests/mod.rs +++ b/burn-autodiff/src/tests/mod.rs @@ -48,61 +48,61 @@ mod transpose; #[macro_export] macro_rules! 
testgen_all {
-  () => {
-    type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
-    type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
+    () => {
+        type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
+        type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;

-    // Behavior
-    burn_autodiff::testgen_ad_broadcast!();
-    burn_autodiff::testgen_gradients!();
+        // Behavior
+        burn_autodiff::testgen_ad_broadcast!();
+        burn_autodiff::testgen_gradients!();

-    // Activation
-    burn_autodiff::testgen_ad_relu!();
-    burn_autodiff::testgen_ad_gelu!();
+        // Activation
+        burn_autodiff::testgen_ad_relu!();
+        burn_autodiff::testgen_ad_gelu!();

-    // Modules
-    burn_autodiff::testgen_ad_conv1d!();
-    burn_autodiff::testgen_ad_conv2d!();
-    burn_autodiff::testgen_ad_conv_transpose1d!();
-    burn_autodiff::testgen_ad_conv_transpose2d!();
-    burn_autodiff::testgen_ad_max_pool1d!();
-    burn_autodiff::testgen_ad_max_pool2d!();
-    burn_autodiff::testgen_ad_avg_pool1d!();
-    burn_autodiff::testgen_ad_avg_pool2d!();
-    burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
-    burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
-    burn_autodiff::testgen_module_backward!();
+        // Modules
+        burn_autodiff::testgen_ad_conv1d!();
+        burn_autodiff::testgen_ad_conv2d!();
+        burn_autodiff::testgen_ad_conv_transpose1d!();
+        burn_autodiff::testgen_ad_conv_transpose2d!();
+        burn_autodiff::testgen_ad_max_pool1d!();
+        burn_autodiff::testgen_ad_max_pool2d!();
+        burn_autodiff::testgen_ad_avg_pool1d!();
+        burn_autodiff::testgen_ad_avg_pool2d!();
+        burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
+        burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
+        burn_autodiff::testgen_module_backward!();

-    // Tensor
-    burn_autodiff::testgen_ad_complex!();
-    burn_autodiff::testgen_ad_multithread!();
-    burn_autodiff::testgen_ad_add!();
-    burn_autodiff::testgen_ad_aggregation!();
-    burn_autodiff::testgen_ad_maxmin!();
-    burn_autodiff::testgen_ad_cat!();
-    burn_autodiff::testgen_ad_cos!();
-    burn_autodiff::testgen_ad_cross_entropy_loss!();
-    burn_autodiff::testgen_ad_div!();
-    burn_autodiff::testgen_ad_erf!();
-    burn_autodiff::testgen_ad_exp!();
-    burn_autodiff::testgen_ad_slice!();
-    burn_autodiff::testgen_ad_gather_scatter!();
-    burn_autodiff::testgen_ad_select!();
-    burn_autodiff::testgen_ad_log!();
-    burn_autodiff::testgen_ad_log1p!();
-    burn_autodiff::testgen_ad_mask!();
-    burn_autodiff::testgen_ad_matmul!();
-    burn_autodiff::testgen_ad_mul!();
-    burn_autodiff::testgen_ad_neg!();
-    burn_autodiff::testgen_ad_powf!();
-    burn_autodiff::testgen_ad_recip!();
-    burn_autodiff::testgen_ad_reshape!();
-    burn_autodiff::testgen_ad_sin!();
-    burn_autodiff::testgen_ad_softmax!();
-    burn_autodiff::testgen_ad_sqrt!();
-    burn_autodiff::testgen_ad_abs!();
-    burn_autodiff::testgen_ad_sub!();
-    burn_autodiff::testgen_ad_tanh!();
-    burn_autodiff::testgen_ad_transpose!();
-  };
+        // Tensor
+        burn_autodiff::testgen_ad_complex!();
+        burn_autodiff::testgen_ad_multithread!();
+        burn_autodiff::testgen_ad_add!();
+        burn_autodiff::testgen_ad_aggregation!();
+        burn_autodiff::testgen_ad_maxmin!();
+        burn_autodiff::testgen_ad_cat!();
+        burn_autodiff::testgen_ad_cos!();
+        burn_autodiff::testgen_ad_cross_entropy_loss!();
+        burn_autodiff::testgen_ad_div!();
+        burn_autodiff::testgen_ad_erf!();
+        burn_autodiff::testgen_ad_exp!();
+        burn_autodiff::testgen_ad_slice!();
+        burn_autodiff::testgen_ad_gather_scatter!();
+        burn_autodiff::testgen_ad_select!();
+        burn_autodiff::testgen_ad_log!();
+        burn_autodiff::testgen_ad_log1p!();
+        burn_autodiff::testgen_ad_mask!();
+        burn_autodiff::testgen_ad_matmul!();
+
burn_autodiff::testgen_ad_mul!(); + burn_autodiff::testgen_ad_neg!(); + burn_autodiff::testgen_ad_powf!(); + burn_autodiff::testgen_ad_recip!(); + burn_autodiff::testgen_ad_reshape!(); + burn_autodiff::testgen_ad_sin!(); + burn_autodiff::testgen_ad_softmax!(); + burn_autodiff::testgen_ad_sqrt!(); + burn_autodiff::testgen_ad_abs!(); + burn_autodiff::testgen_ad_sub!(); + burn_autodiff::testgen_ad_tanh!(); + burn_autodiff::testgen_ad_transpose!(); + }; } diff --git a/burn-autodiff/src/tests/mul.rs b/burn-autodiff/src/tests/mul.rs index a214abd24b..85eec40498 100644 --- a/burn-autodiff/src/tests/mul.rs +++ b/burn-autodiff/src/tests/mul.rs @@ -1,64 +1,64 @@ #[burn_tensor_testgen::testgen(ad_mul)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_mul() { - let data_1 = Data::from([1.0, 7.0]); - let data_2 = Data::from([4.0, 7.0]); + #[test] + fn should_diff_mul() { + let data_1 = Data::from([1.0, 7.0]); + let data_2 = Data::from([4.0, 7.0]); - let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); - let tensor_3 = tensor_1.clone().mul(tensor_2.clone()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().mul(tensor_2.clone()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), data_2); - assert_eq!(grad_2.to_data(), data_1); - assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0])); - } + assert_eq!(grad_1.to_data(), data_2); + assert_eq!(grad_2.to_data(), data_1); + assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0])); + } - #[test] - fn should_diff_mul_scalar() { - let data = Data::from([2.0, 5.0]); + #[test] + fn should_diff_mul_scalar() { + let data = Data::from([2.0, 5.0]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().mul_scalar(4.0); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().mul_scalar(4.0); - let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); + let grads = tensor_out.backward(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0])); - assert_eq!(grad.to_data(), Data::from([4.0, 4.0])); - } + assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0])); + assert_eq!(grad.to_data(), Data::from([4.0, 4.0])); + } - #[test] - fn test_mul_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + #[test] + fn test_mul_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - let tensor_4 = tensor_1.clone().mul(tensor_2.clone()); - let tensor_5 = tensor_4.mul(tensor_3); - let tensor_6 = tensor_1.clone().mul(tensor_5); + let tensor_4 = tensor_1.clone().mul(tensor_2.clone()); + let tensor_5 = tensor_4.mul(tensor_3); + let tensor_6 = tensor_1.clone().mul(tensor_5); - let grads = tensor_6.backward(); + let grads = tensor_6.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!( - grad_1.to_data(), - Data::from([[16.0, 196.0], [104.0, -36.0]]) - ); - assert_eq!(grad_2.to_data(), Data::from([[2.0, 98.0], [338.0, 18.0]])); - } + assert_eq!( + grad_1.to_data(), + Data::from([[16.0, 196.0], [104.0, -36.0]]) + ); + assert_eq!(grad_2.to_data(), Data::from([[2.0, 98.0], [338.0, 18.0]])); + } } diff --git a/burn-autodiff/src/tests/multithread.rs b/burn-autodiff/src/tests/multithread.rs index 3b30b52a8a..041572da6e 100644 --- a/burn-autodiff/src/tests/multithread.rs +++ b/burn-autodiff/src/tests/multithread.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(ad_multithread)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_behave_the_same_with_multithread() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + #[test] + fn should_behave_the_same_with_multithread() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let with_move = || { - let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); + let with_move = || { + let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3); - // Task 1 - let tensor_1_cloned = tensor_1.clone(); - let tensor_2_cloned = tensor_2.clone(); - let tensor_5_cloned = tensor_5.clone(); + // Task 1 + let tensor_1_cloned = tensor_1.clone(); + let tensor_2_cloned = tensor_2.clone(); + let tensor_5_cloned = tensor_5.clone(); - let first_call = move || { - let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned); - tensor_6_1.matmul(tensor_1_cloned) - }; + let first_call = move || { + let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned); + tensor_6_1.matmul(tensor_1_cloned) + }; - // Task 2 - let tensor_1_cloned = tensor_1.clone(); - let tensor_2_cloned = tensor_2.clone(); - let tensor_5_cloned = tensor_5; + // Task 2 + let tensor_1_cloned = tensor_1.clone(); + let tensor_2_cloned = tensor_2.clone(); + let tensor_5_cloned = tensor_5; - let second_call = move || { - let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned); - tensor_6_2.matmul(tensor_2_cloned) - }; + let second_call = move || { + let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned); + tensor_6_2.matmul(tensor_2_cloned) + }; - let tensor_7_1_handle = std::thread::spawn(first_call); - 
let tensor_7_2_handle = std::thread::spawn(second_call); + let tensor_7_1_handle = std::thread::spawn(first_call); + let tensor_7_2_handle = std::thread::spawn(second_call); - let tensor_7_1 = tensor_7_1_handle.join().unwrap(); - let tensor_7_2 = tensor_7_2_handle.join().unwrap(); - let tensor_8 = tensor_7_1.matmul(tensor_7_2); + let tensor_7_1 = tensor_7_1_handle.join().unwrap(); + let tensor_7_2 = tensor_7_2_handle.join().unwrap(); + let tensor_8 = tensor_7_1.matmul(tensor_7_2); - let grads = tensor_8.backward(); + let grads = tensor_8.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - (grad_1, grad_2) - }; - let without_move = || { - let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); + (grad_1, grad_2) + }; + let without_move = || { + let tensor_1 = TestAutodiffTensor::from_data(data_1.clone()).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2.clone()).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_4.matmul(tensor_3); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_4.matmul(tensor_3); - // Task 1 - let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone()); - let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone()); + // Task 1 + let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone()); + let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone()); - // Task 2 - let tensor_6_2 = tensor_5.matmul(tensor_1.clone()); - let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone()); + // Task 2 + let tensor_6_2 = tensor_5.matmul(tensor_1.clone()); + let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone()); - let tensor_8 = tensor_7_1.matmul(tensor_7_2); + let tensor_8 = tensor_7_1.matmul(tensor_7_2); - let grads = tensor_8.backward(); + let grads = tensor_8.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - (grad_1, grad_2) - }; + (grad_1, grad_2) + }; - let (grad_1, grad_2) = without_move(); - let (grad_1_moved, grad_2_moved) = with_move(); + let (grad_1, grad_2) = without_move(); + let (grad_1_moved, grad_2_moved) = with_move(); - assert_eq!(grad_1.to_data(), grad_1_moved.to_data()); - assert_eq!(grad_2.to_data(), grad_2_moved.to_data()); - } + assert_eq!(grad_1.to_data(), grad_1_moved.to_data()); + assert_eq!(grad_2.to_data(), grad_2_moved.to_data()); + } } diff --git a/burn-autodiff/src/tests/neg.rs b/burn-autodiff/src/tests/neg.rs index 51974cb025..83657ea1f5 100644 --- a/burn-autodiff/src/tests/neg.rs +++ b/burn-autodiff/src/tests/neg.rs @@ -1,24 +1,24 @@ #[burn_tensor_testgen::testgen(ad_neg)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_neg() { - let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); + #[test] + fn should_diff_neg() { + let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg()); - let tensor_4 = tensor_3.neg(); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg()); + let tensor_4 = tensor_3.neg(); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); + } } diff --git a/burn-autodiff/src/tests/pow.rs b/burn-autodiff/src/tests/pow.rs index aadbd1d88d..7321951ddc 100644 --- a/burn-autodiff/src/tests/pow.rs +++ b/burn-autodiff/src/tests/pow.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_powf)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_powf() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_powf() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf(0.4)); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf(0.4)); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3); + } } diff --git a/burn-autodiff/src/tests/recip.rs b/burn-autodiff/src/tests/recip.rs index dc6911008e..c77579e273 100644 --- a/burn-autodiff/src/tests/recip.rs +++ b/burn-autodiff/src/tests/recip.rs @@ -1,21 +1,20 @@ #[burn_tensor_testgen::testgen(ad_recip)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_recip() { - let data = Data::from([2.0, 5.0, 0.4]); + #[test] + fn should_diff_recip() { + let data = Data::from([2.0, 5.0, 0.4]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().recip(); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().recip(); - let grads = tensor_out.backward(); - let grad = 
tensor.grad(&grads).unwrap(); + let grads = tensor_out.backward(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5])); - grad - .to_data() - .assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3); - } + assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5])); + grad.to_data() + .assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3); + } } diff --git a/burn-autodiff/src/tests/relu.rs b/burn-autodiff/src/tests/relu.rs index 13a6de9b24..57cfd51baa 100644 --- a/burn-autodiff/src/tests/relu.rs +++ b/burn-autodiff/src/tests/relu.rs @@ -1,25 +1,25 @@ #[burn_tensor_testgen::testgen(ad_relu)] mod tests { - use super::*; - use burn_tensor::{activation, Data}; + use super::*; + use burn_tensor::{activation, Data}; - #[test] - fn should_diff_relu() { - let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); - let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); + #[test] + fn should_diff_relu() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0, -7.0], [2.0, 3.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = activation::relu(tensor_3); - let tensor_5 = tensor_4.matmul(tensor_2.clone()); - let grads = tensor_5.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::relu(tensor_3); + let tensor_5 = tensor_4.matmul(tensor_2.clone()); + let grads = tensor_5.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[-47.0, 9.0], [-35.0, 15.0]])); - assert_eq!(grad_2.to_data(), Data::from([[15.0, 13.0], [-2.0, 39.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[-47.0, 9.0], [-35.0, 15.0]])); + assert_eq!(grad_2.to_data(), Data::from([[15.0, 13.0], [-2.0, 39.0]])); + } } diff --git a/burn-autodiff/src/tests/reshape.rs b/burn-autodiff/src/tests/reshape.rs index 7d3bc2cdd1..057241aba5 100644 --- a/burn-autodiff/src/tests/reshape.rs +++ b/burn-autodiff/src/tests/reshape.rs @@ -1,24 +1,24 @@ #[burn_tensor_testgen::testgen(ad_reshape)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_reshape() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([4.0, 7.0, 2.0, 3.0]); + #[test] + fn should_diff_reshape() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([4.0, 7.0, 2.0, 3.0]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_2.clone().reshape([2, 2]); - let tensor_4 = tensor_1.clone().matmul(tensor_3); - let grads = tensor_4.backward(); + let tensor_3 = tensor_2.clone().reshape([2, 2]); + let tensor_4 = tensor_1.clone().matmul(tensor_3); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = 
tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!(grad_2.to_data(), Data::from([3.0, 3.0, 10.0, 10.0])); - } + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.to_data(), Data::from([3.0, 3.0, 10.0, 10.0])); + } } diff --git a/burn-autodiff/src/tests/select.rs b/burn-autodiff/src/tests/select.rs index 9c20fa5068..21c49f5242 100644 --- a/burn-autodiff/src/tests/select.rs +++ b/burn-autodiff/src/tests/select.rs @@ -1,52 +1,54 @@ #[burn_tensor_testgen::testgen(ad_select)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_select_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); - let indices = Tensor::::from_data(Data::from([1, 0])); - - let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().select(0, indices); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[109., 148., 187.], [37., 58., 79.]]) - ); - } - - #[test] - fn test_select_assign_grad() { - let tensor_1 = - TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); - let values = - TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad(); - let indices = Tensor::::from_data(Data::from([1, 0])); - - let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().select_assign(0, indices, values.clone()); - let tensor_4 = tensor_2.matmul(tensor_3); - - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = values.grad(&grads).unwrap(); - - assert_eq!( - grad_1.into_data(), - Data::from([[127., 199., 271.], [172., 244., 316.]]) - ); - assert_eq!( - grad_2.into_data(), - Data::from([[64., 64., 64.], [19., 19., 19.]]) - ); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_select_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) + .require_grad(); + let indices = Tensor::::from_data(Data::from([1, 0])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().select(0, indices); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[109., 148., 187.], [37., 58., 79.]]) + ); + } + + #[test] + fn test_select_assign_grad() { + let tensor_1 = + TestAutodiffTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])) + .require_grad(); + let values = TestAutodiffTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) + .require_grad(); + let indices = Tensor::::from_data(Data::from([1, 0])); + + let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); + let tensor_3 = tensor_1.clone().select_assign(0, indices, values.clone()); + let tensor_4 = tensor_2.matmul(tensor_3); + + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = values.grad(&grads).unwrap(); + + assert_eq!( + grad_1.into_data(), + Data::from([[127., 199., 271.], [172., 244., 316.]]) + ); + assert_eq!( + 
grad_2.into_data(), + Data::from([[64., 64., 64.], [19., 19., 19.]]) + ); + } } diff --git a/burn-autodiff/src/tests/sin.rs b/burn-autodiff/src/tests/sin.rs index 2e7b544928..8462893d9a 100644 --- a/burn-autodiff/src/tests/sin.rs +++ b/burn-autodiff/src/tests/sin.rs @@ -1,29 +1,29 @@ #[burn_tensor_testgen::testgen(ad_sin)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_sin() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_sin() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq_diff(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 2.6e-3); - grad_2.to_data().assert_approx_eq_diff( - &Data::from([[38.668987, 44.194775], [-59.97261, -80.46094]]), - 2.6e-3, - ); - } + grad_1 + .to_data() + .assert_approx_eq_diff(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 2.6e-3); + grad_2.to_data().assert_approx_eq_diff( + &Data::from([[38.668987, 44.194775], [-59.97261, -80.46094]]), + 2.6e-3, + ); + } } diff --git a/burn-autodiff/src/tests/slice.rs b/burn-autodiff/src/tests/slice.rs index d6bb7c1505..6b8b46d70b 100644 --- a/burn-autodiff/src/tests/slice.rs +++ b/burn-autodiff/src/tests/slice.rs @@ -1,77 +1,77 @@ #[burn_tensor_testgen::testgen(ad_slice)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_matmul_with_slice() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_2.clone().slice([0..2, 0..2]); - let tensor_4 = tensor_1.clone().matmul(tensor_3); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); - assert_eq!( - grad_2.to_data(), - Data::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]) - ); - } - - #[test] - fn should_diff_matmul_with_slice_assign() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_assigned: Data = Data::from([[9.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_assigned = TestAutodiffTensor::from_data(data_assigned).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = 
tensor_3.slice_assign([0..1, 0..1], tensor_assigned); - let tensor_5 = tensor_4.matmul(tensor_1.clone()); - - let grads = tensor_5.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[58.0, 38.0], [118.0, 82.0]])); - assert_eq!(grad_2.to_data(), Data::from([[16.0, 15.0], [24.0, 50.0]])); - } - - #[test] - fn should_diff_matmul_with_slice_assign_complex() { - let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[9.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - - let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_5 = tensor_2.clone().slice([0..1, 0..1]); - let tensor_6 = tensor_5.mul(tensor_3.clone()); - let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6); - let tensor_8 = tensor_7.matmul(tensor_1.clone()); - - let grads = tensor_8.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - let grad_3 = tensor_3.grad(&grads).unwrap(); - - assert_eq!(grad_3.to_data(), Data::from([[32.0]])); - assert_eq!(grad_1.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]])); - assert_eq!(grad_2.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]])); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_matmul_with_slice() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_2.clone().slice([0..2, 0..2]); + let tensor_4 = tensor_1.clone().matmul(tensor_3); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!( + grad_2.to_data(), + Data::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]) + ); + } + + #[test] + fn should_diff_matmul_with_slice_assign() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_assigned: Data = Data::from([[9.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_assigned = TestAutodiffTensor::from_data(data_assigned).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned); + let tensor_5 = tensor_4.matmul(tensor_1.clone()); + + let grads = tensor_5.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[58.0, 38.0], [118.0, 82.0]])); + assert_eq!(grad_2.to_data(), Data::from([[16.0, 15.0], [24.0, 50.0]])); + } + + #[test] + fn should_diff_matmul_with_slice_assign_complex() { + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[9.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + + let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_5 = tensor_2.clone().slice([0..1, 0..1]); + let tensor_6 = tensor_5.mul(tensor_3.clone()); + let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6); + let tensor_8 = tensor_7.matmul(tensor_1.clone()); + + let grads = tensor_8.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_3 = tensor_3.grad(&grads).unwrap(); + + assert_eq!(grad_3.to_data(), Data::from([[32.0]])); + assert_eq!(grad_1.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]])); + assert_eq!(grad_2.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]])); + } } diff --git a/burn-autodiff/src/tests/softmax.rs b/burn-autodiff/src/tests/softmax.rs index c825ae3192..d282a4be45 100644 --- a/burn-autodiff/src/tests/softmax.rs +++ b/burn-autodiff/src/tests/softmax.rs @@ -1,72 +1,72 @@ #[burn_tensor_testgen::testgen(ad_softmax)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_softmax_grad() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = Tensor::::from_data(data_1).require_grad(); - let tensor_2 = Tensor::::from_data(data_2).require_grad(); + #[test] + fn test_softmax_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + let tensor_1 = Tensor::::from_data(data_1).require_grad(); + let tensor_2 = Tensor::::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); + } - #[test] - fn test_log_softmax_grad() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = Tensor::::from_data(data_1).require_grad(); - let tensor_2 = Tensor::::from_data(data_2).require_grad(); + #[test] + fn test_log_softmax_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + let tensor_1 = Tensor::::from_data(data_1).require_grad(); + let tensor_2 = Tensor::::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone()); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone()); - let grads = tensor_4.backward(); - let grad_1 = 
tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3); + } - #[test] - fn test_quiet_softmax_grad() { - let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn test_quiet_softmax_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = Tensor::::from_data(data_1).require_grad(); - let tensor_2 = Tensor::::from_data(data_2).require_grad(); + let tensor_1 = Tensor::::from_data(data_1).require_grad(); + let tensor_2 = Tensor::::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); - let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); - let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); + } } diff --git a/burn-autodiff/src/tests/sqrt.rs b/burn-autodiff/src/tests/sqrt.rs index 94b0d6c860..c9d075fedd 100644 --- a/burn-autodiff/src/tests/sqrt.rs +++ b/burn-autodiff/src/tests/sqrt.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_sqrt)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_sqrt() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_sqrt() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - 
.assert_approx_eq(&Data::from([[82.1126, 99.0832], [82.1126, 99.0832]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[30.3093, 33.1204], [34.5819, 38.7694]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[82.1126, 99.0832], [82.1126, 99.0832]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[30.3093, 33.1204], [34.5819, 38.7694]]), 3); + } } diff --git a/burn-autodiff/src/tests/sub.rs b/burn-autodiff/src/tests/sub.rs index b89850f506..50beae42f3 100644 --- a/burn-autodiff/src/tests/sub.rs +++ b/burn-autodiff/src/tests/sub.rs @@ -1,60 +1,60 @@ #[burn_tensor_testgen::testgen(ad_sub)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_sub() { - let data_1 = Data::from([2.0, 5.0]); - let data_2 = Data::from([4.0, 1.0]); + #[test] + fn should_diff_sub() { + let data_1 = Data::from([2.0, 5.0]); + let data_2 = Data::from([4.0, 1.0]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().sub(tensor_2.clone()); - let grads = tensor_3.backward(); + let tensor_3 = tensor_1.clone().sub(tensor_2.clone()); + let grads = tensor_3.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); - assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0])); - assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0])); - } + assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0])); + assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0])); + assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0])); + } - #[test] - fn should_diff_sub_scalar() { - let data = Data::from([2.0, 10.0]); - let tensor = TestAutodiffTensor::from_data(data).require_grad(); - let tensor_out = tensor.clone().sub_scalar(5.0); - let grads = tensor_out.backward(); + #[test] + fn should_diff_sub_scalar() { + let data = Data::from([2.0, 10.0]); + let tensor = TestAutodiffTensor::from_data(data).require_grad(); + let tensor_out = tensor.clone().sub_scalar(5.0); + let grads = tensor_out.backward(); - let grad = tensor.grad(&grads).unwrap(); + let grad = tensor.grad(&grads).unwrap(); - assert_eq!(grad.to_data(), Data::from([1.0, 1.0])); - assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0])); - } + assert_eq!(grad.to_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0])); + } - #[test] - fn test_sub_complex_1() { - let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); - let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); + #[test] + fn test_sub_complex_1() { + let data_1: Data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + let data_3: Data = Data::from([[2.0, 2.0], [2.0, 2.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = 
TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad(); - let tensor_4 = tensor_1.clone().sub(tensor_2.clone()); - let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0); - let tensor_6 = tensor_1.clone().sub(tensor_5); + let tensor_4 = tensor_1.clone().sub(tensor_2.clone()); + let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0); + let tensor_6 = tensor_1.clone().sub(tensor_5); - let grads = tensor_6.backward(); + let grads = tensor_6.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]])); - assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]])); - } + assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]])); + assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]])); + } } diff --git a/burn-autodiff/src/tests/tanh.rs b/burn-autodiff/src/tests/tanh.rs index 3dc8700451..db1b884baf 100644 --- a/burn-autodiff/src/tests/tanh.rs +++ b/burn-autodiff/src/tests/tanh.rs @@ -1,28 +1,28 @@ #[burn_tensor_testgen::testgen(ad_tanh)] mod tests { - use super::*; - use burn_tensor::Data; + use super::*; + use burn_tensor::Data; - #[test] - fn should_diff_tanh() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + #[test] + fn should_diff_tanh() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh()); - let tensor_4 = tensor_3.matmul(tensor_2.clone()); - let grads = tensor_4.backward(); + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh()); + let tensor_4 = tensor_3.matmul(tensor_2.clone()); + let grads = tensor_4.backward(); - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[8.00092, 8.000153], [8.000003, 7.999995]]), 3); - } + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[8.00092, 8.000153], [8.000003, 7.999995]]), 3); + } } diff --git a/burn-autodiff/src/tests/transpose.rs b/burn-autodiff/src/tests/transpose.rs index aaf4ecd952..bead7b4671 100644 --- a/burn-autodiff/src/tests/transpose.rs +++ b/burn-autodiff/src/tests/transpose.rs @@ -1,50 +1,50 @@ #[burn_tensor_testgen::testgen(ad_transpose)] mod tests { - use super::*; - use burn_tensor::Data; - - #[test] - fn should_diff_transpose() { - let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); - let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); - - let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); - let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose()); - let tensor_4 = tensor_3.transpose(); - 
let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!(grad_1.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]])); - assert_eq!(grad_2.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]])); - } - - #[test] - fn should_diff_swap_dims() { - let tensor_1 = - TestAutodiffTensor::from_floats([[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]]) - .require_grad(); - let tensor_2 = - TestAutodiffTensor::from_floats([[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]]) - .require_grad(); - - let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2)); - let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2)); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - assert_eq!( - grad_1.to_data(), - Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]) - ); - assert_eq!( - grad_2.to_data(), - Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]) - ); - } + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_transpose() { + let data_1 = Data::::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2 = Data::::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose()); + let tensor_4 = tensor_3.transpose(); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]])); + assert_eq!(grad_2.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]])); + } + + #[test] + fn should_diff_swap_dims() { + let tensor_1 = + TestAutodiffTensor::from_floats([[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]]) + .require_grad(); + let tensor_2 = + TestAutodiffTensor::from_floats([[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]]) + .require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2)); + let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2)); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!( + grad_1.to_data(), + Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]) + ); + assert_eq!( + grad_2.to_data(), + Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]) + ); + } } diff --git a/burn-autodiff/src/utils.rs b/burn-autodiff/src/utils.rs index 56480f4805..617c101e2a 100644 --- a/burn-autodiff/src/utils.rs +++ b/burn-autodiff/src/utils.rs @@ -9,16 +9,16 @@ use crate::graph::NodeRef; /// /// If the object is a tensor and if one reference exists, it can be updated inplace. 
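///
/// A rough usage sketch (the `lhs_node` and `grad_tensor` names below are purely
/// illustrative, not part of this crate's API): each output slot receives a clone of
/// `obj` only when the matching node slot is populated, and stays `None` otherwise.
///
/// ```ignore
/// let nodes: [Option<NodeRef>; 2] = [Some(lhs_node), None];
/// let grads = duplicate(&nodes, Some(grad_tensor));
/// assert!(grads[0].is_some() && grads[1].is_none());
/// ```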
pub fn duplicate( - nodes: &[Option; N], - obj: Option, + nodes: &[Option; N], + obj: Option, ) -> [Option; N] { - nodes - .iter() - .map(|node| match node { - Some(_) => obj.clone(), - None => None, - }) - .collect::>() - .try_into() - .unwrap() + nodes + .iter() + .map(|node| match node { + Some(_) => obj.clone(), + None => None, + }) + .collect::>() + .try_into() + .unwrap() } diff --git a/burn-candle/src/backend.rs b/burn-candle/src/backend.rs index 9aab0cc306..c2bec24177 100644 --- a/burn-candle/src/backend.rs +++ b/burn-candle/src/backend.rs @@ -4,8 +4,8 @@ use burn_tensor::backend::Backend; use candle_core::DeviceLocation; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + CandleTensor, }; /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations. @@ -15,11 +15,11 @@ use crate::{ #[derive(Clone, Copy, Default, Debug)] pub struct Candle where - F: FloatCandleElement, - I: IntCandleElement, + F: FloatCandleElement, + I: IntCandleElement, { - _float: PhantomData, - _int: PhantomData, + _float: PhantomData, + _int: PhantomData, } /// The device type for the candle backend. @@ -28,62 +28,62 @@ where /// /// Note that you need to provide the device index when using Cuda. pub enum CandleDevice { - /// CPU device. - Cpu, + /// CPU device. + Cpu, - /// Cuda device with the given index. The index is the index of the Cuda device in the list of - /// all Cuda devices found on the system. - Cuda(usize), + /// Cuda device with the given index. The index is the index of the Cuda device in the list of + /// all Cuda devices found on the system. + Cuda(usize), } impl From for candle_core::Device { - fn from(device: CandleDevice) -> Self { - match device { - CandleDevice::Cpu => candle_core::Device::Cpu, - CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(), + fn from(device: CandleDevice) -> Self { + match device { + CandleDevice::Cpu => candle_core::Device::Cpu, + CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(), + } } - } } impl From for CandleDevice { - fn from(device: candle_core::Device) -> Self { - match device.location() { - DeviceLocation::Cpu => CandleDevice::Cpu, - DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id), + fn from(device: candle_core::Device) -> Self { + match device.location() { + DeviceLocation::Cpu => CandleDevice::Cpu, + DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id), + } } - } } impl Default for CandleDevice { - fn default() -> Self { - Self::Cpu - } + fn default() -> Self { + Self::Cpu + } } impl Backend for Candle { - type Device = CandleDevice; + type Device = CandleDevice; - type FullPrecisionBackend = Candle; - type FullPrecisionElem = f32; + type FullPrecisionBackend = Candle; + type FullPrecisionElem = f32; - type TensorPrimitive = CandleTensor; - type FloatElem = F; + type TensorPrimitive = CandleTensor; + type FloatElem = F; - type IntTensorPrimitive = CandleTensor; - type IntElem = I; + type IntTensorPrimitive = CandleTensor; + type IntElem = I; - type BoolTensorPrimitive = CandleTensor; + type BoolTensorPrimitive = CandleTensor; - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn name() -> String { - "candle".to_string() - } + fn name() -> String { + "candle".to_string() + } - fn seed(seed: u64) { - // TODO submit an issue at Candle - panic!("Manual seed not supported by Candle. 
") - } + fn seed(seed: u64) { + // TODO submit an issue at Candle + panic!("Manual seed not supported by Candle. ") + } } diff --git a/burn-candle/src/lib.rs b/burn-candle/src/lib.rs index 34d7b992be..a0ee7f0490 100644 --- a/burn-candle/src/lib.rs +++ b/burn-candle/src/lib.rs @@ -15,132 +15,132 @@ pub use tensor::*; #[cfg(test)] mod tests { - extern crate alloc; - use super::*; + extern crate alloc; + use super::*; - pub type TestBackend = Candle; - pub type ReferenceBackend = burn_tch::LibTorch; + pub type TestBackend = Candle; + pub type ReferenceBackend = burn_tch::LibTorch; - pub type TestTensor = burn_tensor::Tensor; - pub type ReferenceTensor = burn_tensor::Tensor; - pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensor = burn_tensor::Tensor; + pub type ReferenceTensor = burn_tensor::Tensor; + pub type TestTensorInt = burn_tensor::Tensor; - type TestAutodiffBackend = burn_autodiff::Autodiff; - type TestAutodiffTensor = burn_tensor::Tensor; + type TestAutodiffBackend = burn_autodiff::Autodiff; + type TestAutodiffTensor = burn_tensor::Tensor; - // test activation - burn_tensor::testgen_gelu!(); - burn_tensor::testgen_relu!(); - burn_tensor::testgen_softmax!(); - burn_tensor::testgen_sigmoid!(); - burn_tensor::testgen_silu!(); + // test activation + burn_tensor::testgen_gelu!(); + burn_tensor::testgen_relu!(); + burn_tensor::testgen_softmax!(); + burn_tensor::testgen_sigmoid!(); + burn_tensor::testgen_silu!(); - // test module - burn_tensor::testgen_module_forward!(); - burn_tensor::testgen_module_conv1d!(); - // burn_tensor::testgen_module_conv2d!(); - // burn_tensor::testgen_module_conv_transpose1d!(); - // burn_tensor::testgen_module_conv_transpose2d!(); - // burn_tensor::testgen_module_max_pool1d!(); - // burn_tensor::testgen_module_max_pool2d!(); - // burn_tensor::testgen_module_avg_pool1d!(); - // burn_tensor::testgen_module_avg_pool2d!(); - // burn_tensor::testgen_module_adaptive_avg_pool1d!(); - // burn_tensor::testgen_module_adaptive_avg_pool2d!(); + // test module + burn_tensor::testgen_module_forward!(); + burn_tensor::testgen_module_conv1d!(); + // burn_tensor::testgen_module_conv2d!(); + // burn_tensor::testgen_module_conv_transpose1d!(); + // burn_tensor::testgen_module_conv_transpose2d!(); + // burn_tensor::testgen_module_max_pool1d!(); + // burn_tensor::testgen_module_max_pool2d!(); + // burn_tensor::testgen_module_avg_pool1d!(); + // burn_tensor::testgen_module_avg_pool2d!(); + // burn_tensor::testgen_module_adaptive_avg_pool1d!(); + // burn_tensor::testgen_module_adaptive_avg_pool2d!(); - // test ops - burn_tensor::testgen_add!(); - // burn_tensor::testgen_aggregation!(); - burn_tensor::testgen_arange!(); - burn_tensor::testgen_arange_step!(); - burn_tensor::testgen_arg!(); - burn_tensor::testgen_cast!(); - burn_tensor::testgen_cat!(); - burn_tensor::testgen_recip!(); - burn_tensor::testgen_clamp!(); - burn_tensor::testgen_cos!(); - // burn_tensor::testgen_div!(); - burn_tensor::testgen_erf!(); - burn_tensor::testgen_exp!(); - burn_tensor::testgen_flatten!(); - burn_tensor::testgen_full!(); - burn_tensor::testgen_gather_scatter!(); - burn_tensor::testgen_init!(); - burn_tensor::testgen_log!(); - burn_tensor::testgen_log1p!(); - burn_tensor::testgen_map_comparison!(); - burn_tensor::testgen_mask!(); - burn_tensor::testgen_matmul!(); - burn_tensor::testgen_maxmin!(); - burn_tensor::testgen_mul!(); - burn_tensor::testgen_neg!(); - burn_tensor::testgen_powf!(); - burn_tensor::testgen_random!(); - // burn_tensor::testgen_repeat!(); - 
burn_tensor::testgen_reshape!(); - burn_tensor::testgen_select!(); - burn_tensor::testgen_sin!(); - // burn_tensor::testgen_slice!(); - burn_tensor::testgen_sqrt!(); - burn_tensor::testgen_abs!(); - burn_tensor::testgen_squeeze!(); - burn_tensor::testgen_sub!(); - burn_tensor::testgen_tanh!(); - burn_tensor::testgen_transpose!(); + // test ops + burn_tensor::testgen_add!(); + // burn_tensor::testgen_aggregation!(); + burn_tensor::testgen_arange!(); + burn_tensor::testgen_arange_step!(); + burn_tensor::testgen_arg!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_recip!(); + burn_tensor::testgen_clamp!(); + burn_tensor::testgen_cos!(); + // burn_tensor::testgen_div!(); + burn_tensor::testgen_erf!(); + burn_tensor::testgen_exp!(); + burn_tensor::testgen_flatten!(); + burn_tensor::testgen_full!(); + burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_init!(); + burn_tensor::testgen_log!(); + burn_tensor::testgen_log1p!(); + burn_tensor::testgen_map_comparison!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_matmul!(); + burn_tensor::testgen_maxmin!(); + burn_tensor::testgen_mul!(); + burn_tensor::testgen_neg!(); + burn_tensor::testgen_powf!(); + burn_tensor::testgen_random!(); + // burn_tensor::testgen_repeat!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_select!(); + burn_tensor::testgen_sin!(); + // burn_tensor::testgen_slice!(); + burn_tensor::testgen_sqrt!(); + burn_tensor::testgen_abs!(); + burn_tensor::testgen_squeeze!(); + burn_tensor::testgen_sub!(); + burn_tensor::testgen_tanh!(); + burn_tensor::testgen_transpose!(); - // test stats - burn_tensor::testgen_var!(); - burn_tensor::testgen_display!(); + // test stats + burn_tensor::testgen_var!(); + burn_tensor::testgen_display!(); - // Behavior - // burn_autodiff::testgen_ad_broadcast!(); + // Behavior + // burn_autodiff::testgen_ad_broadcast!(); - // Activation - burn_autodiff::testgen_ad_relu!(); - burn_autodiff::testgen_ad_gelu!(); + // Activation + burn_autodiff::testgen_ad_relu!(); + burn_autodiff::testgen_ad_gelu!(); - // Modules - // burn_autodiff::testgen_ad_conv1d!(); - // burn_autodiff::testgen_ad_conv2d!(); - // burn_autodiff::testgen_ad_conv_transpose1d!(); - // burn_autodiff::testgen_ad_conv_transpose2d!(); - // burn_autodiff::testgen_ad_max_pool1d!(); - // burn_autodiff::testgen_ad_max_pool2d!(); - // burn_autodiff::testgen_ad_avg_pool1d!(); - // burn_autodiff::testgen_ad_avg_pool2d!(); - // burn_autodiff::testgen_ad_adaptive_avg_pool1d!(); - // burn_autodiff::testgen_ad_adaptive_avg_pool2d!(); - burn_autodiff::testgen_module_backward!(); + // Modules + // burn_autodiff::testgen_ad_conv1d!(); + // burn_autodiff::testgen_ad_conv2d!(); + // burn_autodiff::testgen_ad_conv_transpose1d!(); + // burn_autodiff::testgen_ad_conv_transpose2d!(); + // burn_autodiff::testgen_ad_max_pool1d!(); + // burn_autodiff::testgen_ad_max_pool2d!(); + // burn_autodiff::testgen_ad_avg_pool1d!(); + // burn_autodiff::testgen_ad_avg_pool2d!(); + // burn_autodiff::testgen_ad_adaptive_avg_pool1d!(); + // burn_autodiff::testgen_ad_adaptive_avg_pool2d!(); + burn_autodiff::testgen_module_backward!(); - // Tensor - burn_autodiff::testgen_ad_complex!(); - burn_autodiff::testgen_ad_multithread!(); - burn_autodiff::testgen_ad_add!(); - burn_autodiff::testgen_ad_aggregation!(); - burn_autodiff::testgen_ad_maxmin!(); - // burn_autodiff::testgen_ad_cat!(); - burn_autodiff::testgen_ad_cos!(); - burn_autodiff::testgen_ad_cross_entropy_loss!(); - burn_autodiff::testgen_ad_div!(); - 
burn_autodiff::testgen_ad_erf!(); - burn_autodiff::testgen_ad_exp!(); - // burn_autodiff::testgen_ad_slice!(); - burn_autodiff::testgen_ad_gather_scatter!(); - burn_autodiff::testgen_ad_select!(); - burn_autodiff::testgen_ad_log!(); - burn_autodiff::testgen_ad_log1p!(); - burn_autodiff::testgen_ad_mask!(); - burn_autodiff::testgen_ad_matmul!(); - burn_autodiff::testgen_ad_mul!(); - burn_autodiff::testgen_ad_neg!(); - burn_autodiff::testgen_ad_powf!(); - burn_autodiff::testgen_ad_recip!(); - burn_autodiff::testgen_ad_reshape!(); - burn_autodiff::testgen_ad_sin!(); - burn_autodiff::testgen_ad_softmax!(); - burn_autodiff::testgen_ad_sqrt!(); - burn_autodiff::testgen_ad_abs!(); - burn_autodiff::testgen_ad_sub!(); - burn_autodiff::testgen_ad_tanh!(); - burn_autodiff::testgen_ad_transpose!(); + // Tensor + burn_autodiff::testgen_ad_complex!(); + burn_autodiff::testgen_ad_multithread!(); + burn_autodiff::testgen_ad_add!(); + burn_autodiff::testgen_ad_aggregation!(); + burn_autodiff::testgen_ad_maxmin!(); + // burn_autodiff::testgen_ad_cat!(); + burn_autodiff::testgen_ad_cos!(); + burn_autodiff::testgen_ad_cross_entropy_loss!(); + burn_autodiff::testgen_ad_div!(); + burn_autodiff::testgen_ad_erf!(); + burn_autodiff::testgen_ad_exp!(); + // burn_autodiff::testgen_ad_slice!(); + burn_autodiff::testgen_ad_gather_scatter!(); + burn_autodiff::testgen_ad_select!(); + burn_autodiff::testgen_ad_log!(); + burn_autodiff::testgen_ad_log1p!(); + burn_autodiff::testgen_ad_mask!(); + burn_autodiff::testgen_ad_matmul!(); + burn_autodiff::testgen_ad_mul!(); + burn_autodiff::testgen_ad_neg!(); + burn_autodiff::testgen_ad_powf!(); + burn_autodiff::testgen_ad_recip!(); + burn_autodiff::testgen_ad_reshape!(); + burn_autodiff::testgen_ad_sin!(); + burn_autodiff::testgen_ad_softmax!(); + burn_autodiff::testgen_ad_sqrt!(); + burn_autodiff::testgen_ad_abs!(); + burn_autodiff::testgen_ad_sub!(); + burn_autodiff::testgen_ad_tanh!(); + burn_autodiff::testgen_ad_transpose!(); } diff --git a/burn-candle/src/ops/activation.rs b/burn-candle/src/ops/activation.rs index fadeb94e7f..0cedb23aa5 100644 --- a/burn-candle/src/ops/activation.rs +++ b/burn-candle/src/ops/activation.rs @@ -1,16 +1,16 @@ use burn_tensor::ops::{ActivationOps, FloatTensor}; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - tensor, Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + tensor, Candle, CandleTensor, }; impl ActivationOps for Candle { - fn gelu(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.gelu().unwrap()) - } + fn gelu(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.gelu().unwrap()) + } - fn relu(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.relu().unwrap()) - } + fn relu(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.relu().unwrap()) + } } diff --git a/burn-candle/src/ops/base.rs b/burn-candle/src/ops/base.rs index a241622fb6..643c27a73d 100644 --- a/burn-candle/src/ops/base.rs +++ b/burn-candle/src/ops/base.rs @@ -3,88 +3,88 @@ use std::marker::PhantomData; use burn_tensor::{backend::Backend, Data, Reader, Shape}; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleDevice, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + Candle, CandleDevice, CandleTensor, }; use super::tensor; pub fn cat( - tensors: Vec>, - dim: usize, + tensors: Vec>, + dim: usize, ) -> CandleTensor { - let tensors: Vec = 
tensors.into_iter().map(|t| t.tensor).collect(); - CandleTensor::new(candle_core::Tensor::cat(&tensors, dim).unwrap()) + let tensors: Vec = tensors.into_iter().map(|t| t.tensor).collect(); + CandleTensor::new(candle_core::Tensor::cat(&tensors, dim).unwrap()) } pub fn from_data( - data: Data, - device: &CandleDevice, + data: Data, + device: &CandleDevice, ) -> CandleTensor { - CandleTensor::from_data(data, *device) + CandleTensor::from_data(data, *device) } pub fn into_data(tensor: CandleTensor) -> Data { - Data::new( - tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(), - tensor.shape(), - ) + Data::new( + tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(), + tensor.shape(), + ) } pub fn to_device( - tensor: CandleTensor, - device: &CandleDevice, + tensor: CandleTensor, + device: &CandleDevice, ) -> CandleTensor { - CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap()) + CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap()) } pub fn empty( - shape: Shape, - device: &CandleDevice, + shape: Shape, + device: &CandleDevice, ) -> CandleTensor { - CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(*device).into()).unwrap()) + CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(*device).into()).unwrap()) } pub fn swap_dims( - mut tensor: CandleTensor, - dim1: usize, - dim2: usize, + mut tensor: CandleTensor, + dim1: usize, + dim2: usize, ) -> CandleTensor { - CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap()) + CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap()) } pub fn reshape( - tensor: CandleTensor, - shape: Shape, + tensor: CandleTensor, + shape: Shape, ) -> CandleTensor { - CandleTensor::new(tensor.tensor.reshape(&shape.dims).unwrap()) + CandleTensor::new(tensor.tensor.reshape(&shape.dims).unwrap()) } pub fn device(tensor: &CandleTensor) -> CandleDevice { - tensor.tensor.device().clone().into() + tensor.tensor.device().clone().into() } pub fn shape(tensor: &CandleTensor) -> Shape { - tensor.shape() + tensor.shape() } pub fn slice( - tensor: CandleTensor, - ranges: [std::ops::Range; D2], + tensor: CandleTensor, + ranges: [std::ops::Range; D2], ) -> CandleTensor { - let mut narrow_tensor = tensor.tensor; - for (i, range) in ranges.iter().enumerate().take(D2) { - narrow_tensor = narrow_tensor - .narrow(i, range.start, range.end - range.start) - .unwrap() - } - CandleTensor::new(narrow_tensor) + let mut narrow_tensor = tensor.tensor; + for (i, range) in ranges.iter().enumerate().take(D2) { + narrow_tensor = narrow_tensor + .narrow(i, range.start, range.end - range.start) + .unwrap() + } + CandleTensor::new(narrow_tensor) } pub fn slice_assign( - tensor: CandleTensor, - ranges: [std::ops::Range; D2], - value: CandleTensor, + tensor: CandleTensor, + ranges: [std::ops::Range; D2], + value: CandleTensor, ) -> CandleTensor { - panic!("slice_assign not supported by Candle") + panic!("slice_assign not supported by Candle") } diff --git a/burn-candle/src/ops/bool_tensor.rs b/burn-candle/src/ops/bool_tensor.rs index 4cd6fa0da9..e5fd153f85 100644 --- a/burn-candle/src/ops/bool_tensor.rs +++ b/burn-candle/src/ops/bool_tensor.rs @@ -1,113 +1,112 @@ use burn_tensor::{ - ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor}, - Data, Device, Reader, Shape, + ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor}, + Data, Device, Reader, Shape, }; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, 
IntCandleElement}, + Candle, CandleTensor, }; impl BoolTensorOps for Candle { - fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - super::base::empty(shape, device) - } - - fn bool_shape(tensor: &BoolTensor) -> Shape { - super::base::shape(tensor) - } - - fn bool_into_data(tensor: BoolTensor) -> Reader> { - let x: Vec = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(); - let y = x.iter().map(|b| !matches!(b, 0)).collect(); - let data = Data::new(y, tensor.shape()); - - Reader::Concrete(data) - } - - fn bool_from_data( - data: Data, - device: &Device, - ) -> BoolTensor { - let data: Data = Data::new( - data - .value - .into_iter() - .map(|c| match c { - true => 1, - false => 0, - }) - .collect(), - data.shape, - ); - super::base::from_data(data, device) - } - - fn bool_into_int(tensor: BoolTensor) -> IntTensor { - CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) - } - - fn bool_into_float(tensor: BoolTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) - } - - fn bool_device(tensor: &BoolTensor) -> Device { - super::base::device(tensor) - } - - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - super::base::to_device(tensor, device) - } - - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - super::base::reshape(tensor, shape) - } - - fn bool_slice( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - ) -> BoolTensor { - super::base::slice(tensor, ranges) - } - - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - value: BoolTensor, - ) -> BoolTensor { - super::base::slice_assign(tensor, ranges, value) - } - - fn bool_cat( - tensors: Vec>, - dim: usize, - ) -> BoolTensor { - super::base::cat(tensors, dim) - } - - fn bool_equal( - lhs: BoolTensor, - rhs: BoolTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) - } - - fn bool_not(tensor: BoolTensor) -> BoolTensor { - let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap()); - CandleTensor::new(tensor.tensor.eq(&x).unwrap()) - } - - fn bool_swap_dims( - tensor: as burn_tensor::backend::Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { - super::base::swap_dims(tensor, dim1, dim2) - } + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + super::base::empty(shape, device) + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + super::base::shape(tensor) + } + + fn bool_into_data(tensor: BoolTensor) -> Reader> { + let x: Vec = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(); + let y = x.iter().map(|b| !matches!(b, 0)).collect(); + let data = Data::new(y, tensor.shape()); + + Reader::Concrete(data) + } + + fn bool_from_data( + data: Data, + device: &Device, + ) -> BoolTensor { + let data: Data = Data::new( + data.value + .into_iter() + .map(|c| match c { + true => 1, + false => 0, + }) + .collect(), + data.shape, + ); + super::base::from_data(data, device) + } + + fn bool_into_int(tensor: BoolTensor) -> IntTensor { + CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) + } + + fn bool_into_float(tensor: BoolTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) + } + + fn bool_device(tensor: &BoolTensor) -> Device { + super::base::device(tensor) + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + super::base::to_device(tensor, device) + } + + fn bool_reshape( + tensor: BoolTensor, + 
shape: Shape, + ) -> BoolTensor { + super::base::reshape(tensor, shape) + } + + fn bool_slice( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + ) -> BoolTensor { + super::base::slice(tensor, ranges) + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + value: BoolTensor, + ) -> BoolTensor { + super::base::slice_assign(tensor, ranges, value) + } + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> BoolTensor { + super::base::cat(tensors, dim) + } + + fn bool_equal( + lhs: BoolTensor, + rhs: BoolTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap()); + CandleTensor::new(tensor.tensor.eq(&x).unwrap()) + } + + fn bool_swap_dims( + tensor: as burn_tensor::backend::Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { + super::base::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-candle/src/ops/candle_utils.rs b/burn-candle/src/ops/candle_utils.rs index 499b46cff3..2a9b92cb37 100644 --- a/burn-candle/src/ops/candle_utils.rs +++ b/burn-candle/src/ops/candle_utils.rs @@ -3,23 +3,23 @@ use candle_core::{DType, Device, Shape, Tensor}; use crate::element::CandleElement; pub(crate) fn fill>( - value: E, - shape: S, - dtype: DType, - device: &Device, + value: E, + shape: S, + dtype: DType, + device: &Device, ) -> Tensor { - let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::()).unwrap(); - values.expand(shape).unwrap() + let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::()).unwrap(); + values.expand(shape).unwrap() } pub(crate) fn fill_like( - value: E, - reference_tensor: &Tensor, + value: E, + reference_tensor: &Tensor, ) -> Tensor { - fill( - value, - reference_tensor.shape(), - reference_tensor.dtype(), - reference_tensor.device(), - ) + fill( + value, + reference_tensor.shape(), + reference_tensor.dtype(), + reference_tensor.device(), + ) } diff --git a/burn-candle/src/ops/int_tensor.rs b/burn-candle/src/ops/int_tensor.rs index 429d27029a..c1c502bf33 100644 --- a/burn-candle/src/ops/int_tensor.rs +++ b/burn-candle/src/ops/int_tensor.rs @@ -1,365 +1,362 @@ use burn_tensor::{ - ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, - Bool, Data, Device, Reader, Shape, + ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, + Bool, Data, Device, Reader, Shape, }; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + Candle, CandleTensor, }; impl IntTensorOps for Candle { - fn int_empty(shape: Shape, device: &Device) -> IntTensor { - super::base::empty(shape, device) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - super::base::shape(tensor) - } - - fn int_into_data(tensor: IntTensor) -> Reader, D>> { - Reader::Concrete(super::base::into_data(tensor)) - } - - fn int_from_data( - data: Data, D>, - device: &Device, - ) -> IntTensor { - super::base::from_data(data, device) - } - - fn int_device(tensor: &IntTensor) -> Device { - super::base::device(tensor) - } - - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - super::base::to_device(tensor, device) - } - - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - super::base::reshape(tensor, shape) - } - - fn int_slice( - tensor: IntTensor, - indices: 
[std::ops::Range; D2], - ) -> IntTensor { - super::base::slice(tensor, indices) - } - - fn int_slice_assign( - tensor: IntTensor, - indices: [std::ops::Range; D2], - value: IntTensor, - ) -> IntTensor { - super::base::slice_assign(tensor, indices, value) - } - - fn int_into_float(tensor: IntTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) - } - - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - source: IntTensor, - ) -> IntTensor { - CandleTensor::new( - mask - .tensor - .where_cond(&source.tensor, &tensor.tensor) - .unwrap(), - ) - } - - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor { - CandleTensor::new( - mask - .tensor - .where_cond( - &super::candle_utils::fill_like::(value, &tensor.tensor), - &tensor.tensor, + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + super::base::empty(shape, device) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + super::base::shape(tensor) + } + + fn int_into_data(tensor: IntTensor) -> Reader, D>> { + Reader::Concrete(super::base::into_data(tensor)) + } + + fn int_from_data( + data: Data, D>, + device: &Device, + ) -> IntTensor { + super::base::from_data(data, device) + } + + fn int_device(tensor: &IntTensor) -> Device { + super::base::device(tensor) + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + super::base::to_device(tensor, device) + } + + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + super::base::reshape(tensor, shape) + } + + fn int_slice( + tensor: IntTensor, + indices: [std::ops::Range; D2], + ) -> IntTensor { + super::base::slice(tensor, indices) + } + + fn int_slice_assign( + tensor: IntTensor, + indices: [std::ops::Range; D2], + value: IntTensor, + ) -> IntTensor { + super::base::slice_assign(tensor, indices, value) + } + + fn int_into_float(tensor: IntTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + source: IntTensor, + ) -> IntTensor { + CandleTensor::new( + mask.tensor + .where_cond(&source.tensor, &tensor.tensor) + .unwrap(), ) - .unwrap(), - ) - } - - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .scatter_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) - } - - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .index_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - super::base::cat(tensors, dim) - } - - fn int_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) - } - - fn int_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_greater( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - 
CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) - } - - fn int_greater_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_lower( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) - } - - fn int_lower_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) - } - - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn int_add( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) - } - - fn int_add_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) - } - - fn int_sub( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) - } - - fn int_sub_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) - } - - fn int_mul( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) - } - - fn int_mul_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) - } - - fn int_div( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) - } - - fn int_div_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. - panic!("Not supported by Candle") - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(*device).into()).unwrap()) - } - - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - CandleTensor::new(candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(*device).into()).unwrap()) - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); - CandleTensor::from_data( - Data::new([sum].into(), [1].into()), - Self::int_device(&tensor), - ) - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. 
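// A minimal self-contained sketch in plain Rust (not the Candle API) of why that
// reciprocal rewrite cannot work for integer elements: for any integer b with
// |b| > 1, the reciprocal 1 / b truncates to 0, so a * (1 / b) collapses to 0 even
// though the integer quotient a / b is well defined.
fn reciprocal_trick_fails_for_ints() {
    let (a, b): (i64, i64) = (7, 2);
    assert_eq!(1 / b, 0); // integer reciprocal truncates to zero
    assert_eq!(a * (1 / b), 0); // so the a * (1/b) rewrite always yields zero
    assert_eq!(a / b, 3); // while direct integer division gives the expected quotient
}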
- panic!("Not supported by Candle") - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmax_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmin_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn int_abs(tensor: IntTensor) -> IntTensor { - // Ugly type conversion here as Candle does not support unary ops on ints - CandleTensor::new( - tensor - .tensor - .to_dtype(F::DTYPE) - .unwrap() - .abs() - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn int_swap_dims( - tensor: as burn_tensor::backend::Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as burn_tensor::backend::Backend>::IntTensorPrimitive { - super::base::swap_dims(tensor, dim1, dim2) - } + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor { + CandleTensor::new( + mask.tensor + .where_cond( + &super::candle_utils::fill_like::(value, &tensor.tensor), + &tensor.tensor, + ) + .unwrap(), + ) + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .scatter_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .index_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + super::base::cat(tensors, dim) + } + + fn int_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) + } + + fn int_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_greater( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) + } + + fn int_greater_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_lower( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) + } + + fn int_lower_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) + } + + fn int_lower_equal_elem( + 
lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn int_add( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) + } + + fn int_add_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) + } + + fn int_sub( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) + } + + fn int_sub_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) + } + + fn int_mul( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) + } + + fn int_mul_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) + } + + fn int_div( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) + } + + fn int_div_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. + panic!("Not supported by Candle") + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + CandleTensor::new( + candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(*device).into()).unwrap(), + ) + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + CandleTensor::new( + candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(*device).into()).unwrap(), + ) + } + + fn int_sum(tensor: IntTensor) -> IntTensor { + let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); + CandleTensor::from_data( + Data::new([sum].into(), [1].into()), + Self::int_device(&tensor), + ) + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. 
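// A possible fallback, shown only as an untested sketch (not part of this patch):
// round-trip through the float dtype, where `mean_keepdim` is available, and cast
// back to the integer dtype afterwards. The final cast truncates, so this is not an
// exact integer mean, which is presumably why the backend panics instead.
//
// let mean = tensor.tensor.to_dtype(F::DTYPE).unwrap().mean_keepdim(dim).unwrap();
// CandleTensor::new(mean.to_dtype(I::DTYPE).unwrap())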
+ panic!("Not supported by Candle") + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmax_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmin_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn int_abs(tensor: IntTensor) -> IntTensor { + // Ugly type conversion here as Candle does not support unary ops on ints + CandleTensor::new( + tensor + .tensor + .to_dtype(F::DTYPE) + .unwrap() + .abs() + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn int_swap_dims( + tensor: as burn_tensor::backend::Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as burn_tensor::backend::Backend>::IntTensorPrimitive { + super::base::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-candle/src/ops/module.rs b/burn-candle/src/ops/module.rs index f8c1ae719f..0a169277fa 100644 --- a/burn-candle/src/ops/module.rs +++ b/burn-candle/src/ops/module.rs @@ -1,220 +1,223 @@ use burn_tensor::{ - ops::{ - ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool2dBackward, - MaxPool2dWithIndices, ModuleOps, UnfoldOptions, - }, - Shape, + ops::{ + ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool2dBackward, + MaxPool2dWithIndices, ModuleOps, UnfoldOptions, + }, + Shape, }; use candle_core::ToUsize2; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - ops::base::reshape, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + ops::base::reshape, + Candle, CandleTensor, }; impl ModuleOps for Candle { - fn conv1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> FloatTensor { - let conv = x - .tensor - .conv1d( - &weight.tensor, - options.padding[0], - options.stride[0], - options.dilation[0], - options.groups, - ) - .unwrap(); - CandleTensor::new(match bias { - Some(bias) => conv - .broadcast_add(&bias.tensor.unsqueeze(1).unwrap()) - .unwrap(), - None => conv, - }) - } + fn conv1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> FloatTensor { + let conv = x + .tensor + .conv1d( + &weight.tensor, + options.padding[0], + options.stride[0], + options.dilation[0], + options.groups, + ) + .unwrap(); + CandleTensor::new(match bias { + Some(bias) => conv + .broadcast_add(&bias.tensor.unsqueeze(1).unwrap()) + .unwrap(), + None => conv, + }) + } - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor { - assert!( - options.dilation[0] == options.dilation[1] - && options.padding[0] == options.padding[1] - && options.stride[0] == options.stride[1], - "Candle does not support per dimension options in convolutions" - ); - let conv = x - .tensor - .conv2d( - &weight.tensor, - options.padding[0], - options.stride[0], - options.dilation[0], - options.groups, - ) - .unwrap(); - CandleTensor::new(match bias { - Some(bias) => conv - .broadcast_add( - &bias + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor { + assert!( + options.dilation[0] == options.dilation[1] + && options.padding[0] == options.padding[1] + && options.stride[0] == options.stride[1], + "Candle does not support per dimension options in convolutions" + ); + let conv = x .tensor - .unsqueeze(0) - .unwrap() - .unsqueeze(2) - .unwrap() - 
.unsqueeze(3) - .unwrap(), - ) - .unwrap(), - None => conv, - }) - } + .conv2d( + &weight.tensor, + options.padding[0], + options.stride[0], + options.dilation[0], + options.groups, + ) + .unwrap(); + CandleTensor::new(match bias { + Some(bias) => conv + .broadcast_add( + &bias + .tensor + .unsqueeze(0) + .unwrap() + .unsqueeze(2) + .unwrap() + .unsqueeze(3) + .unwrap(), + ) + .unwrap(), + None => conv, + }) + } - fn conv_transpose1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> FloatTensor { - panic!("Candle does not support conv_transpose1d") - } + fn conv_transpose1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> FloatTensor { + panic!("Candle does not support conv_transpose1d") + } - fn conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor { - assert!( - options.dilation[0] == options.dilation[1] - && options.padding[0] == options.padding[1] - && options.padding_out[0] == options.padding_out[1] - && options.stride[0] == options.stride[1], - "Candle does not support per dimension options in transposed convolutions" - ); - assert!( - options.groups == 1, - "Candle does not support groups in transposed convolutions" - ); - let conv_transpose = x - .tensor - .conv_transpose2d( - &weight.tensor, - options.padding[0], - options.padding_out[0], - options.stride[0], - options.dilation[0], - ) - .unwrap(); - CandleTensor::new(match bias { - Some(bias) => conv_transpose - .broadcast_add( - &bias + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + assert!( + options.dilation[0] == options.dilation[1] + && options.padding[0] == options.padding[1] + && options.padding_out[0] == options.padding_out[1] + && options.stride[0] == options.stride[1], + "Candle does not support per dimension options in transposed convolutions" + ); + assert!( + options.groups == 1, + "Candle does not support groups in transposed convolutions" + ); + let conv_transpose = x .tensor - .unsqueeze(0) - .unwrap() - .unsqueeze(2) - .unwrap() - .unsqueeze(3) - .unwrap(), - ) - .unwrap(), - None => conv_transpose, - }) - } + .conv_transpose2d( + &weight.tensor, + options.padding[0], + options.padding_out[0], + options.stride[0], + options.dilation[0], + ) + .unwrap(); + CandleTensor::new(match bias { + Some(bias) => conv_transpose + .broadcast_add( + &bias + .tensor + .unsqueeze(0) + .unwrap() + .unsqueeze(2) + .unwrap() + .unsqueeze(3) + .unwrap(), + ) + .unwrap(), + None => conv_transpose, + }) + } - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - assert!( - padding[0] == 0 && padding[1] == 0, - "Candle does not support padding in pooling" - ); - assert!( - count_include_pad, - "Candle does not support excluding pad count in pooling" - ); - CandleTensor::new( - x.tensor - .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) - .unwrap(), - ) - } + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + assert!( + padding[0] == 0 && padding[1] == 0, + "Candle does not support padding in pooling" + ); + assert!( + count_include_pad, + "Candle does not support excluding pad count in pooling" + ); + CandleTensor::new( + x.tensor + 
.avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) + .unwrap(), + ) + } - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - panic!("avg_pool2d_backward is not supported by Candle") - } + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + panic!("avg_pool2d_backward is not supported by Candle") + } - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor { - assert!( - padding[0] == 0 && padding[1] == 0, - "Candle does not support padding in pooling" - ); - assert!( - dilation[0] == 1 && dilation[1] == 1, - "Candle does not support dilation in pooling" - ); - CandleTensor::new( - x.tensor - .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) - .unwrap(), - ) - } + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor { + assert!( + padding[0] == 0 && padding[1] == 0, + "Candle does not support padding in pooling" + ); + assert!( + dilation[0] == 1 && dilation[1] == 1, + "Candle does not support dilation in pooling" + ); + CandleTensor::new( + x.tensor + .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) + .unwrap(), + ) + } - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - panic!("max_pool2d_with_indices is not supported by Candle") - } + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + panic!("max_pool2d_with_indices is not supported by Candle") + } - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward> { - panic!("max_pool2d_with_indices_backward is not supported by Candle") - } + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward> { + panic!("max_pool2d_with_indices_backward is not supported by Candle") + } - fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { - panic!("adaptive_avg_pool2 is not supported by Candle") - } + fn adaptive_avg_pool2d( + x: FloatTensor, + output_size: [usize; 2], + ) -> FloatTensor { + panic!("adaptive_avg_pool2 is not supported by Candle") + } - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - panic!("adaptive_avg_pool2d_backward is not supported by Candle") - } + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + panic!("adaptive_avg_pool2d_backward is not supported by Candle") + } } diff --git a/burn-candle/src/ops/tensor.rs b/burn-candle/src/ops/tensor.rs index 5da1bf86cc..3f096e92f4 100644 --- a/burn-candle/src/ops/tensor.rs +++ b/burn-candle/src/ops/tensor.rs @@ -1,456 +1,449 @@ use std::borrow::Borrow; use 
burn_tensor::{ - ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, - Data, Device, Distribution, ElementConversion, Reader, Shape, + ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, + Data, Device, Distribution, ElementConversion, Reader, Shape, }; use candle_core::{backend::BackendStorage, shape, Tensor}; use crate::{ - element::{CandleElement, FloatCandleElement, IntCandleElement}, - Candle, CandleTensor, + element::{CandleElement, FloatCandleElement, IntCandleElement}, + Candle, CandleTensor, }; impl TensorOps for Candle { - fn from_data(data: Data, device: &Device) -> CandleTensor { - CandleTensor::from_data(data, *device) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: &Device, - ) -> FloatTensor { - let shape = &shape.dims; - let device = &(*device).into(); - match distribution { - Distribution::Default => CandleTensor::new( - candle_core::Tensor::rand(0., 1., shape, device) - .unwrap() - .to_dtype(F::DTYPE) - .unwrap(), - ), - Distribution::Bernoulli(prob) => CandleTensor::new( - candle_core::Tensor::rand(0., 1., shape, device) - .unwrap() - .to_dtype(F::DTYPE) - .unwrap() - .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device)) - .unwrap() - .to_dtype(F::DTYPE) - .unwrap(), - ), - Distribution::Uniform(from, to) => { - CandleTensor::new(candle_core::Tensor::rand(from, to, shape, device).unwrap()) - } - Distribution::Normal(mean, std) => { - CandleTensor::new(candle_core::Tensor::randn(mean, std, shape, device).unwrap()) - } - } - } - - fn shape(tensor: &CandleTensor) -> Shape { - super::base::shape(tensor) - } - - fn into_data(tensor: CandleTensor) -> Reader> { - Reader::Concrete(super::base::into_data(tensor)) - } - - fn device(tensor: &CandleTensor) -> Device { - super::base::device(tensor) - } - - fn to_device( - tensor: CandleTensor, - device: &Device, - ) -> CandleTensor { - super::base::to_device(tensor, device) - } - - fn into_int(tensor: CandleTensor) -> IntTensor { - CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) - } - - fn empty(shape: Shape, device: &Device) -> FloatTensor { - super::base::empty(shape, device) - } - - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) - } - - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - CandleTensor::new((lhs.tensor / rhs.elem::()).unwrap()) - } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs.tensor).unwrap()) - } - - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, 
- ) -> FloatTensor { - super::base::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - super::base::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - CandleTensor::new( - tensor - .tensor - .scatter_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - CandleTensor::new( - tensor - .tensor - .index_add(&indices.tensor, &value.tensor, dim) - .unwrap(), - ) - } - - fn slice( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - ) -> FloatTensor { - super::base::slice(tensor, ranges) - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [std::ops::Range; D2], - value: FloatTensor, - ) -> FloatTensor { - super::base::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor { - CandleTensor::new( - mask - .tensor - .where_cond(&value.tensor, &tensor.tensor) - .unwrap(), - ) - } - - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - CandleTensor::new( - mask - .tensor - .where_cond( - &super::candle_utils::fill_like::(value, &tensor.tensor), - &tensor.tensor, + fn from_data(data: Data, device: &Device) -> CandleTensor { + CandleTensor::from_data(data, *device) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: &Device, + ) -> FloatTensor { + let shape = &shape.dims; + let device = &(*device).into(); + match distribution { + Distribution::Default => CandleTensor::new( + candle_core::Tensor::rand(0., 1., shape, device) + .unwrap() + .to_dtype(F::DTYPE) + .unwrap(), + ), + Distribution::Bernoulli(prob) => CandleTensor::new( + candle_core::Tensor::rand(0., 1., shape, device) + .unwrap() + .to_dtype(F::DTYPE) + .unwrap() + .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device)) + .unwrap() + .to_dtype(F::DTYPE) + .unwrap(), + ), + Distribution::Uniform(from, to) => { + CandleTensor::new(candle_core::Tensor::rand(from, to, shape, device).unwrap()) + } + Distribution::Normal(mean, std) => { + CandleTensor::new(candle_core::Tensor::randn(mean, std, shape, device).unwrap()) + } + } + } + + fn shape(tensor: &CandleTensor) -> Shape { + super::base::shape(tensor) + } + + fn into_data(tensor: CandleTensor) -> Reader> { + Reader::Concrete(super::base::into_data(tensor)) + } + + fn device(tensor: &CandleTensor) -> Device { + super::base::device(tensor) + } + + fn to_device( + tensor: CandleTensor, + device: &Device, + ) -> CandleTensor { + super::base::to_device(tensor, device) + } + + fn into_int(tensor: CandleTensor) -> IntTensor { + CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + super::base::empty(shape, device) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> 
FloatTensor { + CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + CandleTensor::new((lhs.tensor / rhs.elem::()).unwrap()) + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs.tensor).unwrap()) + } + + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + super::base::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + super::base::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.gather(&indices.tensor, dim).unwrap()) + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + CandleTensor::new( + tensor + .tensor + .scatter_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + CandleTensor::new( + tensor + .tensor + .index_add(&indices.tensor, &value.tensor, dim) + .unwrap(), + ) + } + + fn slice( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + ) -> FloatTensor { + super::base::slice(tensor, ranges) + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [std::ops::Range; D2], + value: FloatTensor, + ) -> FloatTensor { + super::base::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor { + CandleTensor::new( + mask.tensor + .where_cond(&value.tensor, &tensor.tensor) + .unwrap(), + ) + } + + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + CandleTensor::new( + mask.tensor + .where_cond( + &super::candle_utils::fill_like::(value, &tensor.tensor), + &tensor.tensor, + ) + .unwrap(), + ) + } + + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) + } + + fn equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), ) - .unwrap(), - ) - } - - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap()) - } - - fn equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - 
.eq(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn greater( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) - } - - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) - } - - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) - } - - fn lower_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) - } - - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - CandleTensor::new( - lhs - .tensor - .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) - .unwrap(), - ) - } - - fn sum(tensor: FloatTensor) -> FloatTensor { - let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); - CandleTensor::from_data(Data::new([sum].into(), [1].into()), Self::device(&tensor)) - } - - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) - } - - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap()) - } - - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - CandleTensor::new(tensor.tensor.to_dtype(candle_core::DType::F32).unwrap()) - } - - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) - } - - fn exp(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.exp().unwrap()) - } - - fn log(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.log().unwrap()) - } - - fn log1p(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap()) - } - - fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { - CandleTensor::new(tensor.tensor.powf(value.elem::()).unwrap()) - } - - fn sqrt(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.sqrt().unwrap()) - } - - fn abs(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.abs().unwrap()) - } - - fn cos(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.cos().unwrap()) - } - - fn sin(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.sin().unwrap()) - } - - fn tanh(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.tanh().unwrap()) - } - - fn erf(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.erf().unwrap()) - } - - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - super::base::cat(tensors, dim) - } - - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmax_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn 
argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - CandleTensor::new( - tensor - .tensor - .argmin_keepdim(dim) - .unwrap() - .to_dtype(I::DTYPE) - .unwrap(), - ) - } - - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.minimum(max).unwrap()) - } - - fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.maximum(min).unwrap()) - } - - fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.clamp(min, max).unwrap()) - } - - fn recip(tensor: FloatTensor) -> FloatTensor { - CandleTensor::new(tensor.tensor.recip().unwrap()) - } + } + + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.gt(&rhs.tensor).unwrap()) + } + + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .gt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.ge(&rhs.tensor).unwrap()) + } + + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .ge(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn lower( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.lt(&rhs.tensor).unwrap()) + } + + fn lower_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .lt(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + CandleTensor::new(lhs.tensor.le(&rhs.tensor).unwrap()) + } + + fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .le(&super::candle_utils::fill_like::(rhs, &lhs.tensor)) + .unwrap(), + ) + } + + fn sum(tensor: FloatTensor) -> FloatTensor { + let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); + CandleTensor::from_data(Data::new([sum].into(), [1].into()), Self::device(&tensor)) + } + + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) + } + + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap()) + } + + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D> { + CandleTensor::new(tensor.tensor.to_dtype(candle_core::DType::F32).unwrap()) + } + + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) + } + + fn exp(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.exp().unwrap()) + } + + fn log(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.log().unwrap()) + } + + fn log1p(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap()) + } + + fn powf(tensor: FloatTensor, value: f32) -> FloatTensor { + CandleTensor::new(tensor.tensor.powf(value.elem::()).unwrap()) + } + + fn sqrt(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.sqrt().unwrap()) + } + + fn abs(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.abs().unwrap()) + } + + fn cos(tensor: FloatTensor) -> FloatTensor { + 
CandleTensor::new(tensor.tensor.cos().unwrap()) + } + + fn sin(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.sin().unwrap()) + } + + fn tanh(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.tanh().unwrap()) + } + + fn erf(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.erf().unwrap()) + } + + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + super::base::cat(tensors, dim) + } + + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmax_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + CandleTensor::new( + tensor + .tensor + .argmin_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) + } + + fn clamp_max( + tensor: FloatTensor, + max: FloatElem, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.minimum(max).unwrap()) + } + + fn clamp_min( + tensor: FloatTensor, + min: FloatElem, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.maximum(min).unwrap()) + } + + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.clamp(min, max).unwrap()) + } + + fn recip(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.recip().unwrap()) + } } diff --git a/burn-candle/src/tensor.rs b/burn-candle/src/tensor.rs index 7f98b8ee3b..5ade8cc366 100644 --- a/burn-candle/src/tensor.rs +++ b/burn-candle/src/tensor.rs @@ -7,38 +7,38 @@ use crate::{element::CandleElement, CandleDevice}; /// A tensor that uses the candle backend. #[derive(Debug, Clone)] pub struct CandleTensor { - pub(crate) tensor: candle_core::Tensor, - phantom: PhantomData, + pub(crate) tensor: candle_core::Tensor, + phantom: PhantomData, } impl CandleTensor { - /// Create a new tensor. - pub fn new(tensor: candle_core::Tensor) -> Self { - Self { - tensor, - phantom: PhantomData, + /// Create a new tensor. + pub fn new(tensor: candle_core::Tensor) -> Self { + Self { + tensor, + phantom: PhantomData, + } } - } - /// Creates a new tensor from data and a device. - /// - /// # Arguments - /// - /// * `data` - The tensor's data. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// A new tensor. - pub fn from_data(data: Data, device: CandleDevice) -> Self { - let candle_shape: candle_core::Shape = (&data.shape.dims).into(); - let tensor = - candle_core::Tensor::from_slice(data.value.as_slice(), candle_shape, &device.into()); - Self::new(tensor.unwrap()) - } + /// Creates a new tensor from data and a device. + /// + /// # Arguments + /// + /// * `data` - The tensor's data. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// A new tensor. 
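// Hypothetical usage, as a sketch only (assumes an `f32` element type and the CPU
// device variant; neither appears in this patch):
//
// let data = Data::<f32, 2>::from([[1.0, 2.0], [3.0, 4.0]]);
// let tensor = CandleTensor::<f32, 2>::from_data(data, CandleDevice::Cpu);
// assert_eq!(tensor.shape().dims, [2, 2]);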
+ pub fn from_data(data: Data, device: CandleDevice) -> Self { + let candle_shape: candle_core::Shape = (&data.shape.dims).into(); + let tensor = + candle_core::Tensor::from_slice(data.value.as_slice(), candle_shape, &device.into()); + Self::new(tensor.unwrap()) + } - pub(crate) fn shape(&self) -> Shape { - let x: [usize; D] = self.tensor.dims().try_into().unwrap(); - Shape::from(x) - } + pub(crate) fn shape(&self) -> Shape { + let x: [usize; D] = self.tensor.dims().try_into().unwrap(); + Shape::from(x) + } } diff --git a/burn-common/src/benchmark.rs b/burn-common/src/benchmark.rs index 4de028cbf9..a4abba2f3f 100644 --- a/burn-common/src/benchmark.rs +++ b/burn-common/src/benchmark.rs @@ -10,45 +10,45 @@ use std::time::Instant; /// Results of a benchmark run. #[derive(new, Debug)] pub struct BenchmarkResult { - durations: Vec, + durations: Vec, } impl BenchmarkResult { - /// Returns the median duration among all durations - pub fn median_duration(&self) -> Duration { - let mut sorted = self.durations.clone(); - sorted.sort(); - *sorted.get(sorted.len() / 2).unwrap() - } - pub(crate) fn mean_duration(&self) -> Duration { - self.durations.iter().sum::() / self.durations.len() as u32 - } + /// Returns the median duration among all durations + pub fn median_duration(&self) -> Duration { + let mut sorted = self.durations.clone(); + sorted.sort(); + *sorted.get(sorted.len() / 2).unwrap() + } + pub(crate) fn mean_duration(&self) -> Duration { + self.durations.iter().sum::() / self.durations.len() as u32 + } } impl Display for BenchmarkResult { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mean = self.mean_duration(); - let var = self - .durations - .iter() - .map(|duration| { - let tmp = duration.as_secs_f64() - mean.as_secs_f64(); - Duration::from_secs_f64(tmp * tmp) - }) - .sum::() - / self.durations.len() as u32; - - let mut sorted = self.durations.clone(); - sorted.sort(); - - let min = sorted.first().unwrap(); - let max = sorted.last().unwrap(); - let median = sorted.get(sorted.len() / 2).unwrap(); - let num_sample = self.durations.len(); - - f.write_str( - format!( - " + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mean = self.mean_duration(); + let var = self + .durations + .iter() + .map(|duration| { + let tmp = duration.as_secs_f64() - mean.as_secs_f64(); + Duration::from_secs_f64(tmp * tmp) + }) + .sum::() + / self.durations.len() as u32; + + let mut sorted = self.durations.clone(); + sorted.sort(); + + let min = sorted.first().unwrap(); + let max = sorted.last().unwrap(); + let median = sorted.get(sorted.len() / 2).unwrap(); + let num_sample = self.durations.len(); + + f.write_str( + format!( + " ―――――――― Result ――――――――― Samples {num_sample} Mean {mean:.3?} @@ -57,85 +57,85 @@ impl Display for BenchmarkResult { Min {min:.3?} Max {max:.3?} ―――――――――――――――――――――――――" - ) - .as_str(), - ) - } + ) + .as_str(), + ) + } } /// Benchmark trait. pub trait Benchmark { - /// Benchmark arguments. - type Args; - - /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't - /// count as included in the duration. - /// - /// # Notes - /// - /// This should not include warmup, the benchmark will be run at least one time without - /// measuring the execution time. - fn prepare(&self) -> Self::Args; - /// Execute the benchmark and returns the time it took to complete. - fn execute(&self, args: Self::Args); - /// Number of samples required to have a statistical significance. 
- fn num_samples(&self) -> usize { - 10 - } - /// Name of the benchmark. - fn name(&self) -> String; - /// Wait for computations to be over - fn sync(&self); - /// Run the benchmark a number of times. - fn run(&self) -> BenchmarkResult { - #[cfg(not(feature = "std"))] - panic!("Attempting to run benchmark in a no-std environment"); - - #[cfg(feature = "std")] - { - // Warmup - self.execute(self.prepare()); - self.sync(); - - let mut durations = Vec::with_capacity(self.num_samples()); - - for _ in 0..self.num_samples() { - // Prepare - let args = self.prepare(); - self.sync(); - - // Execute the benchmark - let start = Instant::now(); - self.execute(args); - self.sync(); - let end = Instant::now(); - - // Register the duration - durations.push(end - start); - } - - BenchmarkResult { durations } + /// Benchmark arguments. + type Args; + + /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't + /// count as included in the duration. + /// + /// # Notes + /// + /// This should not include warmup, the benchmark will be run at least one time without + /// measuring the execution time. + fn prepare(&self) -> Self::Args; + /// Execute the benchmark and returns the time it took to complete. + fn execute(&self, args: Self::Args); + /// Number of samples required to have a statistical significance. + fn num_samples(&self) -> usize { + 10 + } + /// Name of the benchmark. + fn name(&self) -> String; + /// Wait for computations to be over + fn sync(&self); + /// Run the benchmark a number of times. + fn run(&self) -> BenchmarkResult { + #[cfg(not(feature = "std"))] + panic!("Attempting to run benchmark in a no-std environment"); + + #[cfg(feature = "std")] + { + // Warmup + self.execute(self.prepare()); + self.sync(); + + let mut durations = Vec::with_capacity(self.num_samples()); + + for _ in 0..self.num_samples() { + // Prepare + let args = self.prepare(); + self.sync(); + + // Execute the benchmark + let start = Instant::now(); + self.execute(args); + self.sync(); + let end = Instant::now(); + + // Register the duration + durations.push(end - start); + } + + BenchmarkResult { durations } + } } - } } } #[cfg(feature = "std")] /// Runs the given benchmark on the device and prints result and information. pub fn run_benchmark(benchmark: BM) where - BM: Benchmark, + BM: Benchmark, { - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(); - let output = std::process::Command::new("git") - .args(["rev-parse", "HEAD"]) - .output() - .unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - - println!("Timestamp: {}", timestamp); - println!("Git Hash: {}", str::trim(&git_hash)); - println!("Benchmarking - {}{}", benchmark.name(), benchmark.run()); + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis(); + let output = std::process::Command::new("git") + .args(["rev-parse", "HEAD"]) + .output() + .unwrap(); + let git_hash = String::from_utf8(output.stdout).unwrap(); + + println!("Timestamp: {}", timestamp); + println!("Git Hash: {}", str::trim(&git_hash)); + println!("Benchmarking - {}{}", benchmark.name(), benchmark.run()); } diff --git a/burn-common/src/id.rs b/burn-common/src/id.rs index 45dd4b0c9f..25c2161817 100644 --- a/burn-common/src/id.rs +++ b/burn-common/src/id.rs @@ -6,70 +6,70 @@ use uuid::{Builder, Bytes}; pub struct IdGenerator {} impl IdGenerator { - /// Generates a new ID in the form of a UUID.
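// Hypothetical usage, as a sketch only (not part of this patch): every call returns a
// fresh hyphenated UUID string built from random bytes.
//
// let id = IdGenerator::generate();
// assert_eq!(id.len(), 36); // 32 hex digits plus 4 hyphens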
- pub fn generate() -> String { - let random_bytes: Bytes = gen_random(); + /// Generates a new ID in the form of a UUID. + pub fn generate() -> String { + let random_bytes: Bytes = gen_random(); - let uuid = Builder::from_random_bytes(random_bytes).into_uuid(); + let uuid = Builder::from_random_bytes(random_bytes).into_uuid(); - uuid.as_hyphenated().to_string() - } + uuid.as_hyphenated().to_string() + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - use alloc::{collections::BTreeSet, string::String}; + use alloc::{collections::BTreeSet, string::String}; - #[cfg(feature = "std")] - use dashmap::DashSet; //Concurrent HashMap - #[cfg(feature = "std")] - use std::{sync::Arc, thread}; + #[cfg(feature = "std")] + use dashmap::DashSet; //Concurrent HashMap + #[cfg(feature = "std")] + use std::{sync::Arc, thread}; - #[test] - fn not_empty_test() { - assert!(!IdGenerator::generate().is_empty()); - } + #[test] + fn not_empty_test() { + assert!(!IdGenerator::generate().is_empty()); + } - #[test] - fn uniqueness_test() { - const IDS_CNT: usize = 10_000; + #[test] + fn uniqueness_test() { + const IDS_CNT: usize = 10_000; - let mut set: BTreeSet = BTreeSet::new(); + let mut set: BTreeSet = BTreeSet::new(); - for _i in 0..IDS_CNT { - assert!(set.insert(IdGenerator::generate())); - } + for _i in 0..IDS_CNT { + assert!(set.insert(IdGenerator::generate())); + } - assert_eq!(set.len(), IDS_CNT); - } + assert_eq!(set.len(), IDS_CNT); + } - #[cfg(feature = "std")] - #[test] - fn thread_safety_test() { - const NUM_THREADS: usize = 10; - const NUM_REPEATS: usize = 1_000; - const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS; + #[cfg(feature = "std")] + #[test] + fn thread_safety_test() { + const NUM_THREADS: usize = 10; + const NUM_REPEATS: usize = 1_000; + const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS; - let set: Arc> = Arc::new(DashSet::new()); + let set: Arc> = Arc::new(DashSet::new()); - let mut handles = vec![]; + let mut handles = vec![]; - for _ in 0..NUM_THREADS { - let set = set.clone(); + for _ in 0..NUM_THREADS { + let set = set.clone(); - let handle = thread::spawn(move || { - for _i in 0..NUM_REPEATS { - assert!(set.insert(IdGenerator::generate())); + let handle = thread::spawn(move || { + for _i in 0..NUM_REPEATS { + assert!(set.insert(IdGenerator::generate())); + } + }); + handles.push(handle); } - }); - handles.push(handle); - } - for handle in handles { - handle.join().unwrap(); + for handle in handles { + handle.join().unwrap(); + } + assert_eq!(set.len(), EXPECTED_TOTAL_IDS); } - assert_eq!(set.len(), EXPECTED_TOTAL_IDS); - } } diff --git a/burn-common/src/rand.rs b/burn-common/src/rand.rs index 8e50e48c46..c9198930c3 100644 --- a/burn-common/src/rand.rs +++ b/burn-common/src/rand.rs @@ -7,15 +7,15 @@ use rand::prelude::Distribution; #[cfg(feature = "std")] #[inline(always)] pub fn get_seeded_rng() -> StdRng { - StdRng::from_entropy() + StdRng::from_entropy() } /// Returns a seeded random number generator using a pre-generated seed. #[cfg(not(feature = "std"))] #[inline(always)] pub fn get_seeded_rng() -> StdRng { - const CONST_SEED: u64 = 42; - StdRng::seed_from_u64(CONST_SEED) + const CONST_SEED: u64 = 42; + StdRng::seed_from_u64(CONST_SEED) } /// Generates random data from a thread-local RNG. @@ -23,9 +23,9 @@ pub fn get_seeded_rng() -> StdRng { #[inline] pub fn gen_random() -> T where - Standard: Distribution, + Standard: Distribution, { - rand::thread_rng().gen() + rand::thread_rng().gen() } /// Generates random data from a mutex-protected RNG. 
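// Hypothetical usage of the thread-local helper above, as a sketch only (not part of
// this patch): the produced type is selected by the annotation on the binding.
//
// let flip: bool = gen_random();
// let byte: u8 = gen_random();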
@@ -33,13 +33,13 @@ where #[inline] pub fn gen_random() -> T where - Standard: Distribution, + Standard: Distribution, { - use crate::stub::Mutex; - static RNG: Mutex> = Mutex::new(None); - let mut rng = RNG.lock().unwrap(); - if rng.is_none() { - *rng = Some(get_seeded_rng()); - } - rng.as_mut().unwrap().gen() + use crate::stub::Mutex; + static RNG: Mutex> = Mutex::new(None); + let mut rng = RNG.lock().unwrap(); + if rng.is_none() { + *rng = Some(get_seeded_rng()); + } + rng.as_mut().unwrap().gen() } diff --git a/burn-common/src/reader.rs b/burn-common/src/reader.rs index 7a9d0b7af7..91f4492c1a 100644 --- a/burn-common/src/reader.rs +++ b/burn-common/src/reader.rs @@ -5,111 +5,111 @@ use core::marker::PhantomData; #[async_trait::async_trait] /// Allows to create async reader. pub trait AsyncReader: Send { - /// Read asynchronously. - async fn read(self: Box) -> T; + /// Read asynchronously. + async fn read(self: Box) -> T; } /// Define how data is read, sync or async. pub enum Reader { - /// Concrete variant. - Concrete(T), - /// Sync data variant. - Sync(Box>), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Async data variant. - Async(Box>), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Future data variant. - Future(core::pin::Pin + Send>>), + /// Concrete variant. + Concrete(T), + /// Sync data variant. + Sync(Box>), + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Async data variant. + Async(Box>), + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Future data variant. + Future(core::pin::Pin + Send>>), } /// Allows to create sync reader. pub trait SyncReader: Send { - /// Read synchronously. - fn read(self: Box) -> T; + /// Read synchronously. + fn read(self: Box) -> T; } #[derive(new)] struct MappedReader { - reader: Reader, - mapper: F, - _output: PhantomData, + reader: Reader, + mapper: F, + _output: PhantomData, } impl SyncReader for MappedReader where - I: Send, - O: Send, - F: Send + FnOnce(I) -> O, + I: Send, + O: Send, + F: Send + FnOnce(I) -> O, { - fn read(self: Box) -> O { - let input = self - .reader - .read_sync() - .expect("Only sync data supported in a sync reader."); + fn read(self: Box) -> O { + let input = self + .reader + .read_sync() + .expect("Only sync data supported in a sync reader."); - (self.mapper)(input) - } + (self.mapper)(input) + } } #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] #[async_trait::async_trait] impl AsyncReader for MappedReader where - I: Send, - O: Send, - F: Send + FnOnce(I) -> O, + I: Send, + O: Send, + F: Send + FnOnce(I) -> O, { - async fn read(self: Box) -> O { - let input = self.reader.read().await; - (self.mapper)(input) - } + async fn read(self: Box) -> O { + let input = self.reader.read().await; + (self.mapper)(input) + } } impl Reader { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Read the data. - pub async fn read(self) -> T { - match self { - Self::Concrete(data) => data, - Self::Sync(reader) => reader.read(), - Self::Async(func) => func.read().await, - Self::Future(future) => future.await, + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Read the data. + pub async fn read(self) -> T { + match self { + Self::Concrete(data) => data, + Self::Sync(reader) => reader.read(), + Self::Async(func) => func.read().await, + Self::Future(future) => future.await, + } } - } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Read the data. 
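// Hypothetical usage on a non-wasm target, as a sketch only (not part of this patch):
// a concrete reader can be mapped lazily and then read synchronously.
//
// let reader = Reader::Concrete(vec![1u8, 2, 3]).map(|bytes| bytes.len());
// assert_eq!(reader.read(), 3);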
- pub fn read(self) -> T { - match self { - Self::Concrete(data) => data, - Self::Sync(reader) => reader.read(), + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Read the data. + pub fn read(self) -> T { + match self { + Self::Concrete(data) => data, + Self::Sync(reader) => reader.read(), + } } - } - /// Read the data only if sync, returns None if an async reader. - pub fn read_sync(self) -> Option { - match self { - Self::Concrete(data) => Some(data), - Self::Sync(reader) => Some(reader.read()), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - Self::Async(_func) => return None, - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - Self::Future(_future) => return None, + /// Read the data only if sync, returns None if an async reader. + pub fn read_sync(self) -> Option { + match self { + Self::Concrete(data) => Some(data), + Self::Sync(reader) => Some(reader.read()), + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + Self::Async(_func) => return None, + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + Self::Future(_future) => return None, + } } - } - /// Map the current reader to another type. - pub fn map O>(self, mapper: F) -> Reader - where - T: 'static + Send, - O: 'static + Send, - F: 'static + Send, - { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - return Reader::Async(Box::new(MappedReader::new(self, mapper))); + /// Map the current reader to another type. + pub fn map O>(self, mapper: F) -> Reader + where + T: 'static + Send, + O: 'static + Send, + F: 'static + Send, + { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + return Reader::Async(Box::new(MappedReader::new(self, mapper))); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - Reader::Sync(Box::new(MappedReader::new(self, mapper))) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + Reader::Sync(Box::new(MappedReader::new(self, mapper))) + } } diff --git a/burn-common/src/stub.rs b/burn-common/src/stub.rs index ecec818581..93d0715f5f 100644 --- a/burn-common/src/stub.rs +++ b/burn-common/src/stub.rs @@ -1,5 +1,5 @@ use spin::{ - Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard, + Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard, }; /// A mutual exclusion primitive useful for protecting shared data @@ -10,23 +10,23 @@ use spin::{ /// [Mutex] wrapper to make `spin::Mutex` API compatible with `std::sync::Mutex` to swap #[derive(Debug)] pub struct Mutex { - inner: MutexImported, + inner: MutexImported, } impl Mutex { - /// Creates a new mutex in an unlocked state ready for use. - #[inline(always)] - pub const fn new(value: T) -> Self { - Self { - inner: MutexImported::new(value), + /// Creates a new mutex in an unlocked state ready for use. + #[inline(always)] + pub const fn new(value: T) -> Self { + Self { + inner: MutexImported::new(value), + } } - } - /// Locks the mutex blocking the current thread until it is able to do so. - #[inline(always)] - pub fn lock(&self) -> Result, alloc::string::String> { - Ok(self.inner.lock()) - } + /// Locks the mutex blocking the current thread until it is able to do so. + #[inline(always)] + pub fn lock(&self) -> Result, alloc::string::String> { + Ok(self.inner.lock()) + } } /// A reader-writer lock which is exclusively locked for writing or shared for reading. 
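// Hypothetical usage of these wrappers, as a sketch only (not part of this patch):
// the `Result`-returning API mirrors `std::sync`, so calling code can stay unchanged
// when swapping between the spin-based and std-based implementations.
//
// let counter = Mutex::new(0u32);
// *counter.lock().unwrap() += 1;
//
// let cache = RwLock::new(alloc::vec::Vec::<u8>::new());
// cache.write().unwrap().push(42);
// let len = cache.read().unwrap().len();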
@@ -35,31 +35,31 @@ impl Mutex { /// [RwLock] wrapper to make `spin::RwLock` API compatible with `std::sync::RwLock` to swap #[derive(Debug)] pub struct RwLock { - inner: RwLockImported, + inner: RwLockImported, } impl RwLock { - /// Creates a new reader-writer lock in an unlocked state ready for use. - #[inline(always)] - pub const fn new(value: T) -> Self { - Self { - inner: RwLockImported::new(value), + /// Creates a new reader-writer lock in an unlocked state ready for use. + #[inline(always)] + pub const fn new(value: T) -> Self { + Self { + inner: RwLockImported::new(value), + } } - } - /// Locks this rwlock with shared read access, blocking the current thread - /// until it can be acquired. - #[inline(always)] - pub fn read(&self) -> Result, alloc::string::String> { - Ok(self.inner.read()) - } + /// Locks this rwlock with shared read access, blocking the current thread + /// until it can be acquired. + #[inline(always)] + pub fn read(&self) -> Result, alloc::string::String> { + Ok(self.inner.read()) + } - /// Locks this rwlock with exclusive write access, blocking the current thread - /// until it can be acquired. - #[inline(always)] - pub fn write(&self) -> Result, alloc::string::String> { - Ok(self.inner.write()) - } + /// Locks this rwlock with exclusive write access, blocking the current thread + /// until it can be acquired. + #[inline(always)] + pub fn write(&self) -> Result, alloc::string::String> { + Ok(self.inner.write()) + } } /// A unique identifier for a running thread. diff --git a/burn-compute/src/channel/base.rs b/burn-compute/src/channel/base.rs index 78259fdadb..9b2c7e3db5 100644 --- a/burn-compute/src/channel/base.rs +++ b/burn-compute/src/channel/base.rs @@ -5,18 +5,18 @@ use burn_common::reader::Reader; /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety pub trait ComputeChannel: Clone + core::fmt::Debug { - /// Given a handle, returns owned resource as bytes - fn read(&self, handle: &Handle) -> Reader>; + /// Given a handle, returns owned resource as bytes + fn read(&self, handle: &Handle) -> Reader>; - /// Given a resource as bytes, stores it and returns the resource handle - fn create(&self, data: &[u8]) -> Handle; + /// Given a resource as bytes, stores it and returns the resource handle + fn create(&self, data: &[u8]) -> Handle; - /// Reserves `size` bytes in the storage, and returns a handle over them - fn empty(&self, size: usize) -> Handle; + /// Reserves `size` bytes in the storage, and returns a handle over them + fn empty(&self, size: usize) -> Handle; - /// Executes the `kernel` over the given `handles`. - fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]); + /// Executes the `kernel` over the given `handles`. + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]); - /// Wait for the completion of every task in the server. - fn sync(&self); + /// Wait for the completion of every task in the server. + fn sync(&self); } diff --git a/burn-compute/src/channel/cell.rs b/burn-compute/src/channel/cell.rs index cb110e4d77..002b237271 100644 --- a/burn-compute/src/channel/cell.rs +++ b/burn-compute/src/channel/cell.rs @@ -15,53 +15,52 @@ use burn_common::reader::Reader; /// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels. 
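// Hypothetical round-trip through a channel, as a sketch only (`server`, `kernel`,
// and `input_bytes` are placeholder values, not taken from this patch):
//
// let channel = RefCellComputeChannel::new(server);
// let input = channel.create(&input_bytes);
// let output = channel.empty(input_bytes.len());
// channel.execute(kernel, &[&input, &output]);
// channel.sync();
// let result_bytes = channel.read(&output).read();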
#[derive(Debug)] pub struct RefCellComputeChannel { - server: Arc>, + server: Arc>, } impl Clone for RefCellComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), + fn clone(&self) -> Self { + Self { + server: self.server.clone(), + } } - } } impl RefCellComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - /// Create a new cell compute channel. - pub fn new(server: Server) -> Self { - Self { - server: Arc::new(core::cell::RefCell::new(server)), + /// Create a new cell compute channel. + pub fn new(server: Server) -> Self { + Self { + server: Arc::new(core::cell::RefCell::new(server)), + } } - } } impl ComputeChannel for RefCellComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - fn read(&self, handle: &Handle) -> Reader> { - self.server.borrow_mut().read(handle) - } + fn read(&self, handle: &Handle) -> Reader> { + self.server.borrow_mut().read(handle) + } - fn create(&self, resource: &[u8]) -> Handle { - self.server.borrow_mut().create(resource) - } + fn create(&self, resource: &[u8]) -> Handle { + self.server.borrow_mut().create(resource) + } - fn empty(&self, size: usize) -> Handle { - self.server.borrow_mut().empty(size) - } + fn empty(&self, size: usize) -> Handle { + self.server.borrow_mut().empty(size) + } - fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle]) { - self - .server - .borrow_mut() - .execute(kernel_description, handles) - } + fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle]) { + self.server + .borrow_mut() + .execute(kernel_description, handles) + } - fn sync(&self) { - self.server.borrow_mut().sync() - } + fn sync(&self) { + self.server.borrow_mut().sync() + } } diff --git a/burn-compute/src/channel/mpsc.rs b/burn-compute/src/channel/mpsc.rs index 8bd6bc576c..e0f07ccca8 100644 --- a/burn-compute/src/channel/mpsc.rs +++ b/burn-compute/src/channel/mpsc.rs @@ -1,6 +1,6 @@ use std::{ - sync::{mpsc, Arc}, - thread, + sync::{mpsc, Arc}, + thread, }; use burn_common::reader::Reader; @@ -13,150 +13,146 @@ use crate::server::{ComputeServer, Handle}; #[derive(Debug)] pub struct MpscComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - state: Arc>, + state: Arc>, } #[derive(Debug)] struct MpscComputeChannelState where - Server: ComputeServer, + Server: ComputeServer, { - _handle: thread::JoinHandle<()>, - sender: mpsc::SyncSender>, + _handle: thread::JoinHandle<()>, + sender: mpsc::SyncSender>, } type Callback = mpsc::SyncSender; enum Message where - Server: ComputeServer, + Server: ComputeServer, { - Read(Handle, Callback>>), - Create(Vec, Callback>), - Empty(usize, Callback>), - ExecuteKernel(Server::Kernel, Vec>), - Sync(Callback<()>), + Read(Handle, Callback>>), + Create(Vec, Callback>), + Empty(usize, Callback>), + ExecuteKernel(Server::Kernel, Vec>), + Sync(Callback<()>), } impl MpscComputeChannel where - Server: ComputeServer + 'static, + Server: ComputeServer + 'static, { - /// Create a new mpsc compute channel. 
- pub fn new(mut server: Server, bound: usize) -> Self { - let (sender, receiver) = mpsc::sync_channel(bound); - - let _handle = thread::spawn(move || { - while let Ok(message) = receiver.recv() { - match message { - Message::Read(handle, callback) => { - let data = server.read(&handle); - core::mem::drop(handle); - callback.send(data).unwrap(); - } - Message::Create(data, callback) => { - let handle = server.create(&data); - callback.send(handle).unwrap(); - } - Message::Empty(size, callback) => { - let handle = server.empty(size); - callback.send(handle).unwrap(); - } - Message::ExecuteKernel(kernel, handles) => { - server.execute(kernel, &handles.iter().collect::>()); - } - Message::Sync(callback) => { - server.sync(); - callback.send(()).unwrap(); - } - }; - } - }); - - let state = Arc::new(MpscComputeChannelState { sender, _handle }); - - Self { state } - } + /// Create a new mpsc compute channel. + pub fn new(mut server: Server, bound: usize) -> Self { + let (sender, receiver) = mpsc::sync_channel(bound); + + let _handle = thread::spawn(move || { + while let Ok(message) = receiver.recv() { + match message { + Message::Read(handle, callback) => { + let data = server.read(&handle); + core::mem::drop(handle); + callback.send(data).unwrap(); + } + Message::Create(data, callback) => { + let handle = server.create(&data); + callback.send(handle).unwrap(); + } + Message::Empty(size, callback) => { + let handle = server.empty(size); + callback.send(handle).unwrap(); + } + Message::ExecuteKernel(kernel, handles) => { + server.execute(kernel, &handles.iter().collect::>()); + } + Message::Sync(callback) => { + server.sync(); + callback.send(()).unwrap(); + } + }; + } + }); + + let state = Arc::new(MpscComputeChannelState { sender, _handle }); + + Self { state } + } } impl Clone for MpscComputeChannel { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + } } - } } impl ComputeChannel for MpscComputeChannel where - Server: ComputeServer + 'static, + Server: ComputeServer + 'static, { - fn read(&self, handle: &Handle) -> Reader> { - let (callback, response) = mpsc::sync_channel(1); - - self - .state - .sender - .send(Message::Read(handle.clone(), callback)) - .unwrap(); - - self.response(response) - } - - fn create(&self, data: &[u8]) -> Handle { - let (callback, response) = mpsc::sync_channel(1); - - self - .state - .sender - .send(Message::Create(data.to_vec(), callback)) - .unwrap(); - - self.response(response) - } - - fn empty(&self, size: usize) -> Handle { - let (callback, response) = mpsc::sync_channel(1); - - self - .state - .sender - .send(Message::Empty(size, callback)) - .unwrap(); - - self.response(response) - } - - fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { - self - .state - .sender - .send(Message::ExecuteKernel( - kernel, - handles - .iter() - .map(|h| (*h).clone()) - .collect::>>(), - )) - .unwrap() - } - - fn sync(&self) { - let (callback, response) = mpsc::sync_channel(1); - - self.state.sender.send(Message::Sync(callback)).unwrap(); - - self.response(response) - } + fn read(&self, handle: &Handle) -> Reader> { + let (callback, response) = mpsc::sync_channel(1); + + self.state + .sender + .send(Message::Read(handle.clone(), callback)) + .unwrap(); + + self.response(response) + } + + fn create(&self, data: &[u8]) -> Handle { + let (callback, response) = mpsc::sync_channel(1); + + self.state + .sender + .send(Message::Create(data.to_vec(), callback)) + .unwrap(); + + 
self.response(response) + } + + fn empty(&self, size: usize) -> Handle { + let (callback, response) = mpsc::sync_channel(1); + + self.state + .sender + .send(Message::Empty(size, callback)) + .unwrap(); + + self.response(response) + } + + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.state + .sender + .send(Message::ExecuteKernel( + kernel, + handles + .iter() + .map(|h| (*h).clone()) + .collect::>>(), + )) + .unwrap() + } + + fn sync(&self) { + let (callback, response) = mpsc::sync_channel(1); + + self.state.sender.send(Message::Sync(callback)).unwrap(); + + self.response(response) + } } impl MpscComputeChannel { - fn response(&self, response: mpsc::Receiver) -> Response { - match response.recv() { - Ok(val) => val, - Err(err) => panic!("Can't connect to the server correctly {err:?}"), + fn response(&self, response: mpsc::Receiver) -> Response { + match response.recv() { + Ok(val) => val, + Err(err) => panic!("Can't connect to the server correctly {err:?}"), + } } - } } diff --git a/burn-compute/src/channel/mutex.rs b/burn-compute/src/channel/mutex.rs index 731d7e1fb0..140b850eb0 100644 --- a/burn-compute/src/channel/mutex.rs +++ b/burn-compute/src/channel/mutex.rs @@ -9,49 +9,49 @@ use spin::Mutex; /// on every operation #[derive(Debug)] pub struct MutexComputeChannel { - server: Arc>, + server: Arc>, } impl Clone for MutexComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), + fn clone(&self) -> Self { + Self { + server: self.server.clone(), + } } - } } impl MutexComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - /// Create a new mutex compute channel. - pub fn new(server: Server) -> Self { - Self { - server: Arc::new(Mutex::new(server)), + /// Create a new mutex compute channel. + pub fn new(server: Server) -> Self { + Self { + server: Arc::new(Mutex::new(server)), + } } - } } impl ComputeChannel for MutexComputeChannel where - Server: ComputeServer, + Server: ComputeServer, { - fn read(&self, handle: &Handle) -> Reader> { - self.server.lock().read(handle) - } + fn read(&self, handle: &Handle) -> Reader> { + self.server.lock().read(handle) + } - fn create(&self, data: &[u8]) -> Handle { - self.server.lock().create(data) - } + fn create(&self, data: &[u8]) -> Handle { + self.server.lock().create(data) + } - fn empty(&self, size: usize) -> Handle { - self.server.lock().empty(size) - } + fn empty(&self, size: usize) -> Handle { + self.server.lock().empty(size) + } - fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { - self.server.lock().execute(kernel, handles) - } + fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.server.lock().execute(kernel, handles) + } - fn sync(&self) { - self.server.lock().sync() - } + fn sync(&self) { + self.server.lock().sync() + } } diff --git a/burn-compute/src/client.rs b/burn-compute/src/client.rs index 8c989e4ce4..3832652aeb 100644 --- a/burn-compute/src/client.rs +++ b/burn-compute/src/client.rs @@ -1,7 +1,7 @@ use crate::{ - channel::ComputeChannel, - server::{ComputeServer, Handle}, - tune::{AutotuneOperationSet, Tuner}, + channel::ComputeChannel, + server::{ComputeServer, Handle}, + tune::{AutotuneOperationSet, Tuner}, }; use alloc::vec::Vec; use alloc::{boxed::Box, sync::Arc}; @@ -13,72 +13,71 @@ use spin::Mutex; /// It should be obtained for a specific device via the Compute struct. 
#[derive(Debug)] pub struct ComputeClient { - channel: Channel, - tuner: Arc>>, - _server: PhantomData, + channel: Channel, + tuner: Arc>>, + _server: PhantomData, } impl Clone for ComputeClient where - S: ComputeServer, - C: ComputeChannel, + S: ComputeServer, + C: ComputeChannel, { - fn clone(&self) -> Self { - Self { - channel: self.channel.clone(), - tuner: self.tuner.clone(), - _server: PhantomData, + fn clone(&self) -> Self { + Self { + channel: self.channel.clone(), + tuner: self.tuner.clone(), + _server: PhantomData, + } } - } } impl ComputeClient where - Server: ComputeServer, - Channel: ComputeChannel, + Server: ComputeServer, + Channel: ComputeChannel, { - /// Create a new client. - pub fn new(channel: Channel, tuner: Arc>>) -> Self { - Self { - channel, - tuner, - _server: PhantomData, + /// Create a new client. + pub fn new(channel: Channel, tuner: Arc>>) -> Self { + Self { + channel, + tuner, + _server: PhantomData, + } } - } - /// Given a handle, returns owned resource as bytes. - pub fn read(&self, handle: &Handle) -> Reader> { - self.channel.read(handle) - } + /// Given a handle, returns owned resource as bytes. + pub fn read(&self, handle: &Handle) -> Reader> { + self.channel.read(handle) + } - /// Given a resource, stores it and returns the resource handle. - pub fn create(&self, data: &[u8]) -> Handle { - self.channel.create(data) - } + /// Given a resource, stores it and returns the resource handle. + pub fn create(&self, data: &[u8]) -> Handle { + self.channel.create(data) + } - /// Reserves `size` bytes in the storage, and returns a handle over them. - pub fn empty(&self, size: usize) -> Handle { - self.channel.empty(size) - } + /// Reserves `size` bytes in the storage, and returns a handle over them. + pub fn empty(&self, size: usize) -> Handle { + self.channel.empty(size) + } - /// Executes the `kernel` over the given `handles`. - pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { - self.channel.execute(kernel, handles) - } + /// Executes the `kernel` over the given `handles`. + pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle]) { + self.channel.execute(kernel, handles) + } - /// Wait for the completion of every task in the server. - pub fn sync(&self) { - self.channel.sync() - } + /// Wait for the completion of every task in the server. + pub fn sync(&self) { + self.channel.sync() + } - /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks - pub fn execute_autotune( - &self, - autotune_operation_set: Box>, - ) { - self - .tuner - .lock() - .execute_autotune(autotune_operation_set, self); - } + /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks + pub fn execute_autotune( + &self, + autotune_operation_set: Box>, + ) { + self.tuner + .lock() + .execute_autotune(autotune_operation_set, self); + } } diff --git a/burn-compute/src/compute.rs b/burn-compute/src/compute.rs index a9dd96cbff..33f5d34337 100644 --- a/burn-compute/src/compute.rs +++ b/burn-compute/src/compute.rs @@ -5,79 +5,79 @@ use hashbrown::HashMap; /// The compute type has the responsibility to retrieve the correct compute client based on the /// given device. 
pub struct Compute { - clients: spin::Mutex>>>, + clients: spin::Mutex>>>, } impl Compute where - Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, - Server: ComputeServer, - Channel: ComputeChannel, + Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, + Server: ComputeServer, + Channel: ComputeChannel, { - /// Create a new compute. - pub const fn new() -> Self { - Self { - clients: spin::Mutex::new(None), + /// Create a new compute. + pub const fn new() -> Self { + Self { + clients: spin::Mutex::new(None), + } } - } - /// Get the compute client for the given device. - /// - /// Provide the init function to create a new client if it isn't already initialized. - pub fn client(&self, device: &Device, init: Init) -> ComputeClient - where - Init: Fn() -> ComputeClient, - { - let mut clients = self.clients.lock(); + /// Get the compute client for the given device. + /// + /// Provide the init function to create a new client if it isn't already initialized. + pub fn client(&self, device: &Device, init: Init) -> ComputeClient + where + Init: Fn() -> ComputeClient, + { + let mut clients = self.clients.lock(); - if clients.is_none() { - Self::register_inner(device, init(), &mut clients); - } + if clients.is_none() { + Self::register_inner(device, init(), &mut clients); + } - match clients.deref_mut() { - Some(clients) => match clients.get(device) { - Some(client) => client.clone(), - None => { - let client = init(); - clients.insert(device.clone(), client.clone()); - client + match clients.deref_mut() { + Some(clients) => match clients.get(device) { + Some(client) => client.clone(), + None => { + let client = init(); + clients.insert(device.clone(), client.clone()); + client + } + }, + _ => unreachable!(), } - }, - _ => unreachable!(), } - } - - /// Register the compute client for the given device. - /// - /// # Note - /// - /// This function is mostly useful when the creation of the compute client can't be done - /// synchronously and require special context. - /// - /// # Panics - /// - /// If a client is already registered for the given device. - pub fn register(&self, device: &Device, client: ComputeClient) { - let mut clients = self.clients.lock(); - Self::register_inner(device, client, &mut clients); - } + /// Register the compute client for the given device. + /// + /// # Note + /// + /// This function is mostly useful when the creation of the compute client can't be done + /// synchronously and require special context. + /// + /// # Panics + /// + /// If a client is already registered for the given device. 
+ pub fn register(&self, device: &Device, client: ComputeClient) { + let mut clients = self.clients.lock(); - fn register_inner( - device: &Device, - client: ComputeClient, - clients: &mut Option>>, - ) { - if clients.is_none() { - *clients = Some(HashMap::new()); + Self::register_inner(device, client, &mut clients); } - if let Some(clients) = clients { - if clients.contains_key(device) { - panic!("Client already created for device {:?}", device); - } + fn register_inner( + device: &Device, + client: ComputeClient, + clients: &mut Option>>, + ) { + if clients.is_none() { + *clients = Some(HashMap::new()); + } + + if let Some(clients) = clients { + if clients.contains_key(device) { + panic!("Client already created for device {:?}", device); + } - clients.insert(device.clone(), client); + clients.insert(device.clone(), client); + } } - } } diff --git a/burn-compute/src/id.rs b/burn-compute/src/id.rs index 90fb4e28c8..33ba53c044 100644 --- a/burn-compute/src/id.rs +++ b/burn-compute/src/id.rs @@ -1,53 +1,53 @@ #[macro_export(local_inner_macros)] /// Create a new storage ID type. macro_rules! storage_id_type { - ($name:ident) => { - #[derive(Clone, Hash, PartialEq, Eq)] - /// Storage ID. - pub struct $name { - id: alloc::sync::Arc, - } + ($name:ident) => { + #[derive(Clone, Hash, PartialEq, Eq)] + /// Storage ID. + pub struct $name { + id: alloc::sync::Arc, + } - impl $name { - /// Create a new ID. - pub fn new() -> Self { - Self { - id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), + impl $name { + /// Create a new ID. + pub fn new() -> Self { + Self { + id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), + } + } } - } - } - impl Default for $name { - fn default() -> Self { - Self::new() - } - } - }; + impl Default for $name { + fn default() -> Self { + Self::new() + } + } + }; } #[macro_export(local_inner_macros)] /// Create a new memory ID type. macro_rules! memory_id_type { - ($name:ident) => { - #[derive(Clone, Hash, PartialEq, Eq, Debug)] - /// Memory ID. - pub struct $name { - id: alloc::sync::Arc, - } + ($name:ident) => { + #[derive(Clone, Hash, PartialEq, Eq, Debug)] + /// Memory ID. + pub struct $name { + id: alloc::sync::Arc, + } - impl $name { - /// Create a new ID. - pub(crate) fn new() -> Self { - Self { - id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), + impl $name { + /// Create a new ID. + pub(crate) fn new() -> Self { + Self { + id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()), + } + } } - } - } - impl Default for $name { - fn default() -> Self { - Self::new() - } - } - }; + impl Default for $name { + fn default() -> Self { + Self::new() + } + } + }; } diff --git a/burn-compute/src/memory_management/base.rs b/burn-compute/src/memory_management/base.rs index 1be75be6cc..4a6310cf9e 100644 --- a/burn-compute/src/memory_management/base.rs +++ b/burn-compute/src/memory_management/base.rs @@ -6,8 +6,8 @@ use crate::storage::ComputeStorage; /// It is responsible for determining if the memory segment can be mutated, /// for instance by keeping track of a reference count pub trait MemoryHandle: Clone + Send + core::fmt::Debug { - /// Checks if the underlying memory can be safely mutated. - fn can_mut(&self) -> bool; + /// Checks if the underlying memory can be safely mutated. + fn can_mut(&self) -> bool; } /// The MemoryManagement trait encapsulates strategies for (de)allocating memory. 
@@ -16,38 +16,38 @@ pub trait MemoryHandle: Clone + Send + core::fmt::Debug {
 /// The MemoryManagement can only reserve memory space or get the resource located at a space.
 /// Modification of the resource data should be done directly on the resource.
 pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
-  /// The associated type Handle must implement MemoryHandle
-  type Handle: MemoryHandle;
+    /// The associated type Handle must implement MemoryHandle
+    type Handle: MemoryHandle;
 
-  /// Returns the resource from the storage at the specified handle
-  fn get(&mut self, handle: &Self::Handle) -> Storage::Resource;
+    /// Returns the resource from the storage at the specified handle
+    fn get(&mut self, handle: &Self::Handle) -> Storage::Resource;
 
-  /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
-  fn reserve(&mut self, size: usize) -> Self::Handle;
+    /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
+    fn reserve(&mut self, size: usize) -> Self::Handle;
 
-  /// Bypass the memory allocation algorithm to allocate data directly.
-  ///
-  /// # Notes
-  ///
-  /// Can be useful for servers that want specific control over memory.
-  fn alloc(&mut self, size: usize) -> Self::Handle;
+    /// Bypass the memory allocation algorithm to allocate data directly.
+    ///
+    /// # Notes
+    ///
+    /// Can be useful for servers that want specific control over memory.
+    fn alloc(&mut self, size: usize) -> Self::Handle;
 
-  /// Bypass the memory allocation algorithm to deallocate data directly.
-  ///
-  /// # Notes
-  ///
-  /// Can be useful for servers that want specific control over memory.
-  fn dealloc(&mut self, handle: &Self::Handle);
+    /// Bypass the memory allocation algorithm to deallocate data directly.
+    ///
+    /// # Notes
+    ///
+    /// Can be useful for servers that want specific control over memory.
+    fn dealloc(&mut self, handle: &Self::Handle);
 
-  /// Fetch the storage used by the memory manager.
-  ///
-  /// # Notes
-  ///
-  /// The storage should probably not be used for allocations since the handles won't be
-  /// compatible with the ones provided by the current trait. Prefer using the
-  /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions.
-  ///
-  /// This is useful if you need to time the deallocations based on async computation, or to
-  /// change the mode of storage for different reasons.
-  fn storage(&mut self) -> &mut Storage;
+    /// Fetch the storage used by the memory manager.
+    ///
+    /// # Notes
+    ///
+    /// The storage should probably not be used for allocations since the handles won't be
+    /// compatible with the ones provided by the current trait. Prefer using the
+    /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions.
+    ///
+    /// This is useful if you need to time the deallocations based on async computation, or to
+    /// change the mode of storage for different reasons.
+ fn storage(&mut self) -> &mut Storage; } diff --git a/burn-compute/src/memory_management/simple.rs b/burn-compute/src/memory_management/simple.rs index 1605ee0839..e6bb4fb37d 100644 --- a/burn-compute/src/memory_management/simple.rs +++ b/burn-compute/src/memory_management/simple.rs @@ -1,7 +1,7 @@ use super::{MemoryHandle, MemoryManagement}; use crate::{ - memory_id_type, - storage::{ComputeStorage, StorageHandle, StorageUtilization}, + memory_id_type, + storage::{ComputeStorage, StorageHandle, StorageUtilization}, }; use alloc::{sync::Arc, vec::Vec}; use hashbrown::HashMap; @@ -12,451 +12,451 @@ memory_id_type!(ChunkId); memory_id_type!(SliceId); impl ChunkId { - /// A chunk is free if it is only referred by the chunk hashmap. - fn is_free(&self) -> bool { - Arc::strong_count(&self.id) <= 1 - } + /// A chunk is free if it is only referred by the chunk hashmap. + fn is_free(&self) -> bool { + Arc::strong_count(&self.id) <= 1 + } } impl SliceId { - /// A slice is free if it is only referred by the slice hashmap and the chunk it is in. - fn is_free(&self) -> bool { - Arc::strong_count(&self.id) <= 2 - } + /// A slice is free if it is only referred by the slice hashmap and the chunk it is in. + fn is_free(&self) -> bool { + Arc::strong_count(&self.id) <= 2 + } } /// The SimpleHandle is a memory handle, referring to either a chunk or a slice. #[derive(Debug, Clone)] pub enum SimpleHandle { - /// A whole chunk of memory. - Chunk(ChunkId), - /// A slice of a chunk of memory. - Slice(SliceId), + /// A whole chunk of memory. + Chunk(ChunkId), + /// A slice of a chunk of memory. + Slice(SliceId), } /// The strategy defines the frequency at which deallocation of unused memory chunks should occur. #[derive(Debug)] pub enum DeallocStrategy { - /// Once every n calls to reserve. - PeriodTick { - /// Number of calls to be executed before triggering the deallocation. - period: usize, - /// Current state. Should start at zero. - state: usize, - }, - #[cfg(feature = "std")] - /// Once every period of time - PeriodTime { - /// Number of time before triggering the deallocation. - period: std::time::Duration, - /// Current state. Should start at now. - state: std::time::Instant, - }, - /// Never deallocate. - Never, + /// Once every n calls to reserve. + PeriodTick { + /// Number of calls to be executed before triggering the deallocation. + period: usize, + /// Current state. Should start at zero. + state: usize, + }, + #[cfg(feature = "std")] + /// Once every period of time + PeriodTime { + /// Number of time before triggering the deallocation. + period: std::time::Duration, + /// Current state. Should start at now. + state: std::time::Instant, + }, + /// Never deallocate. + Never, } /// The strategy defines when to reuse chunk with slices. #[derive(Debug)] pub enum SliceStrategy { - /// Never use slices. - Never, - /// Ratio needed before the chunk can be used as a slice. Between 0 and 1. - Ratio(f32), - /// When the reserved memory is at least {} bytes. - MinimumSize(usize), - /// When the reserved memory less than {} bytes. - MaximumSize(usize), + /// Never use slices. + Never, + /// Ratio needed before the chunk can be used as a slice. Between 0 and 1. + Ratio(f32), + /// When the reserved memory is at least {} bytes. + MinimumSize(usize), + /// When the reserved memory less than {} bytes. + MaximumSize(usize), } impl SliceStrategy { - /// If the chunk can be used with a slice. 
- pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool { - if chunk_size < reserved_size { - return false; - } + /// If the chunk can be used with a slice. + pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool { + if chunk_size < reserved_size { + return false; + } - match self { - SliceStrategy::Never => false, - SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio, - SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes, - SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes, + match self { + SliceStrategy::Never => false, + SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio, + SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes, + SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes, + } } - } } impl DeallocStrategy { - /// Create a new strategy with the given period. - pub fn new_period_tick(period: usize) -> Self { - DeallocStrategy::PeriodTick { period, state: 0 } - } - - fn should_dealloc(&mut self) -> bool { - match self { - DeallocStrategy::PeriodTick { period, state } => { - *state = (*state + 1) % *period; - *state == 0 - } - #[cfg(feature = "std")] - DeallocStrategy::PeriodTime { period, state } => { - if &state.elapsed() > period { - *state = std::time::Instant::now(); - true - } else { - false + /// Create a new strategy with the given period. + pub fn new_period_tick(period: usize) -> Self { + DeallocStrategy::PeriodTick { period, state: 0 } + } + + fn should_dealloc(&mut self) -> bool { + match self { + DeallocStrategy::PeriodTick { period, state } => { + *state = (*state + 1) % *period; + *state == 0 + } + #[cfg(feature = "std")] + DeallocStrategy::PeriodTime { period, state } => { + if &state.elapsed() > period { + *state = std::time::Instant::now(); + true + } else { + false + } + } + DeallocStrategy::Never => false, } - } - DeallocStrategy::Never => false, } - } } /// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks. pub struct SimpleMemoryManagement { - chunks: HashMap)>, - slices: HashMap, - dealloc_strategy: DeallocStrategy, - slice_strategy: SliceStrategy, - storage: Storage, + chunks: HashMap)>, + slices: HashMap, + dealloc_strategy: DeallocStrategy, + slice_strategy: SliceStrategy, + storage: Storage, } impl core::fmt::Debug for SimpleMemoryManagement { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - alloc::format!( - "SimpleMemoryManagement {:?} - {:?}", - self.dealloc_strategy, - core::any::type_name::(), - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + alloc::format!( + "SimpleMemoryManagement {:?} - {:?}", + self.dealloc_strategy, + core::any::type_name::(), + ) + .as_str(), + ) + } } impl MemoryHandle for SimpleHandle { - /// Returns true if referenced by only one tensor, and only once by the - /// memory management hashmaps - fn can_mut(&self) -> bool { - // One reference in the chunk hashmap, another owned by one tensor. - const REFERENCE_LIMIT_CHUNK: usize = 2; - // One reference in the chunk hashmap (for the chunk on which this slice is built), - // another in the slice hashmap for this slice, and another owned by one tensor. 
- const REFERENCE_LIMIT_SLICE: usize = 3; - - match &self { - SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK, - SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE, + /// Returns true if referenced by only one tensor, and only once by the + /// memory management hashmaps + fn can_mut(&self) -> bool { + // One reference in the chunk hashmap, another owned by one tensor. + const REFERENCE_LIMIT_CHUNK: usize = 2; + // One reference in the chunk hashmap (for the chunk on which this slice is built), + // another in the slice hashmap for this slice, and another owned by one tensor. + const REFERENCE_LIMIT_SLICE: usize = 3; + + match &self { + SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK, + SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE, + } } - } } impl MemoryManagement for SimpleMemoryManagement { - type Handle = SimpleHandle; + type Handle = SimpleHandle; - /// Returns the resource from the storage, for the specified handle. - fn get(&mut self, handle: &Self::Handle) -> Storage::Resource { - let resource = match &handle { - SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0, - SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0, - }; + /// Returns the resource from the storage, for the specified handle. + fn get(&mut self, handle: &Self::Handle) -> Storage::Resource { + let resource = match &handle { + SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0, + SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0, + }; - self.storage.get(resource) - } + self.storage.get(resource) + } - /// Reserves memory of specified size using the reserve algorithm, and return - /// a handle to the reserved memory. - /// - /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy. - fn reserve(&mut self, size: usize) -> Self::Handle { - self.cleanup_slices(); + /// Reserves memory of specified size using the reserve algorithm, and return + /// a handle to the reserved memory. + /// + /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy. + fn reserve(&mut self, size: usize) -> Self::Handle { + self.cleanup_slices(); - let handle = self.reserve_algorithm(size); + let handle = self.reserve_algorithm(size); - if self.dealloc_strategy.should_dealloc() { - self.cleanup_chunks(); - } + if self.dealloc_strategy.should_dealloc() { + self.cleanup_chunks(); + } - handle - } + handle + } - fn alloc(&mut self, size: usize) -> Self::Handle { - self.create_chunk(size) - } + fn alloc(&mut self, size: usize) -> Self::Handle { + self.create_chunk(size) + } - fn dealloc(&mut self, handle: &Self::Handle) { - match handle { - SimpleHandle::Chunk(id) => { - if let Some((handle, _slices)) = self.chunks.remove(id) { - self.storage.dealloc(handle.id); + fn dealloc(&mut self, handle: &Self::Handle) { + match handle { + SimpleHandle::Chunk(id) => { + if let Some((handle, _slices)) = self.chunks.remove(id) { + self.storage.dealloc(handle.id); + } + } + SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"), } - } - SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"), } - } - fn storage(&mut self) -> &mut Storage { - &mut self.storage - } + fn storage(&mut self) -> &mut Storage { + &mut self.storage + } } impl SimpleMemoryManagement { - /// Creates a new instance using the given storage, deallocation strategy and slice strategy. 
- pub fn new( - storage: Storage, - dealloc_strategy: DeallocStrategy, - slice_strategy: SliceStrategy, - ) -> Self { - Self { - chunks: HashMap::new(), - slices: HashMap::new(), - dealloc_strategy, - slice_strategy, - storage, + /// Creates a new instance using the given storage, deallocation strategy and slice strategy. + pub fn new( + storage: Storage, + dealloc_strategy: DeallocStrategy, + slice_strategy: SliceStrategy, + ) -> Self { + Self { + chunks: HashMap::new(), + slices: HashMap::new(), + dealloc_strategy, + slice_strategy, + storage, + } } - } - fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { - // Looks for a large enough, existing but unused chunk of memory. - let chunk = self.find_free_chunk(size); + fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { + // Looks for a large enough, existing but unused chunk of memory. + let chunk = self.find_free_chunk(size); + + match chunk { + Some((chunk_id, chunk_size)) => { + if size == chunk_size { + // If there is one of exactly the same size, it reuses it. + SimpleHandle::Chunk(chunk_id.clone()) + } else { + // Otherwise creates a slice of the right size upon it, always starting at zero. + self.create_slice(size, chunk_id) + } + } + // If no chunk available, creates one of exactly the right size. + None => self.create_chunk(size), + } + } - match chunk { - Some((chunk_id, chunk_size)) => { - if size == chunk_size { - // If there is one of exactly the same size, it reuses it. - SimpleHandle::Chunk(chunk_id.clone()) - } else { - // Otherwise creates a slice of the right size upon it, always starting at zero. - self.create_slice(size, chunk_id) + /// Finds the smallest of the free and large enough chunks to fit `size` + /// Returns the chunk's id and size. + fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> { + let mut size_diff_current = usize::MAX; + let mut current = None; + + for (chunk_id, (resource, slices)) in self.chunks.iter() { + // If chunk is already used, we do not choose it + if !slices.is_empty() || !chunk_id.is_free() { + continue; + } + + let resource_size = resource.size(); + + // If we find a chunk of exactly the right size, we stop searching altogether + if size == resource_size { + current = Some((chunk_id, resource)); + break; + } + + // Finds the smallest of the large enough chunks that can accept a slice + // of the given size + if self.slice_strategy.can_use_chunk(resource_size, size) { + let size_diff = resource_size - size; + + if size_diff < size_diff_current { + current = Some((chunk_id, resource)); + size_diff_current = size_diff; + } + } } - } - // If no chunk available, creates one of exactly the right size. - None => self.create_chunk(size), + + current.map(|(id, handle)| (id.clone(), handle.size())) } - } - - /// Finds the smallest of the free and large enough chunks to fit `size` - /// Returns the chunk's id and size. 
- fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> { - let mut size_diff_current = usize::MAX; - let mut current = None; - - for (chunk_id, (resource, slices)) in self.chunks.iter() { - // If chunk is already used, we do not choose it - if !slices.is_empty() || !chunk_id.is_free() { - continue; - } - - let resource_size = resource.size(); - - // If we find a chunk of exactly the right size, we stop searching altogether - if size == resource_size { - current = Some((chunk_id, resource)); - break; - } - - // Finds the smallest of the large enough chunks that can accept a slice - // of the given size - if self.slice_strategy.can_use_chunk(resource_size, size) { - let size_diff = resource_size - size; - - if size_diff < size_diff_current { - current = Some((chunk_id, resource)); - size_diff_current = size_diff; + + /// Creates a slice of size `size` upon the given chunk. + /// + /// For now slices must start at zero, therefore there can be only one per chunk + fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle { + let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap(); + let slice_id = SliceId::new(); + + let storage = StorageHandle { + id: handle.id.clone(), + utilization: StorageUtilization::Slice(0, size), + }; + + if slices.is_empty() { + self.slices.insert(slice_id.clone(), (storage, chunk_id)); + } else { + panic!("Can't have more than 1 slice yet."); } - } + + slices.push(slice_id.clone()); + + SimpleHandle::Slice(slice_id) } - current.map(|(id, handle)| (id.clone(), handle.size())) - } + /// Creates a chunk of given size by allocating on the storage. + fn create_chunk(&mut self, size: usize) -> SimpleHandle { + let resource = self.storage.alloc(size); + let chunk_id = ChunkId::new(); - /// Creates a slice of size `size` upon the given chunk. - /// - /// For now slices must start at zero, therefore there can be only one per chunk - fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle { - let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap(); - let slice_id = SliceId::new(); + self.chunks.insert(chunk_id.clone(), (resource, Vec::new())); - let storage = StorageHandle { - id: handle.id.clone(), - utilization: StorageUtilization::Slice(0, size), - }; + SimpleHandle::Chunk(chunk_id) + } - if slices.is_empty() { - self.slices.insert(slice_id.clone(), (storage, chunk_id)); - } else { - panic!("Can't have more than 1 slice yet."); + /// Deallocates free chunks and remove them from chunks map. + fn cleanup_chunks(&mut self) { + let mut ids_to_remove = Vec::new(); + + self.chunks.iter().for_each(|(chunk_id, _resource)| { + if chunk_id.is_free() { + ids_to_remove.push(chunk_id.clone()); + } + }); + + ids_to_remove + .iter() + .map(|chunk_id| self.chunks.remove(chunk_id).unwrap()) + .for_each(|(resource, _slices)| { + self.storage.dealloc(resource.id); + }); } - slices.push(slice_id.clone()); - - SimpleHandle::Slice(slice_id) - } - - /// Creates a chunk of given size by allocating on the storage. - fn create_chunk(&mut self, size: usize) -> SimpleHandle { - let resource = self.storage.alloc(size); - let chunk_id = ChunkId::new(); - - self.chunks.insert(chunk_id.clone(), (resource, Vec::new())); - - SimpleHandle::Chunk(chunk_id) - } - - /// Deallocates free chunks and remove them from chunks map. 
- fn cleanup_chunks(&mut self) { - let mut ids_to_remove = Vec::new(); - - self.chunks.iter().for_each(|(chunk_id, _resource)| { - if chunk_id.is_free() { - ids_to_remove.push(chunk_id.clone()); - } - }); - - ids_to_remove - .iter() - .map(|chunk_id| self.chunks.remove(chunk_id).unwrap()) - .for_each(|(resource, _slices)| { - self.storage.dealloc(resource.id); - }); - } - - /// Removes free slices from slice map and corresponding chunks. - fn cleanup_slices(&mut self) { - let mut ids_to_remove = Vec::new(); - - self.slices.iter().for_each(|(slice_id, _resource)| { - if slice_id.is_free() { - ids_to_remove.push(slice_id.clone()); - } - }); - - ids_to_remove - .iter() - .map(|slice_id| { - let value = self.slices.remove(slice_id).unwrap(); - (slice_id, value.1) - }) - .for_each(|(slice_id, chunk_id)| { - let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap(); - slices.retain(|id| id != slice_id); - }); - } + /// Removes free slices from slice map and corresponding chunks. + fn cleanup_slices(&mut self) { + let mut ids_to_remove = Vec::new(); + + self.slices.iter().for_each(|(slice_id, _resource)| { + if slice_id.is_free() { + ids_to_remove.push(slice_id.clone()); + } + }); + + ids_to_remove + .iter() + .map(|slice_id| { + let value = self.slices.remove(slice_id).unwrap(); + (slice_id, value.1) + }) + .for_each(|(slice_id, chunk_id)| { + let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap(); + slices.retain(|id| id != slice_id); + }); + } } #[cfg(test)] mod tests { - use crate::{ - memory_management::{MemoryHandle, MemoryManagement, SliceStrategy}, - storage::BytesStorage, - }; - - use super::{DeallocStrategy, SimpleMemoryManagement}; - - #[test] - fn can_mut_with_single_tensor_reference() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - - let chunk_size = 4; - let simple_handle = memory_management.create_chunk(chunk_size); - - let x = simple_handle.clone(); - core::mem::drop(simple_handle); - - assert!(x.can_mut()); - } - - #[test] - fn two_tensor_references_remove_mutability() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - - let chunk_size = 4; - let simple_handle = memory_management.create_chunk(chunk_size); - - let x = simple_handle.clone(); - - assert!(!simple_handle.can_mut()); - assert!(!x.can_mut()) - } - - #[test] - fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - let chunk_size = 4; - let _chunk_handle = memory_management.reserve(chunk_size); - let _new_handle = memory_management.reserve(chunk_size); - - assert_eq!(memory_management.chunks.len(), 2); - } - - #[test] - fn when_empty_chunk_is_cleaned_upexists_it_disappears() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - let chunk_size = 4; - let chunk_handle = memory_management.reserve(chunk_size); - drop(chunk_handle); - memory_management.cleanup_chunks(); - - assert_eq!(memory_management.chunks.len(), 0); - } - - #[test] - fn never_dealloc_strategy_never_deallocs() { - let mut never_dealloc = DeallocStrategy::Never; - for _ in 0..20 { - assert!(!never_dealloc.should_dealloc()) + use crate::{ + memory_management::{MemoryHandle, MemoryManagement, SliceStrategy}, 
+ storage::BytesStorage, + }; + + use super::{DeallocStrategy, SimpleMemoryManagement}; + + #[test] + fn can_mut_with_single_tensor_reference() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + + let chunk_size = 4; + let simple_handle = memory_management.create_chunk(chunk_size); + + let x = simple_handle.clone(); + core::mem::drop(simple_handle); + + assert!(x.can_mut()); + } + + #[test] + fn two_tensor_references_remove_mutability() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + + let chunk_size = 4; + let simple_handle = memory_management.create_chunk(chunk_size); + + let x = simple_handle.clone(); + + assert!(!simple_handle.can_mut()); + assert!(!x.can_mut()) } - } - - #[test] - fn period_tick_dealloc_strategy_should_dealloc_after_period() { - let period = 3; - let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period); - - for _ in 0..3 { - for _ in 0..period - 1 { - assert!(!period_tick_dealloc.should_dealloc()); - } - assert!(period_tick_dealloc.should_dealloc()); + + #[test] + fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + let chunk_size = 4; + let _chunk_handle = memory_management.reserve(chunk_size); + let _new_handle = memory_management.reserve(chunk_size); + + assert_eq!(memory_management.chunks.len(), 2); } - } - #[test] - fn slice_strategy_minimum_bytes() { - let strategy = SliceStrategy::MinimumSize(100); + #[test] + fn when_empty_chunk_is_cleaned_upexists_it_disappears() { + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); + let chunk_size = 4; + let chunk_handle = memory_management.reserve(chunk_size); + drop(chunk_handle); + memory_management.cleanup_chunks(); + + assert_eq!(memory_management.chunks.len(), 0); + } - assert!(strategy.can_use_chunk(200, 101)); - assert!(!strategy.can_use_chunk(200, 99)); - } + #[test] + fn never_dealloc_strategy_never_deallocs() { + let mut never_dealloc = DeallocStrategy::Never; + for _ in 0..20 { + assert!(!never_dealloc.should_dealloc()) + } + } - #[test] - fn slice_strategy_maximum_bytes() { - let strategy = SliceStrategy::MaximumSize(100); + #[test] + fn period_tick_dealloc_strategy_should_dealloc_after_period() { + let period = 3; + let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period); - assert!(strategy.can_use_chunk(200, 99)); - assert!(!strategy.can_use_chunk(200, 101)); - } + for _ in 0..3 { + for _ in 0..period - 1 { + assert!(!period_tick_dealloc.should_dealloc()); + } + assert!(period_tick_dealloc.should_dealloc()); + } + } - #[test] - fn slice_strategy_ratio() { - let strategy = SliceStrategy::Ratio(0.9); + #[test] + fn slice_strategy_minimum_bytes() { + let strategy = SliceStrategy::MinimumSize(100); - assert!(strategy.can_use_chunk(200, 180)); - assert!(!strategy.can_use_chunk(200, 179)); - } + assert!(strategy.can_use_chunk(200, 101)); + assert!(!strategy.can_use_chunk(200, 99)); + } + + #[test] + fn slice_strategy_maximum_bytes() { + let strategy = SliceStrategy::MaximumSize(100); + + assert!(strategy.can_use_chunk(200, 99)); + assert!(!strategy.can_use_chunk(200, 101)); + } + + #[test] + fn slice_strategy_ratio() { + let strategy = 
SliceStrategy::Ratio(0.9); + + assert!(strategy.can_use_chunk(200, 180)); + assert!(!strategy.can_use_chunk(200, 179)); + } } diff --git a/burn-compute/src/server.rs b/burn-compute/src/server.rs index e8b75413e6..3682ed487a 100644 --- a/burn-compute/src/server.rs +++ b/burn-compute/src/server.rs @@ -1,9 +1,9 @@ use core::fmt::Debug; use crate::{ - memory_management::{MemoryHandle, MemoryManagement}, - storage::ComputeStorage, - tune::AutotuneKey, + memory_management::{MemoryHandle, MemoryManagement}, + storage::ComputeStorage, + tune::AutotuneKey, }; use alloc::vec::Vec; use burn_common::reader::Reader; @@ -14,54 +14,54 @@ use burn_common::reader::Reader; /// [compute channel](crate::channel::ComputeChannel) for thread safety. pub trait ComputeServer: Send + core::fmt::Debug where - Self: Sized, + Self: Sized, { - /// The kernel type defines the computation algorithms. - type Kernel: Send; - /// The [storage](ComputeStorage) type defines how data is stored and accessed. - type Storage: ComputeStorage; - /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type. - type MemoryManagement: MemoryManagement; - /// The key used to cache operations used on specific inputs in autotune - type AutotuneKey: AutotuneKey; + /// The kernel type defines the computation algorithms. + type Kernel: Send; + /// The [storage](ComputeStorage) type defines how data is stored and accessed. + type Storage: ComputeStorage; + /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type. + type MemoryManagement: MemoryManagement; + /// The key used to cache operations used on specific inputs in autotune + type AutotuneKey: AutotuneKey; - /// Given a handle, returns the owned resource as bytes. - fn read(&mut self, handle: &Handle) -> Reader>; + /// Given a handle, returns the owned resource as bytes. + fn read(&mut self, handle: &Handle) -> Reader>; - /// Given a resource as bytes, stores it and returns the memory handle. - fn create(&mut self, data: &[u8]) -> Handle; + /// Given a resource as bytes, stores it and returns the memory handle. + fn create(&mut self, data: &[u8]) -> Handle; - /// Reserves `size` bytes in the storage, and returns a handle over them. - fn empty(&mut self, size: usize) -> Handle; + /// Reserves `size` bytes in the storage, and returns a handle over them. + fn empty(&mut self, size: usize) -> Handle; - /// Executes the `kernel` over the given memory `handles`. - /// - /// Kernels have mutable access to every resource they are given - /// and are responsible of determining which should be read or written. - fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]); + /// Executes the `kernel` over the given memory `handles`. + /// + /// Kernels have mutable access to every resource they are given + /// and are responsible of determining which should be read or written. + fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]); - /// Wait for the completion of every task in the server. - fn sync(&mut self); + /// Wait for the completion of every task in the server. + fn sync(&mut self); } /// Server handle containing the [memory handle](MemoryManagement::Handle). #[derive(new, Debug)] pub struct Handle { - /// Handle for the memory in use. - pub memory: >::Handle, + /// Handle for the memory in use. + pub memory: >::Handle, } impl Handle { - /// If the tensor handle can be mut with an inplace operation. 
- pub fn can_mut(&self) -> bool { - self.memory.can_mut() - } + /// If the tensor handle can be mut with an inplace operation. + pub fn can_mut(&self) -> bool { + self.memory.can_mut() + } } impl Clone for Handle { - fn clone(&self) -> Self { - Self { - memory: self.memory.clone(), + fn clone(&self) -> Self { + Self { + memory: self.memory.clone(), + } } - } } diff --git a/burn-compute/src/storage/base.rs b/burn-compute/src/storage/base.rs index 37a293f408..ce6be5bceb 100644 --- a/burn-compute/src/storage/base.rs +++ b/burn-compute/src/storage/base.rs @@ -6,43 +6,43 @@ storage_id_type!(StorageId); /// Defines if data uses a full memory chunk or a slice of it. #[derive(Clone)] pub enum StorageUtilization { - /// Full memory chunk of specified size - Full(usize), - /// Slice of memory chunk with start index and size. - Slice(usize, usize), + /// Full memory chunk of specified size + Full(usize), + /// Slice of memory chunk with start index and size. + Slice(usize, usize), } /// Contains the [storage id](StorageId) of a resource and the way it is used. #[derive(new)] pub struct StorageHandle { - /// Storage id. - pub id: StorageId, - /// How the storage is used. - pub utilization: StorageUtilization, + /// Storage id. + pub id: StorageId, + /// How the storage is used. + pub utilization: StorageUtilization, } impl StorageHandle { - /// Returns the size the handle is pointing to in memory. - pub fn size(&self) -> usize { - match self.utilization { - StorageUtilization::Full(size) => size, - StorageUtilization::Slice(_, size) => size, + /// Returns the size the handle is pointing to in memory. + pub fn size(&self) -> usize { + match self.utilization { + StorageUtilization::Full(size) => size, + StorageUtilization::Slice(_, size) => size, + } } - } } /// Storage types are responsible for allocating and deallocating memory. pub trait ComputeStorage: Send { - /// The resource associated type determines the way data is implemented and how - /// it can be accessed by kernels. - type Resource: Send; + /// The resource associated type determines the way data is implemented and how + /// it can be accessed by kernels. + type Resource: Send; - /// Returns the underlying resource for a specified storage handle - fn get(&mut self, handle: &StorageHandle) -> Self::Resource; + /// Returns the underlying resource for a specified storage handle + fn get(&mut self, handle: &StorageHandle) -> Self::Resource; - /// Allocates `size` units of memory and returns a handle to it - fn alloc(&mut self, size: usize) -> StorageHandle; + /// Allocates `size` units of memory and returns a handle to it + fn alloc(&mut self, size: usize) -> StorageHandle; - /// Deallocates the memory pointed by the given storage id. - fn dealloc(&mut self, id: StorageId); + /// Deallocates the memory pointed by the given storage id. + fn dealloc(&mut self, id: StorageId); } diff --git a/burn-compute/src/storage/bytes_cpu.rs b/burn-compute/src/storage/bytes_cpu.rs index 7f54d180fe..bfaf07950e 100644 --- a/burn-compute/src/storage/bytes_cpu.rs +++ b/burn-compute/src/storage/bytes_cpu.rs @@ -5,13 +5,13 @@ use hashbrown::HashMap; /// The bytes storage maps ids to pointers of bytes in a contiguous layout. 
#[derive(Default)] pub struct BytesStorage { - memory: HashMap, + memory: HashMap, } impl core::fmt::Debug for BytesStorage { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str("BytesStorage") - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str("BytesStorage") + } } /// Can send to other threads. @@ -20,108 +20,108 @@ unsafe impl Send for BytesResource {} /// This struct is a pointer to a memory chunk or slice. pub struct BytesResource { - ptr: *mut u8, - utilization: StorageUtilization, + ptr: *mut u8, + utilization: StorageUtilization, } /// This struct refers to a specific (contiguous) layout of bytes. struct AllocatedBytes { - ptr: *mut u8, - layout: Layout, + ptr: *mut u8, + layout: Layout, } impl BytesResource { - fn get_exact_location_and_length(&self) -> (*mut u8, usize) { - match self.utilization { - StorageUtilization::Full(len) => (self.ptr, len), - StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) }, + fn get_exact_location_and_length(&self) -> (*mut u8, usize) { + match self.utilization { + StorageUtilization::Full(len) => (self.ptr, len), + StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) }, + } } - } - /// Returns the resource as a mutable slice of bytes. - pub fn write<'a>(&self) -> &'a mut [u8] { - let (ptr, len) = self.get_exact_location_and_length(); + /// Returns the resource as a mutable slice of bytes. + pub fn write<'a>(&self) -> &'a mut [u8] { + let (ptr, len) = self.get_exact_location_and_length(); - unsafe { core::slice::from_raw_parts_mut(ptr, len) } - } + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + } - /// Returns the resource as an immutable slice of bytes. - pub fn read<'a>(&self) -> &'a [u8] { - let (ptr, len) = self.get_exact_location_and_length(); + /// Returns the resource as an immutable slice of bytes. 
+ pub fn read<'a>(&self) -> &'a [u8] { + let (ptr, len) = self.get_exact_location_and_length(); - unsafe { core::slice::from_raw_parts(ptr, len) } - } + unsafe { core::slice::from_raw_parts(ptr, len) } + } } impl ComputeStorage for BytesStorage { - type Resource = BytesResource; + type Resource = BytesResource; - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let allocated_bytes = self.memory.get_mut(&handle.id).unwrap(); + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let allocated_bytes = self.memory.get_mut(&handle.id).unwrap(); - BytesResource { - ptr: allocated_bytes.ptr, - utilization: handle.utilization.clone(), + BytesResource { + ptr: allocated_bytes.ptr, + utilization: handle.utilization.clone(), + } } - } - fn alloc(&mut self, size: usize) -> StorageHandle { - let id = StorageId::new(); - let handle = StorageHandle { - id: id.clone(), - utilization: StorageUtilization::Full(size), - }; + fn alloc(&mut self, size: usize) -> StorageHandle { + let id = StorageId::new(); + let handle = StorageHandle { + id: id.clone(), + utilization: StorageUtilization::Full(size), + }; - unsafe { - let layout = Layout::array::(size).unwrap(); - let ptr = alloc(layout); - let memory = AllocatedBytes { ptr, layout }; + unsafe { + let layout = Layout::array::(size).unwrap(); + let ptr = alloc(layout); + let memory = AllocatedBytes { ptr, layout }; - self.memory.insert(id, memory); - } + self.memory.insert(id, memory); + } - handle - } + handle + } - fn dealloc(&mut self, id: StorageId) { - if let Some(memory) = self.memory.remove(&id) { - unsafe { - dealloc(memory.ptr, memory.layout); - } + fn dealloc(&mut self, id: StorageId) { + if let Some(memory) = self.memory.remove(&id) { + unsafe { + dealloc(memory.ptr, memory.layout); + } + } } - } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_can_alloc_and_dealloc() { - let mut storage = BytesStorage::default(); - let handle_1 = storage.alloc(64); - - assert_eq!(handle_1.size(), 64); - storage.dealloc(handle_1.id); - } - - #[test] - fn test_slices() { - let mut storage = BytesStorage::default(); - let handle_1 = storage.alloc(64); - let handle_2 = StorageHandle::new(handle_1.id.clone(), StorageUtilization::Slice(24, 8)); - - storage - .get(&handle_1) - .write() - .iter_mut() - .enumerate() - .for_each(|(i, b)| { - *b = i as u8; - }); - - let bytes = storage.get(&handle_2).read().to_vec(); - storage.dealloc(handle_1.id); - assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]); - } + use super::*; + + #[test] + fn test_can_alloc_and_dealloc() { + let mut storage = BytesStorage::default(); + let handle_1 = storage.alloc(64); + + assert_eq!(handle_1.size(), 64); + storage.dealloc(handle_1.id); + } + + #[test] + fn test_slices() { + let mut storage = BytesStorage::default(); + let handle_1 = storage.alloc(64); + let handle_2 = StorageHandle::new(handle_1.id.clone(), StorageUtilization::Slice(24, 8)); + + storage + .get(&handle_1) + .write() + .iter_mut() + .enumerate() + .for_each(|(i, b)| { + *b = i as u8; + }); + + let bytes = storage.get(&handle_2).read().to_vec(); + storage.dealloc(handle_1.id); + assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]); + } } diff --git a/burn-compute/src/tune/operation.rs b/burn-compute/src/tune/operation.rs index 1dd1d19e60..548b5f59f2 100644 --- a/burn-compute/src/tune/operation.rs +++ b/burn-compute/src/tune/operation.rs @@ -6,30 +6,30 @@ use core::hash::Hash; /// Groups operations of the same type for autotune pub trait AutotuneOperationSet: Send { - /// The key used in 
the tune cache - fn key(&self) -> K; + /// The key used in the tune cache + fn key(&self) -> K; - /// All candidate operations for autotuning this operation type - /// Operations can run on toy tensors of relevant size - fn autotunables(&self) -> Vec>; + /// All candidate operations for autotuning this operation type + /// Operations can run on toy tensors of relevant size + fn autotunables(&self) -> Vec>; - /// Returns the operation for the given index, matching the order - /// returned by autotunables. Operation obtained here runs on original tensors - fn fastest(self: Box, fastest_index: usize) -> Box; + /// Returns the operation for the given index, matching the order + /// returned by autotunables. Operation obtained here runs on original tensors + fn fastest(self: Box, fastest_index: usize) -> Box; } /// Contains operation to run and inputs on which to run it pub trait AutotuneOperation { - /// Runs the operation - fn execute(self: Box); + /// Runs the operation + fn execute(self: Box); - /// The name of the operation. - fn name(&self) -> &str { - core::any::type_name::() - } + /// The name of the operation. + fn name(&self) -> &str { + core::any::type_name::() + } - /// Clones the operation and inputs - fn clone(&self) -> Box; + /// Clones the operation and inputs + fn clone(&self) -> Box; } /// Trait alias diff --git a/burn-compute/src/tune/tune_benchmark.rs b/burn-compute/src/tune/tune_benchmark.rs index bf4051e77c..dd0231340d 100644 --- a/burn-compute/src/tune/tune_benchmark.rs +++ b/burn-compute/src/tune/tune_benchmark.rs @@ -11,30 +11,30 @@ use alloc::string::{String, ToString}; /// A benchmark that runs on server handles #[derive(new)] pub struct TuneBenchmark { - operation: Box, - client: ComputeClient, + operation: Box, + client: ComputeClient, } impl> Benchmark for TuneBenchmark { - type Args = Box; + type Args = Box; - fn prepare(&self) -> Self::Args { - self.operation.clone() - } + fn prepare(&self) -> Self::Args { + self.operation.clone() + } - fn num_samples(&self) -> usize { - 10 - } + fn num_samples(&self) -> usize { + 10 + } - fn execute(&self, operation: Self::Args) { - AutotuneOperation::execute(operation); - } + fn execute(&self, operation: Self::Args) { + AutotuneOperation::execute(operation); + } - fn name(&self) -> String { - "Autotune".to_string() - } + fn name(&self) -> String { + "Autotune".to_string() + } - fn sync(&self) { - self.client.sync(); - } + fn sync(&self) { + self.client.sync(); + } } diff --git a/burn-compute/src/tune/tune_cache.rs b/burn-compute/src/tune/tune_cache.rs index 30d63619da..d91b70ec12 100644 --- a/burn-compute/src/tune/tune_cache.rs +++ b/burn-compute/src/tune/tune_cache.rs @@ -7,37 +7,37 @@ use hashbrown::HashMap; /// Use to find and reuse the best kernel for some input #[derive(Debug, Default)] pub(crate) struct TuneCache { - cache: HashMap, + cache: HashMap, } /// Result of the cache try pub enum TuneCacheResult { - /// An operation is found and given - Hit(Box), - /// No operation is found and the set is given back for ownership - Miss(Box>), + /// An operation is found and given + Hit(Box), + /// No operation is found and the set is given back for ownership + Miss(Box>), } impl TuneCache { - pub(crate) fn new() -> Self { - TuneCache { - cache: HashMap::new(), + pub(crate) fn new() -> Self { + TuneCache { + cache: HashMap::new(), + } } - } - #[allow(clippy::borrowed_box)] - pub(crate) fn try_cache( - &self, - autotune_operation_set: Box>, - ) -> TuneCacheResult { - let index = self.cache.get(&autotune_operation_set.key()); - if 
let Some(&i) = index { - return TuneCacheResult::Hit(autotune_operation_set.fastest(i)); + #[allow(clippy::borrowed_box)] + pub(crate) fn try_cache( + &self, + autotune_operation_set: Box>, + ) -> TuneCacheResult { + let index = self.cache.get(&autotune_operation_set.key()); + if let Some(&i) = index { + return TuneCacheResult::Hit(autotune_operation_set.fastest(i)); + } + TuneCacheResult::Miss(autotune_operation_set) } - TuneCacheResult::Miss(autotune_operation_set) - } - pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) { - self.cache.insert(key, fastest_index); - } + pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) { + self.cache.insert(key, fastest_index); + } } diff --git a/burn-compute/src/tune/tuner.rs b/burn-compute/src/tune/tuner.rs index a2d8a3e7fe..c9a9afeb8f 100644 --- a/burn-compute/src/tune/tuner.rs +++ b/burn-compute/src/tune/tuner.rs @@ -14,87 +14,87 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa #[derive(Debug, Default)] /// Executes autotune benchmarking and caching pub struct Tuner { - tune_cache: TuneCache, - _channel: PhantomData, + tune_cache: TuneCache, + _channel: PhantomData, } impl> Tuner { - /// Returns a tuner with empty cache - pub fn new() -> Self { - Self { - tune_cache: TuneCache::new(), - _channel: PhantomData, - } - } - - pub(crate) fn execute_autotune( - &mut self, - autotune_operation_set: Box>, - client: &ComputeClient, - ) { - let operation = match self.tune_cache.try_cache(autotune_operation_set) { - super::TuneCacheResult::Hit(ops) => ops, - super::TuneCacheResult::Miss(set) => self.autotuning(set, client), - }; - - AutotuneOperation::execute(operation); - } - - fn autotuning( - &mut self, - autotune_operation_set: Box>, - client: &ComputeClient, - ) -> Box { - let key = autotune_operation_set.key(); - let autotunables = autotune_operation_set.autotunables(); - let mut names = Vec::with_capacity(autotunables.len()); - - // Run all autotune benchmarks - let results: Vec = autotunables - .into_iter() - .map(|op| { - names.push(op.name().to_string()); - self.run_benchmark(op, client) - }) - .collect(); - - for (name, result) in names.iter().zip(results.iter()) { - log::info!("Benchmark result {name}-{key} => {result}"); + /// Returns a tuner with empty cache + pub fn new() -> Self { + Self { + tune_cache: TuneCache::new(), + _channel: PhantomData, + } } - // Finds the fastest operation, stores it and returns it - let fastest_index = self.find_fastest(results); - let fastest_name = names.get(fastest_index).unwrap(); - log::info!("Fastest result {fastest_name}-{key}"); + pub(crate) fn execute_autotune( + &mut self, + autotune_operation_set: Box>, + client: &ComputeClient, + ) { + let operation = match self.tune_cache.try_cache(autotune_operation_set) { + super::TuneCacheResult::Hit(ops) => ops, + super::TuneCacheResult::Miss(set) => self.autotuning(set, client), + }; + + AutotuneOperation::execute(operation); + } - self.tune_cache.cache_insert(key, fastest_index); - match self.tune_cache.try_cache(autotune_operation_set) { - super::TuneCacheResult::Hit(ops) => ops, - super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"), + fn autotuning( + &mut self, + autotune_operation_set: Box>, + client: &ComputeClient, + ) -> Box { + let key = autotune_operation_set.key(); + let autotunables = autotune_operation_set.autotunables(); + let mut names = Vec::with_capacity(autotunables.len()); + + // Run all autotune benchmarks + let results: Vec = autotunables + 
.into_iter() + .map(|op| { + names.push(op.name().to_string()); + self.run_benchmark(op, client) + }) + .collect(); + + for (name, result) in names.iter().zip(results.iter()) { + log::info!("Benchmark result {name}-{key} => {result}"); + } + + // Finds the fastest operation, stores it and returns it + let fastest_index = self.find_fastest(results); + let fastest_name = names.get(fastest_index).unwrap(); + log::info!("Fastest result {fastest_name}-{key}"); + + self.tune_cache.cache_insert(key, fastest_index); + match self.tune_cache.try_cache(autotune_operation_set) { + super::TuneCacheResult::Hit(ops) => ops, + super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"), + } } - } - - fn run_benchmark( - &mut self, - operation: Box, - client: &ComputeClient, - ) -> BenchmarkResult { - TuneBenchmark::new(operation, client.clone()).run() - } - - fn find_fastest(&self, results: Vec) -> usize { - let mut smallest_duration = Duration::MAX; - let mut fastest_tunable = None; - - for (i, result) in results.into_iter().enumerate() { - let duration = result.median_duration(); - - if duration < smallest_duration { - smallest_duration = duration; - fastest_tunable = Some(i); - } + + fn run_benchmark( + &mut self, + operation: Box, + client: &ComputeClient, + ) -> BenchmarkResult { + TuneBenchmark::new(operation, client.clone()).run() } - fastest_tunable.expect("At least one kernel needed. ") - } + fn find_fastest(&self, results: Vec) -> usize { + let mut smallest_duration = Duration::MAX; + let mut fastest_tunable = None; + + for (i, result) in results.into_iter().enumerate() { + let duration = result.median_duration(); + + if duration < smallest_duration { + smallest_duration = duration; + fastest_tunable = Some(i); + } + } + + fastest_tunable.expect("At least one kernel needed. 
") + } } diff --git a/burn-compute/tests/dummy/compute.rs b/burn-compute/tests/dummy/compute.rs index e651082565..54081a1510 100644 --- a/burn-compute/tests/dummy/compute.rs +++ b/burn-compute/tests/dummy/compute.rs @@ -19,14 +19,14 @@ pub type DummyClient = ComputeClient; static COMPUTE: Compute = Compute::new(); pub fn client(device: &DummyDevice) -> DummyClient { - COMPUTE.client(device, || { - let storage = BytesStorage::default(); - let memory_management = - SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never); - let server = DummyServer::new(memory_management); - let channel = MutexComputeChannel::new(server); - let tuner = Arc::new(Mutex::new(Tuner::new())); + COMPUTE.client(device, || { + let storage = BytesStorage::default(); + let memory_management = + SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never); + let server = DummyServer::new(memory_management); + let channel = MutexComputeChannel::new(server); + let tuner = Arc::new(Mutex::new(Tuner::new())); - ComputeClient::new(channel, tuner) - }) + ComputeClient::new(channel, tuner) + }) } diff --git a/burn-compute/tests/dummy/kernel.rs b/burn-compute/tests/dummy/kernel.rs index b2f8cf2668..30a67d5538 100644 --- a/burn-compute/tests/dummy/kernel.rs +++ b/burn-compute/tests/dummy/kernel.rs @@ -2,24 +2,24 @@ use burn_compute::storage::BytesResource; /// The DummyKernel trait should be implemented for every supported operation pub trait DummyKernel: Sync + Send { - fn compute(&self, resources: &mut [BytesResource]); + fn compute(&self, resources: &mut [BytesResource]); } /// Contains the algorithm for element-wise addition pub struct DummyElementwiseAddition; impl DummyKernel for DummyElementwiseAddition { - fn compute(&self, inputs: &mut [BytesResource]) { - // Notice how the kernel is responsible for determining which inputs - // are read-only and which are writable. - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + // Notice how the kernel is responsible for determining which inputs + // are read-only and which are writable. + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - out[i] = lhs[i] + rhs[i]; + for i in 0..size { + out[i] = lhs[i] + rhs[i]; + } } - } } diff --git a/burn-compute/tests/dummy/server.rs b/burn-compute/tests/dummy/server.rs index 5749a407ab..55d8f49c2e 100644 --- a/burn-compute/tests/dummy/server.rs +++ b/burn-compute/tests/dummy/server.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use burn_common::reader::Reader; use burn_compute::{ - memory_management::{MemoryManagement, SimpleMemoryManagement}, - server::{ComputeServer, Handle}, - storage::BytesStorage, + memory_management::{MemoryManagement, SimpleMemoryManagement}, + server::{ComputeServer, Handle}, + storage::BytesStorage, }; use derive_new::new; @@ -14,51 +14,51 @@ use super::DummyKernel; /// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks. 
#[derive(new, Debug)] pub struct DummyServer> { - memory_management: MM, + memory_management: MM, } impl ComputeServer for DummyServer where - MM: MemoryManagement, + MM: MemoryManagement, { - type Kernel = Arc; - type Storage = BytesStorage; - type MemoryManagement = MM; - type AutotuneKey = String; + type Kernel = Arc; + type Storage = BytesStorage; + type MemoryManagement = MM; + type AutotuneKey = String; - fn read(&mut self, handle: &Handle) -> Reader> { - let bytes = self.memory_management.get(&handle.memory); + fn read(&mut self, handle: &Handle) -> Reader> { + let bytes = self.memory_management.get(&handle.memory); - Reader::Concrete(bytes.read().to_vec()) - } + Reader::Concrete(bytes.read().to_vec()) + } - fn create(&mut self, data: &[u8]) -> Handle { - let handle = self.memory_management.reserve(data.len()); - let resource = self.memory_management.get(&handle); + fn create(&mut self, data: &[u8]) -> Handle { + let handle = self.memory_management.reserve(data.len()); + let resource = self.memory_management.get(&handle); - let bytes = resource.write(); + let bytes = resource.write(); - for (i, val) in data.iter().enumerate() { - bytes[i] = *val; - } + for (i, val) in data.iter().enumerate() { + bytes[i] = *val; + } - Handle::new(handle) - } + Handle::new(handle) + } - fn empty(&mut self, size: usize) -> Handle { - Handle::new(self.memory_management.reserve(size)) - } + fn empty(&mut self, size: usize) -> Handle { + Handle::new(self.memory_management.reserve(size)) + } - fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]) { - let mut resources = handles - .iter() - .map(|handle| self.memory_management.get(&handle.memory)) - .collect::>(); + fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]) { + let mut resources = handles + .iter() + .map(|handle| self.memory_management.get(&handle.memory)) + .collect::>(); - kernel.compute(&mut resources); - } + kernel.compute(&mut resources); + } - fn sync(&mut self) { - // Nothing to do with dummy backend. - } + fn sync(&mut self) { + // Nothing to do with dummy backend. + } } diff --git a/burn-compute/tests/dummy/tune/autotune_operations.rs b/burn-compute/tests/dummy/tune/autotune_operations.rs index fec4ab029c..5af0eaa472 100644 --- a/burn-compute/tests/dummy/tune/autotune_operations.rs +++ b/burn-compute/tests/dummy/tune/autotune_operations.rs @@ -9,25 +9,25 @@ use crate::dummy::{DummyChannel, DummyKernel, DummyServer}; /// Extended kernel that accounts for additional parameters, i.e. needed /// information that does not count as an input/output. 
pub struct OneKernelAutotuneOperation { - kernel: Arc, - client: ComputeClient, - shapes: Vec>, - handles: Vec>, + kernel: Arc, + client: ComputeClient, + shapes: Vec>, + handles: Vec>, } impl AutotuneOperation for OneKernelAutotuneOperation { - /// Executes the operation on given handles and server, with the additional parameters - fn execute(self: Box) { - let handle_refs: &Vec<&Handle> = &self.handles.iter().collect(); - self.client.execute(self.kernel.clone(), handle_refs); - } + /// Executes the operation on given handles and server, with the additional parameters + fn execute(self: Box) { + let handle_refs: &Vec<&Handle> = &self.handles.iter().collect(); + self.client.execute(self.kernel.clone(), handle_refs); + } - fn clone(&self) -> Box { - Box::new(Self { - kernel: self.kernel.clone(), - client: self.client.clone(), - shapes: self.shapes.clone(), - handles: self.handles.clone(), - }) - } + fn clone(&self) -> Box { + Box::new(Self { + kernel: self.kernel.clone(), + client: self.client.clone(), + shapes: self.shapes.clone(), + handles: self.handles.clone(), + }) + } } diff --git a/burn-compute/tests/dummy/tune/kernels.rs b/burn-compute/tests/dummy/tune/kernels.rs index ffce69802f..bd0058310b 100644 --- a/burn-compute/tests/dummy/tune/kernels.rs +++ b/burn-compute/tests/dummy/tune/kernels.rs @@ -14,93 +14,93 @@ pub struct CacheTestSlowOn3; pub struct ParameteredKernel; impl DummyKernel for DummyElementwiseAdditionSlowWrong { - fn compute(&self, inputs: &mut [BytesResource]) { - // Slow and wrong on purpose, for tests - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + // Slow and wrong on purpose, for tests + let lhs = &inputs[0].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i] + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = lhs[i] + } } - } } impl DummyKernel for DummyElementwiseMultiplication { - fn compute(&self, inputs: &mut [BytesResource]) { - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - out[i] = lhs[i] * rhs[i]; + for i in 0..size { + out[i] = lhs[i] * rhs[i]; + } } - } } impl DummyKernel for DummyElementwiseMultiplicationSlowWrong { - fn compute(&self, inputs: &mut [BytesResource]) { - // Slow and wrong on purpose, for tests - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); + fn compute(&self, inputs: &mut [BytesResource]) { + // Slow and wrong on purpose, for tests + let lhs = &inputs[0].read(); + let out = &mut inputs[2].write(); - let size = lhs.len(); + let size = lhs.len(); - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i]; + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = lhs[i]; + } } - } } impl DummyKernel for CacheTestFastOn3 { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for testing cache only - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - if size == 3 { - out[..size].copy_from_slice(&lhs[..size]); - } else { - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i]; - } + fn compute(&self, 
inputs: &mut [BytesResource]) { + // This is an artificial kernel designed for testing cache only + let lhs = &inputs[0].read(); + let out = &mut inputs[2].write(); + + let size = lhs.len(); + if size == 3 { + out[..size].copy_from_slice(&lhs[..size]); + } else { + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = lhs[i]; + } + } } - } } impl DummyKernel for CacheTestSlowOn3 { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for testing cache only - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - if size == 3 { - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = rhs[i]; - } - } else { - out[..size].copy_from_slice(&rhs[..size]); + fn compute(&self, inputs: &mut [BytesResource]) { + // This is an artificial kernel designed for testing cache only + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); + + let size = lhs.len(); + if size == 3 { + for i in 0..size { + sleep(Duration::from_millis(SLEEP_MS)); + out[i] = rhs[i]; + } + } else { + out[..size].copy_from_slice(&rhs[..size]); + } } - } } impl DummyKernel for ParameteredKernel { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for info buffer - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - let info = &inputs[3].read(); - - for i in 0..lhs.len() { - out[i] = lhs[i] + rhs[i] + info[0]; + fn compute(&self, inputs: &mut [BytesResource]) { + // This is an artificial kernel designed for info buffer + let lhs = &inputs[0].read(); + let rhs = &inputs[1].read(); + let out = &mut inputs[2].write(); + let info = &inputs[3].read(); + + for i in 0..lhs.len() { + out[i] = lhs[i] + rhs[i] + info[0]; + } } - } } diff --git a/burn-compute/tests/dummy/tune/operation_sets.rs b/burn-compute/tests/dummy/tune/operation_sets.rs index 89a68f1aa6..f5b30e0727 100644 --- a/burn-compute/tests/dummy/tune/operation_sets.rs +++ b/burn-compute/tests/dummy/tune/operation_sets.rs @@ -1,170 +1,170 @@ use std::sync::Arc; use burn_compute::{ - server::Handle, - tune::{AutotuneOperation, AutotuneOperationSet}, + server::Handle, + tune::{AutotuneOperation, AutotuneOperationSet}, }; use crate::dummy::{ - CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition, - DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer, - OneKernelAutotuneOperation, + CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition, + DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer, + OneKernelAutotuneOperation, }; use super::DummyElementwiseAdditionSlowWrong; pub struct AdditionAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - handles: Vec>, -} - -impl AdditionAutotuneOperationSet { - pub fn new( client: DummyClient, + key: String, shapes: Vec>, handles: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "add", log_shape_input_key(&shapes)), - shapes, - handles, +} + +impl AdditionAutotuneOperationSet { + pub fn new( + client: DummyClient, + shapes: Vec>, + handles: Vec>, + ) -> Self { + Self { + client, + key: format!("{}-{}", "add", log_shape_input_key(&shapes)), + shapes, + handles, + } } - } } impl AutotuneOperationSet for AdditionAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - vec![ - 
Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseAddition), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseAdditionSlowWrong), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } -} + fn key(&self) -> String { + self.key.clone() + } -pub struct MultiplicationAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - handles: Vec>, + fn autotunables(&self) -> Vec> { + vec![ + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseAddition), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseAdditionSlowWrong), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + self.autotunables()[fastest_index].clone() + } } -impl MultiplicationAutotuneOperationSet { - pub fn new( +pub struct MultiplicationAutotuneOperationSet { client: DummyClient, + key: String, shapes: Vec>, handles: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "mul", log_shape_input_key(&shapes)), - shapes, - handles, +} + +impl MultiplicationAutotuneOperationSet { + pub fn new( + client: DummyClient, + shapes: Vec>, + handles: Vec>, + ) -> Self { + Self { + client, + key: format!("{}-{}", "mul", log_shape_input_key(&shapes)), + shapes, + handles, + } } - } } impl AutotuneOperationSet for MultiplicationAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - vec![ - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseMultiplicationSlowWrong), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseMultiplication), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } -} + fn key(&self) -> String { + self.key.clone() + } -pub struct CacheTestAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - handles: Vec>, + fn autotunables(&self) -> Vec> { + vec![ + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseMultiplicationSlowWrong), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + Box::new(OneKernelAutotuneOperation::new( + Arc::new(DummyElementwiseMultiplication), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + self.autotunables()[fastest_index].clone() + } } -impl CacheTestAutotuneOperationSet { - pub fn new( +pub struct CacheTestAutotuneOperationSet { client: DummyClient, + key: String, shapes: Vec>, handles: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)), - shapes, - handles, +} + +impl CacheTestAutotuneOperationSet { + pub fn new( + client: DummyClient, + shapes: Vec>, + handles: Vec>, + ) -> Self { + Self { + client, + key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)), + shapes, + handles, + } } - } } impl AutotuneOperationSet for CacheTestAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - 
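// Illustrative sketch (not part of this patch): how a caller consumes an
// `AutotuneOperationSet` such as `AdditionAutotuneOperationSet` above. The key
// type is assumed to be `String`, matching `DummyServer::AutotuneKey`; the
// function name and the `fastest_index` argument are hypothetical.
use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};

fn pick_and_run(set: Box<dyn AutotuneOperationSet<String>>, fastest_index: usize) {
    // Every candidate is normally benchmarked on toy tensors by the `Tuner`
    // before `fastest_index` is chosen and cached under `set.key()`.
    let _candidates = set.autotunables();
    // Consuming the set yields the winning operation bound to the real tensors.
    let fastest = set.fastest(fastest_index);
    AutotuneOperation::execute(fastest);
}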
- fn autotunables(&self) -> Vec> { - vec![ - Box::new(OneKernelAutotuneOperation::new( - Arc::new(CacheTestFastOn3), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(CacheTestSlowOn3), - self.client.clone(), - self.shapes.clone(), - self.handles.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } + fn key(&self) -> String { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + vec![ + Box::new(OneKernelAutotuneOperation::new( + Arc::new(CacheTestFastOn3), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + Box::new(OneKernelAutotuneOperation::new( + Arc::new(CacheTestSlowOn3), + self.client.clone(), + self.shapes.clone(), + self.handles.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + self.autotunables()[fastest_index].clone() + } } pub fn log_shape_input_key(shapes: &[Vec]) -> String { - let mut hash = String::new(); - let lhs = &shapes[0]; - for size in lhs { - let exp = f32::ceil(f32::log2(*size as f32)) as u32; - hash.push_str(2_u32.pow(exp).to_string().as_str()); - hash.push(','); - } - hash + let mut hash = String::new(); + let lhs = &shapes[0]; + for size in lhs { + let exp = f32::ceil(f32::log2(*size as f32)) as u32; + hash.push_str(2_u32.pow(exp).to_string().as_str()); + hash.push(','); + } + hash } diff --git a/burn-compute/tests/integration_test.rs b/burn-compute/tests/integration_test.rs index c9db430423..79532f7f9a 100644 --- a/burn-compute/tests/integration_test.rs +++ b/burn-compute/tests/integration_test.rs @@ -8,141 +8,141 @@ use serial_test::serial; #[test] fn created_resource_is_the_same_when_read() { - let client = client(&DummyDevice); - let resource = Vec::from([0, 1, 2]); - let resource_description = client.create(&resource); + let client = client(&DummyDevice); + let resource = Vec::from([0, 1, 2]); + let resource_description = client.create(&resource); - let obtained_resource = client.read(&resource_description); + let obtained_resource = client.read(&resource_description); - assert_eq!(resource, obtained_resource.read()) + assert_eq!(resource, obtained_resource.read()) } #[test] fn empty_allocates_memory() { - let client = client(&DummyDevice); - let size = 4; - let resource_description = client.empty(size); - let empty_resource = client.read(&resource_description); + let client = client(&DummyDevice); + let size = 4; + let resource_description = client.empty(size); + let empty_resource = client.read(&resource_description); - assert_eq!(empty_resource.read().len(), 4); + assert_eq!(empty_resource.read().len(), 4); } #[test] fn execute_elementwise_addition() { - let client = client(&DummyDevice); - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); + let client = client(&DummyDevice); + let lhs = client.create(&[0, 1, 2]); + let rhs = client.create(&[4, 4, 4]); + let out = client.empty(3); - client.execute(Arc::new(DummyElementwiseAddition), &[&lhs, &rhs, &out]); + client.execute(Arc::new(DummyElementwiseAddition), &[&lhs, &rhs, &out]); - let obtained_resource = client.read(&out); + let obtained_resource = client.read(&out); - assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])) + assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])) } #[test] #[serial] #[cfg(feature = "std")] fn autotune_basic_addition_execution() { - let client = client(&DummyDevice); + let client = 
client(&DummyDevice); - let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - let handles = vec![lhs, rhs, out.clone()]; + let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs = client.create(&[0, 1, 2]); + let rhs = client.create(&[4, 4, 4]); + let out = client.empty(3); + let handles = vec![lhs, rhs, out.clone()]; - let addition_autotune_kernel = - dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, handles); - client.execute_autotune(Box::new(addition_autotune_kernel)); + let addition_autotune_kernel = + dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, handles); + client.execute_autotune(Box::new(addition_autotune_kernel)); - let obtained_resource = client.read(&out); + let obtained_resource = client.read(&out); - // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])); + // If slow kernel was selected it would output [0, 1, 2] + assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])); } #[test] #[serial] #[cfg(feature = "std")] fn autotune_basic_multiplication_execution() { - let client = client(&DummyDevice); + let client = client(&DummyDevice); - let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - let handles = vec![lhs, rhs, out.clone()]; + let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs = client.create(&[0, 1, 2]); + let rhs = client.create(&[4, 4, 4]); + let out = client.empty(3); + let handles = vec![lhs, rhs, out.clone()]; - let multiplication_autotune_kernel = - dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, handles); - client.execute_autotune(Box::new(multiplication_autotune_kernel)); + let multiplication_autotune_kernel = + dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, handles); + client.execute_autotune(Box::new(multiplication_autotune_kernel)); - let obtained_resource = client.read(&out); + let obtained_resource = client.read(&out); - // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource.read(), Vec::from([0, 4, 8])); + // If slow kernel was selected it would output [0, 1, 2] + assert_eq!(obtained_resource.read(), Vec::from([0, 4, 8])); } #[test] #[serial] #[cfg(feature = "std")] fn autotune_cache_hit_test() { - let client = client(&DummyDevice); - - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = client.empty(3); - let handles_1 = vec![lhs_1, rhs_1, out_1]; - - let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]]; - let lhs_2 = client.create(&[0, 1, 2, 3]); - let rhs_2 = client.create(&[5, 6, 7, 8]); - let out_2 = client.empty(4); - let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; - - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - let cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); - client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); - - let obtained_resource = client.read(&out_2); - - // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs - assert_eq!(obtained_resource.read(), Vec::from([0, 1, 2, 3])); + let 
client = client(&DummyDevice); + + let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs_1 = client.create(&[0, 1, 2]); + let rhs_1 = client.create(&[4, 4, 4]); + let out_1 = client.empty(3); + let handles_1 = vec![lhs_1, rhs_1, out_1]; + + let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]]; + let lhs_2 = client.create(&[0, 1, 2, 3]); + let rhs_2 = client.create(&[5, 6, 7, 8]); + let out_2 = client.empty(4); + let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; + + let cache_test_autotune_kernel_1 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); + let cache_test_autotune_kernel_2 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); + client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); + client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); + + let obtained_resource = client.read(&out_2); + + // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs + assert_eq!(obtained_resource.read(), Vec::from([0, 1, 2, 3])); } #[test] #[serial] #[cfg(feature = "std")] fn autotune_cache_miss_test() { - let client = client(&DummyDevice); - - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = client.empty(3); - let handles_1 = vec![lhs_1, rhs_1, out_1]; - - let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]]; - let lhs_2 = client.create(&[0, 1, 2, 3, 4]); - let rhs_2 = client.create(&[5, 6, 7, 8, 9]); - let out_2 = client.empty(5); - let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; - - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - let cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); - client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); - - let obtained_resource = client.read(&out_2); - - // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs - assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9])); + let client = client(&DummyDevice); + + let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; + let lhs_1 = client.create(&[0, 1, 2]); + let rhs_1 = client.create(&[4, 4, 4]); + let out_1 = client.empty(3); + let handles_1 = vec![lhs_1, rhs_1, out_1]; + + let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]]; + let lhs_2 = client.create(&[0, 1, 2, 3, 4]); + let rhs_2 = client.create(&[5, 6, 7, 8, 9]); + let out_2 = client.empty(5); + let handles_2 = vec![lhs_2, rhs_2, out_2.clone()]; + + let cache_test_autotune_kernel_1 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); + let cache_test_autotune_kernel_2 = + dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); + client.execute_autotune(Box::new(cache_test_autotune_kernel_1)); + client.execute_autotune(Box::new(cache_test_autotune_kernel_2)); + + let obtained_resource = client.read(&out_2); + + // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs + assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9])); } diff --git a/burn-core/src/config.rs b/burn-core/src/config.rs index ba5340a30b..94166228da 100644 --- a/burn-core/src/config.rs +++ b/burn-core/src/config.rs @@ -4,28 +4,28 @@ pub use burn_derive::Config; /// Configuration IO error. 
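// Illustrative sketch (not part of this patch): a save/load round trip with the
// `Config` trait and `ConfigError` defined just below. `MyTrainingConfig`, the
// field values, and the file path are hypothetical; `#[derive(Config)]` is
// assumed to provide the serde implementations the trait requires.
use burn::config::{Config, ConfigError}; // assumed re-export paths

#[derive(Config)]
struct MyTrainingConfig {
    batch_size: usize,
    learning_rate: f64,
}

fn roundtrip() -> Result<MyTrainingConfig, ConfigError> {
    let config = MyTrainingConfig {
        batch_size: 32,
        learning_rate: 1e-3,
    };
    // `save` writes pretty-printed JSON; `load` reads it back, reporting a
    // `ConfigError` if the file is missing or malformed.
    config
        .save("/tmp/my_config.json")
        .expect("config file should be writable");
    MyTrainingConfig::load("/tmp/my_config.json")
}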
#[derive(Debug)] pub enum ConfigError { - /// Invalid format. - InvalidFormat(String), + /// Invalid format. + InvalidFormat(String), - /// File not found. - FileNotFound(String), + /// File not found. + FileNotFound(String), } impl core::fmt::Display for ConfigError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mut message = "Config error => ".to_string(); + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut message = "Config error => ".to_string(); - match self { - Self::InvalidFormat(err) => { - message += format!("Invalid format: {err}").as_str(); - } - Self::FileNotFound(err) => { - message += format!("File not found: {err}").as_str(); - } - }; + match self { + Self::InvalidFormat(err) => { + message += format!("Invalid format: {err}").as_str(); + } + Self::FileNotFound(err) => { + message += format!("File not found: {err}").as_str(); + } + }; - f.write_str(message.as_str()) - } + f.write_str(message.as_str()) + } } // TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765) @@ -34,50 +34,51 @@ impl std::error::Error for ConfigError {} /// Configuration trait. pub trait Config: serde::Serialize + serde::de::DeserializeOwned { - /// Saves the configuration to a file. - /// - /// # Arguments - /// - /// * `file` - File to save the configuration to. - /// - /// # Returns - /// - /// The output of the save operation. - #[cfg(feature = "std")] - fn save>(&self, file: P) -> std::io::Result<()> { - std::fs::write(file, config_to_json(self)) - } + /// Saves the configuration to a file. + /// + /// # Arguments + /// + /// * `file` - File to save the configuration to. + /// + /// # Returns + /// + /// The output of the save operation. + #[cfg(feature = "std")] + fn save>(&self, file: P) -> std::io::Result<()> { + std::fs::write(file, config_to_json(self)) + } - /// Loads the configuration from a file. - /// - /// # Arguments - /// - /// * `file` - File to load the configuration from. - /// - /// # Returns - /// - /// The loaded configuration. - #[cfg(feature = "std")] - fn load>(file: P) -> Result { - let content = std::fs::read_to_string(file.as_ref()) - .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?; - config_from_str(&content) - } + /// Loads the configuration from a file. + /// + /// # Arguments + /// + /// * `file` - File to load the configuration from. + /// + /// # Returns + /// + /// The loaded configuration. + #[cfg(feature = "std")] + fn load>(file: P) -> Result { + let content = std::fs::read_to_string(file.as_ref()) + .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?; + config_from_str(&content) + } - /// Loads the configuration from a binary buffer. - /// - /// # Arguments - /// - /// * `data` - Binary buffer to load the configuration from. - /// - /// # Returns - /// - /// The loaded configuration. - fn load_binary(data: &[u8]) -> Result { - let content = core::str::from_utf8(data) - .map_err(|_| ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string()))?; - config_from_str(content) - } + /// Loads the configuration from a binary buffer. + /// + /// # Arguments + /// + /// * `data` - Binary buffer to load the configuration from. + /// + /// # Returns + /// + /// The loaded configuration. 
+ fn load_binary(data: &[u8]) -> Result { + let content = core::str::from_utf8(data).map_err(|_| { + ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string()) + })?; + config_from_str(content) + } } /// Converts a configuration to a JSON string. @@ -90,9 +91,9 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned { /// /// The JSON string. pub fn config_to_json(config: &C) -> String { - serde_json::to_string_pretty(config).unwrap() + serde_json::to_string_pretty(config).unwrap() } fn config_from_str(content: &str) -> Result { - serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}"))) + serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}"))) } diff --git a/burn-core/src/data/dataloader/base.rs b/burn-core/src/data/dataloader/base.rs index 7619f95b47..0222248ad3 100644 --- a/burn-core/src/data/dataloader/base.rs +++ b/burn-core/src/data/dataloader/base.rs @@ -4,21 +4,21 @@ use core::iter::Iterator; /// A progress struct that can be used to track the progress of a data loader. #[derive(Clone, Debug)] pub struct Progress { - /// The number of items that have been processed. - pub items_processed: usize, + /// The number of items that have been processed. + pub items_processed: usize, - /// The total number of items that need to be processed. - pub items_total: usize, + /// The total number of items that need to be processed. + pub items_total: usize, } /// A data loader iterator that can be used to iterate over a data loader. pub trait DataLoaderIterator: Iterator { - /// Returns the progress of the data loader. - fn progress(&self) -> Progress; + /// Returns the progress of the data loader. + fn progress(&self) -> Progress; } /// A data loader that can be used to iterate over a dataset. pub trait DataLoader { - /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader. - fn iter<'a>(&'a self) -> Box + 'a>; + /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader. + fn iter<'a>(&'a self) -> Box + 'a>; } diff --git a/burn-core/src/data/dataloader/batch.rs b/burn-core/src/data/dataloader/batch.rs index 34b82c953d..978127bad8 100644 --- a/burn-core/src/data/dataloader/batch.rs +++ b/burn-core/src/data/dataloader/batch.rs @@ -1,253 +1,254 @@ use super::{ - batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, MultiThreadDataLoader, Progress, + batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, MultiThreadDataLoader, + Progress, }; use burn_dataset::{ - transform::{PartialDataset, ShuffledDataset}, - Dataset, + transform::{PartialDataset, ShuffledDataset}, + Dataset, }; use rand::{distributions::Standard, prelude::Distribution, rngs::StdRng, Rng, SeedableRng}; use std::sync::Arc; /// A data loader that can be used to iterate over a dataset in batches. pub struct BatchDataLoader { - strategy: Box>, - dataset: Arc>, - batcher: Arc>, - rng: Option>, -} - -impl BatchDataLoader { - /// Creates a new batch data loader. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader - /// iterator is created. - /// - /// # Returns - /// - /// The batch data loader. 
- pub fn new( strategy: Box>, dataset: Arc>, batcher: Arc>, - rng: Option, - ) -> Self { - Self { - strategy, - dataset, - batcher, - rng: rng.map(spin::Mutex::new), + rng: Option>, +} + +impl BatchDataLoader { + /// Creates a new batch data loader. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader + /// iterator is created. + /// + /// # Returns + /// + /// The batch data loader. + pub fn new( + strategy: Box>, + dataset: Arc>, + batcher: Arc>, + rng: Option, + ) -> Self { + Self { + strategy, + dataset, + batcher, + rng: rng.map(spin::Mutex::new), + } } - } } /// A data loader iterator that can be used to iterate over a data loader. struct BatchDataloaderIterator { - current_index: usize, - strategy: Box>, - dataset: Arc>, - batcher: Arc>, + current_index: usize, + strategy: Box>, + dataset: Arc>, + batcher: Arc>, } impl BatchDataLoader where - I: Send + Sync + Clone + 'static, - O: Send + Sync + Clone + 'static, + I: Send + Sync + Clone + 'static, + O: Send + Sync + Clone + 'static, { - /// Creates a new multi-threaded batch data loader. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// * `num_threads` - The number of threads. - /// - /// # Returns - /// - /// The multi-threaded batch data loader. - pub fn multi_thread( - strategy: Box>, - dataset: Arc>, - batcher: Arc>, - num_threads: usize, - mut rng: Option, - ) -> MultiThreadDataLoader { - let datasets = PartialDataset::split(dataset, num_threads); - - let mut dataloaders: Vec + Send + Sync>> = - Vec::with_capacity(num_threads); - - // Create more rngs from the first one, one for each new dataloader. - let rngs = (0..num_threads).map(|_| { - rng - .as_mut() - .map(|rng| StdRng::seed_from_u64(Distribution::sample(&Standard, rng))) - }); - - for (dataset, rng) in datasets.into_iter().zip(rngs) { - let strategy = strategy.new_like(); - let dataloader = BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone(), rng); - let dataloader = Arc::new(dataloader); - dataloaders.push(dataloader); + /// Creates a new multi-threaded batch data loader. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// * `num_threads` - The number of threads. + /// + /// # Returns + /// + /// The multi-threaded batch data loader. + pub fn multi_thread( + strategy: Box>, + dataset: Arc>, + batcher: Arc>, + num_threads: usize, + mut rng: Option, + ) -> MultiThreadDataLoader { + let datasets = PartialDataset::split(dataset, num_threads); + + let mut dataloaders: Vec + Send + Sync>> = + Vec::with_capacity(num_threads); + + // Create more rngs from the first one, one for each new dataloader. 
+ let rngs = (0..num_threads).map(|_| { + rng.as_mut() + .map(|rng| StdRng::seed_from_u64(Distribution::sample(&Standard, rng))) + }); + + for (dataset, rng) in datasets.into_iter().zip(rngs) { + let strategy = strategy.new_like(); + let dataloader = + BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone(), rng); + let dataloader = Arc::new(dataloader); + dataloaders.push(dataloader); + } + MultiThreadDataLoader::new(dataloaders) } - MultiThreadDataLoader::new(dataloaders) - } } impl DataLoader for BatchDataLoader { - fn iter<'a>(&'a self) -> Box + 'a> { - // When starting a new iteration, we first check if the dataloader was created with an rng, - // implying that we should shuffle the dataset beforehand, while advancing the current - // rng to ensure that each new iteration shuffles the dataset differently. - let dataset = match &self.rng { - Some(rng) => { - let mut rng = rng.lock(); - - Arc::new(ShuffledDataset::with_seed( - self.dataset.clone(), - rng.sample(Standard), + fn iter<'a>(&'a self) -> Box + 'a> { + // When starting a new iteration, we first check if the dataloader was created with an rng, + // implying that we should shuffle the dataset beforehand, while advancing the current + // rng to ensure that each new iteration shuffles the dataset differently. + let dataset = match &self.rng { + Some(rng) => { + let mut rng = rng.lock(); + + Arc::new(ShuffledDataset::with_seed( + self.dataset.clone(), + rng.sample(Standard), + )) + } + None => self.dataset.clone(), + }; + Box::new(BatchDataloaderIterator::new( + self.strategy.new_like(), + dataset, + self.batcher.clone(), )) - } - None => self.dataset.clone(), - }; - Box::new(BatchDataloaderIterator::new( - self.strategy.new_like(), - dataset, - self.batcher.clone(), - )) - } + } } impl BatchDataloaderIterator { - /// Creates a new batch data loader iterator. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The batch data loader iterator. - pub fn new( - strategy: Box>, - dataset: Arc>, - batcher: Arc>, - ) -> Self { - BatchDataloaderIterator { - current_index: 0, - strategy, - dataset, - batcher, + /// Creates a new batch data loader iterator. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// + /// # Returns + /// + /// The batch data loader iterator. 
+ pub fn new( + strategy: Box>, + dataset: Arc>, + batcher: Arc>, + ) -> Self { + BatchDataloaderIterator { + current_index: 0, + strategy, + dataset, + batcher, + } } - } } impl Iterator for BatchDataloaderIterator { - type Item = O; + type Item = O; - fn next(&mut self) -> Option { - while let Some(item) = self.dataset.get(self.current_index) { - self.current_index += 1; - self.strategy.add(item); + fn next(&mut self) -> Option { + while let Some(item) = self.dataset.get(self.current_index) { + self.current_index += 1; + self.strategy.add(item); - if let Some(items) = self.strategy.batch(false) { - return Some(self.batcher.batch(items)); - } - } + if let Some(items) = self.strategy.batch(false) { + return Some(self.batcher.batch(items)); + } + } - if let Some(items) = self.strategy.batch(true) { - return Some(self.batcher.batch(items)); - } + if let Some(items) = self.strategy.batch(true) { + return Some(self.batcher.batch(items)); + } - None - } + None + } } impl DataLoaderIterator for BatchDataloaderIterator { - fn progress(&self) -> Progress { - Progress { - items_processed: self.current_index, - items_total: self.dataset.len(), + fn progress(&self) -> Progress { + Progress { + items_processed: self.current_index, + items_total: self.dataset.len(), + } } - } } #[cfg(test)] mod tests { - use std::collections::HashSet; - - use super::*; - use crate::data::dataloader::batcher::TestBatcher; - use crate::data::dataloader::FixBatchStrategy; - use crate::data::dataset::FakeDataset; - - #[test] - fn test_batch_dataloader() { - let batcher = Arc::new(TestBatcher::new()); - let dataset = Arc::new(FakeDataset::::new(27)); - let dataloader = BatchDataLoader::new( - Box::new(FixBatchStrategy::new(5)), - dataset.clone(), - batcher, - None, - ); - - let mut items_dataset = HashSet::new(); - let mut items_dataloader = HashSet::new(); - - for item in dataset.iter() { - items_dataset.insert(item); - } - - for items in dataloader.iter() { - for item in items { - items_dataloader.insert(item); - } + use std::collections::HashSet; + + use super::*; + use crate::data::dataloader::batcher::TestBatcher; + use crate::data::dataloader::FixBatchStrategy; + use crate::data::dataset::FakeDataset; + + #[test] + fn test_batch_dataloader() { + let batcher = Arc::new(TestBatcher::new()); + let dataset = Arc::new(FakeDataset::::new(27)); + let dataloader = BatchDataLoader::new( + Box::new(FixBatchStrategy::new(5)), + dataset.clone(), + batcher, + None, + ); + + let mut items_dataset = HashSet::new(); + let mut items_dataloader = HashSet::new(); + + for item in dataset.iter() { + items_dataset.insert(item); + } + + for items in dataloader.iter() { + for item in items { + items_dataloader.insert(item); + } + } + + assert_eq!(items_dataset, items_dataloader); } - assert_eq!(items_dataset, items_dataloader); - } - - #[test] - fn test_multi_thread_batch_dataloader() { - let batcher = Arc::new(TestBatcher::new()); - let dataset = Arc::new(FakeDataset::::new(27)); - let dataloader_single_thread = BatchDataLoader::new( - Box::new(FixBatchStrategy::new(5)), - dataset.clone(), - batcher.clone(), - None, - ); - let dataloader_multi_thread = BatchDataLoader::multi_thread( - Box::new(FixBatchStrategy::new(5)), - dataset, - batcher, - 4, - None, - ); - - let mut items_single_thread = HashSet::new(); - let mut items_multi_thread = HashSet::new(); - - for items in dataloader_single_thread.iter() { - for item in items { - items_single_thread.insert(item); - } + #[test] + fn test_multi_thread_batch_dataloader() { + let batcher = 
Arc::new(TestBatcher::new()); + let dataset = Arc::new(FakeDataset::::new(27)); + let dataloader_single_thread = BatchDataLoader::new( + Box::new(FixBatchStrategy::new(5)), + dataset.clone(), + batcher.clone(), + None, + ); + let dataloader_multi_thread = BatchDataLoader::multi_thread( + Box::new(FixBatchStrategy::new(5)), + dataset, + batcher, + 4, + None, + ); + + let mut items_single_thread = HashSet::new(); + let mut items_multi_thread = HashSet::new(); + + for items in dataloader_single_thread.iter() { + for item in items { + items_single_thread.insert(item); + } + } + + for items in dataloader_multi_thread.iter() { + for item in items { + items_multi_thread.insert(item); + } + } + + assert_eq!(items_single_thread, items_multi_thread); } - - for items in dataloader_multi_thread.iter() { - for item in items { - items_multi_thread.insert(item); - } - } - - assert_eq!(items_single_thread, items_multi_thread); - } } diff --git a/burn-core/src/data/dataloader/batcher.rs b/burn-core/src/data/dataloader/batcher.rs index 0e52444da1..724a2e3a54 100644 --- a/burn-core/src/data/dataloader/batcher.rs +++ b/burn-core/src/data/dataloader/batcher.rs @@ -1,15 +1,15 @@ /// A trait for batching items of type `I` into items of type `O`. pub trait Batcher: Send + Sync { - /// Batches the given items. - /// - /// # Arguments - /// - /// * `items` - The items to batch. - /// - /// # Returns - /// - /// The batched items. - fn batch(&self, items: Vec) -> O; + /// Batches the given items. + /// + /// # Arguments + /// + /// * `items` - The items to batch. + /// + /// # Returns + /// + /// The batched items. + fn batch(&self, items: Vec) -> O; } #[cfg(test)] @@ -17,7 +17,7 @@ pub trait Batcher: Send + Sync { pub struct TestBatcher; #[cfg(test)] impl Batcher> for TestBatcher { - fn batch(&self, items: Vec) -> Vec { - items - } + fn batch(&self, items: Vec) -> Vec { + items + } } diff --git a/burn-core/src/data/dataloader/builder.rs b/burn-core/src/data/dataloader/builder.rs index 8c6d29154b..d6227ebc49 100644 --- a/burn-core/src/data/dataloader/builder.rs +++ b/burn-core/src/data/dataloader/builder.rs @@ -5,113 +5,113 @@ use std::sync::Arc; /// A builder for data loaders. pub struct DataLoaderBuilder { - strategy: Option>>, - batcher: Arc>, - num_threads: Option, - shuffle: Option, + strategy: Option>>, + batcher: Arc>, + num_threads: Option, + shuffle: Option, } impl DataLoaderBuilder where - I: Send + Sync + Clone + std::fmt::Debug + 'static, - O: Send + Sync + Clone + std::fmt::Debug + 'static, + I: Send + Sync + Clone + std::fmt::Debug + 'static, + O: Send + Sync + Clone + std::fmt::Debug + 'static, { - /// Creates a new data loader builder. - /// - /// # Arguments - /// - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The data loader builder. - pub fn new(batcher: B) -> Self - where - B: Batcher + 'static, - { - Self { - batcher: Arc::new(batcher), - strategy: None, - num_threads: None, - shuffle: None, + /// Creates a new data loader builder. + /// + /// # Arguments + /// + /// * `batcher` - The batcher. + /// + /// # Returns + /// + /// The data loader builder. + pub fn new(batcher: B) -> Self + where + B: Batcher + 'static, + { + Self { + batcher: Arc::new(batcher), + strategy: None, + num_threads: None, + shuffle: None, + } } - } - /// Sets the batch size to a fix number.The [fix batch strategy](FixBatchStrategy) - /// will be used. - /// - /// # Arguments - /// - /// * `batch_size` - The batch size. - /// - /// # Returns - /// - /// The data loader builder. 
- pub fn batch_size(mut self, batch_size: usize) -> Self { - self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size))); - self - } + /// Sets the batch size to a fix number.The [fix batch strategy](FixBatchStrategy) + /// will be used. + /// + /// # Arguments + /// + /// * `batch_size` - The batch size. + /// + /// # Returns + /// + /// The data loader builder. + pub fn batch_size(mut self, batch_size: usize) -> Self { + self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size))); + self + } - /// Sets the seed for shuffling. - /// - /// Each time the dataloader starts a new iteration, the dataset will be shuffled. - /// - /// # Arguments - /// - /// * `seed` - The seed. - /// - /// # Returns - /// - /// The data loader builder. - pub fn shuffle(mut self, seed: u64) -> Self { - self.shuffle = Some(seed); - self - } + /// Sets the seed for shuffling. + /// + /// Each time the dataloader starts a new iteration, the dataset will be shuffled. + /// + /// # Arguments + /// + /// * `seed` - The seed. + /// + /// # Returns + /// + /// The data loader builder. + pub fn shuffle(mut self, seed: u64) -> Self { + self.shuffle = Some(seed); + self + } - /// Sets the number of workers. - /// - /// # Arguments - /// - /// * `num_workers` - The number of workers. - /// - /// # Returns - /// - /// The data loader builder. - pub fn num_workers(mut self, num_workers: usize) -> Self { - self.num_threads = Some(num_workers); - self - } + /// Sets the number of workers. + /// + /// # Arguments + /// + /// * `num_workers` - The number of workers. + /// + /// # Returns + /// + /// The data loader builder. + pub fn num_workers(mut self, num_workers: usize) -> Self { + self.num_threads = Some(num_workers); + self + } - /// Builds the data loader. - /// - /// # Arguments - /// - /// * `dataset` - The dataset. - /// - /// # Returns - /// - /// The data loader. - pub fn build(self, dataset: D) -> Arc> - where - D: Dataset + 'static, - { - let dataset = Arc::new(dataset); + /// Builds the data loader. + /// + /// # Arguments + /// + /// * `dataset` - The dataset. + /// + /// # Returns + /// + /// The data loader. + pub fn build(self, dataset: D) -> Arc> + where + D: Dataset + 'static, + { + let dataset = Arc::new(dataset); - let rng = self.shuffle.map(StdRng::seed_from_u64); - let strategy = match self.strategy { - Some(strategy) => strategy, - None => Box::new(FixBatchStrategy::new(1)), - }; - if let Some(num_threads) = self.num_threads { - return Arc::new(BatchDataLoader::multi_thread( - strategy, - dataset, - self.batcher, - num_threads, - rng, - )); - } + let rng = self.shuffle.map(StdRng::seed_from_u64); + let strategy = match self.strategy { + Some(strategy) => strategy, + None => Box::new(FixBatchStrategy::new(1)), + }; + if let Some(num_threads) = self.num_threads { + return Arc::new(BatchDataLoader::multi_thread( + strategy, + dataset, + self.batcher, + num_threads, + rng, + )); + } - Arc::new(BatchDataLoader::new(strategy, dataset, self.batcher, rng)) - } + Arc::new(BatchDataLoader::new(strategy, dataset, self.batcher, rng)) + } } diff --git a/burn-core/src/data/dataloader/multithread.rs b/burn-core/src/data/dataloader/multithread.rs index 28bbb5e0b9..00fabf2957 100644 --- a/burn-core/src/data/dataloader/multithread.rs +++ b/burn-core/src/data/dataloader/multithread.rs @@ -7,134 +7,134 @@ const MAX_QUEUED_ITEMS: usize = 100; /// A multi-threaded data loader that can be used to iterate over a dataset. 
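// Illustrative sketch (not part of this patch): wiring the builder above to the
// multi-threaded loader defined below. `TestBatcher` and `FakeDataset` are the
// crate's test helpers used elsewhere in this diff, so this belongs in a test
// context; the import paths, sizes, and seed are assumptions.
use std::sync::Arc;

use crate::data::dataloader::{batcher::TestBatcher, DataLoader, DataLoaderBuilder};
use crate::data::dataset::FakeDataset;

fn example_loader() -> Arc<dyn DataLoader<Vec<String>>> {
    DataLoaderBuilder::new(TestBatcher::new())
        .batch_size(8) // FixBatchStrategy with batches of 8
        .shuffle(42) // reshuffle the dataset on every call to `iter()`
        .num_workers(4) // routes to BatchDataLoader::multi_thread, backed by the MultiThreadDataLoader below
        .build(FakeDataset::<String>::new(64))
}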
pub struct MultiThreadDataLoader { - dataloaders: Vec + Send + Sync>>, + dataloaders: Vec + Send + Sync>>, } /// A message that can be sent between threads. #[derive(Debug)] pub enum Message { - /// A batch of items. - Batch(usize, O, Progress), + /// A batch of items. + Batch(usize, O, Progress), - /// The thread is done. - Done, + /// The thread is done. + Done, } struct MultiThreadsDataloaderIterator { - num_done: usize, - workers: Vec>, - receiver: mpsc::Receiver>, - progresses: HashMap, + num_done: usize, + workers: Vec>, + receiver: mpsc::Receiver>, + progresses: HashMap, } impl MultiThreadDataLoader { - /// Creates a new multi-threaded data loader. - /// - /// # Arguments - /// - /// * `dataloaders` - The data loaders. - /// - /// # Returns - /// - /// The multi-threaded data loader. - pub fn new(dataloaders: Vec + Send + Sync>>) -> Self { - Self { dataloaders } - } + /// Creates a new multi-threaded data loader. + /// + /// # Arguments + /// + /// * `dataloaders` - The data loaders. + /// + /// # Returns + /// + /// The multi-threaded data loader. + pub fn new(dataloaders: Vec + Send + Sync>>) -> Self { + Self { dataloaders } + } } impl DataLoader for MultiThreadDataLoader where - O: Send + 'static + std::fmt::Debug, + O: Send + 'static + std::fmt::Debug, { - fn iter<'a>(&'a self) -> Box + 'a> { - let (sender, receiver) = mpsc::sync_channel::>(MAX_QUEUED_ITEMS); - - let handlers: Vec<_> = self - .dataloaders - .clone() - .into_iter() - .enumerate() - .map(|(index, dataloader)| { - let dataloader_cloned = dataloader; - let sender_cloned = sender.clone(); - - thread::spawn(move || { - let mut iterator = dataloader_cloned.iter(); - while let Some(item) = iterator.next() { - let progress = iterator.progress(); - - match sender_cloned.send(Message::Batch(index, item, progress)) { - Ok(_) => {} - // The receiver is probably gone, no need to panic, just need to stop - // iterating. - Err(_) => return, - }; - } - // Same thing. - sender_cloned.send(Message::Done).ok(); - }) - }) - .collect(); - - Box::new(MultiThreadsDataloaderIterator::new(receiver, handlers)) - } + fn iter<'a>(&'a self) -> Box + 'a> { + let (sender, receiver) = mpsc::sync_channel::>(MAX_QUEUED_ITEMS); + + let handlers: Vec<_> = self + .dataloaders + .clone() + .into_iter() + .enumerate() + .map(|(index, dataloader)| { + let dataloader_cloned = dataloader; + let sender_cloned = sender.clone(); + + thread::spawn(move || { + let mut iterator = dataloader_cloned.iter(); + while let Some(item) = iterator.next() { + let progress = iterator.progress(); + + match sender_cloned.send(Message::Batch(index, item, progress)) { + Ok(_) => {} + // The receiver is probably gone, no need to panic, just need to stop + // iterating. + Err(_) => return, + }; + } + // Same thing. 
+ sender_cloned.send(Message::Done).ok(); + }) + }) + .collect(); + + Box::new(MultiThreadsDataloaderIterator::new(receiver, handlers)) + } } impl MultiThreadsDataloaderIterator { - pub fn new(receiver: mpsc::Receiver>, workers: Vec>) -> Self { - MultiThreadsDataloaderIterator { - num_done: 0, - workers, - receiver, - progresses: HashMap::new(), + pub fn new(receiver: mpsc::Receiver>, workers: Vec>) -> Self { + MultiThreadsDataloaderIterator { + num_done: 0, + workers, + receiver, + progresses: HashMap::new(), + } } - } } impl DataLoaderIterator for MultiThreadsDataloaderIterator { - fn progress(&self) -> Progress { - let mut items_total = 0; - let mut items_processed = 0; + fn progress(&self) -> Progress { + let mut items_total = 0; + let mut items_processed = 0; - for progress in self.progresses.values() { - items_total += progress.items_total; - items_processed += progress.items_processed; - } + for progress in self.progresses.values() { + items_total += progress.items_total; + items_processed += progress.items_processed; + } - Progress { - items_processed, - items_total, + Progress { + items_processed, + items_total, + } } - } } impl Iterator for MultiThreadsDataloaderIterator { - type Item = O; + type Item = O; - fn next(&mut self) -> Option { - if self.workers.is_empty() { - return None; - } - - loop { - let item = self.receiver.recv(); - let item = item.unwrap(); - - match item { - Message::Batch(index, item, progress) => { - self.progresses.insert(index, progress); - return Some(item); + fn next(&mut self) -> Option { + if self.workers.is_empty() { + return None; } - Message::Done => { - self.num_done += 1; - } - }; - if self.num_done == self.workers.len() { - while let Some(worker) = self.workers.pop() { - worker.join().unwrap(); + loop { + let item = self.receiver.recv(); + let item = item.unwrap(); + + match item { + Message::Batch(index, item, progress) => { + self.progresses.insert(index, progress); + return Some(item); + } + Message::Done => { + self.num_done += 1; + } + }; + + if self.num_done == self.workers.len() { + while let Some(worker) = self.workers.pop() { + worker.join().unwrap(); + } + return None; + } } - return None; - } } - } } diff --git a/burn-core/src/data/dataloader/strategy.rs b/burn-core/src/data/dataloader/strategy.rs index f2302947ba..9e09207edf 100644 --- a/burn-core/src/data/dataloader/strategy.rs +++ b/burn-core/src/data/dataloader/strategy.rs @@ -1,76 +1,76 @@ /// A strategy to batch items. pub trait BatchStrategy: Send + Sync { - /// Adds an item to the strategy. - /// - /// # Arguments - /// - /// * `item` - The item to add. - fn add(&mut self, item: I); + /// Adds an item to the strategy. + /// + /// # Arguments + /// + /// * `item` - The item to add. + fn add(&mut self, item: I); - /// Batches the items. - /// - /// # Arguments - /// - /// * `force` - Whether to force batching. - /// - /// # Returns - /// - /// The batched items. - fn batch(&mut self, force: bool) -> Option>; + /// Batches the items. + /// + /// # Arguments + /// + /// * `force` - Whether to force batching. + /// + /// # Returns + /// + /// The batched items. + fn batch(&mut self, force: bool) -> Option>; - /// Creates a new strategy of the same type. - /// - /// # Returns - /// - /// The new strategy. - fn new_like(&self) -> Box>; + /// Creates a new strategy of the same type. + /// + /// # Returns + /// + /// The new strategy. + fn new_like(&self) -> Box>; } /// A strategy to batch items with a fixed batch size. 
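// Hedged sketch, not part of the patch, of the worker/channel protocol used by the
// multi-threaded loader above, reduced to plain `std` types: each worker pushes
// `Batch` messages through a bounded channel and ends with a single `Done`; the
// consumer stops once it has counted one `Done` per worker, then joins the threads.
// All names here are local to the sketch, not Burn APIs.
use std::sync::mpsc;
use std::thread;

enum Msg {
    Batch(usize, u32), // (worker index, item)
    Done,
}

fn drain_workers(num_workers: usize) -> Vec<u32> {
    let (sender, receiver) = mpsc::sync_channel::<Msg>(100);

    let handles: Vec<_> = (0..num_workers)
        .map(|index| {
            let sender = sender.clone();
            thread::spawn(move || {
                for item in 0..5u32 {
                    // Stop quietly if the receiver is gone, mirroring the loader above.
                    if sender.send(Msg::Batch(index, item)).is_err() {
                        return;
                    }
                }
                sender.send(Msg::Done).ok();
            })
        })
        .collect();
    drop(sender); // only the workers keep a sender alive now

    let mut items = Vec::new();
    let mut num_done = 0;
    while num_done < num_workers {
        match receiver.recv().expect("workers still hold senders") {
            Msg::Batch(_, item) => items.push(item),
            Msg::Done => num_done += 1,
        }
    }
    for handle in handles {
        handle.join().unwrap();
    }
    items
}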
pub struct FixBatchStrategy { - items: Vec, - batch_size: usize, + items: Vec, + batch_size: usize, } impl FixBatchStrategy { - /// Creates a new strategy to batch items with a fixed batch size. - /// - /// # Arguments - /// - /// * `batch_size` - The batch size. - /// - /// # Returns - /// - /// The strategy. - pub fn new(batch_size: usize) -> Self { - FixBatchStrategy { - items: Vec::with_capacity(batch_size), - batch_size, + /// Creates a new strategy to batch items with a fixed batch size. + /// + /// # Arguments + /// + /// * `batch_size` - The batch size. + /// + /// # Returns + /// + /// The strategy. + pub fn new(batch_size: usize) -> Self { + FixBatchStrategy { + items: Vec::with_capacity(batch_size), + batch_size, + } } - } } impl BatchStrategy for FixBatchStrategy { - fn add(&mut self, item: I) { - self.items.push(item); - } - - fn batch(&mut self, force: bool) -> Option> { - if self.items.len() < self.batch_size && !force { - return None; + fn add(&mut self, item: I) { + self.items.push(item); } - let mut items = Vec::with_capacity(self.batch_size); - std::mem::swap(&mut items, &mut self.items); + fn batch(&mut self, force: bool) -> Option> { + if self.items.len() < self.batch_size && !force { + return None; + } - if items.is_empty() { - return None; - } + let mut items = Vec::with_capacity(self.batch_size); + std::mem::swap(&mut items, &mut self.items); - Some(items) - } + if items.is_empty() { + return None; + } - fn new_like(&self) -> Box> { - Box::new(Self::new(self.batch_size)) - } + Some(items) + } + + fn new_like(&self) -> Box> { + Box::new(Self::new(self.batch_size)) + } } diff --git a/burn-core/src/data/mod.rs b/burn-core/src/data/mod.rs index 6e81a9a0a7..5489cae640 100644 --- a/burn-core/src/data/mod.rs +++ b/burn-core/src/data/mod.rs @@ -3,5 +3,5 @@ pub mod dataloader; /// Dataset module. pub mod dataset { - pub use burn_dataset::*; + pub use burn_dataset::*; } diff --git a/burn-core/src/grad_clipping/base.rs b/burn-core/src/grad_clipping/base.rs index 5f38487841..91a6be069d 100644 --- a/burn-core/src/grad_clipping/base.rs +++ b/burn-core/src/grad_clipping/base.rs @@ -6,138 +6,138 @@ use burn_tensor::backend::Backend; /// Gradient Clipping provides a way to mitigate exploding gradients #[derive(Config)] pub enum GradientClippingConfig { - /// Clip the gradient by value. - Value(f32), + /// Clip the gradient by value. + Value(f32), - /// Clip the gradient by norm. - Norm(f32), + /// Clip the gradient by norm. + Norm(f32), } impl GradientClippingConfig { - /// Initialize the gradient clipping. - /// - /// # Returns - /// - /// The gradient clipping. - pub fn init(&self) -> GradientClipping { - match self { - GradientClippingConfig::Value(val) => GradientClipping::Value(*val), - GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val), + /// Initialize the gradient clipping. + /// + /// # Returns + /// + /// The gradient clipping. + pub fn init(&self) -> GradientClipping { + match self { + GradientClippingConfig::Value(val) => GradientClipping::Value(*val), + GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val), + } } - } } /// Gradient Clipping provides a way to mitigate exploding gradients /// by clipping every component of the gradient by value or by norm during /// backpropagation. pub enum GradientClipping { - /// Clip the gradient by value. - Value(f32), + /// Clip the gradient by value. + Value(f32), - /// Clip the gradient by norm. - Norm(f32), + /// Clip the gradient by norm. + Norm(f32), } impl GradientClipping { - /// Clip the gradient. 
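// Hedged sketch, not part of the patch, assuming the stripped generics are
// `FixBatchStrategy<I>` / `BatchStrategy<I>` and that the trait is in scope:
// `batch(false)` only yields once `batch_size` items were added, while `batch(true)`
// flushes the remainder, which is how the loader emits the last, smaller batch.
fn fix_batch_strategy_demo() {
    let mut strategy = FixBatchStrategy::new(3);
    strategy.add(1);
    strategy.add(2);
    assert_eq!(strategy.batch(false), None); // only 2 of 3 items so far
    strategy.add(3);
    assert_eq!(strategy.batch(false), Some(vec![1, 2, 3]));
    strategy.add(4);
    assert_eq!(strategy.batch(true), Some(vec![4])); // forced flush of the leftovers
    assert_eq!(strategy.batch(true), None); // nothing left to flush
}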
- /// - /// # Arguments - /// - /// * `grad` - The gradient to clip. - /// - /// # Returns - /// - /// The clipped gradient. - pub fn clip_gradient(&self, grad: Tensor) -> Tensor { - match self { - GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold), - GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm), + /// Clip the gradient. + /// + /// # Arguments + /// + /// * `grad` - The gradient to clip. + /// + /// # Returns + /// + /// The clipped gradient. + pub fn clip_gradient(&self, grad: Tensor) -> Tensor { + match self { + GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold), + GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm), + } } - } - - fn clip_by_value( - &self, - grad: Tensor, - threshold: f32, - ) -> Tensor { - let greater_mask = grad.clone().greater_elem(threshold); - let lower_mask = grad.clone().lower_elem(-threshold); - - let clipped_grad = grad.mask_fill(greater_mask, threshold); - - clipped_grad.mask_fill(lower_mask, -threshold) - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - fn clip_by_norm( - &self, - _grad: Tensor, - _threshold: f32, - ) -> Tensor { - todo!("Not yet supported on wasm"); - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn clip_by_norm( - &self, - grad: Tensor, - threshold: f32, - ) -> Tensor { - use burn_tensor::ElementConversion; - - let norm = Self::l2_norm(grad.clone()); - let norm_float = norm.into_scalar().elem::(); - - if norm_float > threshold { - let scale = threshold / norm_float; - grad.mul_scalar(scale) - } else { - grad + + fn clip_by_value( + &self, + grad: Tensor, + threshold: f32, + ) -> Tensor { + let greater_mask = grad.clone().greater_elem(threshold); + let lower_mask = grad.clone().lower_elem(-threshold); + + let clipped_grad = grad.mask_fill(greater_mask, threshold); + + clipped_grad.mask_fill(lower_mask, -threshold) + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + fn clip_by_norm( + &self, + _grad: Tensor, + _threshold: f32, + ) -> Tensor { + todo!("Not yet supported on wasm"); + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn clip_by_norm( + &self, + grad: Tensor, + threshold: f32, + ) -> Tensor { + use burn_tensor::ElementConversion; + + let norm = Self::l2_norm(grad.clone()); + let norm_float = norm.into_scalar().elem::(); + + if norm_float > threshold { + let scale = threshold / norm_float; + grad.mul_scalar(scale) + } else { + grad + } } - } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn l2_norm(tensor: Tensor) -> Tensor { - let squared = tensor.powf(2.0); - let sum = squared.sum(); + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn l2_norm(tensor: Tensor) -> Tensor { + let squared = tensor.powf(2.0); + let sum = squared.sum(); - sum.sqrt() - } + sum.sqrt() + } } #[cfg(test)] mod tests { - use super::*; - use crate::tensor::Tensor; - use crate::TestBackend; - - #[test] - fn test_clip_by_value() { - let gradient: Tensor = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]); - - let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient); - let clipped_gradient_data = clipped_gradient.into_data(); - - for value in clipped_gradient_data.value { - assert!(value <= 0.5); + use super::*; + use crate::tensor::Tensor; + use crate::TestBackend; + + #[test] + fn test_clip_by_value() { + let gradient: Tensor = 
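// Hedged numeric illustration, not part of the patch, of what `clip_by_norm` above
// computes: when the L2 norm of the gradient exceeds the threshold, every component
// is scaled by `threshold / norm`, so the clipped gradient has exactly the threshold
// as its norm; otherwise the gradient passes through unchanged.
fn clip_by_norm_scalar(grad: &mut [f32], threshold: f32) {
    let norm = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
    if norm > threshold {
        let scale = threshold / norm;
        for g in grad.iter_mut() {
            *g *= scale;
        }
    }
}
// Example: [3.0, 4.0] has norm 5.0; with threshold 2.5 the scale is 0.5, giving
// [1.5, 2.0], whose norm is exactly 2.5.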
Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]); + + let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient); + let clipped_gradient_data = clipped_gradient.into_data(); + + for value in clipped_gradient_data.value { + assert!(value <= 0.5); + } } - } - #[test] - fn test_clip_by_norm() { - let gradient: Tensor = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]); + #[test] + fn test_clip_by_norm() { + let gradient: Tensor = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]); - let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient); - let clipped_gradient_data = clipped_gradient.into_data(); + let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient); + let clipped_gradient_data = clipped_gradient.into_data(); - for value in clipped_gradient_data.value { - assert!(value <= 0.88); + for value in clipped_gradient_data.value { + assert!(value <= 0.88); + } } - } } diff --git a/burn-core/src/lr_scheduler/base.rs b/burn-core/src/lr_scheduler/base.rs index 662a590d27..23c56f8a58 100644 --- a/burn-core/src/lr_scheduler/base.rs +++ b/burn-core/src/lr_scheduler/base.rs @@ -2,16 +2,16 @@ use crate::{record::Record, LearningRate}; /// Learning rate scheduler defines how the learning rate will evolve during training. pub trait LrScheduler: Send + Sync { - /// Scheduler associative type to be used when saving and loading the state. - type Record: Record; + /// Scheduler associative type to be used when saving and loading the state. + type Record: Record; - /// Perform the scheduler step, potentially updating its state, and returning the effective - /// learning rate. - fn step(&mut self) -> LearningRate; + /// Perform the scheduler step, potentially updating its state, and returning the effective + /// learning rate. + fn step(&mut self) -> LearningRate; - /// Get the current state of the scheduler as a [record](Record). - fn to_record(&self) -> Self::Record; + /// Get the current state of the scheduler as a [record](Record). + fn to_record(&self) -> Self::Record; - /// Load the state of the scheduler as a [record](Record). - fn load_record(self, record: Self::Record) -> Self; + /// Load the state of the scheduler as a [record](Record). + fn load_record(self, record: Self::Record) -> Self; } diff --git a/burn-core/src/lr_scheduler/constant.rs b/burn-core/src/lr_scheduler/constant.rs index 36820bf33b..eb41f1c108 100644 --- a/burn-core/src/lr_scheduler/constant.rs +++ b/burn-core/src/lr_scheduler/constant.rs @@ -8,39 +8,39 @@ use crate::LearningRate; /// You can also use [learning rate](LearningRate) which the same effect. 
#[derive(new, Clone, Debug)] pub struct ConstantLr { - lr: LearningRate, + lr: LearningRate, } impl From for ConstantLr { - fn from(lr: LearningRate) -> Self { - Self { lr } - } + fn from(lr: LearningRate) -> Self { + Self { lr } + } } impl LrScheduler for ConstantLr { - type Record = (); + type Record = (); - fn step(&mut self) -> LearningRate { - self.lr - } + fn step(&mut self) -> LearningRate { + self.lr + } - fn to_record(&self) -> Self::Record {} + fn to_record(&self) -> Self::Record {} - fn load_record(self, _record: Self::Record) -> Self { - self - } + fn load_record(self, _record: Self::Record) -> Self { + self + } } impl LrScheduler for LearningRate { - type Record = (); + type Record = (); - fn step(&mut self) -> LearningRate { - *self - } + fn step(&mut self) -> LearningRate { + *self + } - fn to_record(&self) -> Self::Record {} + fn to_record(&self) -> Self::Record {} - fn load_record(self, _record: Self::Record) -> Self { - self - } + fn load_record(self, _record: Self::Record) -> Self { + self + } } diff --git a/burn-core/src/lr_scheduler/noam.rs b/burn-core/src/lr_scheduler/noam.rs index 2cb415c535..622ee5c5b9 100644 --- a/burn-core/src/lr_scheduler/noam.rs +++ b/burn-core/src/lr_scheduler/noam.rs @@ -6,87 +6,87 @@ use crate::{config::Config, LearningRate}; /// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler. #[derive(Config)] pub struct NoamLrSchedulerConfig { - /// The initial learning rate. - init_lr: LearningRate, - /// The number of steps before the exponential decay stats. - #[config(default = 4000)] - warmup_steps: usize, - /// The size of the model. - #[config(default = 512)] - model_size: usize, + /// The initial learning rate. + init_lr: LearningRate, + /// The number of steps before the exponential decay stats. + #[config(default = 4000)] + warmup_steps: usize, + /// The size of the model. + #[config(default = 512)] + model_size: usize, } /// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762). #[derive(Clone, Debug)] pub struct NoamLrScheduler { - warmup_steps: f64, - embedding_size: f64, - init_lr: LearningRate, - step: f64, + warmup_steps: f64, + embedding_size: f64, + init_lr: LearningRate, + step: f64, } impl NoamLrSchedulerConfig { - /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler. - pub fn init(&self) -> NoamLrScheduler { - NoamLrScheduler { - warmup_steps: self.warmup_steps as f64, - embedding_size: self.model_size as f64, - init_lr: self.init_lr, - step: 0.0, + /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler. 
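// Hedged sketch, not part of the patch: as implemented above, a bare `LearningRate`
// (an f64) already satisfies `LrScheduler`, returning itself on every step, so it
// can be passed wherever a scheduler is expected; `ConstantLr` is the explicit
// wrapper with the same behaviour. Assumes the `LrScheduler` trait is in scope.
fn constant_scheduler_demo() {
    let mut plain: LearningRate = 1e-3;
    let mut wrapped = ConstantLr::from(1e-3);
    assert_eq!(plain.step(), 1e-3);
    assert_eq!(wrapped.step(), 1e-3);
}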
+ pub fn init(&self) -> NoamLrScheduler { + NoamLrScheduler { + warmup_steps: self.warmup_steps as f64, + embedding_size: self.model_size as f64, + init_lr: self.init_lr, + step: 0.0, + } } - } } impl LrScheduler for NoamLrScheduler { - type Record = usize; + type Record = usize; - fn step(&mut self) -> LearningRate { - self.step += 1.0; + fn step(&mut self) -> LearningRate { + self.step += 1.0; - let arg1 = self.step.powf(-0.5); - let arg2 = self.step * self.warmup_steps.powf(-1.5); + let arg1 = self.step.powf(-0.5); + let arg2 = self.step * self.warmup_steps.powf(-1.5); - self.init_lr * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2) - } + self.init_lr * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2) + } - fn to_record(&self) -> Self::Record { - self.step as usize - } + fn to_record(&self) -> Self::Record { + self.step as usize + } - fn load_record(mut self, record: Self::Record) -> Self { - self.step = record as f64; - self - } + fn load_record(mut self, record: Self::Record) -> Self { + self.step = record as f64; + self + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_function_increase_and_decrease() { - let warmup_steps = 100; - let mut scheduler = NoamLrSchedulerConfig::new(10.0) - .with_warmup_steps(warmup_steps) - .init(); - let mut lr_current = 0.0; + #[test] + fn test_function_increase_and_decrease() { + let warmup_steps = 100; + let mut scheduler = NoamLrSchedulerConfig::new(10.0) + .with_warmup_steps(warmup_steps) + .init(); + let mut lr_current = 0.0; - for _ in 0..warmup_steps { - let lr = scheduler.step(); - assert!( - lr > lr_current, - "Learning rate should increase before the warmup_steps is reached." - ); - lr_current = lr; - } + for _ in 0..warmup_steps { + let lr = scheduler.step(); + assert!( + lr > lr_current, + "Learning rate should increase before the warmup_steps is reached." + ); + lr_current = lr; + } - for _ in 0..warmup_steps { - let lr = scheduler.step(); - assert!( - lr < lr_current, - "Learning rate should decrease after the warmup_steps is reached." - ); - lr_current = lr; + for _ in 0..warmup_steps { + let lr = scheduler.step(); + assert!( + lr < lr_current, + "Learning rate should decrease after the warmup_steps is reached." + ); + lr_current = lr; + } } - } } diff --git a/burn-core/src/module/base.rs b/burn-core/src/module/base.rs index f97f31692d..a54f00624d 100644 --- a/burn-core/src/module/base.rs +++ b/burn-core/src/module/base.rs @@ -2,8 +2,8 @@ use alloc::vec::Vec; use super::ParamId; use crate::{ - record::Record, - tensor::backend::{AutodiffBackend, Backend}, + record::Record, + tensor::backend::{AutodiffBackend, Backend}, }; pub use burn_derive::Module; use burn_tensor::Tensor; @@ -11,54 +11,54 @@ use burn_tensor::Tensor; // At the moment, our plan is to continue experimenting with the macro internally and monitor its development. // We may consider making it public in the future. macro_rules! 
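// Hedged sketch, not part of the patch, of the value computed by
// `NoamLrScheduler::step` above: the rate grows roughly linearly for `warmup_steps`
// steps and then decays with the inverse square root of the step count, scaled by
// the inverse square root of the model size.
fn noam_lr(init_lr: f64, model_size: usize, warmup_steps: usize, step: usize) -> f64 {
    let step = step as f64;
    let warmup = warmup_steps as f64;
    let arg1 = step.powf(-0.5);
    let arg2 = step * warmup.powf(-1.5);
    init_lr * (model_size as f64).powf(-0.5) * f64::min(arg1, arg2)
}
// With the defaults (model_size = 512, warmup_steps = 4000), the two arguments of
// `min` meet at step 4000, which is where the schedule peaks.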
module { - (map=$module:ident, ops=$item:expr) => {{ - struct Mapper; - impl ModuleMapper for Mapper { - fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { - let func = $item; - func(tensor) - } - } - let mut mapper = Mapper; - $module.map(&mut mapper) - }}; - (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{ - struct Mapper<'a, B: Backend> { - capture: &'a $ty, - backend: core::marker::PhantomData, - } - impl<'a, B: Backend> ModuleMapper for Mapper<'a, B> { - fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { - let func = $item; - func(tensor, self.capture) - } - } - let mut mapper = Mapper { - capture: $capture, - backend: core::marker::PhantomData, - }; - $module.map(&mut mapper) - }}; - (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{ - struct Visitor<'a, B: Backend> { - state: &'a mut $state_ty, - backend: core::marker::PhantomData, - } - impl<'a, B: Backend> ModuleVisitor for Visitor<'a, B> { - fn visit(&mut self, _id: &ParamId, tensor: &Tensor) { - let func = $item; - func(tensor, &mut self.state) - } - } - #[allow(clippy::redundant_closure_call)] - let mut state = $init(); - let mut visitor = Visitor { - state: &mut state, - backend: core::marker::PhantomData, - }; - $module.visit(&mut visitor); - state - }}; + (map=$module:ident, ops=$item:expr) => {{ + struct Mapper; + impl ModuleMapper for Mapper { + fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { + let func = $item; + func(tensor) + } + } + let mut mapper = Mapper; + $module.map(&mut mapper) + }}; + (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{ + struct Mapper<'a, B: Backend> { + capture: &'a $ty, + backend: core::marker::PhantomData, + } + impl<'a, B: Backend> ModuleMapper for Mapper<'a, B> { + fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { + let func = $item; + func(tensor, self.capture) + } + } + let mut mapper = Mapper { + capture: $capture, + backend: core::marker::PhantomData, + }; + $module.map(&mut mapper) + }}; + (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{ + struct Visitor<'a, B: Backend> { + state: &'a mut $state_ty, + backend: core::marker::PhantomData, + } + impl<'a, B: Backend> ModuleVisitor for Visitor<'a, B> { + fn visit(&mut self, _id: &ParamId, tensor: &Tensor) { + let func = $item; + func(tensor, &mut self.state) + } + } + #[allow(clippy::redundant_closure_call)] + let mut state = $init(); + let mut visitor = Visitor { + state: &mut state, + backend: core::marker::PhantomData, + }; + $module.visit(&mut visitor); + state + }}; } /// Trait for all neural network modules. @@ -91,163 +91,163 @@ macro_rules! module { /// } /// ``` pub trait Module: Clone + Send + Sync + core::fmt::Debug { - /// Type to save and load the module. - type Record: Record; - - /// Get the device list of the module and all of its sub-modules. - fn devices(&self) -> Vec { - module!( - visit = self, - ops = |tensor: &Tensor, state: &mut Vec| { - let device = tensor.device(); - if !state.contains(&device) { - state.push(device); - } - }, - state = Vec, - init = Vec::new - ) - } - - /// Fork the module and all of its sub-modules to the given device. - /// - /// # Notes - /// - /// This is similar to [to_device](Module::to_device), but it ensures the module will - /// have its own autodiff graph. 
- fn fork(self, device: &B::Device) -> Self { - module!( - map = self, - ops = |tensor: Tensor, device: &B::Device| { - let is_require_grad = tensor.is_require_grad(); - let mut tensor = tensor.to_device(device).detach(); - - if is_require_grad { - tensor = tensor.require_grad(); - } + /// Type to save and load the module. + type Record: Record; + + /// Get the device list of the module and all of its sub-modules. + fn devices(&self) -> Vec { + module!( + visit = self, + ops = |tensor: &Tensor, state: &mut Vec| { + let device = tensor.device(); + if !state.contains(&device) { + state.push(device); + } + }, + state = Vec, + init = Vec::new + ) + } - tensor - }, - capture = { device: B::Device } - ) - } - - /// Move the module and all of its sub-modules to the given device. - /// - /// # Warnings - /// - /// The device operations will be registered in the autodiff graph. Therefore, be sure to call - /// backward only one time even if you have the same module on multiple devices. If you want to - /// call backward multiple times, look into using [fork](Module::fork) instead. - fn to_device(self, device: &B::Device) -> Self { - module!( - map = self, - ops = |tensor: Tensor, device: &B::Device| tensor.to_device(device), - capture = { device: B::Device } - ) - } - - /// Each tensor in the module tree will not require grad. - /// - /// # Warnings - /// - /// This should not be used for inference, use [valid](AutodiffModule::valid) when using - /// AD modules. This is mostly useful when performing partial finetuning, which is updating only - /// a small fraction of the parameters instead of finetuning all of them. - fn no_grad(self) -> Self { - module!( - map = self, - ops = |tensor: Tensor| tensor.set_require_grad(false) - ) - } - - /// Get the number of parameters the module has, including all of its sub-modules. - fn num_params(&self) -> usize { - module!( - visit = self, - ops = |tensor: &Tensor, state: &mut usize| { - *state += tensor.shape().num_elements(); - }, - state = usize, - init = || 0 - ) - } - /// Visit each tensor in the module with a [visitor](ModuleVisitor). - fn visit>(&self, visitor: &mut V); - - /// Map each tensor in the module with a [mapper](ModuleMapper). - fn map>(self, mapper: &mut M) -> Self; - - /// Load the module state from a record. - fn load_record(self, record: Self::Record) -> Self; - - /// Convert the module into a record containing the state. - fn into_record(self) -> Self::Record; - - #[cfg(feature = "std")] - /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder). - /// - /// List of supported file recorders: - /// - /// * [default](crate::record::DefaultFileRecorder) - /// * [bincode](crate::record::BinFileRecorder) - /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder) - /// * [json pretty](crate::record::PrettyJsonFileRecorder) - /// * [json compressed with gzip](crate::record::JsonGzFileRecorder) - /// * [named mpk](crate::record::NamedMpkFileRecorder) - /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder) - /// - /// ## Notes - /// - /// The file extension is automatically added depending on the file recorder provided, you - /// don't have to specify it. - fn save_file>( - self, - file_path: PB, - recorder: &FR, - ) -> Result<(), crate::record::RecorderError> { - let record = Self::into_record(self); - recorder.record(record, file_path.into()) - } - - #[cfg(feature = "std")] - /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder). 
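// Hedged sketch, not part of the patch, with the stripped generics restored as
// `Module<B>` / `AutodiffModule<B>`: `to_device` keeps the move registered in the
// autodiff graph, while `fork` detaches the tensors and re-enables gradients, so it
// is the variant to reach for when the same module runs independently on several
// devices.
fn place_on_device<B, M>(module: M, device: &B::Device) -> (M, M)
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
{
    let tracked = module.clone().to_device(device); // still part of the current graph
    let forked = module.fork(device);               // detached, with its own graph
    (tracked, forked)
}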
- /// - /// The recorder should be the same as the one used to save the module, see - /// [save_file](Self::save_file). - /// - /// ## Notes - /// - /// The file extension is automatically added depending on the file recorder provided, you - /// don't have to specify it. - fn load_file>( - self, - file_path: PB, - recorder: &FR, - ) -> Result { - let record = recorder.load(file_path.into())?; - - Ok(self.load_record(record)) - } + /// Fork the module and all of its sub-modules to the given device. + /// + /// # Notes + /// + /// This is similar to [to_device](Module::to_device), but it ensures the module will + /// have its own autodiff graph. + fn fork(self, device: &B::Device) -> Self { + module!( + map = self, + ops = |tensor: Tensor, device: &B::Device| { + let is_require_grad = tensor.is_require_grad(); + let mut tensor = tensor.to_device(device).detach(); + + if is_require_grad { + tensor = tensor.require_grad(); + } + + tensor + }, + capture = { device: B::Device } + ) + } + + /// Move the module and all of its sub-modules to the given device. + /// + /// # Warnings + /// + /// The device operations will be registered in the autodiff graph. Therefore, be sure to call + /// backward only one time even if you have the same module on multiple devices. If you want to + /// call backward multiple times, look into using [fork](Module::fork) instead. + fn to_device(self, device: &B::Device) -> Self { + module!( + map = self, + ops = |tensor: Tensor, device: &B::Device| tensor.to_device(device), + capture = { device: B::Device } + ) + } + + /// Each tensor in the module tree will not require grad. + /// + /// # Warnings + /// + /// This should not be used for inference, use [valid](AutodiffModule::valid) when using + /// AD modules. This is mostly useful when performing partial finetuning, which is updating only + /// a small fraction of the parameters instead of finetuning all of them. + fn no_grad(self) -> Self { + module!( + map = self, + ops = |tensor: Tensor| tensor.set_require_grad(false) + ) + } + + /// Get the number of parameters the module has, including all of its sub-modules. + fn num_params(&self) -> usize { + module!( + visit = self, + ops = |tensor: &Tensor, state: &mut usize| { + *state += tensor.shape().num_elements(); + }, + state = usize, + init = || 0 + ) + } + /// Visit each tensor in the module with a [visitor](ModuleVisitor). + fn visit>(&self, visitor: &mut V); + + /// Map each tensor in the module with a [mapper](ModuleMapper). + fn map>(self, mapper: &mut M) -> Self; + + /// Load the module state from a record. + fn load_record(self, record: Self::Record) -> Self; + + /// Convert the module into a record containing the state. + fn into_record(self) -> Self::Record; + + #[cfg(feature = "std")] + /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder). + /// + /// List of supported file recorders: + /// + /// * [default](crate::record::DefaultFileRecorder) + /// * [bincode](crate::record::BinFileRecorder) + /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder) + /// * [json pretty](crate::record::PrettyJsonFileRecorder) + /// * [json compressed with gzip](crate::record::JsonGzFileRecorder) + /// * [named mpk](crate::record::NamedMpkFileRecorder) + /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder) + /// + /// ## Notes + /// + /// The file extension is automatically added depending on the file recorder provided, you + /// don't have to specify it. 
+ fn save_file>( + self, + file_path: PB, + recorder: &FR, + ) -> Result<(), crate::record::RecorderError> { + let record = Self::into_record(self); + recorder.record(record, file_path.into()) + } + + #[cfg(feature = "std")] + /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder). + /// + /// The recorder should be the same as the one used to save the module, see + /// [save_file](Self::save_file). + /// + /// ## Notes + /// + /// The file extension is automatically added depending on the file recorder provided, you + /// don't have to specify it. + fn load_file>( + self, + file_path: PB, + recorder: &FR, + ) -> Result { + let record = recorder.load(file_path.into())?; + + Ok(self.load_record(record)) + } } /// Module visitor trait. pub trait ModuleVisitor { - /// Visit a tensor in the module. - fn visit(&mut self, id: &ParamId, tensor: &Tensor); + /// Visit a tensor in the module. + fn visit(&mut self, id: &ParamId, tensor: &Tensor); } /// Module mapper trait. pub trait ModuleMapper { - /// Map a tensor in the module. - fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor; + /// Map a tensor in the module. + fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor; } /// Module with auto-differentiation backend. pub trait AutodiffModule: Module + Send + Sync + core::fmt::Debug { - /// Inner module without auto-differentiation. - type InnerModule: Module; + /// Inner module without auto-differentiation. + type InnerModule: Module; - /// Get the same module, but on the inner backend without auto-differentiation. - fn valid(&self) -> Self::InnerModule; + /// Get the same module, but on the inner backend without auto-differentiation. + fn valid(&self) -> Self::InnerModule; } diff --git a/burn-core/src/module/param/base.rs b/burn-core/src/module/param/base.rs index 0e76a6ed50..72174cba70 100644 --- a/burn-core/src/module/param/base.rs +++ b/burn-core/src/module/param/base.rs @@ -4,37 +4,37 @@ use alloc::format; /// Define a parameter. #[derive(new, Debug, Clone)] pub struct Param { - pub(crate) id: ParamId, - pub(crate) value: T, + pub(crate) id: ParamId, + pub(crate) value: T, } impl core::fmt::Display for Param { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("Param: {}", self.id).as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(format!("Param: {}", self.id).as_str()) + } } impl Param { - /// Gets the parameter value. - /// - /// # Returns - /// - /// The parameter value. - pub fn val(&self) -> T { - self.value.clone() - } + /// Gets the parameter value. + /// + /// # Returns + /// + /// The parameter value. + pub fn val(&self) -> T { + self.value.clone() + } - /// Execute the given function on the inner value. - pub fn map T>(mut self, func: F) -> Self { - self.value = func(self.value); - self - } + /// Execute the given function on the inner value. 
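// Hedged sketch, not part of the patch, assuming the `std` feature and the usual
// `Module<B: Backend>` / `FileRecorder` shapes behind the stripped generics: a
// trained module is written with `save_file` and read back into a freshly
// initialized one with `load_file`; the recorder appends the file extension, so only
// the stem is passed. `"my_model"` is a hypothetical path.
fn checkpoint_roundtrip<B, M, FR>(
    trained: M,
    fresh: M,
    recorder: &FR,
) -> Result<M, crate::record::RecorderError>
where
    B: Backend,
    M: Module<B>,
    FR: crate::record::FileRecorder,
{
    trained.save_file("my_model", recorder)?;
    // `load_file` consumes an initialized module and applies the recorded state
    // through `load_record`.
    fresh.load_file("my_model", recorder)
}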
+ pub fn map T>(mut self, func: F) -> Self { + self.value = func(self.value); + self + } } impl core::ops::Deref for Param { - type Target = T; + type Target = T; - fn deref(&self) -> &Self::Target { - &self.value - } + fn deref(&self) -> &Self::Target { + &self.value + } } diff --git a/burn-core/src/module/param/constant.rs b/burn-core/src/module/param/constant.rs index 67f6e2bb52..9c33d1409a 100644 --- a/burn-core/src/module/param/constant.rs +++ b/burn-core/src/module/param/constant.rs @@ -1,14 +1,14 @@ use core::marker::PhantomData; use crate::{ - self as burn, - module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}, - record::Record, + self as burn, + module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}, + record::Record, }; use burn::record::PrecisionSettings; use burn_tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, + backend::{AutodiffBackend, Backend}, + Tensor, }; use super::ParamId; @@ -18,76 +18,76 @@ use super::ParamId; pub struct ConstantRecord; impl serde::Serialize for ConstantRecord { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - // nothing to serialize - S::serialize_none(serializer) - } + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + // nothing to serialize + S::serialize_none(serializer) + } } impl<'de> serde::Deserialize<'de> for ConstantRecord { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_option(serde::de::IgnoredAny).ok(); - Ok(ConstantRecord::new()) - } + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_option(serde::de::IgnoredAny).ok(); + Ok(ConstantRecord::new()) + } } impl Record for ConstantRecord { - type Item = ConstantRecord; + type Item = ConstantRecord; - fn into_item(self) -> Self::Item { - self - } + fn into_item(self) -> Self::Item { + self + } - fn from_item(item: Self::Item) -> Self { - item - } + fn from_item(item: Self::Item) -> Self { + item + } } /// Constant macro. #[macro_export] macro_rules! 
constant { - (module) => { - type Record = burn::module::ConstantRecord; - - fn visit>(&self, _visitor: &mut V) { - // Nothing to do - } - - fn map>(self, _mapper: &mut M) -> Self { - self - } - - fn load_record(self, _record: Self::Record) -> Self { - self - } - - fn into_record(self) -> Self::Record { - burn::module::ConstantRecord::new() - } - }; - - (ad_module, $type:ty) => { - type InnerModule = $type; - - fn valid(&self) -> Self::InnerModule { - self.clone() - } - }; - - ($type:ty) => { - impl burn::module::Module for $type { - constant!(module); - } - - impl burn::module::AutodiffModule for $type { - constant!(ad_module, $type); - } - }; + (module) => { + type Record = burn::module::ConstantRecord; + + fn visit>(&self, _visitor: &mut V) { + // Nothing to do + } + + fn map>(self, _mapper: &mut M) -> Self { + self + } + + fn load_record(self, _record: Self::Record) -> Self { + self + } + + fn into_record(self) -> Self::Record { + burn::module::ConstantRecord::new() + } + }; + + (ad_module, $type:ty) => { + type InnerModule = $type; + + fn valid(&self) -> Self::InnerModule { + self.clone() + } + }; + + ($type:ty) => { + impl burn::module::Module for $type { + constant!(module); + } + + impl burn::module::AutodiffModule for $type { + constant!(ad_module, $type); + } + }; } // General Types @@ -114,121 +114,121 @@ constant!(i16); constant!(i8); impl Module for Tensor { - type Record = ConstantRecord; - - fn visit>(&self, visitor: &mut V) { - // Important: - // We need to implement visit method for Tensor Module because - // to_device will be called during the visit method of the ModuleVisitor - - // We are using a dummy param id because the visit method requires a param id - let dummy_param_id = ParamId::new(); - visitor.visit(&dummy_param_id, self) - } - - fn map>(self, mapper: &mut M) -> Self { - // Important: - // We need to implement visit method for Tensor Module because - // to_device will be called during the visit method of the ModuleVisitor - - // We are using a dummy param id because the visit method requires a param id - let dummy_param_id = ParamId::new(); - mapper.map(&dummy_param_id, self) - } - - fn into_record(self) -> Self::Record { - ConstantRecord - } - - fn load_record(self, _record: Self::Record) -> Self { - self - } + type Record = ConstantRecord; + + fn visit>(&self, visitor: &mut V) { + // Important: + // We need to implement visit method for Tensor Module because + // to_device will be called during the visit method of the ModuleVisitor + + // We are using a dummy param id because the visit method requires a param id + let dummy_param_id = ParamId::new(); + visitor.visit(&dummy_param_id, self) + } + + fn map>(self, mapper: &mut M) -> Self { + // Important: + // We need to implement visit method for Tensor Module because + // to_device will be called during the visit method of the ModuleVisitor + + // We are using a dummy param id because the visit method requires a param id + let dummy_param_id = ParamId::new(); + mapper.map(&dummy_param_id, self) + } + + fn into_record(self) -> Self::Record { + ConstantRecord + } + + fn load_record(self, _record: Self::Record) -> Self { + self + } } impl AutodiffModule for Tensor { - type InnerModule = Tensor; + type InnerModule = Tensor; - fn valid(&self) -> Self::InnerModule { - self.clone().inner() - } + fn valid(&self) -> Self::InnerModule { + self.clone().inner() + } } impl Module for PhantomData { - type Record = ConstantRecord; + type Record = ConstantRecord; - fn visit>(&self, _visitor: &mut V) { - // Nothing to do - } + 
fn visit>(&self, _visitor: &mut V) { + // Nothing to do + } - fn map>(self, _mapper: &mut M) -> Self { - self - } + fn map>(self, _mapper: &mut M) -> Self { + self + } - fn load_record(self, _record: Self::Record) -> Self { - self - } + fn load_record(self, _record: Self::Record) -> Self { + self + } - fn into_record(self) -> Self::Record { - ConstantRecord::new() - } + fn into_record(self) -> Self::Record { + ConstantRecord::new() + } } impl AutodiffModule for PhantomData { - type InnerModule = PhantomData; + type InnerModule = PhantomData; - fn valid(&self) -> Self::InnerModule { - PhantomData - } + fn valid(&self) -> Self::InnerModule { + PhantomData + } } #[cfg(all(test, feature = "std"))] mod tests { - use core::marker::PhantomData; - - use burn_tensor::backend::Backend; - use burn_tensor::Tensor; + use core::marker::PhantomData; - use crate::TestBackend; - use crate::{ - record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, - TestAutodiffBackend, - }; - use burn::module::Module; + use burn_tensor::backend::Backend; + use burn_tensor::Tensor; - use crate as burn; + use crate::TestBackend; + use crate::{ + record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, + TestAutodiffBackend, + }; + use burn::module::Module; - #[test] - fn tensor_load_record_setting() { - let tensor = Tensor::::ones([3, 3]); + use crate as burn; - let byte_recorder = BinBytesRecorder::::default(); - let bytes = byte_recorder - .record(tensor.clone().into_record(), ()) - .unwrap(); + #[test] + fn tensor_load_record_setting() { + let tensor = Tensor::::ones([3, 3]); - let no_grad_is_require_grad = tensor - .clone() - .no_grad() - .load_record(byte_recorder.load(bytes.clone()).unwrap()) - .is_require_grad(); + let byte_recorder = BinBytesRecorder::::default(); + let bytes = byte_recorder + .record(tensor.clone().into_record(), ()) + .unwrap(); - let with_default_is_require_grad = tensor - .load_record(byte_recorder.load(bytes).unwrap()) - .is_require_grad(); + let no_grad_is_require_grad = tensor + .clone() + .no_grad() + .load_record(byte_recorder.load(bytes.clone()).unwrap()) + .is_require_grad(); - assert!(!no_grad_is_require_grad); - assert!(!with_default_is_require_grad); - } + let with_default_is_require_grad = tensor + .load_record(byte_recorder.load(bytes).unwrap()) + .is_require_grad(); - #[test] - fn empty_module_with_phantom() { - #[derive(Module, Debug, new)] - struct EmptyModule { - _phantom: PhantomData, + assert!(!no_grad_is_require_grad); + assert!(!with_default_is_require_grad); } - let _module = EmptyModule::::new(); + #[test] + fn empty_module_with_phantom() { + #[derive(Module, Debug, new)] + struct EmptyModule { + _phantom: PhantomData, + } - assert_eq!(core::mem::size_of::>(), 0); - } + let _module = EmptyModule::::new(); + + assert_eq!(core::mem::size_of::>(), 0); + } } diff --git a/burn-core/src/module/param/id.rs b/burn-core/src/module/param/id.rs index 8ef7607013..2828cf38c7 100644 --- a/burn-core/src/module/param/id.rs +++ b/burn-core/src/module/param/id.rs @@ -4,45 +4,45 @@ use burn_common::id::IdGenerator; /// Parameter ID. 
#[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct ParamId { - value: String, + value: String, } impl From<&str> for ParamId { - fn from(val: &str) -> Self { - Self { - value: val.to_string(), + fn from(val: &str) -> Self { + Self { + value: val.to_string(), + } } - } } impl From for ParamId { - fn from(value: String) -> Self { - Self { value } - } + fn from(value: String) -> Self { + Self { value } + } } impl Default for ParamId { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl ParamId { - /// Create a new parameter ID. - pub fn new() -> Self { - Self { - value: IdGenerator::generate(), + /// Create a new parameter ID. + pub fn new() -> Self { + Self { + value: IdGenerator::generate(), + } } - } - /// Convert the parameter ID into a string. - pub fn into_string(self) -> String { - self.value - } + /// Convert the parameter ID into a string. + pub fn into_string(self) -> String { + self.value + } } impl core::fmt::Display for ParamId { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(self.value.as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.value.as_str()) + } } diff --git a/burn-core/src/module/param/primitive.rs b/burn-core/src/module/param/primitive.rs index e306d17441..dfa23e2d85 100644 --- a/burn-core/src/module/param/primitive.rs +++ b/burn-core/src/module/param/primitive.rs @@ -5,156 +5,153 @@ use core::fmt::Debug; impl Module for Option where - T: Module + Debug + Send + Sync + Clone, - B: Backend, + T: Module + Debug + Send + Sync + Clone, + B: Backend, { - type Record = Option; + type Record = Option; - fn visit>(&self, visitor: &mut V) { - if let Some(module) = self { - module.visit(visitor) + fn visit>(&self, visitor: &mut V) { + if let Some(module) = self { + module.visit(visitor) + } } - } - fn map>(self, mapper: &mut M) -> Self { - self.map(|module| module.map(mapper)) - } + fn map>(self, mapper: &mut M) -> Self { + self.map(|module| module.map(mapper)) + } - fn load_record(self, record: Self::Record) -> Self { - self - .zip(record) - .map(|(module, record)| module.load_record(record)) - } + fn load_record(self, record: Self::Record) -> Self { + self.zip(record) + .map(|(module, record)| module.load_record(record)) + } - fn into_record(self) -> Self::Record { - self.map(Module::into_record) - } + fn into_record(self) -> Self::Record { + self.map(Module::into_record) + } } impl AutodiffModule for Option where - T: AutodiffModule + Debug + Send + Sync + Clone, - B: AutodiffBackend, + T: AutodiffModule + Debug + Send + Sync + Clone, + B: AutodiffBackend, { - type InnerModule = Option; + type InnerModule = Option; - fn valid(&self) -> Self::InnerModule { - self.as_ref().map(|module| module.valid()) - } + fn valid(&self) -> Self::InnerModule { + self.as_ref().map(|module| module.valid()) + } } impl Module for Vec where - T: Module + Debug + Send + Sync + Clone, - B: Backend, + T: Module + Debug + Send + Sync + Clone, + B: Backend, { - type Record = Vec; - - fn num_params(&self) -> usize { - let mut num_params = 0; - for module in self.iter() { - num_params += module.num_params(); - } - - num_params - } - - fn visit>(&self, visitor: &mut V) { - self.iter().for_each(|module| { - module.visit(visitor); - }); - } - - fn map>(self, mapper: &mut M) -> Self { - self.into_iter().map(|module| module.map(mapper)).collect() - } - - fn into_record(self) -> Self::Record { - self.into_iter().map(Module::into_record).collect() - } - - fn 
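// Hedged sketch, not part of the patch, of the `ParamId` conversions above: ids are
// either generated or taken from a plain string, and they round-trip back to a
// string, which is how the record system keys parameters.
fn param_id_demo() {
    let generated = ParamId::new(); // fresh id from IdGenerator::generate()
    let named = ParamId::from("encoder.weight");
    assert_eq!(named.into_string(), "encoder.weight");
    let _ = generated.into_string();
}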
load_record(self, record: Self::Record) -> Self { - self - .into_iter() - .zip(record) - .map(|(module, record)| module.load_record(record)) - .collect() - } + type Record = Vec; + + fn num_params(&self) -> usize { + let mut num_params = 0; + for module in self.iter() { + num_params += module.num_params(); + } + + num_params + } + + fn visit>(&self, visitor: &mut V) { + self.iter().for_each(|module| { + module.visit(visitor); + }); + } + + fn map>(self, mapper: &mut M) -> Self { + self.into_iter().map(|module| module.map(mapper)).collect() + } + + fn into_record(self) -> Self::Record { + self.into_iter().map(Module::into_record).collect() + } + + fn load_record(self, record: Self::Record) -> Self { + self.into_iter() + .zip(record) + .map(|(module, record)| module.load_record(record)) + .collect() + } } impl AutodiffModule for Vec where - T: AutodiffModule + Debug + Send + Sync + Clone, - B: AutodiffBackend, + T: AutodiffModule + Debug + Send + Sync + Clone, + B: AutodiffBackend, { - type InnerModule = Vec; + type InnerModule = Vec; - fn valid(&self) -> Self::InnerModule { - self.iter().map(|module| module.valid()).collect() - } + fn valid(&self) -> Self::InnerModule { + self.iter().map(|module| module.valid()).collect() + } } impl Module for [T; N] where - T: Module + Debug + Send + Sync + Clone + Copy, - T::Record: Debug, - B: Backend, + T: Module + Debug + Send + Sync + Clone + Copy, + T::Record: Debug, + B: Backend, { - type Record = [T::Record; N]; - - fn devices(&self) -> Vec<::Device> { - let mut devices = Vec::new(); - for module in self.iter() { - devices.append(&mut module.devices()); - } - devices - } - - fn num_params(&self) -> usize { - let mut num_params = 0; - for module in self.iter() { - num_params += module.num_params(); - } - - num_params - } - - fn visit>(&self, visitor: &mut V) { - self.iter().for_each(|module| { - module.visit(visitor); - }); - } - - fn map>(self, mapper: &mut M) -> Self { - self.map(|module| module.map(mapper)) - } - - fn load_record(self, record: Self::Record) -> Self { - self - .into_iter() - .zip(record) - .map(|(module, record)| module.load_record(record)) - .collect::>() - .try_into() - .unwrap() - } - - fn into_record(self) -> Self::Record { - self.map(Module::into_record) - } + type Record = [T::Record; N]; + + fn devices(&self) -> Vec<::Device> { + let mut devices = Vec::new(); + for module in self.iter() { + devices.append(&mut module.devices()); + } + devices + } + + fn num_params(&self) -> usize { + let mut num_params = 0; + for module in self.iter() { + num_params += module.num_params(); + } + + num_params + } + + fn visit>(&self, visitor: &mut V) { + self.iter().for_each(|module| { + module.visit(visitor); + }); + } + + fn map>(self, mapper: &mut M) -> Self { + self.map(|module| module.map(mapper)) + } + + fn load_record(self, record: Self::Record) -> Self { + self.into_iter() + .zip(record) + .map(|(module, record)| module.load_record(record)) + .collect::>() + .try_into() + .unwrap() + } + + fn into_record(self) -> Self::Record { + self.map(Module::into_record) + } } impl AutodiffModule for [T; N] where - T: AutodiffModule + Debug + Send + Sync + Clone + Copy, - T::InnerModule: Copy + Debug, - >::Record: Debug, - >::Record: Debug, - B: AutodiffBackend, + T: AutodiffModule + Debug + Send + Sync + Clone + Copy, + T::InnerModule: Copy + Debug, + >::Record: Debug, + >::Record: Debug, + B: AutodiffBackend, { - type InnerModule = [T::InnerModule; N]; + type InnerModule = [T::InnerModule; N]; - fn valid(&self) -> Self::InnerModule { - 
self.map(|module| module.valid()) - } + fn valid(&self) -> Self::InnerModule { + self.map(|module| module.valid()) + } } diff --git a/burn-core/src/module/param/running.rs b/burn-core/src/module/param/running.rs index 952adc2180..ab17f181f0 100644 --- a/burn-core/src/module/param/running.rs +++ b/burn-core/src/module/param/running.rs @@ -3,31 +3,31 @@ use alloc::sync::Arc; use super::ParamId; use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param}; use burn_tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, + backend::{AutodiffBackend, Backend}, + Tensor, }; #[cfg(feature = "std")] mod threading { - pub(super) use std::collections::HashMap; - pub(super) use std::sync::{Mutex, RwLock}; - pub(super) use std::thread::ThreadId; - - #[inline(always)] - pub(super) fn get_thread_current_id() -> ThreadId { - std::thread::current().id() - } + pub(super) use std::collections::HashMap; + pub(super) use std::sync::{Mutex, RwLock}; + pub(super) use std::thread::ThreadId; + + #[inline(always)] + pub(super) fn get_thread_current_id() -> ThreadId { + std::thread::current().id() + } } #[cfg(not(feature = "std"))] mod threading { - pub(super) use burn_common::stub::{Mutex, RwLock, ThreadId}; - pub(super) use hashbrown::HashMap; + pub(super) use burn_common::stub::{Mutex, RwLock, ThreadId}; + pub(super) use hashbrown::HashMap; - #[inline(always)] - pub(super) fn get_thread_current_id() -> ThreadId { - panic!("Current thread id is not available") - } + #[inline(always)] + pub(super) fn get_thread_current_id() -> ThreadId { + panic!("Current thread id is not available") + } } // Re-export items from the disabled/enabled blocks @@ -40,152 +40,152 @@ use threading::*; /// The state value is the average of all updates on all threads. #[derive(Clone, Debug)] pub struct RunningState { - id: ParamId, - values: Arc>>, - value: Arc>, + id: ParamId, + values: Arc>>, + value: Arc>, } impl Module for RunningState> { - type Record = Param>; + type Record = Param>; - fn visit>(&self, visitor: &mut V) { - let tensor = self.value.read().unwrap(); + fn visit>(&self, visitor: &mut V) { + let tensor = self.value.read().unwrap(); - visitor.visit(&self.id, &tensor) - } + visitor.visit(&self.id, &tensor) + } - fn map>(self, mapper: &mut M) -> Self { - let mut tensor = self.value.write().unwrap(); - let tensor_out = mapper.map(&self.id, tensor.clone()); + fn map>(self, mapper: &mut M) -> Self { + let mut tensor = self.value.write().unwrap(); + let tensor_out = mapper.map(&self.id, tensor.clone()); - *tensor = tensor_out; - core::mem::drop(tensor); + *tensor = tensor_out; + core::mem::drop(tensor); - self - } + self + } - fn into_record(self) -> Self::Record { - self.sync(); - let tensor = self.value.read().unwrap(); + fn into_record(self) -> Self::Record { + self.sync(); + let tensor = self.value.read().unwrap(); - Param::new(self.id, tensor.clone()) - } + Param::new(self.id, tensor.clone()) + } - fn load_record(mut self, record: Self::Record) -> Self { - let mut tensor = self.value.write().unwrap(); - *tensor = record.value.to_device(&tensor.device()); - self.id = record.id; + fn load_record(mut self, record: Self::Record) -> Self { + let mut tensor = self.value.write().unwrap(); + *tensor = record.value.to_device(&tensor.device()); + self.id = record.id; - core::mem::drop(tensor); + core::mem::drop(tensor); - self - } + self + } } impl RunningState> { - /// Create a new running state. 
- pub fn new(value: Tensor) -> Self { - Self { - id: ParamId::new(), - values: Arc::new(Mutex::new(HashMap::new())), - value: Arc::new(RwLock::new(value)), + /// Create a new running state. + pub fn new(value: Tensor) -> Self { + Self { + id: ParamId::new(), + values: Arc::new(Mutex::new(HashMap::new())), + value: Arc::new(RwLock::new(value)), + } } - } - - /// Create a new running state. - pub fn with_id(id: ParamId, value: Tensor) -> Self { - Self { - id, - values: Arc::new(Mutex::new(HashMap::new())), - value: Arc::new(RwLock::new(value)), + + /// Create a new running state. + pub fn with_id(id: ParamId, value: Tensor) -> Self { + Self { + id, + values: Arc::new(Mutex::new(HashMap::new())), + value: Arc::new(RwLock::new(value)), + } } - } - - /// Create a new running state from a record. - pub fn from_record(record: Param>) -> Self { - Self { - id: record.id, - values: Arc::new(Mutex::new(HashMap::new())), - value: Arc::new(RwLock::new(record.value)), + + /// Create a new running state from a record. + pub fn from_record(record: Param>) -> Self { + Self { + id: record.id, + values: Arc::new(Mutex::new(HashMap::new())), + value: Arc::new(RwLock::new(record.value)), + } } - } - /// Update the value on the current thread. - pub fn update(&self, value: Tensor) { - let thread_id = get_thread_current_id(); - let mut map = self.values.lock().unwrap(); + /// Update the value on the current thread. + pub fn update(&self, value: Tensor) { + let thread_id = get_thread_current_id(); + let mut map = self.values.lock().unwrap(); + + if map.contains_key(&thread_id) { + self.update_value(&mut map); + } - if map.contains_key(&thread_id) { - self.update_value(&mut map); + map.insert(thread_id, value); } - map.insert(thread_id, value); - } - - /// Get the current value, - /// - /// # Note - /// - /// The current value might be outdated by one update. - pub fn value(&self) -> Tensor { - let value = self.value.read().unwrap(); - value.clone() - } - - /// Get the current value and make sure it is sync. - /// - /// # Note - /// - /// Don't use this function after an update on the same thread where other threads might have to - /// register their update before the actual synchronization needs to happen. - pub fn value_sync(&self) -> Tensor { - let thread_id = get_thread_current_id(); - let mut map = self.values.lock().unwrap(); - - if map.contains_key(&thread_id) { - self.update_value(&mut map); + /// Get the current value, + /// + /// # Note + /// + /// The current value might be outdated by one update. + pub fn value(&self) -> Tensor { + let value = self.value.read().unwrap(); + value.clone() } - let value = self.value.read().unwrap(); - value.clone() - } + /// Get the current value and make sure it is sync. + /// + /// # Note + /// + /// Don't use this function after an update on the same thread where other threads might have to + /// register their update before the actual synchronization needs to happen. 
+ pub fn value_sync(&self) -> Tensor { + let thread_id = get_thread_current_id(); + let mut map = self.values.lock().unwrap(); + + if map.contains_key(&thread_id) { + self.update_value(&mut map); + } + + let value = self.value.read().unwrap(); + value.clone() + } - fn sync(&self) { - let mut map = self.values.lock().unwrap(); + fn sync(&self) { + let mut map = self.values.lock().unwrap(); - if !map.is_empty() { - self.update_value(&mut map); + if !map.is_empty() { + self.update_value(&mut map); + } } - } - fn update_value(&self, map: &mut HashMap>) { - let mut value_updated = None; - let mut counter = 0; + fn update_value(&self, map: &mut HashMap>) { + let mut value_updated = None; + let mut counter = 0; - for (_key, tensor) in map.drain() { - counter += 1; + for (_key, tensor) in map.drain() { + counter += 1; - value_updated = match value_updated { - Some(current) => Some(tensor.add(current)), - None => Some(tensor), - }; - } + value_updated = match value_updated { + Some(current) => Some(tensor.add(current)), + None => Some(tensor), + }; + } - if let Some(value) = value_updated { - let value = value.div_scalar(counter); - let mut value_old = self.value.write().unwrap(); - *value_old = value; + if let Some(value) = value_updated { + let value = value.div_scalar(counter); + let mut value_old = self.value.write().unwrap(); + *value_old = value; + } } - } } impl AutodiffModule for RunningState> { - type InnerModule = RunningState>; + type InnerModule = RunningState>; - fn valid(&self) -> Self::InnerModule { - self.sync(); - let value = self.value(); + fn valid(&self) -> Self::InnerModule { + self.sync(); + let value = self.value(); - RunningState::with_id(self.id.clone(), value.inner()) - } + RunningState::with_id(self.id.clone(), value.inner()) + } } diff --git a/burn-core/src/module/param/tensor.rs b/burn-core/src/module/param/tensor.rs index c9936e0190..3bbe519dc4 100644 --- a/burn-core/src/module/param/tensor.rs +++ b/burn-core/src/module/param/tensor.rs @@ -1,104 +1,104 @@ use super::{Param, ParamId}; use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}; use crate::tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, + backend::{AutodiffBackend, Backend}, + Tensor, }; impl From> for Param> { - fn from(value: Tensor) -> Self { - Param::new(ParamId::new(), value.require_grad()) - } + fn from(value: Tensor) -> Self { + Param::new(ParamId::new(), value.require_grad()) + } } impl Module for Param> { - type Record = Param>; + type Record = Param>; - fn visit>(&self, visitor: &mut V) { - visitor.visit(&self.id, &self.value) - } + fn visit>(&self, visitor: &mut V) { + visitor.visit(&self.id, &self.value) + } - fn map>(self, mapper: &mut M) -> Self { - let value = mapper.map(&self.id, self.value); - Self::new(self.id, value) - } + fn map>(self, mapper: &mut M) -> Self { + let value = mapper.map(&self.id, self.value); + Self::new(self.id, value) + } - fn into_record(self) -> Self::Record { - self - } + fn into_record(self) -> Self::Record { + self + } - fn load_record(self, record: Self::Record) -> Self { - let mut tensor = record.value.detach(); - let device = self.device(); + fn load_record(self, record: Self::Record) -> Self { + let mut tensor = record.value.detach(); + let device = self.device(); - // Make sure we load the record into the same module device. - if tensor.device() != device { - tensor = tensor.to_device(&device).detach(); - } + // Make sure we load the record into the same module device. 
+ if tensor.device() != device { + tensor = tensor.to_device(&device).detach(); + } - // Make sure we load the record with the same autodiff setting. - tensor = tensor.set_require_grad(self.is_require_grad()); + // Make sure we load the record with the same autodiff setting. + tensor = tensor.set_require_grad(self.is_require_grad()); - Self::new(record.id, tensor) - } + Self::new(record.id, tensor) + } } impl AutodiffModule for Param> { - type InnerModule = Param>; - - fn valid(&self) -> Self::InnerModule { - Param::new( - self.id.clone(), - self.value.clone().inner().set_require_grad(false), - ) - } + type InnerModule = Param>; + + fn valid(&self) -> Self::InnerModule { + Param::new( + self.id.clone(), + self.value.clone().inner().set_require_grad(false), + ) + } } #[cfg(all(test, feature = "std"))] mod tests { - use super::*; - use crate::{ - module::Module, - nn::LinearConfig, - record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, - TestAutodiffBackend, - }; - - #[test] - fn test_load_record_setting() { - let tensor = Tensor::::ones([3, 3]); - - let byte_recorder = BinBytesRecorder::::default(); - let bytes = byte_recorder - .record(Param::from(tensor.clone()).into_record(), ()) - .unwrap(); - - let no_grad_is_require_grad = Param::from(tensor.clone()) - .no_grad() - .load_record(byte_recorder.load(bytes.clone()).unwrap()) - .value - .is_require_grad(); - - let with_default_is_require_grad = Param::from(tensor) - .load_record(byte_recorder.load(bytes).unwrap()) - .value - .is_require_grad(); - - assert!(!no_grad_is_require_grad); - assert!(with_default_is_require_grad); - } - - #[test] - fn test_init_with_record_setting() { - let config = LinearConfig::new(32, 32); - let module_init = config.init::(); - - let record = module_init.clone().into_record(); - let module_init_with = config.init_with::(record); - - assert_eq!( - module_init.weight.is_require_grad(), - module_init_with.weight.is_require_grad() - ); - } + use super::*; + use crate::{ + module::Module, + nn::LinearConfig, + record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, + TestAutodiffBackend, + }; + + #[test] + fn test_load_record_setting() { + let tensor = Tensor::::ones([3, 3]); + + let byte_recorder = BinBytesRecorder::::default(); + let bytes = byte_recorder + .record(Param::from(tensor.clone()).into_record(), ()) + .unwrap(); + + let no_grad_is_require_grad = Param::from(tensor.clone()) + .no_grad() + .load_record(byte_recorder.load(bytes.clone()).unwrap()) + .value + .is_require_grad(); + + let with_default_is_require_grad = Param::from(tensor) + .load_record(byte_recorder.load(bytes).unwrap()) + .value + .is_require_grad(); + + assert!(!no_grad_is_require_grad); + assert!(with_default_is_require_grad); + } + + #[test] + fn test_init_with_record_setting() { + let config = LinearConfig::new(32, 32); + let module_init = config.init::(); + + let record = module_init.clone().into_record(); + let module_init_with = config.init_with::(record); + + assert_eq!( + module_init.weight.is_require_grad(), + module_init_with.weight.is_require_grad() + ); + } } diff --git a/burn-core/src/module/param/visitor.rs b/burn-core/src/module/param/visitor.rs index 95bffe173c..9e27e3b6d3 100644 --- a/burn-core/src/module/param/visitor.rs +++ b/burn-core/src/module/param/visitor.rs @@ -5,28 +5,28 @@ use burn_tensor::{backend::Backend, Tensor}; use core::marker::PhantomData; struct ParamIdCollector<'a, M> { - param_ids: &'a mut Vec, - phantom: PhantomData, + param_ids: &'a mut Vec, + phantom: PhantomData, } impl<'a, B, M> 
ModuleVisitor for ParamIdCollector<'a, M> where - B: Backend, - M: Module, + B: Backend, + M: Module, { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { - self.param_ids.push(id.clone()); - } + fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + self.param_ids.push(id.clone()); + } } /// List all the parameter ids in a module. pub fn list_param_ids, B: Backend>(module: &M) -> Vec { - let mut params_ids = Vec::new(); - let mut visitor = ParamIdCollector { - param_ids: &mut params_ids, - phantom: PhantomData::, - }; - module.visit(&mut visitor); + let mut params_ids = Vec::new(); + let mut visitor = ParamIdCollector { + param_ids: &mut params_ids, + phantom: PhantomData::, + }; + module.visit(&mut visitor); - params_ids + params_ids } diff --git a/burn-core/src/nn/attention/mask.rs b/burn-core/src/nn/attention/mask.rs index fb7c6f3f91..b59f3a4f0a 100644 --- a/burn-core/src/nn/attention/mask.rs +++ b/burn-core/src/nn/attention/mask.rs @@ -6,136 +6,136 @@ use burn_tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Shape, T /// /// The mask can be used in Transformer modules to train models to generate tensors sequentially. pub fn generate_autoregressive_mask( - batch_size: usize, - seq_length: usize, - device: &B::Device, + batch_size: usize, + seq_length: usize, + device: &B::Device, ) -> Tensor { - let mut mask = Tensor::::zeros([1, seq_length, seq_length]); + let mut mask = Tensor::::zeros([1, seq_length, seq_length]); - for i in 0..(seq_length - 1) { - let values = Tensor::::ones([1, 1, seq_length - (i + 1)]); - mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values); - } + for i in 0..(seq_length - 1) { + let values = Tensor::::ones([1, 1, seq_length - (i + 1)]); + mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values); + } - mask = mask.to_device(device).repeat(0, batch_size); + mask = mask.to_device(device).repeat(0, batch_size); - mask.equal_elem(1_i64.elem::()) + mask.equal_elem(1_i64.elem::()) } /// Generate a padding attention mask. pub struct GeneratePaddingMask { - /// The generated tensor. - pub tensor: Tensor, + /// The generated tensor. + pub tensor: Tensor, - /// The generated mask. - pub mask: Tensor, + /// The generated mask. + pub mask: Tensor, } /// Generation padding attention mask. 
pub fn generate_padding_mask( - pad_token: usize, - tokens_list: Vec>, - max_seq_length: Option, - device: &B::Device, + pad_token: usize, + tokens_list: Vec>, + max_seq_length: Option, + device: &B::Device, ) -> GeneratePaddingMask { - let mut max_size = 0; - let batch_size = tokens_list.len(); - - for tokens in tokens_list.iter() { - if tokens.len() > max_size { - max_size = tokens.len(); + let mut max_size = 0; + let batch_size = tokens_list.len(); + + for tokens in tokens_list.iter() { + if tokens.len() > max_size { + max_size = tokens.len(); + } + + if let Some(max_seq_length) = max_seq_length { + if tokens.len() >= max_seq_length { + max_size = max_seq_length; + break; + } + } } - if let Some(max_seq_length) = max_seq_length { - if tokens.len() >= max_seq_length { - max_size = max_seq_length; - break; - } + let mut tensor = Tensor::zeros([batch_size, max_size]); + tensor = tensor.add_scalar(pad_token as i64); + + for (index, tokens) in tokens_list.into_iter().enumerate() { + let mut seq_length = tokens.len(); + let mut tokens = tokens; + + if let Some(max_seq_length) = max_seq_length { + if seq_length > max_seq_length { + seq_length = max_seq_length; + let _ = tokens.split_off(seq_length); + } + } + + tensor = tensor.slice_assign( + [index..index + 1, 0..tokens.len()], + Tensor::from_data(Data::new( + tokens.into_iter().map(|e| (e as i64).elem()).collect(), + Shape::new([1, seq_length]), + )), + ); } - } - - let mut tensor = Tensor::zeros([batch_size, max_size]); - tensor = tensor.add_scalar(pad_token as i64); - - for (index, tokens) in tokens_list.into_iter().enumerate() { - let mut seq_length = tokens.len(); - let mut tokens = tokens; - if let Some(max_seq_length) = max_seq_length { - if seq_length > max_seq_length { - seq_length = max_seq_length; - let _ = tokens.split_off(seq_length); - } - } + let mask = tensor + .clone() + .equal_elem(pad_token as i64) + .to_device(device); + let tensor = tensor.to_device(device); - tensor = tensor.slice_assign( - [index..index + 1, 0..tokens.len()], - Tensor::from_data(Data::new( - tokens.into_iter().map(|e| (e as i64).elem()).collect(), - Shape::new([1, seq_length]), - )), - ); - } - - let mask = tensor - .clone() - .equal_elem(pad_token as i64) - .to_device(device); - let tensor = tensor.to_device(device); - - GeneratePaddingMask { tensor, mask } + GeneratePaddingMask { tensor, mask } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use alloc::vec; - use burn_tensor::Data; - - #[test] - fn test_generate_autoregressive_mask() { - let device = ::Device::default(); - - let mask = generate_autoregressive_mask::(2, 3, &device); - - assert_eq!( - mask.into_data(), - Data::from([ - [ - [false, true, true], - [false, false, true], - [false, false, false], - ], - [ - [false, true, true], - [false, false, true], - [false, false, false], - ] - ]) - ); - } - - #[test] - fn test_generate_padding_mask() { - let device = ::Device::default(); - let tokens = vec![ - vec![3, 3, 3], - vec![3, 3, 3], - vec![3, 3, 3, 4], - vec![3, 3, 3, 4, 10, 15], - ]; - - let mask = generate_padding_mask::(0, tokens, None, &device); - - assert_eq!( - mask.mask.into_data(), - Data::from([ - [false, false, false, true, true, true], - [false, false, false, true, true, true], - [false, false, false, false, true, true], - [false, false, false, false, false, false], - ]) - ); - } + use super::*; + use crate::TestBackend; + use alloc::vec; + use burn_tensor::Data; + + #[test] + fn test_generate_autoregressive_mask() { + let device = ::Device::default(); + + let 
mask = generate_autoregressive_mask::(2, 3, &device); + + assert_eq!( + mask.into_data(), + Data::from([ + [ + [false, true, true], + [false, false, true], + [false, false, false], + ], + [ + [false, true, true], + [false, false, true], + [false, false, false], + ] + ]) + ); + } + + #[test] + fn test_generate_padding_mask() { + let device = ::Device::default(); + let tokens = vec![ + vec![3, 3, 3], + vec![3, 3, 3], + vec![3, 3, 3, 4], + vec![3, 3, 3, 4, 10, 15], + ]; + + let mask = generate_padding_mask::(0, tokens, None, &device); + + assert_eq!( + mask.mask.into_data(), + Data::from([ + [false, false, false, true, true, true], + [false, false, false, true, true, true], + [false, false, false, false, true, true], + [false, false, false, false, false, false], + ]) + ); + } } diff --git a/burn-core/src/nn/attention/mha.rs b/burn-core/src/nn/attention/mha.rs index b46e793e89..16ed02508e 100644 --- a/burn-core/src/nn/attention/mha.rs +++ b/burn-core/src/nn/attention/mha.rs @@ -3,39 +3,41 @@ use crate as burn; use crate::nn::cache::TensorCache; use crate::nn::Initializer; use crate::{ - config::Config, - module::Module, - nn, - tensor::{activation, backend::Backend, Bool, Tensor}, + config::Config, + module::Module, + nn, + tensor::{activation, backend::Backend, Bool, Tensor}, }; use libm::sqrtf; /// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer. #[derive(Config)] pub struct MultiHeadAttentionConfig { - /// The size of the each linear layer. - d_model: usize, - /// The number of heads. - n_heads: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - dropout: f64, - /// The minimum value a float can take. Default: -1.0e4 - /// This is used to mask attention scores before calculating attention weights. - /// A value too low might result in NaN. - #[config(default = -1.0e4)] - min_float: f64, - /// Use "quiet softmax" instead of regular softmax. - /// - /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). - /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. - /// - /// Reference: - #[config(default = false)] - quiet_softmax: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] - pub initializer: Initializer, + /// The size of the each linear layer. + d_model: usize, + /// The number of heads. + n_heads: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + dropout: f64, + /// The minimum value a float can take. Default: -1.0e4 + /// This is used to mask attention scores before calculating attention weights. + /// A value too low might result in NaN. + #[config(default = -1.0e4)] + min_float: f64, + /// Use "quiet softmax" instead of regular softmax. + /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. 
+ /// + /// Reference: + #[config(default = false)] + quiet_softmax: bool, + /// The type of function used to initialize neural network parameters + #[config( + default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" + )] + pub initializer: Initializer, } /// The multihead attention module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). @@ -48,415 +50,419 @@ pub struct MultiHeadAttentionConfig { /// - output: [Linear](nn::Linear) layer with `d_model` input and output features. #[derive(Module, Debug)] pub struct MultiHeadAttention { - query: nn::Linear, - key: nn::Linear, - value: nn::Linear, - output: nn::Linear, - dropout: nn::Dropout, - activation: nn::GELU, - n_heads: usize, - d_k: usize, - min_float: f64, - quiet_softmax: bool, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + output: nn::Linear, + dropout: nn::Dropout, + activation: nn::GELU, + n_heads: usize, + d_k: usize, + min_float: f64, + quiet_softmax: bool, } /// [Multihead attention](MultiHeadAttention) forward pass input argument. #[derive(Debug, Clone)] pub struct MhaInput { - query: Tensor, - key: Tensor, - value: Tensor, - mask_pad: Option>, - mask_attn: Option>, + query: Tensor, + key: Tensor, + value: Tensor, + mask_pad: Option>, + mask_attn: Option>, } impl MultiHeadAttentionConfig { - /// Initialize a new [multihead attention](MultiHeadAttention) module. - pub fn init(&self) -> MultiHeadAttention { - let linear = |config: &Self| { - nn::LinearConfig::new(config.d_model, config.d_model) - .with_initializer(self.initializer.clone()) - .init() - }; - - MultiHeadAttention { - query: linear(self), - key: linear(self), - value: linear(self), - output: linear(self), - dropout: nn::DropoutConfig::new(self.dropout).init(), - activation: nn::GELU::new(), - n_heads: self.n_heads, - d_k: self.d_model / self.n_heads, - min_float: self.min_float, - quiet_softmax: self.quiet_softmax, + /// Initialize a new [multihead attention](MultiHeadAttention) module. + pub fn init(&self) -> MultiHeadAttention { + let linear = |config: &Self| { + nn::LinearConfig::new(config.d_model, config.d_model) + .with_initializer(self.initializer.clone()) + .init() + }; + + MultiHeadAttention { + query: linear(self), + key: linear(self), + value: linear(self), + output: linear(self), + dropout: nn::DropoutConfig::new(self.dropout).init(), + activation: nn::GELU::new(), + n_heads: self.n_heads, + d_k: self.d_model / self.n_heads, + min_float: self.min_float, + quiet_softmax: self.quiet_softmax, + } } - } - - /// Initialize a new [multihead attention](MultiHeadAttention) module with a - /// [record](MultiHeadAttentionRecord). - pub fn init_with( - &self, - record: MultiHeadAttentionRecord, - ) -> MultiHeadAttention { - let linear = |config: &Self, record| { - nn::LinearConfig::new(config.d_model, config.d_model).init_with(record) - }; - - MultiHeadAttention { - query: linear(self, record.query), - key: linear(self, record.key), - value: linear(self, record.value), - output: linear(self, record.output), - dropout: nn::DropoutConfig::new(self.dropout).init(), - activation: nn::GELU::new(), - n_heads: self.n_heads, - d_k: self.d_model / self.n_heads, - min_float: self.min_float, - quiet_softmax: self.quiet_softmax, + + /// Initialize a new [multihead attention](MultiHeadAttention) module with a + /// [record](MultiHeadAttentionRecord). 
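As background for the `quiet_softmax` flag documented above: "quiet softmax" adds one to the softmax denominator, so a head whose scores are all strongly negative can emit weights near zero instead of being forced to sum to one. The sketch below is an assumption about the semantics of `activation::quiet_softmax`, written on a plain slice rather than a `Tensor`:

// Regular softmax:  exp(x_i) / sum_j exp(x_j)
// Quiet softmax:    exp(x_i) / (1 + sum_j exp(x_j))
fn quiet_softmax_1d(scores: &[f32]) -> Vec<f32> {
    // Shift by the max (clamped at zero) for numerical stability; the "+1"
    // term becomes exp(-max), so the formula above is unchanged.
    let max = scores
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max)
        .max(0.0);
    let exps: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
    let denom = (-max).exp() + exps.iter().sum::<f32>();
    exps.iter().map(|e| e / denom).collect()
}

When every score is very negative the denominator stays close to one and all weights collapse toward zero, which is the "deposit no information" behaviour the field documentation describes.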
+ pub fn init_with( + &self, + record: MultiHeadAttentionRecord, + ) -> MultiHeadAttention { + let linear = |config: &Self, record| { + nn::LinearConfig::new(config.d_model, config.d_model).init_with(record) + }; + + MultiHeadAttention { + query: linear(self, record.query), + key: linear(self, record.key), + value: linear(self, record.value), + output: linear(self, record.output), + dropout: nn::DropoutConfig::new(self.dropout).init(), + activation: nn::GELU::new(), + n_heads: self.n_heads, + d_k: self.d_model / self.n_heads, + min_float: self.min_float, + quiet_softmax: self.quiet_softmax, + } } - } } impl MhaInput { - /// Create a [multihead attention](MultiHeadAttention) input argument - /// by setting the query, key and value to the given tensor. - pub fn self_attn(tensor: Tensor) -> Self { - Self { - query: tensor.clone(), - key: tensor.clone(), - value: tensor, - mask_pad: None, - mask_attn: None, + /// Create a [multihead attention](MultiHeadAttention) input argument + /// by setting the query, key and value to the given tensor. + pub fn self_attn(tensor: Tensor) -> Self { + Self { + query: tensor.clone(), + key: tensor.clone(), + value: tensor, + mask_pad: None, + mask_attn: None, + } + } + + /// Create a [multihead attention](MultiHeadAttention) input argument. + pub fn new(query: Tensor, key: Tensor, value: Tensor) -> Self { + Self { + query, + key, + value, + mask_pad: None, + mask_attn: None, + } } - } - - /// Create a [multihead attention](MultiHeadAttention) input argument. - pub fn new(query: Tensor, key: Tensor, value: Tensor) -> Self { - Self { - query, - key, - value, - mask_pad: None, - mask_attn: None, + + /// Register the padding mask. + pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { + self.mask_pad = Some(mask_pad); + self + } + + /// Register the attention mask. + pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { + self.mask_attn = Some(mask_attn); + self } - } - - /// Register the padding mask. - pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { - self.mask_pad = Some(mask_pad); - self - } - - /// Register the attention mask. - pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { - self.mask_attn = Some(mask_attn); - self - } } /// [Multihead attention](MultiHeadAttention) outputs. #[derive(Debug, Clone)] pub struct MhaOutput { - /// The attention weights [batch_size, seq_length_1, seq_length_2]. - pub weights: Tensor, - /// The context tensor [batch_size, seq_length_1, d_model]. - pub context: Tensor, + /// The attention weights [batch_size, seq_length_1, seq_length_2]. + pub weights: Tensor, + /// The context tensor [batch_size, seq_length_1, d_model]. + pub context: Tensor, } impl MultiHeadAttention { - /// Applies the forward pass on the input tensors. 
- /// - /// # Shapes - /// - /// - query: `[batch_size, seq_length_1, d_model]` - /// - key: `[batch_size, seq_length_2, d_model]` - /// - value: `[batch_size, seq_length_2, d_model]` - /// - output: `[batch_size, seq_length_1, d_model]` - pub fn forward(&self, input: MhaInput) -> MhaOutput { - let [batch_size, seq_length_1, d_model] = input.query.dims(); - - let query = self.attention_linear(input.query, &self.query); - let key = self.attention_linear(input.key, &self.key); - let value = self.attention_linear(input.value, &self.value); - - let attn_scores = self.attn_scores(query, key); - let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); - - let context = weights.clone().matmul(value); - let context = context - .swap_dims(1, 2) - .reshape([batch_size, seq_length_1, d_model]); - let context = self.output.forward(context); - - MhaOutput { weights, context } - } - - /// Applies the forward pass using a cache. - /// - /// # Shapes - /// - /// - query: `[batch_size, seq_length_1, d_model]` - /// - key: `[batch_size, seq_length_2, d_model]` - /// - value: `[batch_size, seq_length_2, d_model]` - /// - output: `[batch_size, seq_length_1, d_model]` - pub fn forward_cache(&self, input: MhaInput, cache: &mut MhaCache) -> MhaOutput { - let [batch_size, seq_length_1, d_model] = input.query.dims(); - - let query = cache - .query - .forward(input.query, |t| self.attention_linear(t, &self.query)); - let key = cache - .key - .forward(input.key, |t| self.attention_linear(t, &self.key)); - let value = cache - .value - .forward(input.value, |t| self.attention_linear(t, &self.value)); - - let attn_scores = self.attn_scores(query, key); - let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); - - let context = weights.clone().matmul(value); - let context = context - .swap_dims(1, 2) - .reshape([batch_size, seq_length_1, d_model]); - - let context = cache.output.forward(context, |t| self.output.forward(t)); - - MhaOutput { weights, context } - } - - fn attn_scores(&self, query: Tensor, key: Tensor) -> Tensor { - let attn_scores = query - .matmul(key.transpose()) - .div_scalar(sqrtf(self.d_k as f32)); - - self.dropout.forward(attn_scores) - } - - fn attn_weights( - &self, - mut attn_scores: Tensor, - mask_pad: Option>, - mask_attn: Option>, - ) -> Tensor { - if let Some(mask_pad) = mask_pad { - let [batch_size, seq_length] = mask_pad.dims(); - - attn_scores = attn_scores.mask_fill( - mask_pad.reshape([batch_size, 1, 1, seq_length]), - self.min_float, - ); + /// Applies the forward pass on the input tensors. 
+ /// + /// # Shapes + /// + /// - query: `[batch_size, seq_length_1, d_model]` + /// - key: `[batch_size, seq_length_2, d_model]` + /// - value: `[batch_size, seq_length_2, d_model]` + /// - output: `[batch_size, seq_length_1, d_model]` + pub fn forward(&self, input: MhaInput) -> MhaOutput { + let [batch_size, seq_length_1, d_model] = input.query.dims(); + + let query = self.attention_linear(input.query, &self.query); + let key = self.attention_linear(input.key, &self.key); + let value = self.attention_linear(input.value, &self.value); + + let attn_scores = self.attn_scores(query, key); + let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); + + let context = weights.clone().matmul(value); + let context = context + .swap_dims(1, 2) + .reshape([batch_size, seq_length_1, d_model]); + let context = self.output.forward(context); + + MhaOutput { weights, context } } - if let Some(mask_attn) = mask_attn { - let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims(); + /// Applies the forward pass using a cache. + /// + /// # Shapes + /// + /// - query: `[batch_size, seq_length_1, d_model]` + /// - key: `[batch_size, seq_length_2, d_model]` + /// - value: `[batch_size, seq_length_2, d_model]` + /// - output: `[batch_size, seq_length_1, d_model]` + pub fn forward_cache(&self, input: MhaInput, cache: &mut MhaCache) -> MhaOutput { + let [batch_size, seq_length_1, d_model] = input.query.dims(); + + let query = cache + .query + .forward(input.query, |t| self.attention_linear(t, &self.query)); + let key = cache + .key + .forward(input.key, |t| self.attention_linear(t, &self.key)); + let value = cache + .value + .forward(input.value, |t| self.attention_linear(t, &self.value)); + + let attn_scores = self.attn_scores(query, key); + let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); + + let context = weights.clone().matmul(value); + let context = context + .swap_dims(1, 2) + .reshape([batch_size, seq_length_1, d_model]); + + let context = cache.output.forward(context, |t| self.output.forward(t)); + + MhaOutput { weights, context } + } + + fn attn_scores(&self, query: Tensor, key: Tensor) -> Tensor { + let attn_scores = query + .matmul(key.transpose()) + .div_scalar(sqrtf(self.d_k as f32)); - attn_scores = attn_scores.mask_fill( - mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]), - self.min_float, - ); + self.dropout.forward(attn_scores) } - if self.quiet_softmax { - activation::quiet_softmax(attn_scores, 3) - } else { - activation::softmax(attn_scores, 3) + fn attn_weights( + &self, + mut attn_scores: Tensor, + mask_pad: Option>, + mask_attn: Option>, + ) -> Tensor { + if let Some(mask_pad) = mask_pad { + let [batch_size, seq_length] = mask_pad.dims(); + + attn_scores = attn_scores.mask_fill( + mask_pad.reshape([batch_size, 1, 1, seq_length]), + self.min_float, + ); + } + + if let Some(mask_attn) = mask_attn { + let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims(); + + attn_scores = attn_scores.mask_fill( + mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]), + self.min_float, + ); + } + + if self.quiet_softmax { + activation::quiet_softmax(attn_scores, 3) + } else { + activation::softmax(attn_scores, 3) + } + } + + fn attention_linear(&self, x: Tensor, linear: &nn::Linear) -> Tensor { + let [batch_size, seq_length, _d_model] = x.dims(); + linear + .forward(x) + .reshape([batch_size, seq_length, self.n_heads, self.d_k]) + .swap_dims(1, 2) } - } - - fn attention_linear(&self, x: Tensor, linear: &nn::Linear) -> 
Tensor { - let [batch_size, seq_length, _d_model] = x.dims(); - linear - .forward(x) - .reshape([batch_size, seq_length, self.n_heads, self.d_k]) - .swap_dims(1, 2) - } } /// Cache for the [Multi Head Attention](MultiHeadAttention) layer. /// /// To be used during inference when decoding tokens. pub struct MhaCache { - query: MhaLinearCache, - key: MhaLinearCache, - value: MhaLinearCache, - output: MhaLinearCache, + query: MhaLinearCache, + key: MhaLinearCache, + value: MhaLinearCache, + output: MhaLinearCache, } enum MhaLinearCache { - Autoregressive(TensorCache, usize), - Full(TensorCache), + Autoregressive(TensorCache, usize), + Full(TensorCache), } impl MhaCache { - /// Initialize a cache for autoregressive inference. - pub fn autoregressive() -> Self { - Self { - query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), + /// Initialize a cache for autoregressive inference. + pub fn autoregressive() -> Self { + Self { + query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), + } } - } - - /// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and - /// values (cross-attention). - pub fn autoregressive_cross_attention() -> Self { - Self { - query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), - key: MhaLinearCache::Full(TensorCache::empty()), - value: MhaLinearCache::Full(TensorCache::empty()), - output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), + + /// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and + /// values (cross-attention). 
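A hypothetical usage sketch for the new option follows; it assumes the `Config` derive emits a `with_quiet_softmax` setter (in the same way `with_initializer` is used elsewhere in this patch) and reuses `TestBackend`, `Tensor`, and `Distribution` as imported in the test module below:

#[test]
fn quiet_softmax_option_smoke_test() {
    // Assumed builder method: `with_quiet_softmax` generated by the Config derive.
    let mha = MultiHeadAttentionConfig::new(64, 4)
        .with_quiet_softmax(true)
        .init::<TestBackend>();

    // Self-attention over a [batch = 2, seq = 16, d_model = 64] input.
    let input = MhaInput::self_attn(Tensor::random([2, 16, 64], Distribution::Default));
    let output = mha.forward(input);

    // Shapes are unchanged; only the normalization inside `attn_weights` differs.
    assert_eq!(output.context.dims(), [2, 16, 64]);
    assert_eq!(output.weights.dims(), [2, 4, 16, 16]);
}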
+ pub fn autoregressive_cross_attention() -> Self { + Self { + query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), + key: MhaLinearCache::Full(TensorCache::empty()), + value: MhaLinearCache::Full(TensorCache::empty()), + output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), + } } - } } impl MhaLinearCache { - pub fn forward) -> Tensor>( - &mut self, - tensor: Tensor, - func: F, - ) -> Tensor { - match self { - MhaLinearCache::Autoregressive(cache, dim) => { - cache.forward_autoregressive(tensor, *dim, func) - } - MhaLinearCache::Full(cache) => cache.forward_full(tensor, func), + pub fn forward) -> Tensor>( + &mut self, + tensor: Tensor, + func: F, + ) -> Tensor { + match self { + MhaLinearCache::Autoregressive(cache, dim) => { + cache.forward_autoregressive(tensor, *dim, func) + } + MhaLinearCache::Full(cache) => cache.forward_full(tensor, func), + } } - } } #[cfg(test)] mod tests { - use super::*; - use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; - use alloc::vec::Vec; - use burn::tensor::{Distribution, Shape}; - use burn_tensor::Int; - - #[test] - fn test_self_attention_shapes() { - let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - let input = MhaInput::self_attn(Tensor::random( - [batch_size, seq_length, d_model], - Distribution::Default, - )); - - let output = mha.forward(input); - - assert_eq!( - output.context.shape(), - Shape::new([batch_size, seq_length, d_model]), - "Context should have the correct shape", - ); - assert_eq!( - output.weights.shape(), - Shape::new([batch_size, n_heads, seq_length, seq_length]), - "Weights should have the correct shape", - ); - } - - #[test] - fn test_generic_mha_shapes() { - let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - let input = MhaInput::new( - Tensor::random([batch_size, seq_length_1, d_model], Distribution::Default), - Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), - Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), - ); - - let output = mha.forward(input); - - assert_eq!( - output.context.shape(), - Shape::new([batch_size, seq_length_1, d_model]), - "Context should have the correct shape", - ); - assert_eq!( - output.weights.shape(), - Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]), - "Weights should have the correct shape", - ); - } - - #[test] - fn test_self_attention_mask_pad() { - let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - - // Create a padding mask - let mask_pad: Tensor = Tensor::zeros([batch_size, seq_length]); - let mask_pad = mask_pad.slice_assign( - [0..batch_size, seq_length - num_padded..seq_length], - Tensor::ones([batch_size, num_padded]), - ); - let mask_pad = mask_pad.equal_elem(1); - - let tensor_1 = - Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); - // Change the end of the tensor - let tensor_2 = tensor_1.clone().slice_assign( - [ - 0..batch_size, - seq_length - num_padded..seq_length, - 0..d_model, - ], - Tensor::random([batch_size, num_padded, d_model], Distribution::Default), - ); - - let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone()); - let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad); - - let output_1 = mha.forward(input_1); - let output_2 = 
mha.forward(input_2); - - // Check that the beginning of each tensor is the same - output_1 - .context - .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) - .into_data() - .assert_approx_eq( - &output_2 - .context - .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) - .into_data(), - 3, - ); - } - - #[test] - fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() { - let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2]; - let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); - - let tensor = - Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); - let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); - let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn); - - let output_1 = mha.forward(input); - let mut output_2 = Vec::new(); - let mut cache = MhaCache::autoregressive(); - - for i in 1..seq_length + 1 { - let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); - let input = MhaInput::self_attn(tensor); - let next_tok = - mha - .forward_cache(input, &mut cache) - .context - .slice([0..batch_size, i - 1..i, 0..d_model]); - output_2.push(next_tok); + use super::*; + use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; + use alloc::vec::Vec; + use burn::tensor::{Distribution, Shape}; + use burn_tensor::Int; + + #[test] + fn test_self_attention_shapes() { + let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + let input = MhaInput::self_attn(Tensor::random( + [batch_size, seq_length, d_model], + Distribution::Default, + )); + + let output = mha.forward(input); + + assert_eq!( + output.context.shape(), + Shape::new([batch_size, seq_length, d_model]), + "Context should have the correct shape", + ); + assert_eq!( + output.weights.shape(), + Shape::new([batch_size, n_heads, seq_length, seq_length]), + "Weights should have the correct shape", + ); } - let output_2 = Tensor::cat(output_2, 1); + #[test] + fn test_generic_mha_shapes() { + let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + let input = MhaInput::new( + Tensor::random([batch_size, seq_length_1, d_model], Distribution::Default), + Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), + Tensor::random([batch_size, seq_length_2, d_model], Distribution::Default), + ); + + let output = mha.forward(input); + + assert_eq!( + output.context.shape(), + Shape::new([batch_size, seq_length_1, d_model]), + "Context should have the correct shape", + ); + assert_eq!( + output.weights.shape(), + Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]), + "Weights should have the correct shape", + ); + } + + #[test] + fn test_self_attention_mask_pad() { + let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + + // Create a padding mask + let mask_pad: Tensor = Tensor::zeros([batch_size, seq_length]); + let mask_pad = mask_pad.slice_assign( + [0..batch_size, seq_length - num_padded..seq_length], + Tensor::ones([batch_size, num_padded]), + ); + let mask_pad = mask_pad.equal_elem(1); + + let tensor_1 = Tensor::::random( + [batch_size, seq_length, d_model], + Distribution::Default, + ); + // Change the end of the tensor + let tensor_2 = tensor_1.clone().slice_assign( + [ + 
0..batch_size, + seq_length - num_padded..seq_length, + 0..d_model, + ], + Tensor::random([batch_size, num_padded, d_model], Distribution::Default), + ); + + let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone()); + let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad); + + let output_1 = mha.forward(input_1); + let output_2 = mha.forward(input_2); + + // Check that the beginning of each tensor is the same + output_1 + .context + .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) + .into_data() + .assert_approx_eq( + &output_2 + .context + .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) + .into_data(), + 3, + ); + } - output_1 - .context - .into_data() - .assert_approx_eq(&output_2.into_data(), 3); - } + #[test] + fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() { + let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2]; + let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(); + + let tensor = Tensor::::random( + [batch_size, seq_length, d_model], + Distribution::Default, + ); + let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); + let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn); + + let output_1 = mha.forward(input); + let mut output_2 = Vec::new(); + let mut cache = MhaCache::autoregressive(); + + for i in 1..seq_length + 1 { + let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); + let input = MhaInput::self_attn(tensor); + let next_tok = mha.forward_cache(input, &mut cache).context.slice([ + 0..batch_size, + i - 1..i, + 0..d_model, + ]); + output_2.push(next_tok); + } + + let output_2 = Tensor::cat(output_2, 1); + + output_1 + .context + .into_data() + .assert_approx_eq(&output_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/cache/autoregressive.rs b/burn-core/src/nn/cache/autoregressive.rs index 1cabecd6cb..8d1f1b5bb6 100644 --- a/burn-core/src/nn/cache/autoregressive.rs +++ b/burn-core/src/nn/cache/autoregressive.rs @@ -5,47 +5,47 @@ use crate::tensor::backend::Backend; use crate::tensor::Tensor; impl TensorCache { - pub(crate) fn forward_autoregressive( - &mut self, - tensor: Tensor, - dim_cat: usize, - func: F, - ) -> Tensor - where - F: Fn(Tensor) -> Tensor, - { - let mut tensor_old = CacheState::Empty; - core::mem::swap(&mut self.state, &mut tensor_old); - - let tensor_new = match tensor_old { - CacheState::Value(tensor_old) => { - let [batch_size, seq_length, d_model] = tensor.dims(); - let next_seq_token = - tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]); - let next_seq_token = func(next_seq_token); - - Tensor::cat(vec![tensor_old, next_seq_token], dim_cat) - } - _ => func(tensor), - }; - - self.state = CacheState::Value(tensor_new.clone()); - tensor_new - } - - pub(crate) fn forward_full(&mut self, tensor: Tensor, func: F) -> Tensor - where - F: Fn(Tensor) -> Tensor, - { - let mut tensor_old = CacheState::Empty; - core::mem::swap(&mut self.state, &mut tensor_old); - - let tensor_new = match tensor_old { - CacheState::Value(tensor_old) => tensor_old, - _ => func(tensor), - }; - - self.state = CacheState::Value(tensor_new.clone()); - tensor_new - } + pub(crate) fn forward_autoregressive( + &mut self, + tensor: Tensor, + dim_cat: usize, + func: F, + ) -> Tensor + where + F: Fn(Tensor) -> Tensor, + { + let mut tensor_old = CacheState::Empty; + core::mem::swap(&mut self.state, &mut tensor_old); + + let tensor_new = match tensor_old { + CacheState::Value(tensor_old) => { + 
let [batch_size, seq_length, d_model] = tensor.dims(); + let next_seq_token = + tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]); + let next_seq_token = func(next_seq_token); + + Tensor::cat(vec![tensor_old, next_seq_token], dim_cat) + } + _ => func(tensor), + }; + + self.state = CacheState::Value(tensor_new.clone()); + tensor_new + } + + pub(crate) fn forward_full(&mut self, tensor: Tensor, func: F) -> Tensor + where + F: Fn(Tensor) -> Tensor, + { + let mut tensor_old = CacheState::Empty; + core::mem::swap(&mut self.state, &mut tensor_old); + + let tensor_new = match tensor_old { + CacheState::Value(tensor_old) => tensor_old, + _ => func(tensor), + }; + + self.state = CacheState::Value(tensor_new.clone()); + tensor_new + } } diff --git a/burn-core/src/nn/cache/base.rs b/burn-core/src/nn/cache/base.rs index baa85bd414..322c65c810 100644 --- a/burn-core/src/nn/cache/base.rs +++ b/burn-core/src/nn/cache/base.rs @@ -2,24 +2,24 @@ use crate::tensor::backend::Backend; use crate::tensor::Tensor; pub(crate) enum CacheState { - Value(T), - Empty, + Value(T), + Empty, } /// A cache for a tensor. pub struct TensorCache { - pub(crate) state: CacheState>, + pub(crate) state: CacheState>, } impl TensorCache { - /// Creates a new empty cache. - /// - /// # Returns - /// - /// The empty cache. - pub fn empty() -> Self { - Self { - state: CacheState::Empty, + /// Creates a new empty cache. + /// + /// # Returns + /// + /// The empty cache. + pub fn empty() -> Self { + Self { + state: CacheState::Empty, + } } - } } diff --git a/burn-core/src/nn/conv/checks.rs b/burn-core/src/nn/conv/checks.rs index bf2253c8d6..ca470f7b29 100644 --- a/burn-core/src/nn/conv/checks.rs +++ b/burn-core/src/nn/conv/checks.rs @@ -1,8 +1,8 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize, groups: usize) { - let channels_in_div_by_group = channels_in % groups == 0; - let channels_out_div_by_group = channels_out % groups == 0; + let channels_in_div_by_group = channels_in % groups == 0; + let channels_out_div_by_group = channels_out % groups == 0; - if !channels_in_div_by_group && !channels_out_div_by_group { - panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}"); - } + if !channels_in_div_by_group && !channels_out_div_by_group { + panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}"); + } } diff --git a/burn-core/src/nn/conv/conv1d.rs b/burn-core/src/nn/conv/conv1d.rs index 2ceacebad7..8a5b79bb7b 100644 --- a/burn-core/src/nn/conv/conv1d.rs +++ b/burn-core/src/nn/conv/conv1d.rs @@ -15,30 +15,30 @@ use super::checks; /// Configuration to create an [1D convolution](Conv1d) layer. #[derive(Config, Debug)] pub struct Conv1dConfig { - /// The number of input channels. - pub channels_in: usize, - /// The number of output channels. - pub channels_out: usize, - /// The size of the kernel. - pub kernel_size: usize, - /// The stride of the convolution. - #[config(default = "1")] - pub stride: usize, - /// Spacing between kernel elements. - #[config(default = "1")] - pub dilation: usize, - /// Controls the connections between input and output channels. - #[config(default = "1")] - pub groups: usize, - /// The padding configuration. - #[config(default = "PaddingConfig1d::Valid")] - pub padding: PaddingConfig1d, - /// If bias should be added to the output. 
- #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of input channels. + pub channels_in: usize, + /// The number of output channels. + pub channels_out: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The stride of the convolution. + #[config(default = "1")] + pub stride: usize, + /// Spacing between kernel elements. + #[config(default = "1")] + pub dilation: usize, + /// Controls the connections between input and output channels. + #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "PaddingConfig1d::Valid")] + pub padding: PaddingConfig1d, + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 1D convolution over input tensors. @@ -50,113 +50,111 @@ pub struct Conv1dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct Conv1d { - weight: Param>, - bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: PaddingConfig1d, + weight: Param>, + bias: Option>>, + stride: usize, + kernel_size: usize, + dilation: usize, + groups: usize, + padding: PaddingConfig1d, } impl Conv1dConfig { - /// Initialize a new [conv1d](Conv1d) module. - pub fn init(&self) -> Conv1d { - checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); - - let shape = [ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size, - ]; - - let fan_in: usize = self.channels_in / self.groups * self.kernel_size; - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self - .initializer - .init_with([self.channels_out], Some(fan_in), None), - ); + /// Initialize a new [conv1d](Conv1d) module. + pub fn init(&self) -> Conv1d { + checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); + + let shape = [ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size, + ]; + + let fan_in: usize = self.channels_in / self.groups * self.kernel_size; + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self.initializer + .init_with([self.channels_out], Some(fan_in), None), + ); + } + + Conv1d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, + groups: self.groups, + } } - - Conv1d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, - groups: self.groups, - } - } - /// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord). 
- pub fn init_with(&self, record: Conv1dRecord) -> Conv1d { - Conv1d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, - groups: self.groups, + /// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord). + pub fn init_with(&self, record: Conv1dRecord) -> Conv1d { + Conv1d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, + groups: self.groups, + } } - } } impl Conv1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, length_in], - /// - output: [batch_size, channels_out, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels, length] = input.dims(); - let padding = self - .padding - .calculate_padding_1d(length, self.kernel_size, self.stride); - - conv1d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvOptions::new([self.stride], [padding], [self.dilation], self.groups), - ) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, length_in], + /// - output: [batch_size, channels_out, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels, length] = input.dims(); + let padding = self + .padding + .calculate_padding_1d(length, self.kernel_size, self.stride); + + conv1d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvOptions::new([self.stride], [padding], [self.dilation], self.groups), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = Conv1dConfig::new(5, 5, 5); - let k = (config.channels_in * config.kernel_size) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv - .weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = Conv1dConfig::new(5, 5, 5); + let k = (config.channels_in * config.kernel_size) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv.weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/conv/conv2d.rs b/burn-core/src/nn/conv/conv2d.rs index c114b02d10..ed27f3a8a8 100644 --- a/burn-core/src/nn/conv/conv2d.rs +++ b/burn-core/src/nn/conv/conv2d.rs @@ -16,28 +16,28 @@ use super::checks; /// Configuration to create an [2D convolution](Conv2d) layer. #[derive(Config, Debug)] pub struct Conv2dConfig { - /// The number of channels. 
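A worked check of the `initializer_default` bound used in the conv1d test above, under the assumption that the default `KaimingUniform { gain: 1/sqrt(3), fan_out_only: false }` initializer samples uniformly from `±gain * sqrt(3 / fan_in)` (which is what the test's formula implies):

// Conv1dConfig::new(5, 5, 5) with groups = 1:
//   fan_in = channels_in / groups * kernel_size = 5 * 5 = 25
//   bound  = (1 / sqrt(3)) * sqrt(3 / 25) = 1 / sqrt(25) = 0.2
// i.e. the same number as the test's k = sqrt(groups / (channels_in * kernel_size)).
fn kaiming_uniform_bound(gain: f64, fan_in: usize) -> f64 {
    gain * (3.0 / fan_in as f64).sqrt()
}

#[test]
fn conv1d_default_bound_matches_test_formula() {
    let bound = kaiming_uniform_bound(1.0 / 3.0_f64.sqrt(), 25);
    assert!((bound - 0.2).abs() < 1e-9);
}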
- pub channels: [usize; 2], - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The stride of the convolution. - #[config(default = "[1, 1]")] - pub stride: [usize; 2], - /// Spacing between kernel elements. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], - /// Controls the connections between input and output channels. - #[config(default = "1")] - pub groups: usize, - /// The padding configuration. - #[config(default = "PaddingConfig2d::Valid")] - pub padding: PaddingConfig2d, - /// If bias should be added to the output. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of channels. + pub channels: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The stride of the convolution. + #[config(default = "[1, 1]")] + pub stride: [usize; 2], + /// Spacing between kernel elements. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. + #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "PaddingConfig2d::Valid")] + pub padding: PaddingConfig2d, + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 2D convolution over input tensors. @@ -49,115 +49,112 @@ pub struct Conv2dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct Conv2d { - weight: Param>, - bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: PaddingConfig2d, + weight: Param>, + bias: Option>>, + stride: [usize; 2], + kernel_size: [usize; 2], + dilation: [usize; 2], + groups: usize, + padding: PaddingConfig2d, } impl Conv2dConfig { - /// Initialize a new [conv2d](Conv2d) module. - pub fn init(&self) -> Conv2d { - checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - - let shape = [ - self.channels[1], - self.channels[0] / self.groups, - self.kernel_size[0], - self.kernel_size[1], - ]; - - let fan_in = self.channels[0] / self.groups * self.kernel_size.iter().product::(); - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self - .initializer - .init_with([self.channels[1]], Some(fan_in), None), - ); + /// Initialize a new [conv2d](Conv2d) module. 
+ pub fn init(&self) -> Conv2d { + checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + + let shape = [ + self.channels[1], + self.channels[0] / self.groups, + self.kernel_size[0], + self.kernel_size[1], + ]; + + let fan_in = self.channels[0] / self.groups * self.kernel_size.iter().product::(); + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self.initializer + .init_with([self.channels[1]], Some(fan_in), None), + ); + } + + Conv2d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + dilation: self.dilation, + padding: self.padding.clone(), + groups: self.groups, + } } - Conv2d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - dilation: self.dilation, - padding: self.padding.clone(), - groups: self.groups, + /// Initialize a new [conv2d](Conv2d) module with a [record](Conv2dRecord). + pub fn init_with(&self, record: Conv2dRecord) -> Conv2d { + Conv2d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + dilation: self.dilation, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + groups: self.groups, + } } - } - - /// Initialize a new [conv2d](Conv2d) module with a [record](Conv2dRecord). - pub fn init_with(&self, record: Conv2dRecord) -> Conv2d { - Conv2d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - dilation: self.dilation, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - groups: self.groups, - } - } } impl Conv2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, height_in, width_in], - /// - output: [batch_size, channels_out, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels_in, height_in, width_in] = input.dims(); - let padding = - self - .padding - .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); - conv2d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvOptions::new(self.stride, padding, self.dilation, self.groups), - ) - } + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, height_in, width_in], + /// - output: [batch_size, channels_out, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, height_in, width_in] = input.dims(); + let padding = + self.padding + .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); + conv2d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvOptions::new(self.stride, padding, self.dilation, self.groups), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = Conv2dConfig::new([5, 1], [5, 5]); - let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv - .weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = Conv2dConfig::new([5, 1], [5, 5]); + let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv.weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/conv/conv_transpose1d.rs b/burn-core/src/nn/conv/conv_transpose1d.rs index 4309738e68..fb25d6f3ee 100644 --- a/burn-core/src/nn/conv/conv_transpose1d.rs +++ b/burn-core/src/nn/conv/conv_transpose1d.rs @@ -15,31 +15,31 @@ use super::checks; /// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer. #[derive(Config, Debug)] pub struct ConvTranspose1dConfig { - /// The number of channels. - pub channels: [usize; 2], - /// The size of the kernel. - pub kernel_size: usize, - /// The stride of the convolution. - #[config(default = "1")] - pub stride: usize, - /// Spacing between kernel elements. - #[config(default = "1")] - pub dilation: usize, - /// Controls the connections between input and output channels. - #[config(default = "1")] - pub groups: usize, - /// The padding configuration. - #[config(default = "0")] - pub padding: usize, - /// The padding output configuration. - #[config(default = "0")] - pub padding_out: usize, - /// If bias should be added to the output. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of channels. + pub channels: [usize; 2], + /// The size of the kernel. + pub kernel_size: usize, + /// The stride of the convolution. 
+ #[config(default = "1")] + pub stride: usize, + /// Spacing between kernel elements. + #[config(default = "1")] + pub dilation: usize, + /// Controls the connections between input and output channels. + #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "0")] + pub padding: usize, + /// The padding output configuration. + #[config(default = "0")] + pub padding_out: usize, + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 1D transposed convolution over input tensors. @@ -51,118 +51,116 @@ pub struct ConvTranspose1dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct ConvTranspose1d { - weight: Param>, - bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: usize, - padding_out: usize, + weight: Param>, + bias: Option>>, + stride: usize, + kernel_size: usize, + dilation: usize, + groups: usize, + padding: usize, + padding_out: usize, } impl ConvTranspose1dConfig { - /// Initialize a new [conv transpose 1d](ConvTranspose1d) module. - pub fn init(&self) -> ConvTranspose1d { - checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - - let shape = [ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size, - ]; - - let fan_in = self.channels[1] / self.groups * self.kernel_size; - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self - .initializer - .init_with([self.channels[1]], Some(fan_in), None), - ); + /// Initialize a new [conv transpose 1d](ConvTranspose1d) module. + pub fn init(&self) -> ConvTranspose1d { + checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + + let shape = [ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size, + ]; + + let fan_in = self.channels[1] / self.groups * self.kernel_size; + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self.initializer + .init_with([self.channels[1]], Some(fan_in), None), + ); + } + + ConvTranspose1d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + dilation: self.dilation, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, + } } - ConvTranspose1d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - dilation: self.dilation, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, + /// Initialize a new [conv transpose 1d](ConvTranspose1d) module with a [record](ConvTranspose1dRecord). + pub fn init_with(&self, record: ConvTranspose1dRecord) -> ConvTranspose1d { + ConvTranspose1d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + dilation: self.dilation, + kernel_size: self.kernel_size, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, + } } - } - - /// Initialize a new [conv transpose 1d](ConvTranspose1d) module with a [record](ConvTranspose1dRecord). 
- pub fn init_with(&self, record: ConvTranspose1dRecord) -> ConvTranspose1d { - ConvTranspose1d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - dilation: self.dilation, - kernel_size: self.kernel_size, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, - } - } } impl ConvTranspose1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, length_in], - /// - output: [batch_size, channels_out, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - conv_transpose1d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvTransposeOptions::new( - [self.stride], - [self.padding], - [self.padding_out], - [self.dilation], - self.groups, - ), - ) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, length_in], + /// - output: [batch_size, channels_out, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + conv_transpose1d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvTransposeOptions::new( + [self.stride], + [self.padding], + [self.padding_out], + [self.dilation], + self.groups, + ), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = ConvTranspose1dConfig::new([5, 1], 5); - let k = (config.channels[1] * config.kernel_size) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv - .weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = ConvTranspose1dConfig::new([5, 1], 5); + let k = (config.channels[1] * config.kernel_size) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv.weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/conv/conv_transpose2d.rs b/burn-core/src/nn/conv/conv_transpose2d.rs index 7b6fa0ac67..af4a249e62 100644 --- a/burn-core/src/nn/conv/conv_transpose2d.rs +++ b/burn-core/src/nn/conv/conv_transpose2d.rs @@ -15,31 +15,31 @@ use super::checks; /// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer. #[derive(Config, Debug)] pub struct ConvTranspose2dConfig { - /// The number of channels. - pub channels: [usize; 2], - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The stride of the convolution. - #[config(default = "[1, 1]")] - pub stride: [usize; 2], - /// Spacing between kernel elements. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], - /// Controls the connections between input and output channels. 
- #[config(default = "1")] - pub groups: usize, - /// The padding configuration. - #[config(default = "[0, 0]")] - pub padding: [usize; 2], - /// The padding output configuration. - #[config(default = "[0, 0]")] - pub padding_out: [usize; 2], - /// If bias should be added to the output. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] - pub initializer: Initializer, + /// The number of channels. + pub channels: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The stride of the convolution. + #[config(default = "[1, 1]")] + pub stride: [usize; 2], + /// Spacing between kernel elements. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. + #[config(default = "1")] + pub groups: usize, + /// The padding configuration. + #[config(default = "[0, 0]")] + pub padding: [usize; 2], + /// The padding output configuration. + #[config(default = "[0, 0]")] + pub padding_out: [usize; 2], + /// If bias should be added to the output. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0),fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a 2D transposed convolution over input tensors. @@ -51,119 +51,118 @@ pub struct ConvTranspose2dConfig { /// - bias: Tensor of shape `[channels_out]` #[derive(Module, Debug)] pub struct ConvTranspose2d { - weight: Param>, - bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: [usize; 2], - padding_out: [usize; 2], + weight: Param>, + bias: Option>>, + stride: [usize; 2], + kernel_size: [usize; 2], + dilation: [usize; 2], + groups: usize, + padding: [usize; 2], + padding_out: [usize; 2], } impl ConvTranspose2dConfig { - /// Initialize a new [conv transpose 2d](ConvTranspose2d) module. - pub fn init(&self) -> ConvTranspose2d { - checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); - - let shape = [ - self.channels[0], - self.channels[1] / self.groups, - self.kernel_size[0], - self.kernel_size[1], - ]; - - let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::(); - let weight = self.initializer.init_with(shape, Some(fan_in), None); - let mut bias = None; - - if self.bias { - bias = Some( - self - .initializer - .init_with([self.channels[1]], Some(fan_in), None), - ); + /// Initialize a new [conv transpose 2d](ConvTranspose2d) module. 
+ pub fn init(&self) -> ConvTranspose2d { + checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + + let shape = [ + self.channels[0], + self.channels[1] / self.groups, + self.kernel_size[0], + self.kernel_size[1], + ]; + + let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::(); + let weight = self.initializer.init_with(shape, Some(fan_in), None); + let mut bias = None; + + if self.bias { + bias = Some( + self.initializer + .init_with([self.channels[1]], Some(fan_in), None), + ); + } + + ConvTranspose2d { + weight: Param::from(weight), + bias: bias.map(Param::from), + stride: self.stride, + kernel_size: self.kernel_size, + dilation: self.dilation, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, + } } - ConvTranspose2d { - weight: Param::from(weight), - bias: bias.map(Param::from), - stride: self.stride, - kernel_size: self.kernel_size, - dilation: self.dilation, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, + /// Initialize a new [conv transpose 2d](ConvTranspose2d) module with a [record](ConvTranspose2dRecord). + pub fn init_with(&self, record: ConvTranspose2dRecord) -> ConvTranspose2d { + ConvTranspose2d { + weight: record.weight, + bias: record.bias, + stride: self.stride, + dilation: self.dilation, + kernel_size: self.kernel_size, + groups: self.groups, + padding: self.padding, + padding_out: self.padding_out, + } } - } - - /// Initialize a new [conv transpose 2d](ConvTranspose2d) module with a [record](ConvTranspose2dRecord). - pub fn init_with(&self, record: ConvTranspose2dRecord) -> ConvTranspose2d { - ConvTranspose2d { - weight: record.weight, - bias: record.bias, - stride: self.stride, - dilation: self.dilation, - kernel_size: self.kernel_size, - groups: self.groups, - padding: self.padding, - padding_out: self.padding_out, - } - } } impl ConvTranspose2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels_in, height_in, width_in], - /// - output: [batch_size, channels_out, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - conv_transpose2d( - input, - self.weight.val(), - self.bias.as_ref().map(|bias| bias.val()), - ConvTransposeOptions::new( - self.stride, - self.padding, - self.padding_out, - self.dilation, - self.groups, - ), - ) - } + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: [batch_size, channels_in, height_in, width_in], + /// - output: [batch_size, channels_out, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + conv_transpose2d( + input, + self.weight.val(), + self.bias.as_ref().map(|bias| bias.val()), + ConvTransposeOptions::new( + self.stride, + self.padding, + self.padding_out, + self.dilation, + self.groups, + ), + ) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = ConvTranspose2dConfig::new([5, 1], [5, 5]); - let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64; - let k = sqrt(config.groups as f64 / k) as f32; - let conv = config.init::(); - - conv.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); - let conv = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - conv - .weight - .to_data() - .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = ConvTranspose2dConfig::new([5, 1], [5, 5]); + let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64; + let k = sqrt(config.groups as f64 / k) as f32; + let conv = config.init::(); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = + ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); + let conv = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + conv.weight + .to_data() + .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/dropout.rs b/burn-core/src/nn/dropout.rs index 51e4bd14ef..109040e56d 100644 --- a/burn-core/src/nn/dropout.rs +++ b/burn-core/src/nn/dropout.rs @@ -8,8 +8,8 @@ use crate::tensor::{Distribution, Tensor}; /// Configuration to create a [Dropout](Dropout) layer. #[derive(Config, Debug)] pub struct DropoutConfig { - /// The probability of randomly zeroes some elements of the input tensor during training. - pub prob: f64, + /// The probability of randomly zeroes some elements of the input tensor during training. + pub prob: f64, } /// Set at random some elements of the input tensor to zero during training. @@ -20,65 +20,65 @@ pub struct DropoutConfig { /// The input is also scaled during training to `1 / (1 - prob_keep)`. #[derive(Module, Clone, Debug)] pub struct Dropout { - prob: f64, + prob: f64, } impl DropoutConfig { - /// Initialize a new [dropout](Dropout) module. - pub fn init(&self) -> Dropout { - Dropout { prob: self.prob } - } + /// Initialize a new [dropout](Dropout) module. + pub fn init(&self) -> Dropout { + Dropout { prob: self.prob } + } } impl Dropout { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any]` - /// - output: `[..., any]` - pub fn forward(&self, input: Tensor) -> Tensor { - if !B::ad_enabled() || self.prob == 0.0 { - return input; + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: `[..., any]` + /// - output: `[..., any]` + pub fn forward(&self, input: Tensor) -> Tensor { + if !B::ad_enabled() || self.prob == 0.0 { + return input; + } + + let prob_keep = 1.0 - self.prob; + let random = input.random_like(Distribution::Bernoulli(prob_keep)); + let x = input * random; + + x * (1.0 / prob_keep) } - - let prob_keep = 1.0 - self.prob; - let random = input.random_like(Distribution::Bernoulli(prob_keep)); - let x = input * random; - - x * (1.0 / prob_keep) - } } #[cfg(test)] mod tests { - use super::*; - use crate::tensor::Shape; + use super::*; + use crate::tensor::Shape; - #[cfg(feature = "std")] - use crate::{TestAutodiffBackend, TestBackend}; + #[cfg(feature = "std")] + use crate::{TestAutodiffBackend, TestBackend}; - #[cfg(not(feature = "std"))] - use crate::TestBackend; + #[cfg(not(feature = "std"))] + use crate::TestBackend; - #[cfg(feature = "std")] - #[test] - fn with_ad_backend_should_mark_input() { - let tensor = Tensor::::ones(Shape::new([100, 100])); - let dropout = DropoutConfig::new(0.5).init(); + #[cfg(feature = "std")] + #[test] + fn with_ad_backend_should_mark_input() { + let tensor = Tensor::::ones(Shape::new([100, 100])); + let dropout = DropoutConfig::new(0.5).init(); - let output = dropout.forward(tensor.clone()); + let output = dropout.forward(tensor.clone()); - assert_ne!(tensor.to_data(), output.to_data()); - } + assert_ne!(tensor.to_data(), output.to_data()); + } - #[test] - fn without_ad_backend_should_not_change_input() { - let tensor = Tensor::::ones(Shape::new([100, 100])); - let dropout = DropoutConfig::new(0.5).init(); + #[test] + fn without_ad_backend_should_not_change_input() { + let tensor = Tensor::::ones(Shape::new([100, 100])); + let dropout = DropoutConfig::new(0.5).init(); - let output = dropout.forward(tensor.clone()); + let output = dropout.forward(tensor.clone()); - assert_eq!(tensor.to_data(), output.to_data()); - } + assert_eq!(tensor.to_data(), output.to_data()); + } } diff --git a/burn-core/src/nn/embedding.rs b/burn-core/src/nn/embedding.rs index f3deb03a82..49f26eb727 100644 --- a/burn-core/src/nn/embedding.rs +++ b/burn-core/src/nn/embedding.rs @@ -11,13 +11,13 @@ use burn_tensor::Int; /// Configuration to create an [Embedding](Embedding) layer. #[derive(Config)] pub struct EmbeddingConfig { - /// The number of embedding vectors. - n_embedding: usize, - /// The size of each vector. - d_model: usize, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")] - pub initializer: Initializer, + /// The number of embedding vectors. + n_embedding: usize, + /// The size of each vector. + d_model: usize, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")] + pub initializer: Initializer, } /// Lookup table to store a fix number of vectors. @@ -28,80 +28,80 @@ pub struct EmbeddingConfig { /// `N(0, 1)` #[derive(Module, Debug)] pub struct Embedding { - weight: Param>, + weight: Param>, } impl EmbeddingConfig { - /// Initialize a new [embedding](Embedding) module. - pub fn init(&self) -> Embedding { - let weight = self - .initializer - .init([self.n_embedding, self.d_model]) - .require_grad(); + /// Initialize a new [embedding](Embedding) module. 
+ pub fn init(&self) -> Embedding { + let weight = self + .initializer + .init([self.n_embedding, self.d_model]) + .require_grad(); - Embedding { - weight: Param::from(weight), + Embedding { + weight: Param::from(weight), + } } - } - /// Initialize a new [embedding](Embedding) module with a [record](EmbeddingRecord). - pub fn init_with(&self, record: EmbeddingRecord) -> Embedding { - Embedding { - weight: record.weight, + /// Initialize a new [embedding](Embedding) module with a [record](EmbeddingRecord). + pub fn init_with(&self, record: EmbeddingRecord) -> Embedding { + Embedding { + weight: record.weight, + } } - } } impl Embedding { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, seq_length] - /// - output: [batch_size, d_model] - pub fn forward(&self, input: Tensor) -> Tensor { - burn_tensor::module::embedding(self.weight.val(), input) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, seq_length] + /// - output: [batch_size, d_model] + pub fn forward(&self, input: Tensor) -> Tensor { + burn_tensor::module::embedding(self.weight.val(), input) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; + use super::*; + use crate::TestBackend; + use burn_tensor::Data; - #[test] - fn initializer_default() { - TestBackend::seed(0); + #[test] + fn initializer_default() { + TestBackend::seed(0); - let config = EmbeddingConfig::new(100, 10); - let embed = config.init::(); - let weights = embed.weight.val().reshape([1000]); - let (var_act, mean_act) = weights.var_mean(0); + let config = EmbeddingConfig::new(100, 10); + let embed = config.init::(); + let weights = embed.weight.val().reshape([1000]); + let (var_act, mean_act) = weights.var_mean(0); - assert_eq!( - config.initializer, - Initializer::Normal { - mean: 0.0, - std: 1.0 - } - ); - var_act.to_data().assert_approx_eq(&Data::from([1.0f32]), 0); - mean_act - .to_data() - .assert_approx_eq(&Data::from([0.0f32]), 0); - } + assert_eq!( + config.initializer, + Initializer::Normal { + mean: 0.0, + std: 1.0 + } + ); + var_act.to_data().assert_approx_eq(&Data::from([1.0f32]), 0); + mean_act + .to_data() + .assert_approx_eq(&Data::from([0.0f32]), 0); + } - #[test] - fn initializer_zeros() { - TestBackend::seed(0); + #[test] + fn initializer_zeros() { + TestBackend::seed(0); - let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros); - let embed = config.init::(); + let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros); + let embed = config.init::(); - assert_eq!(config.initializer, Initializer::Zeros); - embed - .weight - .to_data() - .assert_approx_eq(&Data::zeros(embed.weight.shape()), 3); - } + assert_eq!(config.initializer, Initializer::Zeros); + embed + .weight + .to_data() + .assert_approx_eq(&Data::zeros(embed.weight.shape()), 3); + } } diff --git a/burn-core/src/nn/gelu.rs b/burn-core/src/nn/gelu.rs index f0c392c4c4..020b6e5ee0 100644 --- a/burn-core/src/nn/gelu.rs +++ b/burn-core/src/nn/gelu.rs @@ -9,18 +9,18 @@ use crate::tensor::Tensor; pub struct GELU {} impl GELU { - /// Create the module. - pub fn new() -> Self { - Self {} - } + /// Create the module. + pub fn new() -> Self { + Self {} + } - /// Applies the forward pass on the input tensor. 
- /// - /// # Shapes - /// - /// - input: `[..., any]` - /// - output: `[..., any]` - pub fn forward(&self, input: Tensor) -> Tensor { - crate::tensor::activation::gelu(input) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[..., any]` + /// - output: `[..., any]` + pub fn forward(&self, input: Tensor) -> Tensor { + crate::tensor::activation::gelu(input) + } } diff --git a/burn-core/src/nn/initializer.rs b/burn-core/src/nn/initializer.rs index 7a97403cef..51575b682d 100644 --- a/burn-core/src/nn/initializer.rs +++ b/burn-core/src/nn/initializer.rs @@ -10,351 +10,363 @@ use crate as burn; /// Enum specifying with what values a tensor should be initialized #[derive(Config, Debug, PartialEq)] pub enum Initializer { - /// Fills tensor with specified value everywhere - Constant { - /// The value to fill the tensor with - value: f64, - }, - /// Fills tensor with 1s everywhere - Ones, - /// Fills tensor with 0s everywhere - Zeros, - /// Fills tensor with values drawn uniformly between specified values - Uniform { - /// The minimum value to draw from - min: f64, - - /// The maximum value to draw from - max: f64, - }, - /// Fills tensor with values drawn from normal distribution with specified mean and std - Normal { - /// The mean of the normal distribution - mean: f64, - - /// The standard deviation of the normal distribution - std: f64, - }, - /// Fills tensor with values according to the uniform version of Kaiming initialization - KaimingUniform { - /// The gain to use in initialization formula - gain: f64, - - /// Whether to use fan out only in initialization formula - fan_out_only: bool, - }, - /// Fills tensor with values according to the uniform version of Kaiming initialization - KaimingNormal { - /// The gain to use in initialization formula - gain: f64, - - /// Whether to use fan out only in initialization formula - fan_out_only: bool, - }, - /// Fills tensor with values according to the uniform version of Xavier Glorot initialization - /// described in [Understanding the difficulty of training deep feedforward neural networks - /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) - XavierUniform { - /// The gain to use in initialization formula - gain: f64, - }, - /// Fills tensor with values according to the normal version of Xavier Glorot initialization - /// described in [Understanding the difficulty of training deep feedforward neural networks - /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) - XavierNormal { - /// The gain to use in initialization formula - gain: f64, - }, + /// Fills tensor with specified value everywhere + Constant { + /// The value to fill the tensor with + value: f64, + }, + /// Fills tensor with 1s everywhere + Ones, + /// Fills tensor with 0s everywhere + Zeros, + /// Fills tensor with values drawn uniformly between specified values + Uniform { + /// The minimum value to draw from + min: f64, + + /// The maximum value to draw from + max: f64, + }, + /// Fills tensor with values drawn from normal distribution with specified mean and std + Normal { + /// The mean of the normal distribution + mean: f64, + + /// The standard deviation of the normal distribution + std: f64, + }, + /// Fills tensor with values according to the uniform version of Kaiming initialization + KaimingUniform { + /// The gain to use in initialization formula + gain: f64, + + /// Whether to use fan out only in initialization formula + fan_out_only: bool, + }, + /// Fills tensor with values according to the uniform 
version of Kaiming initialization + KaimingNormal { + /// The gain to use in initialization formula + gain: f64, + + /// Whether to use fan out only in initialization formula + fan_out_only: bool, + }, + /// Fills tensor with values according to the uniform version of Xavier Glorot initialization + /// described in [Understanding the difficulty of training deep feedforward neural networks + /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) + XavierUniform { + /// The gain to use in initialization formula + gain: f64, + }, + /// Fills tensor with values according to the normal version of Xavier Glorot initialization + /// described in [Understanding the difficulty of training deep feedforward neural networks + /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) + XavierNormal { + /// The gain to use in initialization formula + gain: f64, + }, } impl Initializer { - /// Inits a tensor of given shape with values depending on initializer kind. - /// - /// # Params - /// - /// - shape: Shape of the initiated tensor. - pub fn init>>(&self, shape: S) -> Tensor { - self.init_with(shape, None, None) - } - - /// Inits a tensor of given shape with values depending on initializer kind, with the possibility - /// of specifying fan in and fan out - /// - /// # Params - /// - /// - shape: Shape of the initiated tensor. - /// - fan_in: `Option`, the fan in to use in initialization formula, if needed - /// - fan_out: `Option`, the fan out to use in initialization formula, if needed - pub fn init_with>>( - &self, - shape: S, - fan_in: Option, - fan_out: Option, - ) -> Tensor { - let shape = shape.into(); - match self { - Initializer::Constant { value } => Tensor::::full(shape, *value), - Initializer::Ones => Tensor::::ones(shape), - Initializer::Zeros => Tensor::::zeros(shape), - Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max), - Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std), - Initializer::KaimingUniform { gain, fan_out_only } => { - let a = sqrt(3.0) * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); - uniform_draw(shape, -a, a) - } - Initializer::KaimingNormal { gain, fan_out_only } => { - let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); - normal_draw(shape, 0.0, std) - } - Initializer::XavierUniform { gain } => { - let a = sqrt(3.0) * *gain * self.xavier_std(fan_in, fan_out); - uniform_draw(shape, -a, a) - } - Initializer::XavierNormal { gain } => { - let std = *gain * self.xavier_std(fan_in, fan_out); - normal_draw(shape, 0.0, std) - } + /// Inits a tensor of given shape with values depending on initializer kind. + /// + /// # Params + /// + /// - shape: Shape of the initiated tensor. + pub fn init>>(&self, shape: S) -> Tensor { + self.init_with(shape, None, None) + } + + /// Inits a tensor of given shape with values depending on initializer kind, with the possibility + /// of specifying fan in and fan out + /// + /// # Params + /// + /// - shape: Shape of the initiated tensor. 
+ /// - fan_in: `Option`, the fan in to use in initialization formula, if needed + /// - fan_out: `Option`, the fan out to use in initialization formula, if needed + pub fn init_with>>( + &self, + shape: S, + fan_in: Option, + fan_out: Option, + ) -> Tensor { + let shape = shape.into(); + match self { + Initializer::Constant { value } => Tensor::::full(shape, *value), + Initializer::Ones => Tensor::::ones(shape), + Initializer::Zeros => Tensor::::zeros(shape), + Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max), + Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std), + Initializer::KaimingUniform { gain, fan_out_only } => { + let a = sqrt(3.0) * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); + uniform_draw(shape, -a, a) + } + Initializer::KaimingNormal { gain, fan_out_only } => { + let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); + normal_draw(shape, 0.0, std) + } + Initializer::XavierUniform { gain } => { + let a = sqrt(3.0) * *gain * self.xavier_std(fan_in, fan_out); + uniform_draw(shape, -a, a) + } + Initializer::XavierNormal { gain } => { + let std = *gain * self.xavier_std(fan_in, fan_out); + normal_draw(shape, 0.0, std) + } + } } - } - fn kaiming_std(&self, fan_out_only: bool, fan_in: Option, fan_out: Option) -> f64 { - let fan = if fan_out_only { fan_out } else { fan_in }; - let fan = - fan.expect("Can't use Kaiming initialization without specifying fan. Use init_with method."); + fn kaiming_std( + &self, + fan_out_only: bool, + fan_in: Option, + fan_out: Option, + ) -> f64 { + let fan = if fan_out_only { fan_out } else { fan_in }; + let fan = fan.expect( + "Can't use Kaiming initialization without specifying fan. Use init_with method.", + ); - 1.0 / sqrt(fan as f64) - } + 1.0 / sqrt(fan as f64) + } - fn xavier_std(&self, fan_in: Option, fan_out: Option) -> f64 { - let fan_in = fan_in.expect( + fn xavier_std(&self, fan_in: Option, fan_out: Option) -> f64 { + let fan_in = fan_in.expect( "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in.", ); - let fan_out = fan_out.expect( + let fan_out = fan_out.expect( "Can't use Xavier initialization without specifying fan out. 
Use init_with method and provide fan_out.", ); - sqrt(2.0 / (fan_in + fan_out) as f64) - } + sqrt(2.0 / (fan_in + fan_out) as f64) + } } fn uniform_draw>>( - shape: S, - low: f64, - high: f64, + shape: S, + low: f64, + high: f64, ) -> Tensor { - let distribution = Distribution::Uniform(low.elem::(), high.elem::()); - Tensor::::random(shape, distribution) + let distribution = + Distribution::Uniform(low.elem::(), high.elem::()); + Tensor::::random(shape, distribution) } fn normal_draw>>( - shape: S, - mean: f64, - std: f64, + shape: S, + mean: f64, + std: f64, ) -> Tensor { - let distribution = Distribution::Normal(mean, std); - Tensor::::random(shape, distribution) + let distribution = Distribution::Normal(mean, std); + Tensor::::random(shape, distribution) } #[cfg(test)] mod tests { - use super::*; + use super::*; + + use burn_tensor::Data; + + pub type TB = burn_ndarray::NdArray; + + fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor) { + let (actual_vars, actual_means) = tensor.clone().var_mean(0); + + for i in 0..tensor.shape().dims[0] { + let actual_var = actual_vars.to_data().value[i] as f64; + let actual_mean = actual_means.to_data().value[i] as f64; + + assert!( + (expected_var - actual_var).abs() <= 0.1, + "Expected variance to be between {expected_var} += 0.1, but got {actual_var}" + ); + assert!( + (expected_mean - actual_mean).abs() <= 0.1, + "Expected mean to be between {expected_mean} += 0.1, but got {actual_mean}" + ); + } + } + + #[test] + fn initializer_uniform_init() { + TB::seed(0); + + let (min, max) = (0.0, 1.0); + let uniform = Initializer::Uniform { min, max }; + let tensor: Tensor = uniform.init([2, 2, 2, 2]); + + tensor.into_data().assert_within_range(min..max); + } - use burn_tensor::Data; + #[test] + fn initializer_normal_init() { + // seed random generator + TB::seed(0); + let (mean, std) = (0.0, 1.0); + let normal: Tensor = Initializer::Normal { mean, std }.init([1000]); + let (var_act, mean_act) = normal.var_mean(0); - pub type TB = burn_ndarray::NdArray; + let var_act: f32 = var_act.into_scalar().elem(); + let mean_act: f32 = mean_act.into_scalar().elem(); - fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor) { - let (actual_vars, actual_means) = tensor.clone().var_mean(0); + assert!( + var_act > 0.9 && var_act < 1.1, + "Expected variance to be between 1.0 += 0.1, but got {var_act}" + ); + assert!( + mean_act > -0.1 && mean_act < 0.1, + "Expected mean to be between 0.0 += 0.1, but got {mean_act}" + ); + } - for i in 0..tensor.shape().dims[0] { - let actual_var = actual_vars.to_data().value[i] as f64; - let actual_mean = actual_means.to_data().value[i] as f64; + #[test] + fn initializer_constant_init() { + let value = 5.0; + let constants: Tensor = Initializer::Constant { value }.init([2, 2, 2, 2]); + constants + .sum() + .to_data() + .assert_approx_eq(&Data::from([value as f32 * 16.0]), 3); + } + + #[test] + fn initializer_zeros_init() { + let zeros: Tensor = Initializer::Zeros.init([2, 2, 2, 2]); + zeros + .sum() + .to_data() + .assert_approx_eq(&Data::from([0.0]), 3); + } - assert!( - (expected_var - actual_var).abs() <= 0.1, - "Expected variance to be between {expected_var} += 0.1, but got {actual_var}" - ); - assert!( - (expected_mean - actual_mean).abs() <= 0.1, - "Expected mean to be between {expected_mean} += 0.1, but got {actual_mean}" - ); + #[test] + fn initializer_ones_init() { + let ones: Tensor = Initializer::Ones.init([2, 2, 2, 2]); + ones.sum() + .to_data() + 
.assert_approx_eq(&Data::from([16.0]), 3); } - } - - #[test] - fn initializer_uniform_init() { - TB::seed(0); - - let (min, max) = (0.0, 1.0); - let uniform = Initializer::Uniform { min, max }; - let tensor: Tensor = uniform.init([2, 2, 2, 2]); - - tensor.into_data().assert_within_range(min..max); - } - - #[test] - fn initializer_normal_init() { - // seed random generator - TB::seed(0); - let (mean, std) = (0.0, 1.0); - let normal: Tensor = Initializer::Normal { mean, std }.init([1000]); - let (var_act, mean_act) = normal.var_mean(0); - - let var_act: f32 = var_act.into_scalar().elem(); - let mean_act: f32 = mean_act.into_scalar().elem(); - - assert!( - var_act > 0.9 && var_act < 1.1, - "Expected variance to be between 1.0 += 0.1, but got {var_act}" - ); - assert!( - mean_act > -0.1 && mean_act < 0.1, - "Expected mean to be between 0.0 += 0.1, but got {mean_act}" - ); - } - - #[test] - fn initializer_constant_init() { - let value = 5.0; - let constants: Tensor = Initializer::Constant { value }.init([2, 2, 2, 2]); - constants - .sum() - .to_data() - .assert_approx_eq(&Data::from([value as f32 * 16.0]), 3); - } - - #[test] - fn initializer_zeros_init() { - let zeros: Tensor = Initializer::Zeros.init([2, 2, 2, 2]); - zeros - .sum() - .to_data() - .assert_approx_eq(&Data::from([0.0]), 3); - } - - #[test] - fn initializer_ones_init() { - let ones: Tensor = Initializer::Ones.init([2, 2, 2, 2]); - ones - .sum() - .to_data() - .assert_approx_eq(&Data::from([16.0]), 3); - } - - #[test] - fn initializer_kaiming_uniform_init() { - TB::seed(0); - - let gain = 2_f64; - let (fan_in, fan_out) = (5, 6); - let k = gain * sqrt(3.0 / fan_in as f64); - - let tensor: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: false, + + #[test] + fn initializer_kaiming_uniform_init() { + TB::seed(0); + + let gain = 2_f64; + let (fan_in, fan_out) = (5, 6); + let k = gain * sqrt(3.0 / fan_in as f64); + + let tensor: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: false, + } + .init_with([fan_out, fan_in], Some(fan_in), None); + tensor.into_data().assert_within_range(-k..k); } - .init_with([fan_out, fan_in], Some(fan_in), None); - tensor.into_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_kaiming_normal_init() { - TB::seed(0); - - let gain = 2.; - let (fan_in, fan_out) = (1000, 10); - let expected_mean = 0_f64; - - let expected_var = (gain * sqrt(1. / (fan_in as f64))).powf(2.); - let tensor: Tensor = Initializer::KaimingNormal { - gain, - fan_out_only: false, + + #[test] + fn initializer_kaiming_normal_init() { + TB::seed(0); + + let gain = 2.; + let (fan_in, fan_out) = (1000, 10); + let expected_mean = 0_f64; + + let expected_var = (gain * sqrt(1. 
/ (fan_in as f64))).powf(2.); + let tensor: Tensor = Initializer::KaimingNormal { + gain, + fan_out_only: false, + } + .init_with([fan_out, fan_in], Some(fan_in), None); + assert_normal_init(expected_mean, expected_var, &tensor) } - .init_with([fan_out, fan_in], Some(fan_in), None); - assert_normal_init(expected_mean, expected_var, &tensor) - } - - #[test] - fn initializer_kaiming_uniform_init_bias() { - TB::seed(0); - - let gain = 2_f64; - let shape = [3]; - let fan_in = 5; - let k = gain * sqrt(3.0 / fan_in as f64); - - let tensor: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: false, + + #[test] + fn initializer_kaiming_uniform_init_bias() { + TB::seed(0); + + let gain = 2_f64; + let shape = [3]; + let fan_in = 5; + let k = gain * sqrt(3.0 / fan_in as f64); + + let tensor: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: false, + } + .init_with(shape, Some(fan_in), None); + tensor.into_data().assert_within_range(-k..k); } - .init_with(shape, Some(fan_in), None); - tensor.into_data().assert_within_range(-k..k); - } - #[test] - fn initializer_kaiming_uniform_init_fan_out() { - TB::seed(0); + #[test] + fn initializer_kaiming_uniform_init_fan_out() { + TB::seed(0); - let gain = 2_f64; - let (fan_in, fan_out) = (5, 6); - let k = gain * sqrt(3.0 / fan_out as f64); + let gain = 2_f64; + let (fan_in, fan_out) = (5, 6); + let k = gain * sqrt(3.0 / fan_out as f64); - let tensor: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: true, + let tensor: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: true, + } + .init_with([fan_out, fan_in], None, Some(fan_out)); + tensor.into_data().assert_within_range(-k..k); } - .init_with([fan_out, fan_in], None, Some(fan_out)); - tensor.into_data().assert_within_range(-k..k); - } - #[test] - #[should_panic] - fn initializer_kaiming_uniform_no_fan() { - TB::seed(0); + #[test] + #[should_panic] + fn initializer_kaiming_uniform_no_fan() { + TB::seed(0); + + let gain = 2_f64; + let (fan_in, fan_out) = (5, 6); + + let _: Tensor = Initializer::KaimingUniform { + gain, + fan_out_only: false, + } + .init([fan_out, fan_in]); + } + + #[test] + fn initializer_xavier_uniform_init() { + TB::seed(0); + + let gain = 2.; + let (fan_in, fan_out) = (5, 6); + let bound = gain * sqrt(6. / (fan_in + fan_out) as f64); + let tensor: Tensor = Initializer::XavierUniform { gain }.init_with( + [fan_out, fan_in], + Some(fan_in), + Some(fan_out), + ); + + tensor.into_data().assert_within_range(-bound..bound); + } + + #[test] + fn initializer_xavier_normal_init() { + TB::seed(0); + + let gain = 2.; + let (fan_in, fan_out) = (1000, 10); + let expected_mean = 0_f64; + + let expected_var = (gain * sqrt(2. / (fan_in as f64 + fan_out as f64))).powf(2.); + let tensor: Tensor = Initializer::XavierNormal { gain }.init_with( + [fan_out, fan_in], + Some(fan_in), + Some(fan_out), + ); + assert_normal_init(expected_mean, expected_var, &tensor) + } - let gain = 2_f64; - let (fan_in, fan_out) = (5, 6); + #[test] + #[should_panic] + fn initializer_xavier_uniform_no_fan() { + TB::seed(0); - let _: Tensor = Initializer::KaimingUniform { - gain, - fan_out_only: false, + let gain = 2.; + let (fan_in, fan_out) = (5, 6); + let _: Tensor = Initializer::XavierUniform { gain }.init([fan_out, fan_in]); } - .init([fan_out, fan_in]); - } - - #[test] - fn initializer_xavier_uniform_init() { - TB::seed(0); - - let gain = 2.; - let (fan_in, fan_out) = (5, 6); - let bound = gain * sqrt(6. 
/ (fan_in + fan_out) as f64); - let tensor: Tensor = - Initializer::XavierUniform { gain }.init_with([fan_out, fan_in], Some(fan_in), Some(fan_out)); - - tensor.into_data().assert_within_range(-bound..bound); - } - - #[test] - fn initializer_xavier_normal_init() { - TB::seed(0); - - let gain = 2.; - let (fan_in, fan_out) = (1000, 10); - let expected_mean = 0_f64; - - let expected_var = (gain * sqrt(2. / (fan_in as f64 + fan_out as f64))).powf(2.); - let tensor: Tensor = - Initializer::XavierNormal { gain }.init_with([fan_out, fan_in], Some(fan_in), Some(fan_out)); - assert_normal_init(expected_mean, expected_var, &tensor) - } - - #[test] - #[should_panic] - fn initializer_xavier_uniform_no_fan() { - TB::seed(0); - - let gain = 2.; - let (fan_in, fan_out) = (5, 6); - let _: Tensor = Initializer::XavierUniform { gain }.init([fan_out, fan_in]); - } } diff --git a/burn-core/src/nn/linear.rs b/burn-core/src/nn/linear.rs index 266bc8eea7..0b3ef20db2 100644 --- a/burn-core/src/nn/linear.rs +++ b/burn-core/src/nn/linear.rs @@ -11,16 +11,16 @@ use super::Initializer; /// Configuration to create a [Linear](Linear) layer. #[derive(Config, Debug)] pub struct LinearConfig { - /// The size of the input features. - pub d_input: usize, - /// The size of the output features. - pub d_output: usize, - /// If a bias should be applied during the linear transformation. - #[config(default = true)] - pub bias: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0), fan_out_only:false}")] - pub initializer: Initializer, + /// The size of the input features. + pub d_input: usize, + /// The size of the output features. + pub d_output: usize, + /// If a bias should be applied during the linear transformation. + #[config(default = true)] + pub bias: bool, + /// The type of function used to initialize neural network parameters + #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0), fan_out_only:false}")] + pub initializer: Initializer, } /// Applies a linear transformation to the input tensor: @@ -28,131 +28,131 @@ pub struct LinearConfig { /// `O = IW + b` #[derive(Module, Debug)] pub struct Linear { - /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution: - /// `U(-k, k)`, where `k = sqrt(1 / d_input)` - pub weight: Param>, - /// Vector of size `d_output` initialized from a uniform distribution: - /// `U(-k, k)`, where `k = sqrt(1 / d_input)` - pub bias: Option>>, + /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution: + /// `U(-k, k)`, where `k = sqrt(1 / d_input)` + pub weight: Param>, + /// Vector of size `d_output` initialized from a uniform distribution: + /// `U(-k, k)`, where `k = sqrt(1 / d_input)` + pub bias: Option>>, } impl LinearConfig { - /// Initialize a new [linear](Linear) module. - pub fn init(&self) -> Linear { - let shape = [self.d_input, self.d_output]; - let weight = self - .initializer - .init_with(shape, Some(self.d_input), Some(self.d_output)); - let bias = if self.bias { - Some( - self - .initializer - .init_with([self.d_output], Some(self.d_input), Some(self.d_output)), - ) - } else { - None - }; - - Linear { - weight: Param::from(weight), - bias: bias.map(Param::from), + /// Initialize a new [linear](Linear) module. 
+ pub fn init(&self) -> Linear { + let shape = [self.d_input, self.d_output]; + let weight = self + .initializer + .init_with(shape, Some(self.d_input), Some(self.d_output)); + let bias = if self.bias { + Some(self.initializer.init_with( + [self.d_output], + Some(self.d_input), + Some(self.d_output), + )) + } else { + None + }; + + Linear { + weight: Param::from(weight), + bias: bias.map(Param::from), + } } - } - /// Initialize a new [linear](Linear) module with a [record](LinearRecord). - pub fn init_with(&self, record: LinearRecord) -> Linear { - Linear { - weight: record.weight, - bias: record.bias, + /// Initialize a new [linear](Linear) module with a [record](LinearRecord). + pub fn init_with(&self, record: LinearRecord) -> Linear { + Linear { + weight: record.weight, + bias: record.bias, + } } - } } impl Linear { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any, d_input]` - /// - output: `[..., any, d_output]` - pub fn forward(&self, input: Tensor) -> Tensor { - let output = input.matmul(self.weight.val().unsqueeze()); - - match &self.bias { - Some(bias) => output + bias.val().unsqueeze(), - None => output, + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[..., any, d_input]` + /// - output: `[..., any, d_output]` + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.matmul(self.weight.val().unsqueeze()); + + match &self.bias { + Some(bias) => output + bias.val().unsqueeze(), + None => output, + } } - } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::{Data, Shape}; - use libm::sqrt; - - #[test] - fn initializer_default() { - TestBackend::seed(0); - - let config = LinearConfig::new(5, 5); - let k = sqrt(1.0 / config.d_input as f64) as f32; - let linear = config.init::(); - - assert_eq!( - config.initializer, - Initializer::KaimingUniform { - gain: 1.0 / sqrt(3.0), - fan_out_only: false - } - ); - linear.weight.to_data().assert_within_range(-k..k); - } - - #[test] - fn initializer_zeros() { - TestBackend::seed(0); - - let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros); - let linear = config.init::(); - - assert_eq!(config.initializer, Initializer::Zeros); - linear - .weight - .to_data() - .assert_approx_eq(&Data::zeros(linear.weight.shape()), 3); - } - - #[test] - fn test_linear_forward_no_bias() { - TestBackend::seed(0); - - let value = 2.; - let config = LinearConfig::new(2, 3) - .with_initializer(Initializer::Constant { value }) - .with_bias(false); - let linear = config.init(); - - let input = Tensor::::ones(Shape::new([1, 2])); - let result = linear.forward(input); - let expected_result = Tensor::::from_data([[4., 4., 4.]]); - - assert_eq!(result.into_data(), expected_result.into_data()); - } - - #[test] - fn test_linear_forward_with_bias() { - TestBackend::seed(0); - - let value = 2.; - let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); - let linear = config.init(); - - let input = Tensor::::ones(Shape::new([1, 2])); - let result = linear.forward(input); - let expected_result = Tensor::::from_data([[6., 6., 6.]]); - - assert_eq!(result.into_data(), expected_result.into_data()); - } + use super::*; + use crate::TestBackend; + use burn_tensor::{Data, Shape}; + use libm::sqrt; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = LinearConfig::new(5, 5); + let k = sqrt(1.0 / config.d_input as f64) as f32; + let linear = config.init::(); + 
+ assert_eq!( + config.initializer, + Initializer::KaimingUniform { + gain: 1.0 / sqrt(3.0), + fan_out_only: false + } + ); + linear.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros); + let linear = config.init::(); + + assert_eq!(config.initializer, Initializer::Zeros); + linear + .weight + .to_data() + .assert_approx_eq(&Data::zeros(linear.weight.shape()), 3); + } + + #[test] + fn test_linear_forward_no_bias() { + TestBackend::seed(0); + + let value = 2.; + let config = LinearConfig::new(2, 3) + .with_initializer(Initializer::Constant { value }) + .with_bias(false); + let linear = config.init(); + + let input = Tensor::::ones(Shape::new([1, 2])); + let result = linear.forward(input); + let expected_result = Tensor::::from_data([[4., 4., 4.]]); + + assert_eq!(result.into_data(), expected_result.into_data()); + } + + #[test] + fn test_linear_forward_with_bias() { + TestBackend::seed(0); + + let value = 2.; + let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); + let linear = config.init(); + + let input = Tensor::::ones(Shape::new([1, 2])); + let result = linear.forward(input); + let expected_result = Tensor::::from_data([[6., 6., 6.]]); + + assert_eq!(result.into_data(), expected_result.into_data()); + } } diff --git a/burn-core/src/nn/loss/binary_cross_entropy.rs b/burn-core/src/nn/loss/binary_cross_entropy.rs index dcca0373b5..7cb85cec66 100644 --- a/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -7,170 +7,170 @@ use burn_tensor::{backend::Backend, Int, Tensor}; /// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss). #[derive(Config, Debug)] pub struct BinaryCrossEntropyLossConfig { - /// Create weighted binary cross-entropy. - /// - /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, - /// - /// # Pre-conditions - /// - The order of the weight vector should correspond to the label integer assignment. - /// - Targets assigned negative Int's will not be allowed. - pub weights: Option<[f32; 2]>, - - /// Create binary cross-entropy with label smoothing. - /// - /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. - /// Alpha = 0 would be the same as default. - smoothing: Option, - - /// Create binary cross-entropy with probabilities as input instead of logits. - /// - #[config(default = true)] - logits: bool, + /// Create weighted binary cross-entropy. + /// + /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, + /// + /// # Pre-conditions + /// - The order of the weight vector should correspond to the label integer assignment. + /// - Targets assigned negative Int's will not be allowed. + pub weights: Option<[f32; 2]>, + + /// Create binary cross-entropy with label smoothing. + /// + /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. + /// Alpha = 0 would be the same as default. + smoothing: Option, + + /// Create binary cross-entropy with probabilities as input instead of logits. + /// + #[config(default = true)] + logits: bool, } impl BinaryCrossEntropyLossConfig { - /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss). 
- pub fn init(&self) -> BinaryCrossEntropyLoss { - self.assertions(); - BinaryCrossEntropyLoss { - weights: self - .weights - .as_ref() - .map(|e| Tensor::::from_floats(e.as_slice())), - smoothing: self.smoothing, - logits: self.logits, + /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss). + pub fn init(&self) -> BinaryCrossEntropyLoss { + self.assertions(); + BinaryCrossEntropyLoss { + weights: self + .weights + .as_ref() + .map(|e| Tensor::::from_floats(e.as_slice())), + smoothing: self.smoothing, + logits: self.logits, + } } - } - fn assertions(&self) { - if let Some(alpha) = self.smoothing { - assert!( + fn assertions(&self) { + if let Some(alpha) = self.smoothing { + assert!( (0.0..=1.).contains(&alpha), "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", alpha ); - }; - if let Some(weights) = self.weights.as_ref() { - assert!( - weights.iter().all(|e| e > &0.), - "Weights of cross-entropy have to be positive." - ); + }; + if let Some(weights) = self.weights.as_ref() { + assert!( + weights.iter().all(|e| e > &0.), + "Weights of cross-entropy have to be positive." + ); + } } - } } /// Calculate the cross entropy loss from the input logits and the targets. #[derive(Module, Debug)] pub struct BinaryCrossEntropyLoss { - /// Weights for cross-entropy. - pub weights: Option>, - smoothing: Option, - logits: bool, + /// Weights for cross-entropy. + pub weights: Option>, + smoothing: Option, + logits: bool, } impl Default for BinaryCrossEntropyLoss { - fn default() -> Self { - BinaryCrossEntropyLossConfig::new().init() - } + fn default() -> Self { + BinaryCrossEntropyLossConfig::new().init() + } } impl BinaryCrossEntropyLoss { - /// Compute the criterion on the input tensor. - /// - /// # Shapes - /// - /// - logits: `[batch_size]` - /// - targets: `[batch_size]` - pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { - Self::assertions(logits.clone(), targets.clone()); - let mut targets_float = targets.clone().float(); - if let Some(alpha) = self.smoothing { - targets_float = targets_float * (1. - alpha) + alpha / 2.; + /// Compute the criterion on the input tensor. + /// + /// # Shapes + /// + /// - logits: `[batch_size]` + /// - targets: `[batch_size]` + pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { + Self::assertions(logits.clone(), targets.clone()); + let mut targets_float = targets.clone().float(); + if let Some(alpha) = self.smoothing { + targets_float = targets_float * (1. - alpha) + alpha / 2.; + } + let logits = if self.logits { sigmoid(logits) } else { logits }; + let loss = targets_float.clone() * logits.clone().log() + + (targets_float.clone().neg() + 1.) * (logits.neg() + 1.).log(); + + match &self.weights { + Some(weights) => { + let weights = weights.clone().gather(0, targets); + let loss = loss * weights.clone(); + loss.neg().sum() / weights.sum() + } + None => loss.mean().neg(), + } } - let logits = if self.logits { sigmoid(logits) } else { logits }; - let loss = targets_float.clone() * logits.clone().log() - + (targets_float.clone().neg() + 1.) 
* (logits.neg() + 1.).log(); - - match &self.weights { - Some(weights) => { - let weights = weights.clone().gather(0, targets); - let loss = loss * weights.clone(); - loss.neg().sum() / weights.sum() - } - None => loss.mean().neg(), + + fn assertions(logits: Tensor, targets: Tensor) { + let [logits_height] = logits.dims(); + let [targets_height] = targets.dims(); + assert!( + logits_height == targets_height, + "Shape of targets ({}) should correspond to outer shape of logits ({}).", + targets_height, + logits_height + ); } - } - - fn assertions(logits: Tensor, targets: Tensor) { - let [logits_height] = logits.dims(); - let [targets_height] = targets.dims(); - assert!( - logits_height == targets_height, - "Shape of targets ({}) should correspond to outer shape of logits ({}).", - targets_height, - logits_height - ); - } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::{activation::sigmoid, Data, Distribution}; - - #[test] - fn test_binary_cross_entropy() { - let [batch_size] = [4]; - let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); - - let loss_1 = BinaryCrossEntropyLossConfig::new() - .init() - .forward(logits.clone(), targets.clone()); - let logits = sigmoid(logits); - let loss_2 = - targets.clone().float() * logits.clone().log() + (-targets.float() + 1) * (-logits + 1).log(); - let loss_2 = loss_2.mean().neg(); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_binary_cross_entropy_with_weights() { - let [batch_size] = [4]; - let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); - let weights = [3., 7.]; - - let loss_1 = BinaryCrossEntropyLossConfig::new() - .with_weights(Some(weights)) - .init() - .forward(logits.clone(), targets.clone()); - let logits = sigmoid(logits); - let loss_2 = - targets.clone().float() * logits.clone().log() + (-targets.float() + 1) * (-logits + 1).log(); - - let loss_2 = loss_2 * Tensor::from_floats([3., 7., 3., 7.]); - let loss_2 = loss_2.neg().sum() / (3. + 3. + 7. + 7.); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_binary_cross_entropy_with_smoothing() { - let [batch_size] = [4]; - let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); - - let loss_1 = BinaryCrossEntropyLossConfig::new() - .with_smoothing(Some(0.1)) - .init() - .forward(logits.clone(), targets.clone()); - - let logits = sigmoid(logits); - let targets = targets.float() * (1. 
- 0.1) + 0.1 / 2.; - let loss_2 = targets.clone() * logits.clone().log() + (-targets + 1) * (-logits + 1).log(); - let loss_2 = loss_2.mean().neg(); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::{activation::sigmoid, Data, Distribution}; + + #[test] + fn test_binary_cross_entropy() { + let [batch_size] = [4]; + let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); + + let loss_1 = BinaryCrossEntropyLossConfig::new() + .init() + .forward(logits.clone(), targets.clone()); + let logits = sigmoid(logits); + let loss_2 = targets.clone().float() * logits.clone().log() + + (-targets.float() + 1) * (-logits + 1).log(); + let loss_2 = loss_2.mean().neg(); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_binary_cross_entropy_with_weights() { + let [batch_size] = [4]; + let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); + let weights = [3., 7.]; + + let loss_1 = BinaryCrossEntropyLossConfig::new() + .with_weights(Some(weights)) + .init() + .forward(logits.clone(), targets.clone()); + let logits = sigmoid(logits); + let loss_2 = targets.clone().float() * logits.clone().log() + + (-targets.float() + 1) * (-logits + 1).log(); + + let loss_2 = loss_2 * Tensor::from_floats([3., 7., 3., 7.]); + let loss_2 = loss_2.neg().sum() / (3. + 3. + 7. + 7.); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_binary_cross_entropy_with_smoothing() { + let [batch_size] = [4]; + let logits = Tensor::::random([batch_size], Distribution::Normal(0., 1.0)); + let targets = Tensor::::from_data(Data::from([0, 1, 0, 1])); + + let loss_1 = BinaryCrossEntropyLossConfig::new() + .with_smoothing(Some(0.1)) + .init() + .forward(logits.clone(), targets.clone()); + + let logits = sigmoid(logits); + let targets = targets.float() * (1. - 0.1) + 0.1 / 2.; + let loss_2 = targets.clone() * logits.clone().log() + (-targets + 1) * (-logits + 1).log(); + let loss_2 = loss_2.mean().neg(); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/loss/cross_entropy.rs b/burn-core/src/nn/loss/cross_entropy.rs index f493984965..41c5a0b338 100644 --- a/burn-core/src/nn/loss/cross_entropy.rs +++ b/burn-core/src/nn/loss/cross_entropy.rs @@ -9,377 +9,382 @@ use burn_tensor::{backend::Backend, Bool, Int, Tensor}; /// Configuration to create a [Cross-entropy loss](CrossEntropyLoss). #[derive(Config, Debug)] pub struct CrossEntropyLossConfig { - /// Create padded cross entropy. - /// - /// Prevents pad tokens from impacting loss calculation. - pad_tokens: Option>, - - /// Create weighted cross-entropy. - /// - /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, - /// - /// # Pre-conditions - /// - The order of the weight vector should correspond to the label integer assignment. - /// - Targets assigned negative Int's will not be allowed. - weights: Option>, - - /// Create cross-entropy with label smoothing. - /// - /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. - /// Alpha = 0 would be the same as default. - smoothing: Option, - - /// Create cross-entropy with probabilities as input instead of logits. - /// - #[config(default = true)] - logits: bool, + /// Create padded cross entropy. 
+ /// + /// Prevents pad tokens from impacting loss calculation. + pad_tokens: Option>, + + /// Create weighted cross-entropy. + /// + /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, + /// + /// # Pre-conditions + /// - The order of the weight vector should correspond to the label integer assignment. + /// - Targets assigned negative Int's will not be allowed. + weights: Option>, + + /// Create cross-entropy with label smoothing. + /// + /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. + /// Alpha = 0 would be the same as default. + smoothing: Option, + + /// Create cross-entropy with probabilities as input instead of logits. + /// + #[config(default = true)] + logits: bool, } impl CrossEntropyLossConfig { - /// Initialize [Cross-entropy loss](CrossEntropyLoss). - pub fn init(&self) -> CrossEntropyLoss { - self.assertions(); - CrossEntropyLoss { - pad_tokens: self.pad_tokens.clone(), - weights: self - .weights - .as_ref() - .map(|e| Tensor::::from_floats(e.as_slice())), - smoothing: self.smoothing, - logits: self.logits, + /// Initialize [Cross-entropy loss](CrossEntropyLoss). + pub fn init(&self) -> CrossEntropyLoss { + self.assertions(); + CrossEntropyLoss { + pad_tokens: self.pad_tokens.clone(), + weights: self + .weights + .as_ref() + .map(|e| Tensor::::from_floats(e.as_slice())), + smoothing: self.smoothing, + logits: self.logits, + } } - } - fn assertions(&self) { - if let Some(alpha) = self.smoothing { - assert!( + fn assertions(&self) { + if let Some(alpha) = self.smoothing { + assert!( (0.0..=1.).contains(&alpha), "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", alpha ); - }; - if let Some(weights) = self.weights.as_ref() { - assert!( - weights.iter().all(|e| e > &0.), - "Weights of cross-entropy have to be positive." - ); + }; + if let Some(weights) = self.weights.as_ref() { + assert!( + weights.iter().all(|e| e > &0.), + "Weights of cross-entropy have to be positive." + ); + } } - } } /// Calculate the cross entropy loss from the input logits and the targets. #[derive(Module, Debug)] pub struct CrossEntropyLoss { - pad_tokens: Option>, - /// Weights for cross-entropy. - pub weights: Option>, - smoothing: Option, - logits: bool, + pad_tokens: Option>, + /// Weights for cross-entropy. + pub weights: Option>, + smoothing: Option, + logits: bool, } impl Default for CrossEntropyLoss { - fn default() -> Self { - CrossEntropyLossConfig::new().init() - } + fn default() -> Self { + CrossEntropyLossConfig::new().init() + } } impl CrossEntropyLoss { - /// For backward compatibility. - pub fn new(pad_index: Option) -> Self { - CrossEntropyLossConfig::new() - .with_pad_tokens(pad_index.map(|e| vec![e])) - .init() - } - - /// Compute the criterion on the input tensor. - /// - /// # Shapes - /// - /// - logits: `[batch_size, num_targets]` - /// - targets: `[batch_size]` - pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { - Self::assertions(logits.clone(), targets.clone()); - match self.smoothing { - Some(alpha) => self.forward_smoothed(logits, targets, alpha), - _ => self.forward_default(logits, targets), + /// For backward compatibility. 
+ pub fn new(pad_index: Option) -> Self { + CrossEntropyLossConfig::new() + .with_pad_tokens(pad_index.map(|e| vec![e])) + .init() } - } - - fn forward_smoothed( - &self, - logits: Tensor, - targets: Tensor, - alpha: f32, - ) -> Tensor { - let mask = self.padding_mask(&targets); - let tensor = if self.logits { - log_softmax(logits, 1) - } else { - logits.log() - }; - let [batch_size, nr_classes] = tensor.dims(); - let tensor = - tensor * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha); - - match &self.weights { - Some(weights) => { + + /// Compute the criterion on the input tensor. + /// + /// # Shapes + /// + /// - logits: `[batch_size, num_targets]` + /// - targets: `[batch_size]` + pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { + Self::assertions(logits.clone(), targets.clone()); + match self.smoothing { + Some(alpha) => self.forward_smoothed(logits, targets, alpha), + _ => self.forward_default(logits, targets), + } + } + + fn forward_smoothed( + &self, + logits: Tensor, + targets: Tensor, + alpha: f32, + ) -> Tensor { + let mask = self.padding_mask(&targets); + let tensor = if self.logits { + log_softmax(logits, 1) + } else { + logits.log() + }; + let [batch_size, nr_classes] = tensor.dims(); let tensor = tensor - * weights - .clone() - .reshape([1, nr_classes]) - .repeat(0, batch_size); - let weights = weights.clone().gather(0, targets); - let tensor = Self::apply_mask_2d(tensor, mask); - tensor.sum().neg() / weights.sum() - } - None => { - let tensor = Self::apply_mask_2d(tensor, mask); - tensor.sum_dim(1).mean().neg() - } + * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha); + + match &self.weights { + Some(weights) => { + let tensor = tensor + * weights + .clone() + .reshape([1, nr_classes]) + .repeat(0, batch_size); + let weights = weights.clone().gather(0, targets); + let tensor = Self::apply_mask_2d(tensor, mask); + tensor.sum().neg() / weights.sum() + } + None => { + let tensor = Self::apply_mask_2d(tensor, mask); + tensor.sum_dim(1).mean().neg() + } + } } - } - - fn forward_default(&self, logits: Tensor, targets: Tensor) -> Tensor { - let [batch_size] = targets.dims(); - - let mask = self.padding_mask(&targets); - let tensor = log_softmax(logits, 1); - let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1])); - - match &self.weights { - Some(weights) => { - let weights = weights.clone().gather(0, targets); - let tensor = tensor.reshape([batch_size]) * weights.clone(); - let tensor = Self::apply_mask_1d(tensor, mask); - tensor.sum().neg() / weights.sum() - } - None => { - let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask); - tensor.mean().neg() - } + + fn forward_default(&self, logits: Tensor, targets: Tensor) -> Tensor { + let [batch_size] = targets.dims(); + + let mask = self.padding_mask(&targets); + let tensor = log_softmax(logits, 1); + let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1])); + + match &self.weights { + Some(weights) => { + let weights = weights.clone().gather(0, targets); + let tensor = tensor.reshape([batch_size]) * weights.clone(); + let tensor = Self::apply_mask_1d(tensor, mask); + tensor.sum().neg() / weights.sum() + } + None => { + let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask); + tensor.mean().neg() + } + } + } + + fn compute_smoothed_targets( + shape: [usize; 2], + targets: Tensor, + alpha: f32, + ) -> Tensor { + let [batch_size, nr_classes] = shape; + let device = &targets.device(); + let 
targets_matrix = Tensor::::zeros_device(shape, device).scatter( + 1, + targets.reshape([batch_size, 1]), + Tensor::ones_device([batch_size, 1], device), + ); + targets_matrix * (1. - alpha) + alpha / nr_classes as f32 } - } - - fn compute_smoothed_targets( - shape: [usize; 2], - targets: Tensor, - alpha: f32, - ) -> Tensor { - let [batch_size, nr_classes] = shape; - let device = &targets.device(); - let targets_matrix = Tensor::::zeros_device(shape, device).scatter( - 1, - targets.reshape([batch_size, 1]), - Tensor::ones_device([batch_size, 1], device), - ); - targets_matrix * (1. - alpha) + alpha / nr_classes as f32 - } - - fn padding_mask(&self, targets: &Tensor) -> Option> { - let mut mask = None; - if let Some(pad_tokens) = &self.pad_tokens { - let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int(); - for x in pad_tokens { - res = res + targets.clone().equal_elem(*x as i64).int(); - } - mask = Some(res.greater_elem(0)); + + fn padding_mask(&self, targets: &Tensor) -> Option> { + let mut mask = None; + if let Some(pad_tokens) = &self.pad_tokens { + let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int(); + for x in pad_tokens { + res = res + targets.clone().equal_elem(*x as i64).int(); + } + mask = Some(res.greater_elem(0)); + } + + mask } - mask - } + fn apply_mask_1d(mut tensor: Tensor, mask: Option>) -> Tensor { + if let Some(mask) = mask { + tensor = tensor.mask_fill(mask, 0); + } - fn apply_mask_1d(mut tensor: Tensor, mask: Option>) -> Tensor { - if let Some(mask) = mask { - tensor = tensor.mask_fill(mask, 0); + tensor } - tensor - } + fn apply_mask_2d(mut tensor: Tensor, mask: Option>) -> Tensor { + if let Some(mask) = mask { + let [batch_size, nr_classes] = tensor.dims(); + tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0); + } - fn apply_mask_2d(mut tensor: Tensor, mask: Option>) -> Tensor { - if let Some(mask) = mask { - let [batch_size, nr_classes] = tensor.dims(); - tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0); + tensor } - tensor - } - - fn assertions(logits: Tensor, targets: Tensor) { - let [logits_height, _] = logits.dims(); - let [targets_height] = targets.dims(); - assert!( - logits_height == targets_height, - "Shape of targets ({}) should correspond to outer shape of logits ({}).", - targets_height, - logits_height - ); - } + fn assertions(logits: Tensor, targets: Tensor) { + let [logits_height, _] = logits.dims(); + let [targets_height] = targets.dims(); + assert!( + logits_height == targets_height, + "Shape of targets ({}) should correspond to outer shape of logits ({}).", + targets_height, + logits_height + ); + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::{loss::cross_entropy_with_logits, Data, Distribution}; - - macro_rules! setup { - () => {{ - let [batch_size, num_targets] = [4, 5]; - let logits = - Tensor::::random([batch_size, num_targets], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data(Data::from([2, 0, 4, 1])); - let targets_logits = Tensor::::from_data(Data::from([ - [0.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - ])); - (logits, targets, targets_logits) - }}; - } - - macro_rules! 
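A plain-float check of the formula used by compute_smoothed_targets above; only the expression `y * (1 - alpha) + alpha / nr_classes` from the code is assumed. With alpha = 0.05 and 5 classes, the on-target entry becomes 0.96 and every other entry 0.01, which is the matrix the label-smoothing tests below compare against:

fn smooth(y: f32, alpha: f32, nr_classes: usize) -> f32 {
    // y_smoothed = y * (1 - alpha) + alpha / nr_classes
    y * (1.0 - alpha) + alpha / nr_classes as f32
}

fn main() {
    assert!((smooth(1.0, 0.05, 5) - 0.96).abs() < 1e-6);
    assert!((smooth(0.0, 0.05, 5) - 0.01).abs() < 1e-6);
}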
setup_padded { - () => {{ - let [batch_size, num_targets, pad_index] = [4, 5, 1]; - let logits = - Tensor::::random([batch_size, num_targets], Distribution::Normal(0., 1.0)); - let targets = Tensor::::from_data( - Data::::from([2, 0, 4, pad_index as i64]).convert(), - ); - let targets_logits = Tensor::::from_data(Data::from([ - [0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ])); - (logits, targets, targets_logits) - }}; - } - - #[test] - fn test_cross_entropy_loss_with_weights() { - let (logits, targets, targets_logits) = setup!(); - let weights = vec![1.0, 2., 3., 4., 5.]; - let loss_1 = CrossEntropyLossConfig::new() - .with_weights(Some(weights.clone())) - .init() - .forward(logits.clone(), targets); - let tensor = log_softmax(logits, 1); - let loss_2 = tensor - * targets_logits - * Tensor::::from_floats(weights.as_slice()) - .unsqueeze() - .repeat(0, 4); - let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_with_weights_and_alpha_zero() { - let (logits, targets, _) = setup!(); - let weights = vec![1.0, 2., 3., 4., 5.]; - let loss_1 = CrossEntropyLossConfig::new() - .with_weights(Some(weights.clone())) - .init() - .forward(logits.clone(), targets.clone()); - let loss_2 = CrossEntropyLossConfig::new() - .with_weights(Some(weights.clone())) - .with_smoothing(Some(0.)) - .init() - .forward(logits.clone(), targets); - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_cross_entropy_loss() { - let (logits, targets, targets_logits) = setup!(); - let loss_1 = CrossEntropyLossConfig::new() - .init() - .forward(logits.clone(), targets); - let loss_2 = cross_entropy_with_logits(logits, targets_logits); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_alpha_equal_zero() { - let (logits, targets, _) = setup!(); - let loss_1 = CrossEntropyLossConfig::new() - .init() - .forward(logits.clone(), targets.clone()); - let loss_2 = CrossEntropyLossConfig::new() - .with_smoothing(Some(0.)) - .init() - .forward(logits, targets); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_cross_entropy_loss_with_pad_token() { - let (logits, targets, targets_logits) = setup_padded!(); - let pad_index = 1; - - let loss_1 = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![pad_index, 2])) - .init() - .forward(logits.clone(), targets); - let loss_2 = cross_entropy_with_logits(logits, targets_logits); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_with_zero_alpha_and_pad_token() { - let (logits, targets, _) = setup_padded!(); - let pad_index = 1; - - let loss_1 = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![pad_index, 2])) - .init() - .forward(logits.clone(), targets.clone()); - let loss_2 = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![pad_index, 2])) - .with_smoothing(Some(0.)) - .init() - .forward(logits.clone(), targets); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } - - #[test] - fn test_label_smoothing_target_conversion() { - let (logits, targets, _) = setup!(); - let smoothed_targets = CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05); - let targets_logits = Tensor::::from_data(Data::from([ - [0.01, 0.01, 0.96, 0.01, 0.01], - [0.96, 0.01, 0.01, 0.01, 0.01], - [0.01, 0.01, 
0.01, 0.01, 0.96], - [0.01, 0.96, 0.01, 0.01, 0.01], - ])); - smoothed_targets - .into_data() - .assert_approx_eq(&targets_logits.into_data(), 3); - } - - #[test] - fn test_label_smoothing() { - let (logits, targets, _) = setup!(); - let loss_1 = CrossEntropyLossConfig::new() - .with_smoothing(Some(0.05)) - .init() - .forward(logits.clone(), targets); - let targets_logits = Tensor::::from_data(Data::from([ - [0.01, 0.01, 0.96, 0.01, 0.01], - [0.96, 0.01, 0.01, 0.01, 0.01], - [0.01, 0.01, 0.01, 0.01, 0.96], - [0.01, 0.96, 0.01, 0.01, 0.01], - ])); - - let x = log_softmax(logits, 1); - let loss_2 = (x * targets_logits).sum_dim(1).mean().neg(); - - loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); - } + use super::*; + use crate::TestBackend; + use burn_tensor::{loss::cross_entropy_with_logits, Data, Distribution}; + + macro_rules! setup { + () => {{ + let [batch_size, num_targets] = [4, 5]; + let logits = Tensor::::random( + [batch_size, num_targets], + Distribution::Normal(0., 1.0), + ); + let targets = Tensor::::from_data(Data::from([2, 0, 4, 1])); + let targets_logits = Tensor::::from_data(Data::from([ + [0.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ])); + (logits, targets, targets_logits) + }}; + } + + macro_rules! setup_padded { + () => {{ + let [batch_size, num_targets, pad_index] = [4, 5, 1]; + let logits = Tensor::::random( + [batch_size, num_targets], + Distribution::Normal(0., 1.0), + ); + let targets = Tensor::::from_data( + Data::::from([2, 0, 4, pad_index as i64]).convert(), + ); + let targets_logits = Tensor::::from_data(Data::from([ + [0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ])); + (logits, targets, targets_logits) + }}; + } + + #[test] + fn test_cross_entropy_loss_with_weights() { + let (logits, targets, targets_logits) = setup!(); + let weights = vec![1.0, 2., 3., 4., 5.]; + let loss_1 = CrossEntropyLossConfig::new() + .with_weights(Some(weights.clone())) + .init() + .forward(logits.clone(), targets); + let tensor = log_softmax(logits, 1); + let loss_2 = tensor + * targets_logits + * Tensor::::from_floats(weights.as_slice()) + .unsqueeze() + .repeat(0, 4); + let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. 
+ 5.); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_with_weights_and_alpha_zero() { + let (logits, targets, _) = setup!(); + let weights = vec![1.0, 2., 3., 4., 5.]; + let loss_1 = CrossEntropyLossConfig::new() + .with_weights(Some(weights.clone())) + .init() + .forward(logits.clone(), targets.clone()); + let loss_2 = CrossEntropyLossConfig::new() + .with_weights(Some(weights.clone())) + .with_smoothing(Some(0.)) + .init() + .forward(logits.clone(), targets); + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_cross_entropy_loss() { + let (logits, targets, targets_logits) = setup!(); + let loss_1 = CrossEntropyLossConfig::new() + .init() + .forward(logits.clone(), targets); + let loss_2 = cross_entropy_with_logits(logits, targets_logits); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_alpha_equal_zero() { + let (logits, targets, _) = setup!(); + let loss_1 = CrossEntropyLossConfig::new() + .init() + .forward(logits.clone(), targets.clone()); + let loss_2 = CrossEntropyLossConfig::new() + .with_smoothing(Some(0.)) + .init() + .forward(logits, targets); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_cross_entropy_loss_with_pad_token() { + let (logits, targets, targets_logits) = setup_padded!(); + let pad_index = 1; + + let loss_1 = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![pad_index, 2])) + .init() + .forward(logits.clone(), targets); + let loss_2 = cross_entropy_with_logits(logits, targets_logits); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_with_zero_alpha_and_pad_token() { + let (logits, targets, _) = setup_padded!(); + let pad_index = 1; + + let loss_1 = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![pad_index, 2])) + .init() + .forward(logits.clone(), targets.clone()); + let loss_2 = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![pad_index, 2])) + .with_smoothing(Some(0.)) + .init() + .forward(logits.clone(), targets); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } + + #[test] + fn test_label_smoothing_target_conversion() { + let (logits, targets, _) = setup!(); + let smoothed_targets = + CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05); + let targets_logits = Tensor::::from_data(Data::from([ + [0.01, 0.01, 0.96, 0.01, 0.01], + [0.96, 0.01, 0.01, 0.01, 0.01], + [0.01, 0.01, 0.01, 0.01, 0.96], + [0.01, 0.96, 0.01, 0.01, 0.01], + ])); + smoothed_targets + .into_data() + .assert_approx_eq(&targets_logits.into_data(), 3); + } + + #[test] + fn test_label_smoothing() { + let (logits, targets, _) = setup!(); + let loss_1 = CrossEntropyLossConfig::new() + .with_smoothing(Some(0.05)) + .init() + .forward(logits.clone(), targets); + let targets_logits = Tensor::::from_data(Data::from([ + [0.01, 0.01, 0.96, 0.01, 0.01], + [0.96, 0.01, 0.01, 0.01, 0.01], + [0.01, 0.01, 0.01, 0.01, 0.96], + [0.01, 0.96, 0.01, 0.01, 0.01], + ])); + + let x = log_softmax(logits, 1); + let loss_2 = (x * targets_logits).sum_dim(1).mean().neg(); + + loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/loss/mse.rs b/burn-core/src/nn/loss/mse.rs index 277bea3c36..4ab95cd7d3 100644 --- a/burn-core/src/nn/loss/mse.rs +++ b/burn-core/src/nn/loss/mse.rs @@ -6,74 +6,74 @@ use burn_tensor::{backend::Backend, Tensor}; /// Calculate the mean 
squared error loss from the input logits and the targets. #[derive(Clone, Debug)] pub struct MSELoss { - backend: PhantomData, + backend: PhantomData, } impl Default for MSELoss { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl MSELoss { - /// Create the criterion. - pub fn new() -> Self { - Self { - backend: PhantomData, + /// Create the criterion. + pub fn new() -> Self { + Self { + backend: PhantomData, + } } - } - /// Compute the criterion on the input tensor. - /// - /// # Shapes - /// - /// - logits: [batch_size, num_targets] - /// - targets: [batch_size, num_targets] - pub fn forward( - &self, - logits: Tensor, - targets: Tensor, - reduction: Reduction, - ) -> Tensor { - let tensor = self.forward_no_reduction(logits, targets); - match reduction { - Reduction::Mean | Reduction::Auto => tensor.mean(), - Reduction::Sum => tensor.sum(), + /// Compute the criterion on the input tensor. + /// + /// # Shapes + /// + /// - logits: [batch_size, num_targets] + /// - targets: [batch_size, num_targets] + pub fn forward( + &self, + logits: Tensor, + targets: Tensor, + reduction: Reduction, + ) -> Tensor { + let tensor = self.forward_no_reduction(logits, targets); + match reduction { + Reduction::Mean | Reduction::Auto => tensor.mean(), + Reduction::Sum => tensor.sum(), + } } - } - /// Compute the criterion on the input tensor without reducing. - pub fn forward_no_reduction( - &self, - logits: Tensor, - targets: Tensor, - ) -> Tensor { - logits.sub(targets).powf(2.0) - } + /// Compute the criterion on the input tensor without reducing. + pub fn forward_no_reduction( + &self, + logits: Tensor, + targets: Tensor, + ) -> Tensor { + logits.sub(targets).powf(2.0) + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - use burn_tensor::Data; + use super::*; + use crate::TestBackend; + use burn_tensor::Data; - #[test] - fn test_mse_loss() { - let logits = Tensor::::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]])); + #[test] + fn test_mse_loss() { + let logits = Tensor::::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]])); - let targets = Tensor::::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]])); + let targets = Tensor::::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]])); - let mse = MSELoss::new(); - let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone()); - let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto); - let loss_sum = mse.forward(logits, targets, Reduction::Sum); + let mse = MSELoss::new(); + let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone()); + let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto); + let loss_sum = mse.forward(logits, targets, Reduction::Sum); - assert_eq!( - loss_no_reduction.into_data(), - Data::from([[1.0, 1.0], [0.0, 4.0]]) - ); - assert_eq!(loss.into_data(), Data::from([1.5])); - assert_eq!(loss_sum.into_data(), Data::from([6.0])); - } + assert_eq!( + loss_no_reduction.into_data(), + Data::from([[1.0, 1.0], [0.0, 4.0]]) + ); + assert_eq!(loss.into_data(), Data::from([1.5])); + assert_eq!(loss_sum.into_data(), Data::from([6.0])); + } } diff --git a/burn-core/src/nn/loss/reduction.rs b/burn-core/src/nn/loss/reduction.rs index 32107afd26..499b171537 100644 --- a/burn-core/src/nn/loss/reduction.rs +++ b/burn-core/src/nn/loss/reduction.rs @@ -1,11 +1,11 @@ /// The reduction type for the loss. pub enum Reduction { - /// The mean of the losses will be returned. - Mean, + /// The mean of the losses will be returned. 
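The values asserted by the MSE test above, restated with plain floats so the two reductions are easy to follow; this is only a sketch of the expected arithmetic, not crate code:

fn main() {
    let logits = [[1.0_f32, 2.0], [3.0, 4.0]];
    let targets = [[2.0_f32, 1.0], [3.0, 2.0]];

    // Element-wise squared error, matching forward_no_reduction.
    let se: Vec<f32> = logits
        .iter()
        .flatten()
        .zip(targets.iter().flatten())
        .map(|(l, t)| (l - t).powi(2))
        .collect();
    assert_eq!(se, vec![1.0, 1.0, 0.0, 4.0]);

    // Reduction::Sum and Reduction::Mean (also used for Reduction::Auto).
    let sum: f32 = se.iter().sum();
    assert_eq!(sum, 6.0);
    assert_eq!(sum / se.len() as f32, 1.5);
}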
+ Mean, - /// The sum of the losses will be returned. - Sum, + /// The sum of the losses will be returned. + Sum, - /// The mean of the losses will be returned. - Auto, + /// The mean of the losses will be returned. + Auto, } diff --git a/burn-core/src/nn/norm/batch.rs b/burn-core/src/nn/norm/batch.rs index bbc37b30f2..d48ff76bff 100644 --- a/burn-core/src/nn/norm/batch.rs +++ b/burn-core/src/nn/norm/batch.rs @@ -1,22 +1,22 @@ use crate as burn; use crate::{ - config::Config, - module::{Module, Param, RunningState}, - tensor::{backend::Backend, Tensor}, + config::Config, + module::{Module, Param, RunningState}, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [BatchNorm](BatchNorm) layer. #[derive(Config, Debug)] pub struct BatchNormConfig { - /// The number of features. - pub num_features: usize, - /// A value required for numerical stability. Default: 1e-5 - #[config(default = 1e-5)] - pub epsilon: f64, - /// Momentum used to update the metrics. Default: 0.1 - #[config(default = 0.1)] - pub momentum: f64, + /// The number of features. + pub num_features: usize, + /// A value required for numerical stability. Default: 1e-5 + #[config(default = 1e-5)] + pub epsilon: f64, + /// Momentum used to update the metrics. Default: 0.1 + #[config(default = 0.1)] + pub momentum: f64, } /// Applies Batch Normalization over a tensor as described in the paper [Batch Normalization](https://arxiv.org/abs/1502.03167) @@ -24,361 +24,359 @@ pub struct BatchNormConfig { /// `Y = norm(X) * γ + β` #[derive(Module, Debug)] pub struct BatchNorm { - gamma: Param>, - beta: Param>, - running_mean: RunningState>, - running_var: RunningState>, - momentum: f64, - epsilon: f64, + gamma: Param>, + beta: Param>, + running_mean: RunningState>, + running_var: RunningState>, + momentum: f64, + epsilon: f64, } impl BatchNormConfig { - /// Initialize a new [batch norm](BatchNorm) module. - pub fn init(&self) -> BatchNorm { - let gamma = Tensor::ones([self.num_features]); - let beta = Tensor::zeros([self.num_features]); - - let running_mean = Tensor::zeros([self.num_features]); - let running_var = Tensor::ones([self.num_features]); - - BatchNorm { - gamma: Param::from(gamma), - beta: Param::from(beta), - running_mean: RunningState::new(running_mean), - running_var: RunningState::new(running_var), - momentum: self.momentum, - epsilon: self.epsilon, + /// Initialize a new [batch norm](BatchNorm) module. + pub fn init(&self) -> BatchNorm { + let gamma = Tensor::ones([self.num_features]); + let beta = Tensor::zeros([self.num_features]); + + let running_mean = Tensor::zeros([self.num_features]); + let running_var = Tensor::ones([self.num_features]); + + BatchNorm { + gamma: Param::from(gamma), + beta: Param::from(beta), + running_mean: RunningState::new(running_mean), + running_var: RunningState::new(running_var), + momentum: self.momentum, + epsilon: self.epsilon, + } } - } - - /// Initialize a new [batch norm](BatchNorm) module with a [record](BatchNormRecord). - pub fn init_with( - &self, - record: BatchNormRecord, - ) -> BatchNorm { - BatchNorm { - gamma: record.gamma, - beta: record.beta, - running_mean: RunningState::from_record(record.running_mean), - running_var: RunningState::from_record(record.running_var), - momentum: self.momentum, - epsilon: self.epsilon, + + /// Initialize a new [batch norm](BatchNorm) module with a [record](BatchNormRecord). 
+ pub fn init_with( + &self, + record: BatchNormRecord, + ) -> BatchNorm { + BatchNorm { + gamma: record.gamma, + beta: record.beta, + running_mean: RunningState::from_record(record.running_mean), + running_var: RunningState::from_record(record.running_var), + momentum: self.momentum, + epsilon: self.epsilon, + } } - } } impl BatchNorm { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[batch_size, channels, ...]` - /// - output: `[batch_size, channels, ...]` - pub fn forward(&self, input: Tensor) -> Tensor { - // Should be move to a compilation error when const generic support that kind of - // validation. https://github.com/rust-lang/rust/issues/76560 - if D + 2 != DI { - panic!("BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor", D, D+2, DI); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[batch_size, channels, ...]` + /// - output: `[batch_size, channels, ...]` + pub fn forward(&self, input: Tensor) -> Tensor { + // Should be move to a compilation error when const generic support that kind of + // validation. https://github.com/rust-lang/rust/issues/76560 + if D + 2 != DI { + panic!("BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor", D, D+2, DI); + } + + match B::ad_enabled() { + true => self.forward_train(input), + false => self.forward_inference(input), + } } - match B::ad_enabled() { - true => self.forward_train(input), - false => self.forward_inference(input), + fn forward_inference(&self, input: Tensor) -> Tensor { + let channels = input.dims()[1]; + let mean = self.running_mean.value(); + let var = self.running_var.value(); + + let mut shape = [1; DI]; + shape[1] = channels; + + self.forward_shared(input, mean.reshape(shape), var.reshape(shape)) } - } - fn forward_inference(&self, input: Tensor) -> Tensor { - let channels = input.dims()[1]; - let mean = self.running_mean.value(); - let var = self.running_var.value(); + fn forward_train(&self, input: Tensor) -> Tensor { + let dims = input.dims(); + let batch_size = dims[0]; + let channels = dims[1]; + + let mut shape_unsqueeze = [1; DI]; + let mut flatten_size = batch_size; + shape_unsqueeze[1] = channels; + + for dim in dims.iter().take(DI).skip(2) { + flatten_size *= dim; + } + + let mean = input + .clone() + .swap_dims(0, 1) + .reshape([channels, flatten_size]) + .mean_dim(1) + .reshape(shape_unsqueeze); + + let var = input + .clone() + .sub(mean.clone()) + .powf(2.0) + .swap_dims(0, 1) + .reshape([channels, flatten_size]) + .mean_dim(1) + .reshape(shape_unsqueeze); + + let running_mean = self.running_mean.value_sync(); + let running_var = self.running_var.value_sync(); + + let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add( + mean.clone() + .detach() + .mul_scalar(self.momentum) + .reshape([channels]), + ); + let running_var = running_var.mul_scalar(1.0 - self.momentum).add( + var.clone() + .detach() + .mul_scalar(self.momentum) + .reshape([channels]), + ); + + self.running_mean.update(running_mean.detach()); + self.running_var.update(running_var.detach()); + + self.forward_shared(input, mean, var) + } - let mut shape = [1; DI]; - shape[1] = channels; + fn forward_shared( + &self, + x: Tensor, + mean: Tensor, + var: Tensor, + ) -> Tensor { + let channels = x.dims()[1]; + let mut shape = [1; DI]; + shape[1] = channels; - self.forward_shared(input, 
mean.reshape(shape), var.reshape(shape)) - } + let std = var.add_scalar(self.epsilon).sqrt(); - fn forward_train(&self, input: Tensor) -> Tensor { - let dims = input.dims(); - let batch_size = dims[0]; - let channels = dims[1]; + let x = x.sub(mean); + let x = x.div(std); - let mut shape_unsqueeze = [1; DI]; - let mut flatten_size = batch_size; - shape_unsqueeze[1] = channels; + let x = x.mul(self.gamma.val().reshape(shape)); - for dim in dims.iter().take(DI).skip(2) { - flatten_size *= dim; + x.add(self.beta.val().reshape(shape)) } - - let mean = input - .clone() - .swap_dims(0, 1) - .reshape([channels, flatten_size]) - .mean_dim(1) - .reshape(shape_unsqueeze); - - let var = input - .clone() - .sub(mean.clone()) - .powf(2.0) - .swap_dims(0, 1) - .reshape([channels, flatten_size]) - .mean_dim(1) - .reshape(shape_unsqueeze); - - let running_mean = self.running_mean.value_sync(); - let running_var = self.running_var.value_sync(); - - let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add( - mean - .clone() - .detach() - .mul_scalar(self.momentum) - .reshape([channels]), - ); - let running_var = running_var.mul_scalar(1.0 - self.momentum).add( - var - .clone() - .detach() - .mul_scalar(self.momentum) - .reshape([channels]), - ); - - self.running_mean.update(running_mean.detach()); - self.running_var.update(running_var.detach()); - - self.forward_shared(input, mean, var) - } - - fn forward_shared( - &self, - x: Tensor, - mean: Tensor, - var: Tensor, - ) -> Tensor { - let channels = x.dims()[1]; - let mut shape = [1; DI]; - shape[1] = channels; - - let std = var.add_scalar(self.epsilon).sqrt(); - - let x = x.sub(mean); - let x = x.div(std); - - let x = x.mul(self.gamma.val().reshape(shape)); - - x.add(self.beta.val().reshape(shape)) - } } #[cfg(feature = "std")] #[cfg(test)] mod tests_1d { - use super::*; - use crate::{module::AutodiffModule, TestAutodiffBackend}; - use burn_tensor::Data; - - #[test] - fn batch_norm_forward_train() { - let module = BatchNormConfig::new(3).init::(); - - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [ - [1.1483e+00, 3.7521e-01], - [1.6272e-03, 7.5067e-01], - [1.6204e+00, -4.5168e-02], - ], - [ - [6.8856e-02, -1.5923e+00], - [-1.6318e+00, 8.7949e-01], - [-5.3368e-01, -1.0416e+00], - ], - ]), - 2, - ); - } - - #[test] - fn batch_norm_forward_inference() { - let module = BatchNormConfig::new(3).init::(); - - module.forward(input_tensor()); - let module = module.valid(); - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]], - [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]], - ]), - 2, - ); - } - - fn input_tensor() -> Tensor { - Tensor::::from_floats([ - [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]], - [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]], - ]) - } + use super::*; + use crate::{module::AutodiffModule, TestAutodiffBackend}; + use burn_tensor::Data; + + #[test] + fn batch_norm_forward_train() { + let module = BatchNormConfig::new(3).init::(); + + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [1.1483e+00, 3.7521e-01], + [1.6272e-03, 7.5067e-01], + [1.6204e+00, -4.5168e-02], + ], + [ + [6.8856e-02, -1.5923e+00], + [-1.6318e+00, 8.7949e-01], + [-5.3368e-01, -1.0416e+00], + ], + ]), + 2, + ); + } + + #[test] + fn batch_norm_forward_inference() { + let module = BatchNormConfig::new(3).init::(); + + 
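The running-statistics update in forward_train above is an exponential moving average; a minimal plain-float sketch, assuming only the expression in the code and the config defaults (momentum 0.1, running mean initialised to 0, running variance to 1):

fn ema(running: f32, batch_stat: f32, momentum: f32) -> f32 {
    // running * (1 - momentum) + batch_stat * momentum, as in forward_train.
    running * (1.0 - momentum) + batch_stat * momentum
}

fn main() {
    // One update from the initial running statistics with a batch mean of 0.5
    // and a batch variance of 0.9.
    assert!((ema(0.0, 0.5, 0.1) - 0.05).abs() < 1e-6);
    assert!((ema(1.0, 0.9, 0.1) - 0.99).abs() < 1e-6);
}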
module.forward(input_tensor()); + let module = module.valid(); + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]], + [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]], + ]), + 2, + ); + } + + fn input_tensor() -> Tensor { + Tensor::::from_floats([ + [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]], + [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]], + ]) + } } #[cfg(feature = "std")] #[cfg(test)] mod tests_2d { - use super::*; - use crate::{module::AutodiffModule, TestAutodiffBackend}; - use burn_tensor::Data; - - #[test] - fn batch_norm_forward_train() { - let module = BatchNormConfig::new(3).init::(); - - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [ - [[1.5136, 0.7506], [-1.2216, 0.1477]], - [[0.3135, 1.2252], [-0.4150, 0.6130]], - [[1.4186, 0.3372], [-1.5183, 1.5262]], - ], - [ - [[0.4483, -1.1914], [-1.2010, 0.7537]], - [[-1.6752, 1.3822], [-0.5058, -0.9381]], - [[0.0200, -0.3097], [-0.5715, -0.9026]], - ], - ]), - 2, - ); - } - - #[test] - fn batch_norm_forward_inference() { - let module = BatchNormConfig::new(3).init::(); - - module.forward(input_tensor()); - let module = module.valid(); - let output = module.forward(input_tensor()); - - output.to_data().assert_approx_eq( - &Data::from([ - [ - [[0.9538, 0.7103], [0.0808, 0.5179]], - [[0.6015, 0.8910], [0.3703, 0.6966]], - [[0.9171, 0.6912], [0.3037, 0.9395]], - ], - [ - [[0.6138, 0.0904], [0.0874, 0.7113]], - [[-0.0297, 0.9408], [0.3415, 0.2042]], - [[0.6250, 0.5561], [0.5013, 0.4323]], - ], - ]), - 2, - ); - } - - #[test] - fn batch_norm_running_mean() { - let module = BatchNormConfig::new(3).init::(); - - let _output = module.forward(input_tensor()); - - let running_mean = module.running_mean.value_sync(); - - running_mean - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([0.0499, 0.0532, 0.0656]), 2); - } - - #[test] - fn batch_norm_running_var() { - let module = BatchNormConfig::new(3).init::(); - - let _output = module.forward(input_tensor()); - - let running_var = module.running_var.value_sync(); - - running_var - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([0.9106, 0.9105, 0.9045]), 2); - } - - #[test] - fn batch_norm_running_mean_inner_module() { - let module = BatchNormConfig::new(3).init::(); - - let _output = module.forward(input_tensor()); - - let module_valid = module.valid(); - let running_mean = module_valid.running_mean.value(); - let running_mean_after = module.running_mean.value(); - - running_mean_after - .into_data() - .assert_approx_eq(&running_mean.into_data(), 3); - } - - #[test] - fn batch_norm_grads() { - let module = BatchNormConfig::new(3).init::(); - let input = input_tensor().require_grad(); - - let output = module.forward(input.clone()); - - let grads = output.backward(); - - module - .gamma - .grad(&grads) - .unwrap() - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([0.0000e+00, -5.9035e-07, -6.0011e-07]), 3); - - module - .beta - .grad(&grads) - .unwrap() - .reshape([3]) - .into_data() - .assert_approx_eq(&Data::from([8., 8., 8.]), 3); - - input.grad(&grads).unwrap().into_data().assert_approx_eq( - &Data::from([ - [ - [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], - [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]], - [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]], - ], - [ - [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], - [[-4.0807e-07, 3.3673e-07], 
[-1.2323e-07, -2.2854e-07]], - [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]], - ], - ]), - 4, - ); - } - - fn input_tensor() -> Tensor { - Tensor::::from_floats([ - [ - [[0.9601, 0.7277], [0.1270, 0.5441]], - [[0.6272, 0.9034], [0.4066, 0.7179]], - [[0.9378, 0.7230], [0.3544, 0.9591]], - ], - [ - [[0.6356, 0.1362], [0.1333, 0.7287]], - [[0.0249, 0.9509], [0.3791, 0.2481]], - [[0.6600, 0.5945], [0.5424, 0.4767]], - ], - ]) - } + use super::*; + use crate::{module::AutodiffModule, TestAutodiffBackend}; + use burn_tensor::Data; + + #[test] + fn batch_norm_forward_train() { + let module = BatchNormConfig::new(3).init::(); + + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [[1.5136, 0.7506], [-1.2216, 0.1477]], + [[0.3135, 1.2252], [-0.4150, 0.6130]], + [[1.4186, 0.3372], [-1.5183, 1.5262]], + ], + [ + [[0.4483, -1.1914], [-1.2010, 0.7537]], + [[-1.6752, 1.3822], [-0.5058, -0.9381]], + [[0.0200, -0.3097], [-0.5715, -0.9026]], + ], + ]), + 2, + ); + } + + #[test] + fn batch_norm_forward_inference() { + let module = BatchNormConfig::new(3).init::(); + + module.forward(input_tensor()); + let module = module.valid(); + let output = module.forward(input_tensor()); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [[0.9538, 0.7103], [0.0808, 0.5179]], + [[0.6015, 0.8910], [0.3703, 0.6966]], + [[0.9171, 0.6912], [0.3037, 0.9395]], + ], + [ + [[0.6138, 0.0904], [0.0874, 0.7113]], + [[-0.0297, 0.9408], [0.3415, 0.2042]], + [[0.6250, 0.5561], [0.5013, 0.4323]], + ], + ]), + 2, + ); + } + + #[test] + fn batch_norm_running_mean() { + let module = BatchNormConfig::new(3).init::(); + + let _output = module.forward(input_tensor()); + + let running_mean = module.running_mean.value_sync(); + + running_mean + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([0.0499, 0.0532, 0.0656]), 2); + } + + #[test] + fn batch_norm_running_var() { + let module = BatchNormConfig::new(3).init::(); + + let _output = module.forward(input_tensor()); + + let running_var = module.running_var.value_sync(); + + running_var + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([0.9106, 0.9105, 0.9045]), 2); + } + + #[test] + fn batch_norm_running_mean_inner_module() { + let module = BatchNormConfig::new(3).init::(); + + let _output = module.forward(input_tensor()); + + let module_valid = module.valid(); + let running_mean = module_valid.running_mean.value(); + let running_mean_after = module.running_mean.value(); + + running_mean_after + .into_data() + .assert_approx_eq(&running_mean.into_data(), 3); + } + + #[test] + fn batch_norm_grads() { + let module = BatchNormConfig::new(3).init::(); + let input = input_tensor().require_grad(); + + let output = module.forward(input.clone()); + + let grads = output.backward(); + + module + .gamma + .grad(&grads) + .unwrap() + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([0.0000e+00, -5.9035e-07, -6.0011e-07]), 3); + + module + .beta + .grad(&grads) + .unwrap() + .reshape([3]) + .into_data() + .assert_approx_eq(&Data::from([8., 8., 8.]), 3); + + input.grad(&grads).unwrap().into_data().assert_approx_eq( + &Data::from([ + [ + [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], + [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]], + [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]], + ], + [ + [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], + [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]], + [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]], + ], + ]), + 
4, + ); + } + + fn input_tensor() -> Tensor { + Tensor::::from_floats([ + [ + [[0.9601, 0.7277], [0.1270, 0.5441]], + [[0.6272, 0.9034], [0.4066, 0.7179]], + [[0.9378, 0.7230], [0.3544, 0.9591]], + ], + [ + [[0.6356, 0.1362], [0.1333, 0.7287]], + [[0.0249, 0.9509], [0.3791, 0.2481]], + [[0.6600, 0.5945], [0.5424, 0.4767]], + ], + ]) + } } diff --git a/burn-core/src/nn/norm/layer.rs b/burn-core/src/nn/norm/layer.rs index f8a757abf8..54bed1a7ff 100644 --- a/burn-core/src/nn/norm/layer.rs +++ b/burn-core/src/nn/norm/layer.rs @@ -9,11 +9,11 @@ use crate::tensor::Tensor; /// Configuration to create a [LayerNorm](LayerNorm) layer. #[derive(Config)] pub struct LayerNormConfig { - /// The size of the input features. - pub d_model: usize, - /// A value required for numerical stability. Default: 1e-5 - #[config(default = 1e-5)] - pub epsilon: f64, + /// The size of the input features. + pub d_model: usize, + /// A value required for numerical stability. Default: 1e-5 + #[config(default = 1e-5)] + pub epsilon: f64, } /// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450). @@ -21,112 +21,112 @@ pub struct LayerNormConfig { /// `Y = norm(X) * γ + β` #[derive(Module, Debug)] pub struct LayerNorm { - gamma: Param>, - beta: Param>, - epsilon: f64, + gamma: Param>, + beta: Param>, + epsilon: f64, } impl LayerNormConfig { - /// Initialize a new [layer norm](LayerNorm) module. - pub fn init(&self) -> LayerNorm { - let gamma = Tensor::ones([self.d_model]); - let beta = Tensor::zeros([self.d_model]); - - LayerNorm { - gamma: Param::from(gamma), - beta: Param::from(beta), - epsilon: self.epsilon, + /// Initialize a new [layer norm](LayerNorm) module. + pub fn init(&self) -> LayerNorm { + let gamma = Tensor::ones([self.d_model]); + let beta = Tensor::zeros([self.d_model]); + + LayerNorm { + gamma: Param::from(gamma), + beta: Param::from(beta), + epsilon: self.epsilon, + } } - } - - /// Initialize a new [layer norm](LayerNorm) module with a [record](LayerNormRecord). - pub fn init_with(&self, record: LayerNormRecord) -> LayerNorm { - LayerNorm { - gamma: record.gamma, - beta: record.beta, - epsilon: self.epsilon, + + /// Initialize a new [layer norm](LayerNorm) module with a [record](LayerNormRecord). + pub fn init_with(&self, record: LayerNormRecord) -> LayerNorm { + LayerNorm { + gamma: record.gamma, + beta: record.beta, + epsilon: self.epsilon, + } } - } } impl LayerNorm { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any, d_model]` - /// - output: `[..., any, d_model]` - pub fn forward(&self, input: Tensor) -> Tensor { - let (var, mean) = input.clone().var_mean_bias(D - 1); - - let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon)); - - input_normalized - .mul(self.gamma.val().unsqueeze()) - .add(self.beta.val().unsqueeze()) - } + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: `[..., any, d_model]` + /// - output: `[..., any, d_model]` + pub fn forward(&self, input: Tensor) -> Tensor { + let (var, mean) = input.clone().var_mean_bias(D - 1); + + let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon)); + + input_normalized + .mul(self.gamma.val().unsqueeze()) + .add(self.beta.val().unsqueeze()) + } } #[cfg(test)] mod tests { - use super::*; - use burn_tensor::Data; - - #[cfg(feature = "std")] - use crate::{TestAutodiffBackend, TestBackend}; - - #[cfg(not(feature = "std"))] - use crate::TestBackend; - - #[test] - fn layer_norm_forward() { - let module = LayerNormConfig::new(10).init::(); - let input = Tensor::from_data(Data::from([[ - -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728, - ]])); - - let output = module.forward(input); - - output.to_data().assert_approx_eq( - &Data::from([[ - -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915, - ]]), - 3, - ); - } - - #[cfg(feature = "std")] - #[test] - fn layer_norm_backward() { - let module = LayerNormConfig::new(2).init::(); - let tensor_1 = - Tensor::::from_data(Data::from([[0.0, 1.0], [3.0, 4.0]])) - .require_grad(); - let tensor_2 = - Tensor::::from_data(Data::from([[6.0, 7.0], [9.0, 10.0]])) - .require_grad(); - - let x = tensor_1.clone().matmul(tensor_2.clone()); - - let output = module.forward(x); - let grads = output.backward(); - - let tensor_1_grad = tensor_1.grad(&grads).unwrap(); - let tensor_2_grad = tensor_2.grad(&grads).unwrap(); - let gamma_grad = module.gamma.grad(&grads).unwrap(); - let beta_grad = module.beta.grad(&grads).unwrap(); - - gamma_grad - .to_data() - .assert_approx_eq(&Data::from([-2.0, 2.0]), 3); - beta_grad - .to_data() - .assert_approx_eq(&Data::from([2.0, 2.0]), 3); - tensor_1_grad - .to_data() - .assert_approx_eq(&Data::zeros(tensor_1_grad.shape()), 3); - tensor_2_grad - .to_data() - .assert_approx_eq(&Data::zeros(tensor_2_grad.shape()), 3); - } + use super::*; + use burn_tensor::Data; + + #[cfg(feature = "std")] + use crate::{TestAutodiffBackend, TestBackend}; + + #[cfg(not(feature = "std"))] + use crate::TestBackend; + + #[test] + fn layer_norm_forward() { + let module = LayerNormConfig::new(10).init::(); + let input = Tensor::from_data(Data::from([[ + -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728, + ]])); + + let output = module.forward(input); + + output.to_data().assert_approx_eq( + &Data::from([[ + -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915, + ]]), + 3, + ); + } + + #[cfg(feature = "std")] + #[test] + fn layer_norm_backward() { + let module = LayerNormConfig::new(2).init::(); + let tensor_1 = + Tensor::::from_data(Data::from([[0.0, 1.0], [3.0, 4.0]])) + .require_grad(); + let tensor_2 = + Tensor::::from_data(Data::from([[6.0, 7.0], [9.0, 10.0]])) + .require_grad(); + + let x = tensor_1.clone().matmul(tensor_2.clone()); + + let output = module.forward(x); + let grads = output.backward(); + + let tensor_1_grad = tensor_1.grad(&grads).unwrap(); + let tensor_2_grad = tensor_2.grad(&grads).unwrap(); + let gamma_grad = module.gamma.grad(&grads).unwrap(); + let beta_grad = module.beta.grad(&grads).unwrap(); + + gamma_grad + .to_data() + .assert_approx_eq(&Data::from([-2.0, 2.0]), 3); + beta_grad + .to_data() + .assert_approx_eq(&Data::from([2.0, 2.0]), 3); + tensor_1_grad + .to_data() + .assert_approx_eq(&Data::zeros(tensor_1_grad.shape()), 3); + tensor_2_grad + .to_data() + 
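A plain-float restatement of the normalisation in forward above: a biased mean/variance over the last dimension, then (x - mean) / (sqrt(var) + epsilon), with gamma and beta left at their identity initialisation; treating var_mean_bias as the biased estimator is an assumption based on its name:

fn layer_norm(x: &[f32], epsilon: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    // Assumed biased variance (divide by n, not n - 1).
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter()
        .map(|v| (v - mean) / (var.sqrt() + epsilon))
        .collect()
}

fn main() {
    // mean = 2, var = 1, so the two values normalise to roughly -1 and 1.
    let y = layer_norm(&[1.0, 3.0], 1e-5);
    assert!((y[0] + 1.0).abs() < 1e-3);
    assert!((y[1] - 1.0).abs() < 1e-3);
}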
.assert_approx_eq(&Data::zeros(tensor_2_grad.shape()), 3); + } } diff --git a/burn-core/src/nn/padding.rs b/burn-core/src/nn/padding.rs index 0c7decb9c8..db27bf4486 100644 --- a/burn-core/src/nn/padding.rs +++ b/burn-core/src/nn/padding.rs @@ -8,62 +8,62 @@ use crate::module::Module; /// Padding configuration for 1D operators. #[derive(Module, Config, Debug, PartialEq)] pub enum PaddingConfig1d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. - Same, - /// Same as no padding. - Valid, - /// Applies the specified amount of padding to all inputs. - Explicit(usize), + /// Dynamically calculate the amount of padding necessary to ensure that the output size will be + /// the same as the input. + Same, + /// Same as no padding. + Valid, + /// Applies the specified amount of padding to all inputs. + Explicit(usize), } impl PaddingConfig1d { - pub(crate) fn calculate_padding_1d( - &self, - length: usize, - kernel_size: usize, - stride: usize, - ) -> usize { - let same_padding = || calculate_conv_padding(kernel_size, stride, length, length); - match self { - Self::Valid => 0, - Self::Same => same_padding(), - Self::Explicit(value) => *value, + pub(crate) fn calculate_padding_1d( + &self, + length: usize, + kernel_size: usize, + stride: usize, + ) -> usize { + let same_padding = || calculate_conv_padding(kernel_size, stride, length, length); + match self { + Self::Valid => 0, + Self::Same => same_padding(), + Self::Explicit(value) => *value, + } } - } } /// Padding configuration for 2D operators. #[derive(Module, Config, Debug, PartialEq)] pub enum PaddingConfig2d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. - Same, - /// Same as no padding. - Valid, - /// Applies the specified amount of padding to all inputs. - Explicit(usize, usize), + /// Dynamically calculate the amount of padding necessary to ensure that the output size will be + /// the same as the input. + Same, + /// Same as no padding. + Valid, + /// Applies the specified amount of padding to all inputs. + Explicit(usize, usize), } impl PaddingConfig2d { - pub(crate) fn calculate_padding_2d( - &self, - height: usize, - width: usize, - kernel_size: &[usize; 2], - stride: &[usize; 2], - ) -> [usize; 2] { - let same_padding = || { - let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height); - let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width); + pub(crate) fn calculate_padding_2d( + &self, + height: usize, + width: usize, + kernel_size: &[usize; 2], + stride: &[usize; 2], + ) -> [usize; 2] { + let same_padding = || { + let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height); + let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width); - [p1, p2] - }; + [p1, p2] + }; - match self { - Self::Same => same_padding(), - Self::Valid => [0, 0], - Self::Explicit(v1, v2) => [*v1, *v2], + match self { + Self::Same => same_padding(), + Self::Valid => [0, 0], + Self::Explicit(v1, v2) => [*v1, *v2], + } } - } } diff --git a/burn-core/src/nn/pool/adaptive_avg_pool1d.rs b/burn-core/src/nn/pool/adaptive_avg_pool1d.rs index 547c2d40ed..2bd321f575 100644 --- a/burn-core/src/nn/pool/adaptive_avg_pool1d.rs +++ b/burn-core/src/nn/pool/adaptive_avg_pool1d.rs @@ -9,33 +9,33 @@ use burn_tensor::module::adaptive_avg_pool1d; /// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer. 
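A quick check of what the Same and Valid variants handled by calculate_padding_1d above mean for output size; the standard pooling output-length formula is assumed here rather than taken from this file:

fn output_len(length: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    // (length + 2 * padding - kernel) / stride + 1
    (length + 2 * padding - kernel) / stride + 1
}

fn main() {
    // With stride 1 and kernel 3, "Same" needs padding 1 to keep the length,
    // while "Valid" (padding 0) shrinks it.
    assert_eq!(output_len(8, 3, 1, 1), 8);
    assert_eq!(output_len(8, 3, 1, 0), 6);
}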
#[derive(Config)] pub struct AdaptiveAvgPool1dConfig { - /// The size of the output. - pub output_size: usize, + /// The size of the output. + pub output_size: usize, } /// Applies a 1D adaptive avg pooling over input tensors. #[derive(Module, Debug, Clone)] pub struct AdaptiveAvgPool1d { - output_size: usize, + output_size: usize, } impl AdaptiveAvgPool1dConfig { - /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module. - pub fn init(&self) -> AdaptiveAvgPool1d { - AdaptiveAvgPool1d { - output_size: self.output_size, + /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module. + pub fn init(&self) -> AdaptiveAvgPool1d { + AdaptiveAvgPool1d { + output_size: self.output_size, + } } - } } impl AdaptiveAvgPool1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, length], - /// - output: [batch_size, channels, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - adaptive_avg_pool1d(input, self.output_size) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, length], + /// - output: [batch_size, channels, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + adaptive_avg_pool1d(input, self.output_size) + } } diff --git a/burn-core/src/nn/pool/adaptive_avg_pool2d.rs b/burn-core/src/nn/pool/adaptive_avg_pool2d.rs index b11cf6f362..1aba65648d 100644 --- a/burn-core/src/nn/pool/adaptive_avg_pool2d.rs +++ b/burn-core/src/nn/pool/adaptive_avg_pool2d.rs @@ -9,33 +9,33 @@ use burn_tensor::module::adaptive_avg_pool2d; /// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer. #[derive(Config)] pub struct AdaptiveAvgPool2dConfig { - /// The size of the output. - pub output_size: [usize; 2], + /// The size of the output. + pub output_size: [usize; 2], } /// Applies a 2D adaptive avg pooling over input tensors. #[derive(Module, Debug, Clone)] pub struct AdaptiveAvgPool2d { - output_size: [usize; 2], + output_size: [usize; 2], } impl AdaptiveAvgPool2dConfig { - /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module. - pub fn init(&self) -> AdaptiveAvgPool2d { - AdaptiveAvgPool2d { - output_size: self.output_size, + /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module. + pub fn init(&self) -> AdaptiveAvgPool2d { + AdaptiveAvgPool2d { + output_size: self.output_size, + } } - } } impl AdaptiveAvgPool2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, height_in, width_in], - /// - output: [batch_size, channels, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - adaptive_avg_pool2d(input, self.output_size) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, height_in, width_in], + /// - output: [batch_size, channels, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + adaptive_avg_pool2d(input, self.output_size) + } } diff --git a/burn-core/src/nn/pool/avg_pool1d.rs b/burn-core/src/nn/pool/avg_pool1d.rs index d4e912d49d..50596b5ba3 100644 --- a/burn-core/src/nn/pool/avg_pool1d.rs +++ b/burn-core/src/nn/pool/avg_pool1d.rs @@ -10,17 +10,17 @@ use burn_tensor::module::avg_pool1d; /// Configuration to create a [1D avg pooling](AvgPool1d) layer. #[derive(Config)] pub struct AvgPool1dConfig { - /// The size of the kernel. - pub kernel_size: usize, - /// The stride. 
- #[config(default = "1")] - pub stride: usize, - /// The padding configuration. - #[config(default = "PaddingConfig1d::Valid")] - pub padding: PaddingConfig1d, - /// If the padding is counted in the denominator when computing the average. - #[config(default = "true")] - count_include_pad: bool, + /// The size of the kernel. + pub kernel_size: usize, + /// The stride. + #[config(default = "1")] + pub stride: usize, + /// The padding configuration. + #[config(default = "PaddingConfig1d::Valid")] + pub padding: PaddingConfig1d, + /// If the padding is counted in the denominator when computing the average. + #[config(default = "true")] + count_include_pad: bool, } /// Applies a 1D avg pooling over input tensors. @@ -40,43 +40,43 @@ pub struct AvgPool1dConfig { #[derive(Module, Debug, Clone)] pub struct AvgPool1d { - stride: usize, - kernel_size: usize, - padding: PaddingConfig1d, - count_include_pad: bool, + stride: usize, + kernel_size: usize, + padding: PaddingConfig1d, + count_include_pad: bool, } impl AvgPool1dConfig { - /// Initialize a new [avg pool 1d](AvgPool1d) module. - pub fn init(&self) -> AvgPool1d { - AvgPool1d { - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - count_include_pad: self.count_include_pad, + /// Initialize a new [avg pool 1d](AvgPool1d) module. + pub fn init(&self) -> AvgPool1d { + AvgPool1d { + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + count_include_pad: self.count_include_pad, + } } - } } impl AvgPool1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, length_in], - /// - output: [batch_size, channels, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels, length] = input.dims(); - let padding = self - .padding - .calculate_padding_1d(length, self.kernel_size, self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, length_in], + /// - output: [batch_size, channels, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels, length] = input.dims(); + let padding = self + .padding + .calculate_padding_1d(length, self.kernel_size, self.stride); - avg_pool1d( - input, - self.kernel_size, - self.stride, - padding, - self.count_include_pad, - ) - } + avg_pool1d( + input, + self.kernel_size, + self.stride, + padding, + self.count_include_pad, + ) + } } diff --git a/burn-core/src/nn/pool/avg_pool2d.rs b/burn-core/src/nn/pool/avg_pool2d.rs index a8be01e92b..f3bb2b60ec 100644 --- a/burn-core/src/nn/pool/avg_pool2d.rs +++ b/burn-core/src/nn/pool/avg_pool2d.rs @@ -10,17 +10,17 @@ use burn_tensor::module::avg_pool2d; /// Configuration to create a [2D avg pooling](AvgPool2d) layer. #[derive(Config, Debug)] pub struct AvgPool2dConfig { - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The strides. - #[config(default = "[1, 1]")] - pub strides: [usize; 2], - /// The padding configuration. - #[config(default = "PaddingConfig2d::Valid")] - pub padding: PaddingConfig2d, - /// If the padding is counted in the denominator when computing the average. - #[config(default = "true")] - count_include_pad: bool, + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The strides. + #[config(default = "[1, 1]")] + pub strides: [usize; 2], + /// The padding configuration. 
+ #[config(default = "PaddingConfig2d::Valid")] + pub padding: PaddingConfig2d, + /// If the padding is counted in the denominator when computing the average. + #[config(default = "true")] + count_include_pad: bool, } /// Applies a 2D avg pooling over input tensors. @@ -39,44 +39,43 @@ pub struct AvgPool2dConfig { /// [Issue 636](https://github.com/burn-rs/burn/issues/636) #[derive(Module, Debug, Clone)] pub struct AvgPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: PaddingConfig2d, - count_include_pad: bool, + stride: [usize; 2], + kernel_size: [usize; 2], + padding: PaddingConfig2d, + count_include_pad: bool, } impl AvgPool2dConfig { - /// Initialize a new [avg pool 2d](AvgPool2d) module. - pub fn init(&self) -> AvgPool2d { - AvgPool2d { - stride: self.strides, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - count_include_pad: self.count_include_pad, + /// Initialize a new [avg pool 2d](AvgPool2d) module. + pub fn init(&self) -> AvgPool2d { + AvgPool2d { + stride: self.strides, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + count_include_pad: self.count_include_pad, + } } - } } impl AvgPool2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, height_in, width_in], - /// - output: [batch_size, channels, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels_in, height_in, width_in] = input.dims(); - let padding = - self - .padding - .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, height_in, width_in], + /// - output: [batch_size, channels, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, height_in, width_in] = input.dims(); + let padding = + self.padding + .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); - avg_pool2d( - input, - self.kernel_size, - self.stride, - padding, - self.count_include_pad, - ) - } + avg_pool2d( + input, + self.kernel_size, + self.stride, + padding, + self.count_include_pad, + ) + } } diff --git a/burn-core/src/nn/pool/max_pool1d.rs b/burn-core/src/nn/pool/max_pool1d.rs index 7a7ef99918..ca7a2bf01c 100644 --- a/burn-core/src/nn/pool/max_pool1d.rs +++ b/burn-core/src/nn/pool/max_pool1d.rs @@ -10,53 +10,53 @@ use burn_tensor::module::max_pool1d; /// Configuration to create a [1D max pooling](MaxPool1d) layer. #[derive(Config)] pub struct MaxPool1dConfig { - /// The size of the kernel. - pub kernel_size: usize, - /// The stride. - #[config(default = "1")] - pub stride: usize, - /// The padding configuration. - #[config(default = "PaddingConfig1d::Valid")] - pub padding: PaddingConfig1d, - /// The dilation. - #[config(default = "1")] - pub dilation: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The stride. + #[config(default = "1")] + pub stride: usize, + /// The padding configuration. + #[config(default = "PaddingConfig1d::Valid")] + pub padding: PaddingConfig1d, + /// The dilation. + #[config(default = "1")] + pub dilation: usize, } /// Applies a 1D max pooling over input tensors. 
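For the fixed-kernel average pooling configs above, only `kernel_size` has no default; `strides`, `padding`, and `count_include_pad` fall back to `[1, 1]`, `Valid`, and `true` respectively. A hedged usage sketch follows; the `new` constructor and `with_strides` setter are the ones `#[derive(Config)]` generates (the same pattern the tests in this patch use for `GruConfig` and `LstmConfig`), and the import paths again assume the `burn` umbrella crate:

    use burn::nn::pool::AvgPool2dConfig;
    use burn::tensor::{backend::Backend, Tensor};

    fn avg_pool_demo<B: Backend>() -> Tensor<B, 4> {
        let pool = AvgPool2dConfig::new([3, 3]) // kernel_size is the only required field
            .with_strides([2, 2])               // default would be [1, 1]
            .init();
        let input = Tensor::<B, 4>::zeros([8, 3, 32, 32]);
        // With the default PaddingConfig2d::Valid, each spatial dimension shrinks to
        // floor((in - kernel) / stride) + 1, so 32 -> floor((32 - 3) / 2) + 1 = 15.
        pool.forward(input) // expected shape [8, 3, 15, 15]
    }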
#[derive(Module, Debug, Clone)] pub struct MaxPool1d { - stride: usize, - kernel_size: usize, - padding: PaddingConfig1d, - dilation: usize, + stride: usize, + kernel_size: usize, + padding: PaddingConfig1d, + dilation: usize, } impl MaxPool1dConfig { - /// Initialize a new [max pool 1d](MaxPool1d) module. - pub fn init(&self) -> MaxPool1d { - MaxPool1d { - stride: self.stride, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, + /// Initialize a new [max pool 1d](MaxPool1d) module. + pub fn init(&self) -> MaxPool1d { + MaxPool1d { + stride: self.stride, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, + } } - } } impl MaxPool1d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, length_in], - /// - output: [batch_size, channels, length_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels, length] = input.dims(); - let padding = self - .padding - .calculate_padding_1d(length, self.kernel_size, self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, length_in], + /// - output: [batch_size, channels, length_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels, length] = input.dims(); + let padding = self + .padding + .calculate_padding_1d(length, self.kernel_size, self.stride); - max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation) - } + max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation) + } } diff --git a/burn-core/src/nn/pool/max_pool2d.rs b/burn-core/src/nn/pool/max_pool2d.rs index 33ccd24ba7..f1aebc19ed 100644 --- a/burn-core/src/nn/pool/max_pool2d.rs +++ b/burn-core/src/nn/pool/max_pool2d.rs @@ -10,54 +10,53 @@ use burn_tensor::module::max_pool2d; /// Configuration to create an [2D max pooling](MaxPool2d) layer. #[derive(Debug, Config)] pub struct MaxPool2dConfig { - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The strides. - #[config(default = "[1, 1]")] - pub strides: [usize; 2], - /// The padding configuration. - #[config(default = "PaddingConfig2d::Valid")] - pub padding: PaddingConfig2d, - /// The dilation. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The strides. + #[config(default = "[1, 1]")] + pub strides: [usize; 2], + /// The padding configuration. + #[config(default = "PaddingConfig2d::Valid")] + pub padding: PaddingConfig2d, + /// The dilation. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], } /// Applies a 2D max pooling over input tensors. #[derive(Module, Debug, Clone)] pub struct MaxPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: PaddingConfig2d, - dilation: [usize; 2], + stride: [usize; 2], + kernel_size: [usize; 2], + padding: PaddingConfig2d, + dilation: [usize; 2], } impl MaxPool2dConfig { - /// Initialize a new [max pool 2d](MaxPool2d) module. - pub fn init(&self) -> MaxPool2d { - MaxPool2d { - stride: self.strides, - kernel_size: self.kernel_size, - padding: self.padding.clone(), - dilation: self.dilation, + /// Initialize a new [max pool 2d](MaxPool2d) module. 
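The max pooling configs add a `dilation` field on top of kernel, stride, and padding. Assuming the conventional pooling arithmetic (the same output-length rule PyTorch documents), the expected output length can be sketched as below; this helper is illustrative only and is not a quote of burn's padding utilities:

    // Expected 1D max-pool output length under the conventional rule:
    //   out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    fn max_pool1d_out_len(input: usize, kernel: usize, stride: usize, pad: usize, dilation: usize) -> usize {
        (input + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
    }

    // Example: length 32, kernel 3, stride 2, no padding, dilation 1 gives 15,
    // the same length as the undilated average-pool example above.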
+ pub fn init(&self) -> MaxPool2d { + MaxPool2d { + stride: self.strides, + kernel_size: self.kernel_size, + padding: self.padding.clone(), + dilation: self.dilation, + } } - } } impl MaxPool2d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: [batch_size, channels, height_in, width_in], - /// - output: [batch_size, channels, height_out, width_out], - pub fn forward(&self, input: Tensor) -> Tensor { - let [_batch_size, _channels_in, height_in, width_in] = input.dims(); - let padding = - self - .padding - .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: [batch_size, channels, height_in, width_in], + /// - output: [batch_size, channels, height_out, width_out], + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, height_in, width_in] = input.dims(); + let padding = + self.padding + .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); - max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation) - } + max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation) + } } diff --git a/burn-core/src/nn/pos_encoding.rs b/burn-core/src/nn/pos_encoding.rs index 45d7b2900a..e4a935ac96 100644 --- a/burn-core/src/nn/pos_encoding.rs +++ b/burn-core/src/nn/pos_encoding.rs @@ -12,16 +12,16 @@ use libm::{cosf, expf, logf, sinf}; /// Configuration to create an [PositionalEncoding](PositionalEncoding) layer. #[derive(Config)] pub struct PositionalEncodingConfig { - /// Maximum sequence size to use. - #[config(default = "5_000")] - max_sequence_size: usize, + /// Maximum sequence size to use. + #[config(default = "5_000")] + max_sequence_size: usize, - /// The size of each vector. - d_model: usize, + /// The size of each vector. + d_model: usize, - /// Max time scale to use. - #[config(default = "10_000")] - max_timescale: usize, + /// Max time scale to use. + #[config(default = "10_000")] + max_timescale: usize, } /// Positional encoding layer for transformer models. @@ -38,55 +38,55 @@ pub struct PositionalEncodingConfig { /// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) #[derive(Module, Debug)] pub struct PositionalEncoding { - sinusoids: Tensor, + sinusoids: Tensor, } impl PositionalEncodingConfig { - /// Initialize a new [PositionalEncoding](PositionalEncoding) module. - pub fn init(&self) -> PositionalEncoding { - let sinusoids = - generate_sinusoids::(self.max_sequence_size, self.d_model, self.max_timescale) - .unsqueeze::<3>(); - - PositionalEncoding { sinusoids } - } + /// Initialize a new [PositionalEncoding](PositionalEncoding) module. + pub fn init(&self) -> PositionalEncoding { + let sinusoids = + generate_sinusoids::(self.max_sequence_size, self.d_model, self.max_timescale) + .unsqueeze::<3>(); + + PositionalEncoding { sinusoids } + } } impl PositionalEncoding { - /// Applies the forward pass on the input tensor by adding the sinusoids to the input. - /// - /// # Shapes - /// - /// * input: [batch_size, seq_length, d_model] - /// * output: [batch_size, seq_length, d_model] - /// - /// - /// # Panics - /// - /// * Panics if the input sequence length is greater than the maximum sequence size. - /// * Panics if the input d_model is not equal to the d_model of the sinusoids. 
- pub fn forward(&self, input: Tensor) -> Tensor { - let [_, seq_length, d_model_input] = input.dims(); - - let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims(); - - assert!( - max_sequence_size >= seq_length, - "max_sequence_size({}) must be greater or equal than length({seq_length})", - max_sequence_size, - ); - - assert!( - d_model_input == d_model, - "d_model({}) of the input must be equal to d_model of encoding({})", - d_model_input, - d_model, - ); - - let slices = [0..batch_size, 0..seq_length, 0..d_model]; - - input.add(self.sinusoids.clone().slice(slices)) - } + /// Applies the forward pass on the input tensor by adding the sinusoids to the input. + /// + /// # Shapes + /// + /// * input: [batch_size, seq_length, d_model] + /// * output: [batch_size, seq_length, d_model] + /// + /// + /// # Panics + /// + /// * Panics if the input sequence length is greater than the maximum sequence size. + /// * Panics if the input d_model is not equal to the d_model of the sinusoids. + pub fn forward(&self, input: Tensor) -> Tensor { + let [_, seq_length, d_model_input] = input.dims(); + + let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims(); + + assert!( + max_sequence_size >= seq_length, + "max_sequence_size({}) must be greater or equal than length({seq_length})", + max_sequence_size, + ); + + assert!( + d_model_input == d_model, + "d_model({}) of the input must be equal to d_model of encoding({})", + d_model_input, + d_model, + ); + + let slices = [0..batch_size, 0..seq_length, 0..d_model]; + + input.add(self.sinusoids.clone().slice(slices)) + } } /// Returns sinusoids for positional embedding introduced in @@ -106,124 +106,124 @@ impl PositionalEncoding { /// /// A tensor of shape [length, d_model] containing the sinusoids. 
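The documented formula can be spot-checked against the reference values asserted in the tests at the bottom of this file. The helper below is a standalone sketch (not part of the patch) that mirrors the same `libm` calls the function uses:

    // One (sin, cos) pair of the encoding for position `i` and even dimension index `k`.
    fn check_sinusoid(i: usize, k: usize, d_model: usize, max_timescale: usize) -> (f32, f32) {
        let log_timescale_increment = -libm::logf(max_timescale as f32) / d_model as f32;
        let div_term = libm::expf(k as f32 * log_timescale_increment);
        (libm::sinf(div_term * i as f32), libm::cosf(div_term * i as f32))
    }

    // For d_model = 6, max_timescale = 10_000, position i = 1 and k = 2:
    //   div_term = 10_000^(-2/6) ~= 0.04642, so check_sinusoid(1, 2, 6, 10_000)
    //   ~= (0.04640, 0.99892), the third and fourth entries of row 1 in the
    //   expected tensors asserted by the tests below.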
pub fn generate_sinusoids( - length: usize, - d_model: usize, - max_timescale: usize, + length: usize, + d_model: usize, + max_timescale: usize, ) -> Tensor { - assert!(d_model % 2 == 0, "d_model must be even"); - assert!( - max_timescale >= length, - "max_timescale must be greater than length" - ); - - // Calculate the increment for the logarithmic timescale - let log_timescale_increment = -logf(max_timescale as f32) / d_model as f32; - - // Create a vector to hold the sinusoids - let mut scaled_time_sin_cos = Vec::with_capacity(length); - - // Loop over each position in the sequence - for i in 0..length { - // Create a vector to hold the sinusoids for this position - let mut row = Vec::with_capacity(d_model / 2); - // Loop over each dimension of the sinusoids - for k in (0..d_model).step_by(2) { - // Calculate the division term for this dimension - let div_term = expf(k as f32 * log_timescale_increment); - // Calculate the sine and cosine values for this dimension and position - row.push(sinf(div_term * i as f32)); - row.push(cosf(div_term * i as f32)); - } + assert!(d_model % 2 == 0, "d_model must be even"); + assert!( + max_timescale >= length, + "max_timescale must be greater than length" + ); - // Add the sinusoids for this position to the vector - scaled_time_sin_cos.push(row); - } + // Calculate the increment for the logarithmic timescale + let log_timescale_increment = -logf(max_timescale as f32) / d_model as f32; + + // Create a vector to hold the sinusoids + let mut scaled_time_sin_cos = Vec::with_capacity(length); + + // Loop over each position in the sequence + for i in 0..length { + // Create a vector to hold the sinusoids for this position + let mut row = Vec::with_capacity(d_model / 2); + // Loop over each dimension of the sinusoids + for k in (0..d_model).step_by(2) { + // Calculate the division term for this dimension + let div_term = expf(k as f32 * log_timescale_increment); + // Calculate the sine and cosine values for this dimension and position + row.push(sinf(div_term * i as f32)); + row.push(cosf(div_term * i as f32)); + } + + // Add the sinusoids for this position to the vector + scaled_time_sin_cos.push(row); + } - // Convert the sinusoids to a tensor and return it - let data = Data::new( - scaled_time_sin_cos.into_iter().flatten().collect(), - [length, d_model].into(), - ); + // Convert the sinusoids to a tensor and return it + let data = Data::new( + scaled_time_sin_cos.into_iter().flatten().collect(), + [length, d_model].into(), + ); - Tensor::::from_data(data.convert()) + Tensor::::from_data(data.convert()) } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - - #[test] - fn test_module() { - let d_model = 6; - let length = 3; - - // expected to broadcast - let batch_size = 2; - - let pe = PositionalEncodingConfig::new(d_model).init::(); - - // Use a tensor of zeros as input for easy verification of the output - // The output should be the sinusoids broadcasted to the input shape - let tensor = Tensor::zeros([batch_size, length, d_model]); - - let output = pe.forward(tensor); - - assert_eq!(output.shape().dims, [batch_size, length, d_model]); - - let expected = Tensor::::from_floats([ - [ - [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], - [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], - [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], - ], - [ - [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], - [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], - [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], - 
], - ]); - - output.to_data().assert_approx_eq(&expected.to_data(), 5); - } - - #[test] - fn test_generate_sinusoids() { - let sinusoids = generate_sinusoids::(12, 6, 10_000); - - // The values are taken from the pytorch reference implementation - let expected = Tensor::::from_floats([ - [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], - [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], - [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], - [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998], - [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996], - [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994], - [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992], - [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989], - [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985], - [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981], - [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977], - [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972], - ]); - sinusoids.to_data().assert_approx_eq(&expected.to_data(), 5); - } - - #[test] - #[should_panic] - fn d_model_input_should_match() { - let d_model = 8; - let pe = PositionalEncodingConfig::new(d_model).init::(); - let input = Tensor::zeros([1, 5, 10]); - let _output = pe.forward(input); - } - - #[test] - #[should_panic] - fn input_length_should_be_less_than_max_len() { - let d_model = 8; - let pe = PositionalEncodingConfig::new(d_model).init::(); - let input = Tensor::zeros([1, 6_000, d_model]); - let _output = pe.forward(input); - } + use super::*; + use crate::TestBackend; + + #[test] + fn test_module() { + let d_model = 6; + let length = 3; + + // expected to broadcast + let batch_size = 2; + + let pe = PositionalEncodingConfig::new(d_model).init::(); + + // Use a tensor of zeros as input for easy verification of the output + // The output should be the sinusoids broadcasted to the input shape + let tensor = Tensor::zeros([batch_size, length, d_model]); + + let output = pe.forward(tensor); + + assert_eq!(output.shape().dims, [batch_size, length, d_model]); + + let expected = Tensor::::from_floats([ + [ + [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], + [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], + [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], + ], + [ + [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], + [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], + [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], + ], + ]); + + output.to_data().assert_approx_eq(&expected.to_data(), 5); + } + + #[test] + fn test_generate_sinusoids() { + let sinusoids = generate_sinusoids::(12, 6, 10_000); + + // The values are taken from the pytorch reference implementation + let expected = Tensor::::from_floats([ + [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], + [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], + [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], + [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998], + [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996], + [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994], + [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992], + [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989], + [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985], + [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981], + [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977], + [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972], + ]); + sinusoids.to_data().assert_approx_eq(&expected.to_data(), 5); + } + 
+ #[test] + #[should_panic] + fn d_model_input_should_match() { + let d_model = 8; + let pe = PositionalEncodingConfig::new(d_model).init::(); + let input = Tensor::zeros([1, 5, 10]); + let _output = pe.forward(input); + } + + #[test] + #[should_panic] + fn input_length_should_be_less_than_max_len() { + let d_model = 8; + let pe = PositionalEncodingConfig::new(d_model).init::(); + let input = Tensor::zeros([1, 6_000, d_model]); + let _output = pe.forward(input); + } } diff --git a/burn-core/src/nn/relu.rs b/burn-core/src/nn/relu.rs index a84d7431e3..92e260c7ee 100644 --- a/burn-core/src/nn/relu.rs +++ b/burn-core/src/nn/relu.rs @@ -11,17 +11,17 @@ use crate::tensor::Tensor; pub struct ReLU {} impl ReLU { - /// Create the module. - pub fn new() -> Self { - Self {} - } - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - input: `[..., any]` - /// - output: `[..., any]` - pub fn forward(&self, input: Tensor) -> Tensor { - crate::tensor::activation::relu(input) - } + /// Create the module. + pub fn new() -> Self { + Self {} + } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[..., any]` + /// - output: `[..., any]` + pub fn forward(&self, input: Tensor) -> Tensor { + crate::tensor::activation::relu(input) + } } diff --git a/burn-core/src/nn/rnn/gate_controller.rs b/burn-core/src/nn/rnn/gate_controller.rs index 6a07c2b80b..c063301ac8 100644 --- a/burn-core/src/nn/rnn/gate_controller.rs +++ b/burn-core/src/nn/rnn/gate_controller.rs @@ -15,73 +15,73 @@ use burn_tensor::backend::Backend; /// the gate's output. #[derive(Module, Debug)] pub struct GateController { - /// Represents the affine transformation applied to input vector - pub(crate) input_transform: Linear, - /// Represents the affine transformation applied to the hidden state - pub(crate) hidden_transform: Linear, + /// Represents the affine transformation applied to input vector + pub(crate) input_transform: Linear, + /// Represents the affine transformation applied to the hidden state + pub(crate) hidden_transform: Linear, } impl GateController { - /// Initialize a new [gate_controller](GateController) module. - pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self { - Self { - input_transform: LinearConfig { - d_input, - d_output, - bias, - initializer: initializer.clone(), - } - .init(), - hidden_transform: LinearConfig { - d_input: d_output, - d_output, - bias, - initializer, - } - .init(), + /// Initialize a new [gate_controller](GateController) module. + pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self { + Self { + input_transform: LinearConfig { + d_input, + d_output, + bias, + initializer: initializer.clone(), + } + .init(), + hidden_transform: LinearConfig { + d_input: d_output, + d_output, + bias, + initializer, + } + .init(), + } } - } - /// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord). - pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord) -> Self { - let l1 = LinearConfig::init_with(linear_config, record.input_transform); - let l2 = LinearConfig::init_with(linear_config, record.hidden_transform); + /// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord). 
+ pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord) -> Self { + let l1 = LinearConfig::init_with(linear_config, record.input_transform); + let l2 = LinearConfig::init_with(linear_config, record.hidden_transform); - Self { - input_transform: l1, - hidden_transform: l2, + Self { + input_transform: l1, + hidden_transform: l2, + } } - } - /// Used to initialize a gate controller with known weight layers, - /// allowing for predictable behavior. Used only for testing in - /// lstm. - #[cfg(test)] - pub fn create_with_weights( - d_input: usize, - d_output: usize, - bias: bool, - initializer: Initializer, - input_record: crate::nn::LinearRecord, - hidden_record: crate::nn::LinearRecord, - ) -> Self { - let l1 = LinearConfig { - d_input, - d_output, - bias, - initializer: initializer.clone(), + /// Used to initialize a gate controller with known weight layers, + /// allowing for predictable behavior. Used only for testing in + /// lstm. + #[cfg(test)] + pub fn create_with_weights( + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + input_record: crate::nn::LinearRecord, + hidden_record: crate::nn::LinearRecord, + ) -> Self { + let l1 = LinearConfig { + d_input, + d_output, + bias, + initializer: initializer.clone(), + } + .init_with(input_record); + let l2 = LinearConfig { + d_input, + d_output, + bias, + initializer, + } + .init_with(hidden_record); + Self { + input_transform: l1, + hidden_transform: l2, + } } - .init_with(input_record); - let l2 = LinearConfig { - d_input, - d_output, - bias, - initializer, - } - .init_with(hidden_record); - Self { - input_transform: l1, - hidden_transform: l2, - } - } } diff --git a/burn-core/src/nn/rnn/gru.rs b/burn-core/src/nn/rnn/gru.rs index fa05c77229..bd2c4649f0 100644 --- a/burn-core/src/nn/rnn/gru.rs +++ b/burn-core/src/nn/rnn/gru.rs @@ -14,256 +14,266 @@ use super::gate_controller::GateController; /// The configuration for a [gru](Gru) module. #[derive(Config)] pub struct GruConfig { - /// The size of the input features. - pub d_input: usize, - /// The size of the hidden state. - pub d_hidden: usize, - /// If a bias should be applied during the Gru transformation. - pub bias: bool, - /// Gru initializer - #[config(default = "Initializer::XavierNormal{gain:1.0}")] - pub initializer: Initializer, + /// The size of the input features. + pub d_input: usize, + /// The size of the hidden state. + pub d_hidden: usize, + /// If a bias should be applied during the Gru transformation. + pub bias: bool, + /// Gru initializer + #[config(default = "Initializer::XavierNormal{gain:1.0}")] + pub initializer: Initializer, } /// The Gru module. This implementation is for a unidirectional, stateless, Gru. #[derive(Module, Debug)] pub struct Gru { - update_gate: GateController, - reset_gate: GateController, - new_gate: GateController, - d_hidden: usize, + update_gate: GateController, + reset_gate: GateController, + new_gate: GateController, + d_hidden: usize, } impl GruConfig { - /// Initialize a new [gru](Gru) module. - pub fn init(&self) -> Gru { - let d_output = self.d_hidden; + /// Initialize a new [gru](Gru) module. 
+ pub fn init(&self) -> Gru { + let d_output = self.d_hidden; - let update_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let reset_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let new_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); + let update_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let reset_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let new_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); - Gru { - update_gate, - reset_gate, - new_gate, - d_hidden: self.d_hidden, + Gru { + update_gate, + reset_gate, + new_gate, + d_hidden: self.d_hidden, + } } - } - /// Initialize a new [gru](Gru) module. - pub fn init_with(self, record: GruRecord) -> Gru { - let linear_config = LinearConfig { - d_input: self.d_input, - d_output: self.d_hidden, - bias: self.bias, - initializer: self.initializer.clone(), - }; + /// Initialize a new [gru](Gru) module. + pub fn init_with(self, record: GruRecord) -> Gru { + let linear_config = LinearConfig { + d_input: self.d_input, + d_output: self.d_hidden, + bias: self.bias, + initializer: self.initializer.clone(), + }; - Gru { - update_gate: gate_controller::GateController::new_with(&linear_config, record.update_gate), - reset_gate: gate_controller::GateController::new_with(&linear_config, record.reset_gate), - new_gate: gate_controller::GateController::new_with(&linear_config, record.new_gate), - d_hidden: self.d_hidden, + Gru { + update_gate: gate_controller::GateController::new_with( + &linear_config, + record.update_gate, + ), + reset_gate: gate_controller::GateController::new_with( + &linear_config, + record.reset_gate, + ), + new_gate: gate_controller::GateController::new_with(&linear_config, record.new_gate), + d_hidden: self.d_hidden, + } } - } } impl Gru { - /// Applies the forward pass on the input tensor. This GRU implementation - /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size]. - /// - /// Parameters: - /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. - /// state: An optional tensor representing an initial cell state with the same dimensions - /// as batched_input. If none is provided, one will be generated. - /// - /// Returns: - /// The resulting state tensor, with shape [batch_size, sequence_length, hidden_size]. - pub fn forward(&self, batched_input: Tensor, state: Option>) -> Tensor { - let [batch_size, seq_length, _] = batched_input.shape().dims; + /// Applies the forward pass on the input tensor. This GRU implementation + /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size]. + /// + /// Parameters: + /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. + /// state: An optional tensor representing an initial cell state with the same dimensions + /// as batched_input. If none is provided, one will be generated. + /// + /// Returns: + /// The resulting state tensor, with shape [batch_size, sequence_length, hidden_size]. 
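The gate equations above, together with the `Wx*X + Wh*H + b` gate product documented just below, can be traced by hand for the scalar test at the end of this file (weights 0.5, 0.6, and 0.7 for the update, reset, and new gates, no bias, zero initial hidden state). The sketch below is standalone `f32` arithmetic, not part of the module:

    // One GRU step for scalar input x and previous hidden state h_prev,
    // using the weights from tests_forward_single_input_single_feature.
    fn gru_step_by_hand(x: f32, h_prev: f32) -> f32 {
        let sigmoid = |v: f32| 1.0 / (1.0 + (-v).exp());
        let z = sigmoid(0.5 * x + 0.5 * h_prev);        // update gate,  ~0.5125 for x = 0.1
        let r = sigmoid(0.6 * x + 0.6 * h_prev);        // reset gate,   ~0.5150
        let g = (0.7 * x + 0.7 * (r * h_prev)).tanh();  // candidate,    ~0.0699
        g * (1.0 - z) + z * h_prev                      // ~0.0341, i.e. the 0.034 the test asserts
    }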
+ pub fn forward( + &self, + batched_input: Tensor, + state: Option>, + ) -> Tensor { + let [batch_size, seq_length, _] = batched_input.shape().dims; - let mut hidden_state = match state { - Some(state) => state, - None => Tensor::zeros([batch_size, seq_length, self.d_hidden]), - }; + let mut hidden_state = match state { + Some(state) => state, + None => Tensor::zeros([batch_size, seq_length, self.d_hidden]), + }; - for (t, (input_t, hidden_t)) in batched_input - .iter_dim(1) - .zip(hidden_state.clone().iter_dim(1)) - .enumerate() - { - let input_t = input_t.squeeze(1); - let hidden_t = hidden_t.squeeze(1); - // u(pdate)g(ate) tensors - let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate); - let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t) + for (t, (input_t, hidden_t)) in batched_input + .iter_dim(1) + .zip(hidden_state.clone().iter_dim(1)) + .enumerate() + { + let input_t = input_t.squeeze(1); + let hidden_t = hidden_t.squeeze(1); + // u(pdate)g(ate) tensors + let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate); + let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t) - // r(eset)g(ate) tensors - let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate); - let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t) - let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate + // r(eset)g(ate) tensors + let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate); + let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t) + let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate - // n(ew)g(ate) tensor - let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate); - let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t) + // n(ew)g(ate) tensor + let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate); + let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t) - // calculate linear interpolation between previous hidden state and candidate state: - // g(t) * (1 - z(t)) + z(t) * hidden_t - let state_vector = candidate_state + // calculate linear interpolation between previous hidden state and candidate state: + // g(t) * (1 - z(t)) + z(t) * hidden_t + let state_vector = candidate_state .clone() .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1) + update_values.clone().mul(hidden_t); - let current_shape = state_vector.shape().dims; - let unsqueezed_shape = [current_shape[0], 1, current_shape[1]]; - let reshaped_state_vector = state_vector.reshape(unsqueezed_shape); - hidden_state = hidden_state.slice_assign( - [0..batch_size, t..(t + 1), 0..self.d_hidden], - reshaped_state_vector, - ); - } + let current_shape = state_vector.shape().dims; + let unsqueezed_shape = [current_shape[0], 1, current_shape[1]]; + let reshaped_state_vector = state_vector.reshape(unsqueezed_shape); + hidden_state = hidden_state.slice_assign( + [0..batch_size, t..(t + 1), 0..self.d_hidden], + reshaped_state_vector, + ); + } - hidden_state - } + hidden_state + } - /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. 
- /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: - /// Wx = weight matrix for the connection to input vector X - /// Wh = weight matrix for the connection to hidden state H - /// X = input vector - /// H = hidden state - /// b = bias terms - fn gate_product( - &self, - input: &Tensor, - hidden: &Tensor, - gate: &GateController, - ) -> Tensor { - let input_product = input.clone().matmul(gate.input_transform.weight.val()); - let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); + /// Helper function for performing weighted matrix product for a gate and adds + /// bias, if any. + /// + /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Wx = weight matrix for the connection to input vector X + /// Wh = weight matrix for the connection to hidden state H + /// X = input vector + /// H = hidden state + /// b = bias terms + fn gate_product( + &self, + input: &Tensor, + hidden: &Tensor, + gate: &GateController, + ) -> Tensor { + let input_product = input.clone().matmul(gate.input_transform.weight.val()); + let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); - let input_bias = gate - .input_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - let hidden_bias = gate - .hidden_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); + let input_bias = gate + .input_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + let hidden_bias = gate + .hidden_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { - input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() - } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, + match (input_bias, hidden_bias) { + (Some(input_bias), Some(hidden_bias)) => { + input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() + } + (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, + (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), + (None, None) => input_product + hidden_product, + } } - } } #[cfg(test)] mod tests { - use super::*; - use crate::{module::Param, nn::LinearRecord, TestBackend}; - use burn_tensor::{Data, Distribution}; + use super::*; + use crate::{module::Param, nn::LinearRecord, TestBackend}; + use burn_tensor::{Data, Distribution}; - /// Test forward pass with simple input vector. - /// - /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 - /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 - /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 - /// - /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 - #[test] - fn tests_forward_single_input_single_feature() { - TestBackend::seed(0); - let config = GruConfig::new(1, 1, false); - let mut gru = config.init::(); + /// Test forward pass with simple input vector. 
+ /// + /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 + /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 + /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 + /// + /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 + #[test] + fn tests_forward_single_input_single_feature() { + TestBackend::seed(0); + let config = GruConfig::new(1, 1, false); + let mut gru = config.init::(); - fn create_gate_controller( - weights: f32, - biases: f32, - d_input: usize, - d_output: usize, - bias: bool, - initializer: Initializer, - ) -> GateController { - let record = LinearRecord { - weight: Param::from(Tensor::from_data(Data::from([[weights]]))), - bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), - }; - gate_controller::GateController::create_with_weights( - d_input, - d_output, - bias, - initializer, - record.clone(), - record, - ) - } + fn create_gate_controller( + weights: f32, + biases: f32, + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + ) -> GateController { + let record = LinearRecord { + weight: Param::from(Tensor::from_data(Data::from([[weights]]))), + bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), + }; + gate_controller::GateController::create_with_weights( + d_input, + d_output, + bias, + initializer, + record.clone(), + record, + ) + } - gru.update_gate = create_gate_controller( - 0.5, - 0.0, - 1, - 1, - false, - Initializer::XavierNormal { gain: 1.0 }, - ); - gru.reset_gate = create_gate_controller( - 0.6, - 0.0, - 1, - 1, - false, - Initializer::XavierNormal { gain: 1.0 }, - ); - gru.new_gate = create_gate_controller( - 0.7, - 0.0, - 1, - 1, - false, - Initializer::XavierNormal { gain: 1.0 }, - ); + gru.update_gate = create_gate_controller( + 0.5, + 0.0, + 1, + 1, + false, + Initializer::XavierNormal { gain: 1.0 }, + ); + gru.reset_gate = create_gate_controller( + 0.6, + 0.0, + 1, + 1, + false, + Initializer::XavierNormal { gain: 1.0 }, + ); + gru.new_gate = create_gate_controller( + 0.7, + 0.0, + 1, + 1, + false, + Initializer::XavierNormal { gain: 1.0 }, + ); - let input = Tensor::::from_data(Data::from([[[0.1]]])); + let input = Tensor::::from_data(Data::from([[[0.1]]])); - let state = gru.forward(input, None); + let state = gru.forward(input, None); - let output = state.select(0, Tensor::arange(0..1)).squeeze(0); + let output = state.select(0, Tensor::arange(0..1)).squeeze(0); - output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3); - } + output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3); + } - #[test] - fn test_batched_forward_pass() { - let gru = GruConfig::new(64, 1024, true).init::(); - let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); + #[test] + fn test_batched_forward_pass() { + let gru = GruConfig::new(64, 1024, true).init::(); + let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); - let hidden_state = gru.forward(batched_input, None); + let hidden_state = gru.forward(batched_input, None); - assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); - } + assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); + } } diff --git a/burn-core/src/nn/rnn/lstm.rs b/burn-core/src/nn/rnn/lstm.rs index 00b4b51e6f..b49db35642 100644 --- a/burn-core/src/nn/rnn/lstm.rs +++ b/burn-core/src/nn/rnn/lstm.rs @@ -14,314 +14,323 @@ use super::gate_controller::GateController; /// The configuration for a [lstm](Lstm) module. #[derive(Config)] pub struct LstmConfig { - /// The size of the input features. - pub d_input: usize, - /// The size of the hidden state. 
- pub d_hidden: usize, - /// If a bias should be applied during the Lstm transformation. - pub bias: bool, - /// Lstm initializer - #[config(default = "Initializer::XavierNormal{gain:1.0}")] - pub initializer: Initializer, + /// The size of the input features. + pub d_input: usize, + /// The size of the hidden state. + pub d_hidden: usize, + /// If a bias should be applied during the Lstm transformation. + pub bias: bool, + /// Lstm initializer + #[config(default = "Initializer::XavierNormal{gain:1.0}")] + pub initializer: Initializer, } /// The Lstm module. This implementation is for a unidirectional, stateless, Lstm. #[derive(Module, Debug)] pub struct Lstm { - input_gate: GateController, - forget_gate: GateController, - output_gate: GateController, - cell_gate: GateController, - d_hidden: usize, + input_gate: GateController, + forget_gate: GateController, + output_gate: GateController, + cell_gate: GateController, + d_hidden: usize, } impl LstmConfig { - /// Initialize a new [lstm](Lstm) module. - pub fn init(&self) -> Lstm { - let d_output = self.d_hidden; - - let input_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let forget_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let output_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - let cell_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - ); - - Lstm { - input_gate, - forget_gate, - output_gate, - cell_gate, - d_hidden: self.d_hidden, + /// Initialize a new [lstm](Lstm) module. + pub fn init(&self) -> Lstm { + let d_output = self.d_hidden; + + let input_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let forget_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let output_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let cell_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + + Lstm { + input_gate, + forget_gate, + output_gate, + cell_gate, + d_hidden: self.d_hidden, + } } - } - - /// Initialize a new [lstm](Lstm) module with a [record](LstmRecord). - pub fn init_with(&self, record: LstmRecord) -> Lstm { - let linear_config = LinearConfig { - d_input: self.d_input, - d_output: self.d_hidden, - bias: self.bias, - initializer: self.initializer.clone(), - }; - - Lstm { - input_gate: gate_controller::GateController::new_with(&linear_config, record.input_gate), - forget_gate: gate_controller::GateController::new_with(&linear_config, record.forget_gate), - output_gate: gate_controller::GateController::new_with(&linear_config, record.output_gate), - cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate), - d_hidden: self.d_hidden, + + /// Initialize a new [lstm](Lstm) module with a [record](LstmRecord). 
+ pub fn init_with(&self, record: LstmRecord) -> Lstm { + let linear_config = LinearConfig { + d_input: self.d_input, + d_output: self.d_hidden, + bias: self.bias, + initializer: self.initializer.clone(), + }; + + Lstm { + input_gate: gate_controller::GateController::new_with( + &linear_config, + record.input_gate, + ), + forget_gate: gate_controller::GateController::new_with( + &linear_config, + record.forget_gate, + ), + output_gate: gate_controller::GateController::new_with( + &linear_config, + record.output_gate, + ), + cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate), + d_hidden: self.d_hidden, + } } - } } impl Lstm { - /// Applies the forward pass on the input tensor. This LSTM implementation - /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), - /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size]. - /// - /// Parameters: - /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. - /// state: An optional tuple of tensors representing the initial cell state and hidden state. - /// Each state tensor has shape [batch_size, hidden_size]. - /// If no initial state is provided, these tensors are initialized to zeros. - /// - /// Returns: - /// A tuple of tensors, where the first tensor represents the cell states and - /// the second tensor represents the hidden states for each sequence element. - /// Both output tensors have the shape [batch_size, sequence_length, hidden_size]. - pub fn forward( - &self, - batched_input: Tensor, - state: Option<(Tensor, Tensor)>, - ) -> (Tensor, Tensor) { - let [batch_size, seq_length, _] = batched_input.shape().dims; - let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); - let mut batched_hidden_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); - - let (mut cell_state, mut hidden_state) = match state { - Some((cell_state, hidden_state)) => (cell_state, hidden_state), - None => ( - Tensor::zeros([batch_size, self.d_hidden]), - Tensor::zeros([batch_size, self.d_hidden]), - ), - }; - - for (t, input_t) in batched_input.iter_dim(1).enumerate() { - let input_t = input_t.squeeze(1); - // f(orget)g(ate) tensors - let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate); - let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state - - // i(nput)g(ate) tensors - let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate); - let add_values = activation::sigmoid(biased_ig_input_sum); - - // o(output)g(ate) tensors - let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate); - let output_values = activation::sigmoid(biased_og_input_sum); - - // c(ell)g(ate) tensors - let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate); - let candidate_cell_values = biased_cg_input_sum.tanh(); - - cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; - hidden_state = output_values * cell_state.clone().tanh(); - - let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]]; - - let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape); - let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape); - - // store the state for this timestep - batched_cell_state = batched_cell_state.slice_assign( - [0..batch_size, t..(t + 1), 
0..self.d_hidden], - unsqueezed_cell_state.clone(), - ); - batched_hidden_state = batched_hidden_state.slice_assign( - [0..batch_size, t..(t + 1), 0..self.d_hidden], - unsqueezed_hidden_state.clone(), - ); + /// Applies the forward pass on the input tensor. This LSTM implementation + /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), + /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size]. + /// + /// Parameters: + /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. + /// state: An optional tuple of tensors representing the initial cell state and hidden state. + /// Each state tensor has shape [batch_size, hidden_size]. + /// If no initial state is provided, these tensors are initialized to zeros. + /// + /// Returns: + /// A tuple of tensors, where the first tensor represents the cell states and + /// the second tensor represents the hidden states for each sequence element. + /// Both output tensors have the shape [batch_size, sequence_length, hidden_size]. + pub fn forward( + &self, + batched_input: Tensor, + state: Option<(Tensor, Tensor)>, + ) -> (Tensor, Tensor) { + let [batch_size, seq_length, _] = batched_input.shape().dims; + let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); + let mut batched_hidden_state = Tensor::zeros([batch_size, seq_length, self.d_hidden]); + + let (mut cell_state, mut hidden_state) = match state { + Some((cell_state, hidden_state)) => (cell_state, hidden_state), + None => ( + Tensor::zeros([batch_size, self.d_hidden]), + Tensor::zeros([batch_size, self.d_hidden]), + ), + }; + + for (t, input_t) in batched_input.iter_dim(1).enumerate() { + let input_t = input_t.squeeze(1); + // f(orget)g(ate) tensors + let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate); + let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state + + // i(nput)g(ate) tensors + let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate); + let add_values = activation::sigmoid(biased_ig_input_sum); + + // o(output)g(ate) tensors + let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate); + let output_values = activation::sigmoid(biased_og_input_sum); + + // c(ell)g(ate) tensors + let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate); + let candidate_cell_values = biased_cg_input_sum.tanh(); + + cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; + hidden_state = output_values * cell_state.clone().tanh(); + + let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]]; + + let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape); + let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape); + + // store the state for this timestep + batched_cell_state = batched_cell_state.slice_assign( + [0..batch_size, t..(t + 1), 0..self.d_hidden], + unsqueezed_cell_state.clone(), + ); + batched_hidden_state = batched_hidden_state.slice_assign( + [0..batch_size, t..(t + 1), 0..self.d_hidden], + unsqueezed_hidden_state.clone(), + ); + } + + (batched_cell_state, batched_hidden_state) } - (batched_cell_state, batched_hidden_state) - } - - /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. 
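The same kind of hand trace applies to the LSTM step computed in the loop above: four gate products, a cell update, and a hidden-state update. Using the scalar weights from `test_forward_single_input_single_feature` further down (0.5, 0.7, 0.9, and 1.1 for the input, forget, cell, and output gates, no bias, zero initial state), a standalone `f32` sketch looks like this:

    // One LSTM step for scalar input x and previous (cell, hidden) state.
    fn lstm_step_by_hand(x: f32, c_prev: f32, h_prev: f32) -> (f32, f32) {
        let sigmoid = |v: f32| 1.0 / (1.0 + (-v).exp());
        let f = sigmoid(0.7 * x + 0.7 * h_prev); // forget gate
        let i = sigmoid(0.5 * x + 0.5 * h_prev); // input gate
        let o = sigmoid(1.1 * x + 1.1 * h_prev); // output gate
        let g = (0.9 * x + 0.9 * h_prev).tanh(); // cell gate
        let c = f * c_prev + i * g;              // ~0.046 for x = 0.1 and zero state
        let h = o * c.tanh();                    // ~0.024, matching the test assertions
        (c, h)
    }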
- /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: - /// Wx = weight matrix for the connection to input vector X - /// Wh = weight matrix for the connection to hidden state H - /// X = input vector - /// H = hidden state - /// b = bias terms - fn gate_product( - &self, - input: &Tensor, - hidden: &Tensor, - gate: &GateController, - ) -> Tensor { - let input_product = input.clone().matmul(gate.input_transform.weight.val()); - let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); - - let input_bias = gate - .input_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - let hidden_bias = gate - .hidden_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { - input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() - } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, + /// Helper function for performing weighted matrix product for a gate and adds + /// bias, if any. + /// + /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Wx = weight matrix for the connection to input vector X + /// Wh = weight matrix for the connection to hidden state H + /// X = input vector + /// H = hidden state + /// b = bias terms + fn gate_product( + &self, + input: &Tensor, + hidden: &Tensor, + gate: &GateController, + ) -> Tensor { + let input_product = input.clone().matmul(gate.input_transform.weight.val()); + let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); + + let input_bias = gate + .input_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + let hidden_bias = gate + .hidden_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + + match (input_bias, hidden_bias) { + (Some(input_bias), Some(hidden_bias)) => { + input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() + } + (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, + (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), + (None, None) => input_product + hidden_product, + } } - } } #[cfg(test)] mod tests { - use super::*; - use crate::{module::Param, nn::LinearRecord, TestBackend}; - use burn_tensor::{Data, Distribution}; - - #[test] - fn test_with_uniform_initializer() { - TestBackend::seed(0); - - let config = - LstmConfig::new(5, 5, false).with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 }); - let lstm = config.init::(); - - let gate_to_data = - |gate: GateController| gate.input_transform.weight.val().to_data(); - - gate_to_data(lstm.input_gate).assert_within_range(0..1); - gate_to_data(lstm.forget_gate).assert_within_range(0..1); - gate_to_data(lstm.output_gate).assert_within_range(0..1); - gate_to_data(lstm.cell_gate).assert_within_range(0..1); - } - - /// Test forward pass with simple input vector. 
- /// - /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928 - /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725 - /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723 - /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937 - - /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 - /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 - #[test] - fn test_forward_single_input_single_feature() { - TestBackend::seed(0); - let config = LstmConfig::new(1, 1, false); - let mut lstm = config.init::(); - - fn create_gate_controller( - weights: f32, - biases: f32, - d_input: usize, - d_output: usize, - bias: bool, - initializer: Initializer, - ) -> GateController { - let record = LinearRecord { - weight: Param::from(Tensor::from_data(Data::from([[weights]]))), - bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), - }; - gate_controller::GateController::create_with_weights( - d_input, - d_output, - bias, - initializer, - record.clone(), - record, - ) + use super::*; + use crate::{module::Param, nn::LinearRecord, TestBackend}; + use burn_tensor::{Data, Distribution}; + + #[test] + fn test_with_uniform_initializer() { + TestBackend::seed(0); + + let config = LstmConfig::new(5, 5, false) + .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 }); + let lstm = config.init::(); + + let gate_to_data = + |gate: GateController| gate.input_transform.weight.val().to_data(); + + gate_to_data(lstm.input_gate).assert_within_range(0..1); + gate_to_data(lstm.forget_gate).assert_within_range(0..1); + gate_to_data(lstm.output_gate).assert_within_range(0..1); + gate_to_data(lstm.cell_gate).assert_within_range(0..1); } - lstm.input_gate = create_gate_controller( - 0.5, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - lstm.forget_gate = create_gate_controller( - 0.7, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - lstm.cell_gate = create_gate_controller( - 0.9, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - lstm.output_gate = create_gate_controller( - 1.1, - 0.0, - 1, - 1, - false, - Initializer::XavierUniform { gain: 1.0 }, - ); - - // single timestep with single feature - let input = Tensor::::from_data(Data::from([[[0.1]]])); - - let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None); - let cell_state = cell_state_batch.select(0, Tensor::arange(0..1)).squeeze(0); - let hidden_state = hidden_state_batch - .select(0, Tensor::arange(0..1)) - .squeeze(0); - cell_state - .to_data() - .assert_approx_eq(&Data::from([[0.046]]), 3); - hidden_state - .to_data() - .assert_approx_eq(&Data::from([[0.024]]), 3) - } - - #[test] - fn test_batched_forward_pass() { - let lstm = LstmConfig::new(64, 1024, true).init::(); - let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); - - let (cell_state, hidden_state) = lstm.forward(batched_input, None); - - assert_eq!(cell_state.shape().dims, [8, 10, 1024]); - assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); - } + /// Test forward pass with simple input vector. 
+ /// + /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928 + /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725 + /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723 + /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937 + + /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 + /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 + #[test] + fn test_forward_single_input_single_feature() { + TestBackend::seed(0); + let config = LstmConfig::new(1, 1, false); + let mut lstm = config.init::(); + + fn create_gate_controller( + weights: f32, + biases: f32, + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + ) -> GateController { + let record = LinearRecord { + weight: Param::from(Tensor::from_data(Data::from([[weights]]))), + bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), + }; + gate_controller::GateController::create_with_weights( + d_input, + d_output, + bias, + initializer, + record.clone(), + record, + ) + } + + lstm.input_gate = create_gate_controller( + 0.5, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + lstm.forget_gate = create_gate_controller( + 0.7, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + lstm.cell_gate = create_gate_controller( + 0.9, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + lstm.output_gate = create_gate_controller( + 1.1, + 0.0, + 1, + 1, + false, + Initializer::XavierUniform { gain: 1.0 }, + ); + + // single timestep with single feature + let input = Tensor::::from_data(Data::from([[[0.1]]])); + + let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None); + let cell_state = cell_state_batch.select(0, Tensor::arange(0..1)).squeeze(0); + let hidden_state = hidden_state_batch + .select(0, Tensor::arange(0..1)) + .squeeze(0); + cell_state + .to_data() + .assert_approx_eq(&Data::from([[0.046]]), 3); + hidden_state + .to_data() + .assert_approx_eq(&Data::from([[0.024]]), 3) + } + + #[test] + fn test_batched_forward_pass() { + let lstm = LstmConfig::new(64, 1024, true).init::(); + let batched_input = Tensor::::random([8, 10, 64], Distribution::Default); + + let (cell_state, hidden_state) = lstm.forward(batched_input, None); + + assert_eq!(cell_state.shape().dims, [8, 10, 1024]); + assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); + } } diff --git a/burn-core/src/nn/transformer/decoder.rs b/burn-core/src/nn/transformer/decoder.rs index 3030fecf75..5033a86692 100644 --- a/burn-core/src/nn/transformer/decoder.rs +++ b/burn-core/src/nn/transformer/decoder.rs @@ -2,49 +2,51 @@ use alloc::vec::Vec; use burn_tensor::Bool; use crate::{ - self as burn, - nn::{attention::MhaCache, cache::TensorCache, Initializer}, + self as burn, + nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ - config::Config, - module::Module, - nn::{ - attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, - Dropout, DropoutConfig, LayerNorm, LayerNormConfig, - }, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{ + attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, + Dropout, DropoutConfig, LayerNorm, LayerNormConfig, + }, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [Transformer Decoder](TransformerDecoder) layer. 
#[derive(Config)] pub struct TransformerDecoderConfig { - /// The size of the model. - pub d_model: usize, - /// The size of the position-wise feed-forward network. - pub d_ff: usize, - /// The number of attention heads. - pub n_heads: usize, - /// The number of layers. - pub n_layers: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - pub dropout: f64, - /// Layer norm will be applied first instead of after the other modules. - #[config(default = false)] - pub norm_first: bool, - /// Use "quiet softmax" instead of regular softmax. - /// - /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). - /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. - /// - /// Reference: - #[config(default = false)] - pub quiet_softmax: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] - pub initializer: Initializer, + /// The size of the model. + pub d_model: usize, + /// The size of the position-wise feed-forward network. + pub d_ff: usize, + /// The number of attention heads. + pub n_heads: usize, + /// The number of layers. + pub n_layers: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + pub dropout: f64, + /// Layer norm will be applied first instead of after the other modules. + #[config(default = false)] + pub norm_first: bool, + /// Use "quiet softmax" instead of regular softmax. + /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + pub quiet_softmax: bool, + /// The type of function used to initialize neural network parameters + #[config( + default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" + )] + pub initializer: Initializer, } /// The transformer decoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). @@ -54,400 +56,408 @@ pub struct TransformerDecoderConfig { /// - layers: transformer decoder layers with `d_model` input and output features. #[derive(Module, Debug)] pub struct TransformerDecoder { - layers: Vec>, + layers: Vec>, } impl TransformerDecoderConfig { - /// Initialize a new [Transformer Decoder](TransformerDecoder) module. - pub fn init(&self) -> TransformerDecoder { - let layers = (0..self.n_layers) - .map(|_| TransformerDecoderLayer::new(self)) - .collect::>(); - - TransformerDecoder { layers } - } - - /// Initialize a new [Transformer Decoder](TransformerDecoder) module with a record. - /// - /// # Params - /// - /// - record: the record to initialize the module with. - pub fn init_with( - &self, - record: TransformerDecoderRecord, - ) -> TransformerDecoder { - TransformerDecoder { - layers: record - .layers - .into_iter() - .map(|record| TransformerDecoderLayer::new_with(self, record)) - .collect(), + /// Initialize a new [Transformer Decoder](TransformerDecoder) module. 
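Editorial note (not part of the patch): the "quiet softmax" option above refers to replacing the softmax denominator with `1 + Σ exp(x_i)`, so a head whose logits are all strongly negative can emit weights near zero instead of being forced to sum to one. The plain-Rust sketch below illustrates that behaviour; the crate's actual implementation is not shown in this excerpt, so this is an assumption about the formula, not a quote of it.

    fn softmax(logits: &[f32]) -> Vec<f32> {
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|&e| e / sum).collect()
    }

    fn quiet_softmax(logits: &[f32]) -> Vec<f32> {
        // Numerically stable form of exp(x_i) / (1 + Σ exp(x_j)).
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max).max(0.0);
        let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum::<f32>() + (-max).exp(); // the extra "+1" term
        exps.iter().map(|&e| e / sum).collect()
    }

    fn main() {
        let logits = [-10.0_f32, -12.0, -11.0];
        println!("softmax:       {:?}", softmax(&logits));       // forced to sum to 1
        println!("quiet softmax: {:?}", quiet_softmax(&logits)); // all weights near 0
    }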
+ pub fn init(&self) -> TransformerDecoder { + let layers = (0..self.n_layers) + .map(|_| TransformerDecoderLayer::new(self)) + .collect::>(); + + TransformerDecoder { layers } + } + + /// Initialize a new [Transformer Decoder](TransformerDecoder) module with a record. + /// + /// # Params + /// + /// - record: the record to initialize the module with. + pub fn init_with( + &self, + record: TransformerDecoderRecord, + ) -> TransformerDecoder { + TransformerDecoder { + layers: record + .layers + .into_iter() + .map(|record| TransformerDecoderLayer::new_with(self, record)) + .collect(), + } } - } } /// [Transformer Decoder](TransformerDecoder) forward pass input argument. #[derive(Debug)] pub struct TransformerDecoderInput { - target: Tensor, - target_mask_pad: Option>, - target_mask_attn: Option>, - memory: Tensor, - memory_mask_pad: Option>, - memory_mask_attn: Option>, + target: Tensor, + target_mask_pad: Option>, + target_mask_attn: Option>, + memory: Tensor, + memory_mask_pad: Option>, + memory_mask_attn: Option>, } impl TransformerDecoderInput { - /// Create a [transformer decoder](TransformerDecoder) input argument. - pub fn new(target: Tensor, memory: Tensor) -> Self { - Self { - target, - target_mask_pad: None, - target_mask_attn: None, - memory, - memory_mask_pad: None, - memory_mask_attn: None, + /// Create a [transformer decoder](TransformerDecoder) input argument. + pub fn new(target: Tensor, memory: Tensor) -> Self { + Self { + target, + target_mask_pad: None, + target_mask_attn: None, + memory, + memory_mask_pad: None, + memory_mask_attn: None, + } + } + + /// Register the memory padding mask. + pub fn memory_mask_pad(mut self, mask_pad: Tensor) -> Self { + self.memory_mask_pad = Some(mask_pad); + self + } + + /// Register the memory attention mask. + pub fn memory_mask_attn(mut self, mask_attn: Tensor) -> Self { + self.memory_mask_attn = Some(mask_attn); + self + } + + /// Register the target padding mask. + pub fn target_mask_pad(mut self, mask_pad: Tensor) -> Self { + self.target_mask_pad = Some(mask_pad); + self + } + + /// Register the target attention mask. + pub fn target_mask_attn(mut self, mask_attn: Tensor) -> Self { + self.target_mask_attn = Some(mask_attn); + self } - } - - /// Register the memory padding mask. - pub fn memory_mask_pad(mut self, mask_pad: Tensor) -> Self { - self.memory_mask_pad = Some(mask_pad); - self - } - - /// Register the memory attention mask. - pub fn memory_mask_attn(mut self, mask_attn: Tensor) -> Self { - self.memory_mask_attn = Some(mask_attn); - self - } - - /// Register the target padding mask. - pub fn target_mask_pad(mut self, mask_pad: Tensor) -> Self { - self.target_mask_pad = Some(mask_pad); - self - } - - /// Register the target attention mask. - pub fn target_mask_attn(mut self, mask_attn: Tensor) -> Self { - self.target_mask_attn = Some(mask_attn); - self - } } /// [Transformer Decoder](TransformerDecoder) layer module. 
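Editorial note (not part of the patch): a hypothetical usage sketch of the builder API above. `TestBackend`, `generate_autoregressive_mask`, and the tensor calls are the same helpers the tests in this file use; the explicit generic parameters are assumptions made for illustration, not taken from the patch.

    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
    use burn_tensor::Distribution;

    fn decode_once() -> Tensor<TestBackend, 3> {
        let [batch_size, seq_length, d_model] = [2, 8, 64];
        let decoder = TransformerDecoderConfig::new(d_model, 4 * d_model, 4, 2)
            .with_quiet_softmax(true)
            .init::<TestBackend>();

        let memory = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
        );
        let target = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());

        let input = TransformerDecoderInput::new(target, memory).target_mask_attn(mask_attn);
        decoder.forward(input) // `[batch_size, seq_length, d_model]`
    }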
#[derive(Module, Debug)] pub struct TransformerDecoderLayer { - cross_attn: MultiHeadAttention, - self_attn: MultiHeadAttention, - pwff: PositionWiseFeedForward, - norm_1: LayerNorm, - norm_2: LayerNorm, - norm_3: LayerNorm, - dropout: Dropout, - norm_first: bool, + cross_attn: MultiHeadAttention, + self_attn: MultiHeadAttention, + pwff: PositionWiseFeedForward, + norm_1: LayerNorm, + norm_2: LayerNorm, + norm_3: LayerNorm, + dropout: Dropout, + norm_first: bool, } struct TransformerDecoderLayerAutoregressiveCache { - cross_attn: MhaCache, - self_attn: MhaCache, - pwff: TensorCache, - norm_1: TensorCache, - norm_2: TensorCache, - norm_3: TensorCache, + cross_attn: MhaCache, + self_attn: MhaCache, + pwff: TensorCache, + norm_1: TensorCache, + norm_2: TensorCache, + norm_3: TensorCache, } impl TransformerDecoderLayerAutoregressiveCache { - fn empty() -> Self { - Self { - cross_attn: MhaCache::autoregressive_cross_attention(), - self_attn: MhaCache::autoregressive(), - pwff: TensorCache::empty(), - norm_1: TensorCache::empty(), - norm_2: TensorCache::empty(), - norm_3: TensorCache::empty(), + fn empty() -> Self { + Self { + cross_attn: MhaCache::autoregressive_cross_attention(), + self_attn: MhaCache::autoregressive(), + pwff: TensorCache::empty(), + norm_1: TensorCache::empty(), + norm_2: TensorCache::empty(), + norm_3: TensorCache::empty(), + } } - } } /// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer. /// /// To be used during inference when decoding tokens. pub struct TransformerDecoderAutoregressiveCache { - layers: Vec>, + layers: Vec>, } impl TransformerDecoderAutoregressiveCache { - fn empty(num_layers: usize) -> Self { - Self { - layers: (0..num_layers) - .map(|_| TransformerDecoderLayerAutoregressiveCache::empty()) - .collect(), + fn empty(num_layers: usize) -> Self { + Self { + layers: (0..num_layers) + .map(|_| TransformerDecoderLayerAutoregressiveCache::empty()) + .collect(), + } } - } } impl TransformerDecoderLayer { - fn new(config: &TransformerDecoderConfig) -> Self { - let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .with_quiet_softmax(config.quiet_softmax) - .init(); - - let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .with_quiet_softmax(config.quiet_softmax) - .init(); - let norm_1 = LayerNormConfig::new(config.d_model).init(); - let norm_2 = LayerNormConfig::new(config.d_model).init(); - let norm_3 = LayerNormConfig::new(config.d_model).init(); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - .with_dropout(config.dropout) - .init(); - - Self { - cross_attn, - self_attn, - norm_1, - norm_2, - norm_3, - pwff, - dropout, - norm_first: config.norm_first, - } - } - - fn new_with(config: &TransformerDecoderConfig, record: TransformerDecoderLayerRecord) -> Self { - let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .with_quiet_softmax(config.quiet_softmax) - .init_with(record.self_attn); - let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .with_quiet_softmax(config.quiet_softmax) - .init_with(record.cross_attn); - let norm_1 
= LayerNormConfig::new(config.d_model).init_with(record.norm_1); - let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); - let norm_3 = LayerNormConfig::new(config.d_model).init_with(record.norm_3); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - .with_dropout(config.dropout) - .init_with(record.pwff); - - Self { - cross_attn, - self_attn, - norm_1, - norm_2, - norm_3, - pwff, - dropout, - norm_first: config.norm_first, + fn new(config: &TransformerDecoderConfig) -> Self { + let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init(); + + let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init(); + let norm_1 = LayerNormConfig::new(config.d_model).init(); + let norm_2 = LayerNormConfig::new(config.d_model).init(); + let norm_3 = LayerNormConfig::new(config.d_model).init(); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) + .with_dropout(config.dropout) + .init(); + + Self { + cross_attn, + self_attn, + norm_1, + norm_2, + norm_3, + pwff, + dropout, + norm_first: config.norm_first, + } } - } - fn forward(&self, mut input: TransformerDecoderInput) -> TransformerDecoderInput { - let mut x_0 = input.target; - - if self.norm_first { - x_0 = self.norm_3.forward(x_0); + fn new_with( + config: &TransformerDecoderConfig, + record: TransformerDecoderLayerRecord, + ) -> Self { + let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init_with(record.self_attn); + let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init_with(record.cross_attn); + let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); + let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); + let norm_3 = LayerNormConfig::new(config.d_model).init_with(record.norm_3); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) + .with_dropout(config.dropout) + .init_with(record.pwff); + + Self { + cross_attn, + self_attn, + norm_1, + norm_2, + norm_3, + pwff, + dropout, + norm_first: config.norm_first, + } } - let mut self_attn_input = MhaInput::self_attn(x_0.clone()); - if let Some(mask_pad) = &input.target_mask_pad { - self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); - } - if let Some(mask_attn) = &input.target_mask_attn { - self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); + fn forward(&self, mut input: TransformerDecoderInput) -> TransformerDecoderInput { + let mut x_0 = input.target; + + if self.norm_first { + x_0 = self.norm_3.forward(x_0); + } + + let mut self_attn_input = MhaInput::self_attn(x_0.clone()); + if let Some(mask_pad) = &input.target_mask_pad { + self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); + } + if let Some(mask_attn) = &input.target_mask_attn { + self_attn_input 
= self_attn_input.mask_attn(mask_attn.clone()); + } + + let x_1 = self.self_attn.forward(self_attn_input); + let x_1 = self.dropout.forward(x_1.context) + x_0; + let x_1 = self.norm_1.forward(x_1); + + let mut cross_attn_input = + MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); + if let Some(mask_pad) = &input.memory_mask_pad { + cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone()); + } + if let Some(mask_attn) = &input.memory_mask_attn { + cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone()); + } + + let x_2 = self.cross_attn.forward(cross_attn_input); + let x_2 = self.dropout.forward(x_2.context) + x_1; + let x_2 = self.norm_2.forward(x_2); + + let x_3 = self.pwff.forward(x_2.clone()); + let mut x_3 = self.dropout.forward(x_3) + x_2; + + if !self.norm_first { + x_3 = self.norm_3.forward(x_3) + } + + input.target = x_3; + input } - let x_1 = self.self_attn.forward(self_attn_input); - let x_1 = self.dropout.forward(x_1.context) + x_0; - let x_1 = self.norm_1.forward(x_1); - - let mut cross_attn_input = - MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); - if let Some(mask_pad) = &input.memory_mask_pad { - cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone()); - } - if let Some(mask_attn) = &input.memory_mask_attn { - cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone()); + fn forward_autoregressive_inference( + &self, + mut input: TransformerDecoderInput, + cache: &mut TransformerDecoderLayerAutoregressiveCache, + ) -> TransformerDecoderInput { + let mut x_0 = input.target; + + if self.norm_first { + x_0 = cache + .norm_3 + .forward_autoregressive(x_0, 1, |x| self.norm_3.forward(x)); + } + + let mut self_attn_input = MhaInput::self_attn(x_0.clone()); + if let Some(mask_pad) = &input.target_mask_pad { + self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); + } + if let Some(mask_attn) = &input.target_mask_attn { + self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); + } + + let x_1 = self + .self_attn + .forward_cache(self_attn_input, &mut cache.self_attn); + let x_1 = self.dropout.forward(x_1.context) + x_0; + let x_1 = cache + .norm_1 + .forward_autoregressive(x_1, 1, |x| self.norm_1.forward(x)); + + let mut mha_input = MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); + if let Some(mask_pad) = &input.memory_mask_pad { + mha_input = mha_input.mask_pad(mask_pad.clone()); + } + if let Some(mask_attn) = &input.memory_mask_attn { + mha_input = mha_input.mask_attn(mask_attn.clone()); + } + + let x_2 = self + .cross_attn + .forward_cache(mha_input, &mut cache.cross_attn); + let x_2 = self.dropout.forward(x_2.context) + x_1; + let x_2 = cache + .norm_2 + .forward_autoregressive(x_2, 1, |x| self.norm_2.forward(x)); + + let x_3 = cache + .pwff + .forward_autoregressive(x_2.clone(), 1, |x| self.pwff.forward(x)); + let mut x_3 = self.dropout.forward(x_3) + x_2; + + if !self.norm_first { + x_3 = cache + .norm_3 + .forward_autoregressive(x_3, 1, |x| self.norm_3.forward(x)); + } + + input.target = x_3; + input } +} - let x_2 = self.cross_attn.forward(cross_attn_input); - let x_2 = self.dropout.forward(x_2.context) + x_1; - let x_2 = self.norm_2.forward(x_2); - - let x_3 = self.pwff.forward(x_2.clone()); - let mut x_3 = self.dropout.forward(x_3) + x_2; +impl TransformerDecoder { + /// Applies the forward pass. 
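Editorial note (not part of the patch): the `norm_first` branches above switch between the post-norm ordering of the original Transformer and a pre-norm ordering. The scalar sketch below is a minimal illustration of what the flag means in general (dropout ignored), not a line-by-line mirror of this layer.

    fn post_norm(x: f32, sublayer: impl Fn(f32) -> f32, norm: impl Fn(f32) -> f32) -> f32 {
        // norm_first == false: add the residual, then normalize.
        norm(sublayer(x) + x)
    }

    fn pre_norm(x: f32, sublayer: impl Fn(f32) -> f32, norm: impl Fn(f32) -> f32) -> f32 {
        // norm_first == true: normalize before the sub-layer, keep the raw residual.
        sublayer(norm(x)) + x
    }

    fn main() {
        let attn = |x: f32| 0.5 * x;              // stand-in for attention + dropout
        let norm = |x: f32| x / x.abs().max(1.0); // stand-in for layer norm
        println!("post-norm: {}", post_norm(2.0, attn, norm));
        println!("pre-norm:  {}", pre_norm(2.0, attn, norm));
    }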
+ pub fn forward(&self, mut input: TransformerDecoderInput) -> Tensor { + for layer in self.layers.iter() { + input = layer.forward(input); + } - if !self.norm_first { - x_3 = self.norm_3.forward(x_3) + input.target } - input.target = x_3; - input - } - - fn forward_autoregressive_inference( - &self, - mut input: TransformerDecoderInput, - cache: &mut TransformerDecoderLayerAutoregressiveCache, - ) -> TransformerDecoderInput { - let mut x_0 = input.target; - - if self.norm_first { - x_0 = cache - .norm_3 - .forward_autoregressive(x_0, 1, |x| self.norm_3.forward(x)); - } + /// Applies the forward pass on the input using autoregressive cache. + pub fn forward_autoregressive_inference( + &self, + mut input: TransformerDecoderInput, + cache: &mut TransformerDecoderAutoregressiveCache, + ) -> Tensor { + for i in 0..self.layers.len() { + let layer = self.layers.get(i).unwrap(); + let cache = cache.layers.get_mut(i).unwrap(); - let mut self_attn_input = MhaInput::self_attn(x_0.clone()); - if let Some(mask_pad) = &input.target_mask_pad { - self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); - } - if let Some(mask_attn) = &input.target_mask_attn { - self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); - } + input = layer.forward_autoregressive_inference(input, cache); + } - let x_1 = self - .self_attn - .forward_cache(self_attn_input, &mut cache.self_attn); - let x_1 = self.dropout.forward(x_1.context) + x_0; - let x_1 = cache - .norm_1 - .forward_autoregressive(x_1, 1, |x| self.norm_1.forward(x)); - - let mut mha_input = MhaInput::new(x_1.clone(), input.memory.clone(), input.memory.clone()); - if let Some(mask_pad) = &input.memory_mask_pad { - mha_input = mha_input.mask_pad(mask_pad.clone()); + input.target } - if let Some(mask_attn) = &input.memory_mask_attn { - mha_input = mha_input.mask_attn(mask_attn.clone()); + /// Create an empty autoregressive cache. + pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache { + TransformerDecoderAutoregressiveCache::empty(self.layers.len()) } - - let x_2 = self - .cross_attn - .forward_cache(mha_input, &mut cache.cross_attn); - let x_2 = self.dropout.forward(x_2.context) + x_1; - let x_2 = cache - .norm_2 - .forward_autoregressive(x_2, 1, |x| self.norm_2.forward(x)); - - let x_3 = cache - .pwff - .forward_autoregressive(x_2.clone(), 1, |x| self.pwff.forward(x)); - let mut x_3 = self.dropout.forward(x_3) + x_2; - - if !self.norm_first { - x_3 = cache - .norm_3 - .forward_autoregressive(x_3, 1, |x| self.norm_3.forward(x)); - } - - input.target = x_3; - input - } } -impl TransformerDecoder { - /// Applies the forward pass. - pub fn forward(&self, mut input: TransformerDecoderInput) -> Tensor { - for layer in self.layers.iter() { - input = layer.forward(input); +#[cfg(test)] +mod tests { + use super::*; + use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; + use burn_tensor::Distribution; + + #[test] + fn test_autoregressive_norm_last() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + TestBackend::seed(0); + + test_autoregressive( + TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers) + .with_norm_first(false), + ) } - input.target - } - - /// Applies the forward pass on the input using autoregressive cache. 
- pub fn forward_autoregressive_inference( - &self, - mut input: TransformerDecoderInput, - cache: &mut TransformerDecoderAutoregressiveCache, - ) -> Tensor { - for i in 0..self.layers.len() { - let layer = self.layers.get(i).unwrap(); - let cache = cache.layers.get_mut(i).unwrap(); + #[test] + fn test_autoregressive_norm_first() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + TestBackend::seed(0); - input = layer.forward_autoregressive_inference(input, cache); + test_autoregressive( + TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), + ) } - input.target - } - /// Create an empty autoregressive cache. - pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache { - TransformerDecoderAutoregressiveCache::empty(self.layers.len()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; - use burn_tensor::Distribution; - - #[test] - fn test_autoregressive_norm_last() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - TestBackend::seed(0); - - test_autoregressive( - TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(false), - ) - } - - #[test] - fn test_autoregressive_norm_first() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - TestBackend::seed(0); - - test_autoregressive( - TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), - ) - } - - fn test_autoregressive(config: TransformerDecoderConfig) { - let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; - let transformer = config.init(); - - let memory = - Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); - let target = - Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); - let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device()); - let input = - TransformerDecoderInput::new(target.clone(), memory.clone()).target_mask_attn(mask_attn); - - // Normal forward using masking. - let output_1 = transformer.forward(input); - - // Forward using the autoregressive cache. - let mut output_2 = Vec::new(); - let mut cache = transformer.new_autoregressive_cache(); - - for i in 1..seq_length + 1 { - let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]); - - let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device()); - let input = - TransformerDecoderInput::new(target.clone(), memory.clone()).target_mask_attn(mask_attn); - let next_tok = transformer // Greedy sampling - .forward_autoregressive_inference(input, &mut cache) - .slice([0..batch_size, i - 1..i, 0..d_model]); - output_2.push(next_tok); + fn test_autoregressive(config: TransformerDecoderConfig) { + let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; + let transformer = config.init(); + + let memory = Tensor::::random( + [batch_size, seq_length, d_model], + Distribution::Default, + ); + let target = Tensor::::random( + [batch_size, seq_length, d_model], + Distribution::Default, + ); + let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device()); + let input = TransformerDecoderInput::new(target.clone(), memory.clone()) + .target_mask_attn(mask_attn); + + // Normal forward using masking. + let output_1 = transformer.forward(input); + + // Forward using the autoregressive cache. 
+ let mut output_2 = Vec::new(); + let mut cache = transformer.new_autoregressive_cache(); + + for i in 1..seq_length + 1 { + let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]); + + let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device()); + let input = TransformerDecoderInput::new(target.clone(), memory.clone()) + .target_mask_attn(mask_attn); + let next_tok = transformer // Greedy sampling + .forward_autoregressive_inference(input, &mut cache) + .slice([0..batch_size, i - 1..i, 0..d_model]); + output_2.push(next_tok); + } + + let output_2 = Tensor::cat(output_2, 1); + + // Should produce the same tokens. + output_1 + .into_data() + .assert_approx_eq(&output_2.into_data(), 3); } - - let output_2 = Tensor::cat(output_2, 1); - - // Should produce the same tokens. - output_1 - .into_data() - .assert_approx_eq(&output_2.into_data(), 3); - } } diff --git a/burn-core/src/nn/transformer/encoder.rs b/burn-core/src/nn/transformer/encoder.rs index 82bef97fb5..7e7b4fa7e3 100644 --- a/burn-core/src/nn/transformer/encoder.rs +++ b/burn-core/src/nn/transformer/encoder.rs @@ -2,49 +2,51 @@ use alloc::vec::Vec; use burn_tensor::Bool; use crate::{ - self as burn, - nn::{attention::MhaCache, cache::TensorCache, Initializer}, + self as burn, + nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ - config::Config, - module::Module, - nn::{ - attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, - Dropout, DropoutConfig, LayerNorm, LayerNormConfig, - }, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{ + attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, + Dropout, DropoutConfig, LayerNorm, LayerNormConfig, + }, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [Transformer Encoder](TransformerEncoder) layer. #[derive(Config)] pub struct TransformerEncoderConfig { - /// The size of the model. - pub d_model: usize, - /// The size of the position-wise feed-forward network. - pub d_ff: usize, - /// The number of attention heads. - pub n_heads: usize, - /// The number of layers. - pub n_layers: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - pub dropout: f64, - /// Layer norm will be applied first instead of after the other modules. - #[config(default = false)] - pub norm_first: bool, - /// Use "quiet softmax" instead of regular softmax. - /// - /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). - /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. - /// - /// Reference: - #[config(default = false)] - pub quiet_softmax: bool, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] - pub initializer: Initializer, + /// The size of the model. + pub d_model: usize, + /// The size of the position-wise feed-forward network. + pub d_ff: usize, + /// The number of attention heads. + pub n_heads: usize, + /// The number of layers. + pub n_layers: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + pub dropout: f64, + /// Layer norm will be applied first instead of after the other modules. + #[config(default = false)] + pub norm_first: bool, + /// Use "quiet softmax" instead of regular softmax. 
+ /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + pub quiet_softmax: bool, + /// The type of function used to initialize neural network parameters + #[config( + default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" + )] + pub initializer: Initializer, } /// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). @@ -54,334 +56,340 @@ pub struct TransformerEncoderConfig { /// - layers: transformer encoder layers with `d_model` input and output features. #[derive(Module, Debug)] pub struct TransformerEncoder { - layers: Vec>, + layers: Vec>, } /// [Transformer Encoder](TransformerEncoder) forward pass input argument. #[derive(Debug)] pub struct TransformerEncoderInput { - tensor: Tensor, - mask_pad: Option>, - mask_attn: Option>, + tensor: Tensor, + mask_pad: Option>, + mask_attn: Option>, } impl TransformerEncoderInput { - /// Create a [transformer encoder](TransformerEncoder) input argument. - pub fn new(tensor: Tensor) -> Self { - Self { - tensor, - mask_pad: None, - mask_attn: None, + /// Create a [transformer encoder](TransformerEncoder) input argument. + pub fn new(tensor: Tensor) -> Self { + Self { + tensor, + mask_pad: None, + mask_attn: None, + } + } + + /// Register the padding mask. + pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { + self.mask_pad = Some(mask_pad); + self + } + + /// Register the attention mask. + pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { + self.mask_attn = Some(mask_attn); + self } - } - - /// Register the padding mask. - pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { - self.mask_pad = Some(mask_pad); - self - } - - /// Register the attention mask. - pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { - self.mask_attn = Some(mask_attn); - self - } } impl TransformerEncoderConfig { - /// Initialize a new [transformer encoder](TransformerEncoder) module. - pub fn init(&self) -> TransformerEncoder { - let layers = (0..self.n_layers) - .map(|_| TransformerEncoderLayer::new(self)) - .collect::>(); - - TransformerEncoder { layers } - } - /// Initialize a new [transformer encoder](TransformerEncoder) module with a - /// [record](TransformerEncoderRecord). - pub fn init_with( - &self, - record: TransformerEncoderRecord, - ) -> TransformerEncoder { - TransformerEncoder { - layers: record - .layers - .into_iter() - .map(|record| TransformerEncoderLayer::new_with(self, record)) - .collect(), + /// Initialize a new [transformer encoder](TransformerEncoder) module. + pub fn init(&self) -> TransformerEncoder { + let layers = (0..self.n_layers) + .map(|_| TransformerEncoderLayer::new(self)) + .collect::>(); + + TransformerEncoder { layers } + } + /// Initialize a new [transformer encoder](TransformerEncoder) module with a + /// [record](TransformerEncoderRecord). + pub fn init_with( + &self, + record: TransformerEncoderRecord, + ) -> TransformerEncoder { + TransformerEncoder { + layers: record + .layers + .into_iter() + .map(|record| TransformerEncoderLayer::new_with(self, record)) + .collect(), + } } - } } impl TransformerEncoder { - /// Applies the forward pass on the input tensor. 
- /// - /// # Shapes - /// - /// - tensor: `[batch_size, seq_length, d_model]` - /// - output: `[batch_size, seq_length, d_model]` - pub fn forward(&self, input: TransformerEncoderInput) -> Tensor { - let mut x = input.tensor; - - for layer in self.layers.iter() { - x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone()); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - tensor: `[batch_size, seq_length, d_model]` + /// - output: `[batch_size, seq_length, d_model]` + pub fn forward(&self, input: TransformerEncoderInput) -> Tensor { + let mut x = input.tensor; + + for layer in self.layers.iter() { + x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone()); + } + + x } - - x - } - /// Applies the forward pass on the input tensor using autoregressive cache. - /// - /// # Shapes - /// - /// - tensor: `[batch_size, seq_length, d_model]` - /// - output: `[batch_size, seq_length, d_model]` - pub fn forward_autoregressive_inference( - &self, - input: TransformerEncoderInput, - cache: &mut TransformerEncoderAutoregressiveCache, - ) -> Tensor { - let mut x = input.tensor; - - for i in 0..self.layers.len() { - let layer = self.layers.get(i).unwrap(); - let cache = cache.layers.get_mut(i).unwrap(); - - x = layer.forward_autoregressive_inference( - x, - input.mask_pad.clone(), - input.mask_attn.clone(), - cache, - ); + /// Applies the forward pass on the input tensor using autoregressive cache. + /// + /// # Shapes + /// + /// - tensor: `[batch_size, seq_length, d_model]` + /// - output: `[batch_size, seq_length, d_model]` + pub fn forward_autoregressive_inference( + &self, + input: TransformerEncoderInput, + cache: &mut TransformerEncoderAutoregressiveCache, + ) -> Tensor { + let mut x = input.tensor; + + for i in 0..self.layers.len() { + let layer = self.layers.get(i).unwrap(); + let cache = cache.layers.get_mut(i).unwrap(); + + x = layer.forward_autoregressive_inference( + x, + input.mask_pad.clone(), + input.mask_attn.clone(), + cache, + ); + } + + x } - x - } - - /// Create an empty autoregressive cache. - pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache { - TransformerEncoderAutoregressiveCache::empty(self.layers.len()) - } + /// Create an empty autoregressive cache. + pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache { + TransformerEncoderAutoregressiveCache::empty(self.layers.len()) + } } /// Transformer encoder layer module. 
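Editorial note (not part of the patch): a hypothetical usage sketch of `forward_autoregressive_inference` with the cache returned by `new_autoregressive_cache`, feeding a growing prefix exactly as the test at the end of this file does. The explicit generic parameters are assumptions made for illustration.

    fn incremental_encode(encoder: &TransformerEncoder<TestBackend>, tokens: Tensor<TestBackend, 3>) {
        let [batch_size, seq_length, d_model] = tokens.shape().dims;
        let mut cache = encoder.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            // Only the prefix of length `i` is visible; the cache keeps earlier work.
            let prefix = tokens.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = TransformerEncoderInput::new(prefix);
            let out = encoder.forward_autoregressive_inference(input, &mut cache);
            let _latest = out.slice([0..batch_size, i - 1..i, 0..d_model]);
        }
    }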
#[derive(Module, Debug)] pub struct TransformerEncoderLayer { - mha: MultiHeadAttention, - pwff: PositionWiseFeedForward, - norm_1: LayerNorm, - norm_2: LayerNorm, - dropout: Dropout, - norm_first: bool, + mha: MultiHeadAttention, + pwff: PositionWiseFeedForward, + norm_1: LayerNorm, + norm_2: LayerNorm, + dropout: Dropout, + norm_first: bool, } impl TransformerEncoderLayer { - fn new_with(config: &TransformerEncoderConfig, record: TransformerEncoderLayerRecord) -> Self { - let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .with_quiet_softmax(config.quiet_softmax) - .init_with(record.mha); - let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); - let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init_with(record.pwff); - - Self { - mha, - norm_1, - norm_2, - pwff, - dropout, - norm_first: config.norm_first, - } - } - fn new(config: &TransformerEncoderConfig) -> Self { - let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .with_quiet_softmax(config.quiet_softmax) - .init(); - let norm_1 = LayerNormConfig::new(config.d_model).init(); - let norm_2 = LayerNormConfig::new(config.d_model).init(); - let dropout = DropoutConfig::new(config.dropout).init(); - let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) - .with_initializer(config.initializer.clone()) - .with_dropout(config.dropout) - .init(); - - Self { - mha, - norm_1, - norm_2, - pwff, - dropout, - norm_first: config.norm_first, - } - } - - fn forward( - &self, - mut input: Tensor, - mask_pad: Option>, - mask_attn: Option>, - ) -> Tensor { - if self.norm_first { - input = self.norm_2.forward(input) + fn new_with( + config: &TransformerEncoderConfig, + record: TransformerEncoderLayerRecord, + ) -> Self { + let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init_with(record.mha); + let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); + let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .init_with(record.pwff); + + Self { + mha, + norm_1, + norm_2, + pwff, + dropout, + norm_first: config.norm_first, + } } - - let mut input_mhs = MhaInput::self_attn(input.clone()); - - if let Some(mask_pad) = mask_pad { - input_mhs = input_mhs.mask_pad(mask_pad); + fn new(config: &TransformerEncoderConfig) -> Self { + let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) + .init(); + let norm_1 = LayerNormConfig::new(config.d_model).init(); + let norm_2 = LayerNormConfig::new(config.d_model).init(); + let dropout = DropoutConfig::new(config.dropout).init(); + let pwff = PositionWiseFeedForwardConfig::new(config.d_model, 
config.d_ff) + .with_initializer(config.initializer.clone()) + .with_dropout(config.dropout) + .init(); + + Self { + mha, + norm_1, + norm_2, + pwff, + dropout, + norm_first: config.norm_first, + } } - if let Some(mask_attn) = mask_attn { - input_mhs = input_mhs.mask_attn(mask_attn); - } - - let x_1 = self.mha.forward(input_mhs); - let x_1 = self.dropout.forward(x_1.context) + input; - let x_1 = self.norm_1.forward(x_1); + fn forward( + &self, + mut input: Tensor, + mask_pad: Option>, + mask_attn: Option>, + ) -> Tensor { + if self.norm_first { + input = self.norm_2.forward(input) + } - let x_2 = self.pwff.forward(x_1.clone()); - let mut x_2 = self.dropout.forward(x_2) + x_1; + let mut input_mhs = MhaInput::self_attn(input.clone()); - if !self.norm_first { - x_2 = self.norm_2.forward(x_2) - } + if let Some(mask_pad) = mask_pad { + input_mhs = input_mhs.mask_pad(mask_pad); + } - x_2 - } + if let Some(mask_attn) = mask_attn { + input_mhs = input_mhs.mask_attn(mask_attn); + } - fn forward_autoregressive_inference( - &self, - mut input: Tensor, - mask_pad: Option>, - mask_attn: Option>, - cache: &mut TransformerEncoderLayerAutoregressiveCache, - ) -> Tensor { - if self.norm_first { - input = cache - .norm_2 - .forward_autoregressive(input, 1, |input| self.norm_2.forward(input)); - } + let x_1 = self.mha.forward(input_mhs); + let x_1 = self.dropout.forward(x_1.context) + input; + let x_1 = self.norm_1.forward(x_1); - let mut input_mhs = MhaInput::self_attn(input.clone()); + let x_2 = self.pwff.forward(x_1.clone()); + let mut x_2 = self.dropout.forward(x_2) + x_1; - if let Some(mask_pad) = mask_pad { - input_mhs = input_mhs.mask_pad(mask_pad); - } + if !self.norm_first { + x_2 = self.norm_2.forward(x_2) + } - if let Some(mask_attn) = mask_attn { - input_mhs = input_mhs.mask_attn(mask_attn); + x_2 } - let x_1 = self.mha.forward_cache(input_mhs, &mut cache.mha); - let x_1 = self.dropout.forward(x_1.context) + input; - let x_1 = cache - .norm_1 - .forward_autoregressive(x_1, 1, |x_1| self.norm_1.forward(x_1)); - - let x_2 = cache - .pwff - .forward_autoregressive(x_1.clone(), 1, |x_1| self.pwff.forward(x_1)); - let mut x_2 = self.dropout.forward(x_2) + x_1; - - if !self.norm_first { - x_2 = cache - .norm_2 - .forward_autoregressive(x_2, 1, |x_2| self.norm_2.forward(x_2)); + fn forward_autoregressive_inference( + &self, + mut input: Tensor, + mask_pad: Option>, + mask_attn: Option>, + cache: &mut TransformerEncoderLayerAutoregressiveCache, + ) -> Tensor { + if self.norm_first { + input = cache + .norm_2 + .forward_autoregressive(input, 1, |input| self.norm_2.forward(input)); + } + + let mut input_mhs = MhaInput::self_attn(input.clone()); + + if let Some(mask_pad) = mask_pad { + input_mhs = input_mhs.mask_pad(mask_pad); + } + + if let Some(mask_attn) = mask_attn { + input_mhs = input_mhs.mask_attn(mask_attn); + } + + let x_1 = self.mha.forward_cache(input_mhs, &mut cache.mha); + let x_1 = self.dropout.forward(x_1.context) + input; + let x_1 = cache + .norm_1 + .forward_autoregressive(x_1, 1, |x_1| self.norm_1.forward(x_1)); + + let x_2 = cache + .pwff + .forward_autoregressive(x_1.clone(), 1, |x_1| self.pwff.forward(x_1)); + let mut x_2 = self.dropout.forward(x_2) + x_1; + + if !self.norm_first { + x_2 = cache + .norm_2 + .forward_autoregressive(x_2, 1, |x_2| self.norm_2.forward(x_2)); + } + + x_2 } - - x_2 - } } struct TransformerEncoderLayerAutoregressiveCache { - mha: MhaCache, - pwff: TensorCache, - norm_1: TensorCache, - norm_2: TensorCache, + mha: MhaCache, + pwff: TensorCache, + norm_1: 
TensorCache, + norm_2: TensorCache, } impl TransformerEncoderLayerAutoregressiveCache { - fn empty() -> Self { - Self { - mha: MhaCache::autoregressive(), - pwff: TensorCache::empty(), - norm_1: TensorCache::empty(), - norm_2: TensorCache::empty(), + fn empty() -> Self { + Self { + mha: MhaCache::autoregressive(), + pwff: TensorCache::empty(), + norm_1: TensorCache::empty(), + norm_2: TensorCache::empty(), + } } - } } /// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer. /// /// To be used during inference when decoding tokens. pub struct TransformerEncoderAutoregressiveCache { - layers: Vec>, + layers: Vec>, } impl TransformerEncoderAutoregressiveCache { - fn empty(num_layers: usize) -> Self { - Self { - layers: (0..num_layers) - .map(|_| TransformerEncoderLayerAutoregressiveCache::empty()) - .collect(), + fn empty(num_layers: usize) -> Self { + Self { + layers: (0..num_layers) + .map(|_| TransformerEncoderLayerAutoregressiveCache::empty()) + .collect(), + } } - } } #[cfg(test)] mod tests { - use super::*; - use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; - use burn_tensor::Distribution; - - #[test] - fn test_autoregressive_norm_last() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - test_autoregressive( - TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(false), - ) - } - - #[test] - fn test_autoregressive_norm_first() { - let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; - test_autoregressive( - TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), - ) - } - - fn test_autoregressive(config: TransformerEncoderConfig) { - let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; - let transformer = config.init(); - - let tensor = - Tensor::::random([batch_size, seq_length, d_model], Distribution::Default); - let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); - let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn); - - let output_1 = transformer.forward(input); - let mut output_2 = Vec::new(); - let mut cache = transformer.new_autoregressive_cache(); - - for i in 1..seq_length + 1 { - let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); - let input = TransformerEncoderInput::new(tensor.clone()); - let next_tok = transformer - .forward_autoregressive_inference(input, &mut cache) - .slice([0..batch_size, i - 1..i, 0..d_model]); - output_2.push(next_tok); + use super::*; + use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; + use burn_tensor::Distribution; + + #[test] + fn test_autoregressive_norm_last() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + test_autoregressive( + TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers) + .with_norm_first(false), + ) } - let output_2 = Tensor::cat(output_2, 1); + #[test] + fn test_autoregressive_norm_first() { + let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; + test_autoregressive( + TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), + ) + } - output_1 - .into_data() - .assert_approx_eq(&output_2.into_data(), 3); - } + fn test_autoregressive(config: TransformerEncoderConfig) { + let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; + let transformer = config.init(); + + let tensor = Tensor::::random( + [batch_size, seq_length, d_model], + Distribution::Default, + ); + let mask_attn = generate_autoregressive_mask(batch_size, 
seq_length, &tensor.device()); + let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn); + + let output_1 = transformer.forward(input); + let mut output_2 = Vec::new(); + let mut cache = transformer.new_autoregressive_cache(); + + for i in 1..seq_length + 1 { + let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); + let input = TransformerEncoderInput::new(tensor.clone()); + let next_tok = transformer + .forward_autoregressive_inference(input, &mut cache) + .slice([0..batch_size, i - 1..i, 0..d_model]); + output_2.push(next_tok); + } + + let output_2 = Tensor::cat(output_2, 1); + + output_1 + .into_data() + .assert_approx_eq(&output_2.into_data(), 3); + } } diff --git a/burn-core/src/nn/transformer/pwff.rs b/burn-core/src/nn/transformer/pwff.rs index 207db3910f..1307ba8b80 100644 --- a/burn-core/src/nn/transformer/pwff.rs +++ b/burn-core/src/nn/transformer/pwff.rs @@ -2,25 +2,27 @@ use crate as burn; use crate::nn::Initializer; use crate::{ - config::Config, - module::Module, - nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU}, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU}, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer. #[derive(Config)] pub struct PositionWiseFeedForwardConfig { - /// The size of the input and output features. - pub d_model: usize, - /// The size of the hidden inner features. - pub d_ff: usize, - /// The dropout rate. Default: 0.1 - #[config(default = 0.1)] - pub dropout: f64, - /// The type of function used to initialize neural network parameters - #[config(default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}")] - pub initializer: Initializer, + /// The size of the input and output features. + pub d_model: usize, + /// The size of the hidden inner features. + pub d_ff: usize, + /// The dropout rate. Default: 0.1 + #[config(default = 0.1)] + pub dropout: f64, + /// The type of function used to initialize neural network parameters + #[config( + default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" + )] + pub initializer: Initializer, } /// Applies the position-wise feed-forward network to the input tensor. @@ -31,53 +33,53 @@ pub struct PositionWiseFeedForwardConfig { /// - linear outer: Linear layer with `d_ff` input features and `d_model` output features. #[derive(Module, Debug)] pub struct PositionWiseFeedForward { - linear_inner: Linear, - linear_outer: Linear, - dropout: Dropout, - gelu: GELU, + linear_inner: Linear, + linear_outer: Linear, + dropout: Dropout, + gelu: GELU, } impl PositionWiseFeedForwardConfig { - /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module. - pub fn init(&self) -> PositionWiseFeedForward { - PositionWiseFeedForward { - linear_inner: LinearConfig::new(self.d_model, self.d_ff) - .with_initializer(self.initializer.clone()) - .init(), - linear_outer: LinearConfig::new(self.d_ff, self.d_model) - .with_initializer(self.initializer.clone()) - .init(), - dropout: DropoutConfig::new(self.dropout).init(), - gelu: GELU::new(), + /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module. 
+ pub fn init(&self) -> PositionWiseFeedForward { + PositionWiseFeedForward { + linear_inner: LinearConfig::new(self.d_model, self.d_ff) + .with_initializer(self.initializer.clone()) + .init(), + linear_outer: LinearConfig::new(self.d_ff, self.d_model) + .with_initializer(self.initializer.clone()) + .init(), + dropout: DropoutConfig::new(self.dropout).init(), + gelu: GELU::new(), + } } - } - /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module with a - /// [record](PositionWiseFeedForwardRecord). - pub fn init_with( - &self, - record: PositionWiseFeedForwardRecord, - ) -> PositionWiseFeedForward { - PositionWiseFeedForward { - linear_inner: LinearConfig::new(self.d_model, self.d_ff).init_with(record.linear_inner), - linear_outer: LinearConfig::new(self.d_ff, self.d_model).init_with(record.linear_outer), - dropout: DropoutConfig::new(self.dropout).init(), - gelu: GELU::new(), + /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module with a + /// [record](PositionWiseFeedForwardRecord). + pub fn init_with( + &self, + record: PositionWiseFeedForwardRecord, + ) -> PositionWiseFeedForward { + PositionWiseFeedForward { + linear_inner: LinearConfig::new(self.d_model, self.d_ff).init_with(record.linear_inner), + linear_outer: LinearConfig::new(self.d_ff, self.d_model).init_with(record.linear_outer), + dropout: DropoutConfig::new(self.dropout).init(), + gelu: GELU::new(), + } } - } } impl PositionWiseFeedForward { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// - tensor: `[batch_size, seq_length, d_model]` - /// - output: `[batch_size, seq_length, d_model]` - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.linear_inner.forward(input); - let x = self.gelu.forward(x); - let x = self.dropout.forward(x); + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - tensor: `[batch_size, seq_length, d_model]` + /// - output: `[batch_size, seq_length, d_model]` + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.linear_inner.forward(input); + let x = self.gelu.forward(x); + let x = self.dropout.forward(x); - self.linear_outer.forward(x) - } + self.linear_outer.forward(x) + } } diff --git a/burn-core/src/nn/unfold.rs b/burn-core/src/nn/unfold.rs index 3ad588b86a..26711622e3 100644 --- a/burn-core/src/nn/unfold.rs +++ b/burn-core/src/nn/unfold.rs @@ -10,50 +10,50 @@ use burn_tensor::Tensor; /// Configuration to create an [unfold 4D](Unfold4d) layer. #[derive(Config, Debug)] pub struct Unfold4dConfig { - /// The size of the kernel. - pub kernel_size: [usize; 2], - /// The stride of the convolution. - #[config(default = "[1, 1]")] - pub stride: [usize; 2], - /// Spacing between kernel elements. - #[config(default = "[1, 1]")] - pub dilation: [usize; 2], - /// The padding configuration. - #[config(default = "[0, 0]")] - pub padding: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The stride of the convolution. + #[config(default = "[1, 1]")] + pub stride: [usize; 2], + /// Spacing between kernel elements. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], + /// The padding configuration. + #[config(default = "[0, 0]")] + pub padding: [usize; 2], } /// Four-dimensional unfolding. #[derive(Module, Clone, Debug)] pub struct Unfold4d { - config: Unfold4dConfig, + config: Unfold4dConfig, } impl Unfold4dConfig { - /// Initialize a new [unfold 4k](Unfold4d) module. 
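Editorial note (not part of the patch): per token, the position-wise block above computes inner linear (d_model -> d_ff), GELU, dropout (identity at inference), outer linear (d_ff -> d_model). The plain-Rust sketch below shows that computation; it uses the tanh approximation of GELU, which may differ from the crate's exact GELU.

    fn gelu(x: f32) -> f32 {
        0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
    }

    fn linear(x: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
        // w is laid out as [out][in].
        w.iter()
            .zip(b)
            .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bias)
            .collect()
    }

    fn pwff(x: &[f32], w1: &[Vec<f32>], b1: &[f32], w2: &[Vec<f32>], b2: &[f32]) -> Vec<f32> {
        let hidden: Vec<f32> = linear(x, w1, b1).into_iter().map(gelu).collect();
        linear(&hidden, w2, b2) // dropout omitted: identity at inference time
    }

    fn main() {
        // Tiny example: d_model = 2, d_ff = 3.
        let (w1, b1) = (vec![vec![0.1, -0.2], vec![0.3, 0.4], vec![-0.5, 0.6]], vec![0.0; 3]);
        let (w2, b2) = (vec![vec![0.1, 0.2, 0.3], vec![-0.1, 0.0, 0.1]], vec![0.0; 2]);
        println!("{:?}", pwff(&[1.0, 2.0], &w1, &b1, &w2, &b2)); // back to d_model = 2
    }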
- pub fn init(&self) -> Unfold4d { - Unfold4d { - config: self.clone(), + /// Initialize a new [unfold 4k](Unfold4d) module. + pub fn init(&self) -> Unfold4d { + Unfold4d { + config: self.clone(), + } } - } } impl Unfold4d { - /// Applies the forward pass on the input tensor. - /// - /// # Shapes - /// - /// input: `[batch_size, channels_in, height, width]`, - /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, - pub fn forward(&self, input: Tensor) -> Tensor { - unfold4d( - input, - self.config.kernel_size, - UnfoldOptions::new( - self.config.stride, - self.config.padding, - self.config.dilation, - ), - ) - } + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// input: `[batch_size, channels_in, height, width]`, + /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, + pub fn forward(&self, input: Tensor) -> Tensor { + unfold4d( + input, + self.config.kernel_size, + UnfoldOptions::new( + self.config.stride, + self.config.padding, + self.config.dilation, + ), + ) + } } diff --git a/burn-core/src/optim/adagrad.rs b/burn-core/src/optim/adagrad.rs index 52ee2c0c1b..bdb8dd4a0a 100644 --- a/burn-core/src/optim/adagrad.rs +++ b/burn-core/src/optim/adagrad.rs @@ -1,11 +1,11 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use super::{ - decay::{WeightDecay, WeightDecayConfig}, - Optimizer, SimpleOptimizer, + decay::{WeightDecay, WeightDecayConfig}, + Optimizer, SimpleOptimizer, }; use crate::config::Config; use crate::optim::adaptor::OptimizerAdaptor; @@ -15,263 +15,263 @@ use burn_tensor::backend::Backend; /// AdaGrad configuration. #[derive(Config)] pub struct AdaGradConfig { - #[config(default = 0.)] - lr_decay: f64, - #[config(default = 1e-5)] - epsilon: f32, - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + #[config(default = 0.)] + lr_decay: f64, + #[config(default = 1e-5)] + epsilon: f32, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Gradient Clipping](GradientClippingConfig) config. + grad_clipping: Option, } /// AdaGrad optimizer pub struct AdaGrad { - lr_decay: LRDecay, - weight_decay: Option>, + lr_decay: LRDecay, + weight_decay: Option>, } /// AdaGrad state. 
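Editorial note (not part of the patch): the doc comment above gives the unfolded output shape as `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`. The sketch below computes the "number of blocks" with the standard sliding-window arithmetic, which is assumed to match `unfold4d` (the crate's implementation is not shown in this excerpt).

    fn blocks_1d(size: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
        (size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
    }

    fn unfold_output_shape(
        [batch, channels, height, width]: [usize; 4],
        kernel: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> [usize; 3] {
        let blocks = blocks_1d(height, kernel[0], stride[0], padding[0], dilation[0])
            * blocks_1d(width, kernel[1], stride[1], padding[1], dilation[1]);
        [batch, channels * kernel[0] * kernel[1], blocks]
    }

    fn main() {
        // 8 images, 3 channels, 32x32, 3x3 kernel, default stride/padding/dilation.
        assert_eq!(
            unfold_output_shape([8, 3, 32, 32], [3, 3], [1, 1], [0, 0], [1, 1]),
            [8, 27, 900]
        );
    }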
#[derive(Record, Clone, new)] pub struct AdaGradState { - lr_decay: LRDecayState, + lr_decay: LRDecayState, } impl SimpleOptimizer for AdaGrad { - type State = AdaGradState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - let mut state_lr_decay = None; - - if let Some(state) = state { - state_lr_decay = Some(state.lr_decay); - } + type State = AdaGradState; - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); - } + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + let mut state_lr_decay = None; - let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay); + if let Some(state) = state { + state_lr_decay = Some(state.lr_decay); + } - let state = AdaGradState::new(state_lr_decay); + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); + } - (tensor - grad, Some(state)) - } + let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay); - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.lr_decay = state.lr_decay.to_device(device); - state - } + let state = AdaGradState::new(state_lr_decay); + + (tensor - grad, Some(state)) + } + + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.lr_decay = state.lr_decay.to_device(device); + state + } } impl AdaGradConfig { - /// Initialize AdaGrad optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>(&self) -> impl Optimizer { - let optim = AdaGrad { - lr_decay: LRDecay { - lr_decay: self.lr_decay, - epsilon: self.epsilon, - }, - weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), - }; - - let mut optim = OptimizerAdaptor::from(optim); - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); + /// Initialize AdaGrad optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>(&self) -> impl Optimizer { + let optim = AdaGrad { + lr_decay: LRDecay { + lr_decay: self.lr_decay, + epsilon: self.epsilon, + }, + weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), + }; + + let mut optim = OptimizerAdaptor::from(optim); + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); + } + optim } - optim - } } /// Learning rate decay state (also includes sum state). #[derive(Record, new, Clone)] pub struct LRDecayState { - time: usize, - sum: Tensor, + time: usize, + sum: Tensor, } struct LRDecay { - lr_decay: f64, - epsilon: f32, + lr_decay: f64, + epsilon: f32, } impl LRDecay { - pub fn transform( - &self, - grad: Tensor, - lr: LearningRate, - lr_decay_state: Option>, - ) -> (Tensor, LRDecayState) { - let state = if let Some(mut state) = lr_decay_state { - state.sum = state.sum.add(grad.clone().powf(2.)); - state.time += 1; - state - } else { - LRDecayState::new(1, grad.clone().powf(2.)) - }; - - let new_lr = lr / (1. + (state.time as f64 - 1.) 
* self.lr_decay); - - let grad = grad - .div(state.sum.clone().sqrt().add_scalar(self.epsilon)) - .mul_scalar(new_lr); - - (grad, state) - } + pub fn transform( + &self, + grad: Tensor, + lr: LearningRate, + lr_decay_state: Option>, + ) -> (Tensor, LRDecayState) { + let state = if let Some(mut state) = lr_decay_state { + state.sum = state.sum.add(grad.clone().powf(2.)); + state.time += 1; + state + } else { + LRDecayState::new(1, grad.clone().powf(2.)) + }; + + let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay); + + let grad = grad + .div(state.sum.clone().sqrt().add_scalar(self.epsilon)) + .mul_scalar(new_lr); + + (grad, state) + } } impl LRDecayState { - /// Move state to device. - /// - /// # Arguments - /// - /// * `device` - Device to move state to. - /// - /// # Returns - /// - /// Returns state moved to device. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.sum = self.sum.to_device(device); - self - } + /// Move state to device. + /// + /// # Arguments + /// + /// * `device` - Device to move state to. + /// + /// # Returns + /// + /// Returns state moved to device. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.sum = self.sum.to_device(device); + self + } } #[cfg(test)] mod tests { - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - - const LEARNING_RATE: LearningRate = 0.01; - - #[test] - fn test_adagrad_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_adagrad(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - BinFileRecorder::::default() - .record(optimizer.to_record(), "/tmp/test_optim".into()) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_adagrad(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - const ASSERT_PRECISION: usize = 6; - - #[test] - fn test_adagrad_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdaGradConfig::new() - .with_epsilon(1e-8) - .with_lr_decay(0.5) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = 
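Editorial note (not part of the patch): `LRDecay::transform` above is easiest to read per scalar. The sketch below applies the same formulas: accumulate squared gradients, decay the learning rate with the step count, then scale the gradient by the decayed rate over the root of the accumulated sum.

    fn adagrad_step(
        param: f64,
        grad: f64,
        sum: &mut f64,    // running sum of squared gradients
        time: &mut usize, // step counter, starts at 0
        lr: f64,
        lr_decay: f64,
        epsilon: f64,
    ) -> f64 {
        *sum += grad * grad;
        *time += 1;
        let new_lr = lr / (1.0 + (*time as f64 - 1.0) * lr_decay);
        let update = grad / (sum.sqrt() + epsilon) * new_lr;
        param - update
    }

    fn main() {
        let (mut sum, mut time) = (0.0, 0);
        let mut w = 1.0;
        for _ in 0..3 {
            // Pretend the gradient of the loss w.r.t. `w` is just `w` itself.
            w = adagrad_step(w, w, &mut sum, &mut time, 0.01, 0.5, 1e-5);
        }
        println!("w after 3 steps: {w}");
    }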
linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711], - [ - 0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756, - ], - [ - -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538, - ], - [ - -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964, - ], - [ - 0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504, - ], - [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895], - ]); - let bias_expected = Data::from([ - -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - fn create_adagrad( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> - { - let config = AdaGradConfig::new(); - AdaGrad { - lr_decay: LRDecay { - lr_decay: config.lr_decay, - epsilon: config.epsilon, - }, - weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, TestAutodiffBackend, TestBackend}; + + const LEARNING_RATE: LearningRate = 0.01; + + #[test] + fn test_adagrad_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_adagrad(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + BinFileRecorder::::default() + .record(optimizer.to_record(), "/tmp/test_optim".into()) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_adagrad(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + const ASSERT_PRECISION: usize = 6; + + #[test] + fn test_adagrad_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 
0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdaGradConfig::new() + .with_epsilon(1e-8) + .with_lr_decay(0.5) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711], + [ + 0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756, + ], + [ + -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538, + ], + [ + -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964, + ], + [ + 0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504, + ], + [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895], + ]); + let bias_expected = Data::from([ + -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + fn create_adagrad( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> + { + let config = AdaGradConfig::new(); + AdaGrad { + lr_decay: LRDecay { + lr_decay: config.lr_decay, + epsilon: config.epsilon, + }, + weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), + } + .into() } - .into() - } } diff --git a/burn-core/src/optim/adam.rs b/burn-core/src/optim/adam.rs index 6a02ca0b56..e43bd077fd 100644 --- a/burn-core/src/optim/adam.rs +++ b/burn-core/src/optim/adam.rs @@ -1,11 +1,11 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use super::{ - decay::{WeightDecay, WeightDecayConfig}, - Optimizer, SimpleOptimizer, + decay::{WeightDecay, WeightDecayConfig}, + Optimizer, SimpleOptimizer, }; use crate::config::Config; use crate::optim::adaptor::OptimizerAdaptor; @@ -15,337 +15,338 @@ use burn_tensor::{backend::Backend, ElementConversion}; /// Adam configuration. #[derive(Config)] pub struct AdamConfig { - /// Parameter for Adam. - #[config(default = 0.9)] - beta_1: f32, - /// Parameter for Adam. - #[config(default = 0.999)] - beta_2: f32, - /// A value required for numerical stability. - #[config(default = 1e-5)] - epsilon: f32, - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + /// Parameter for Adam. + #[config(default = 0.9)] + beta_1: f32, + /// Parameter for Adam. + #[config(default = 0.999)] + beta_2: f32, + /// A value required for numerical stability. + #[config(default = 1e-5)] + epsilon: f32, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Gradient Clipping](GradientClippingConfig) config. 
+ grad_clipping: Option, } /// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf). pub struct Adam { - momentum: AdaptiveMomentum, - weight_decay: Option>, + momentum: AdaptiveMomentum, + weight_decay: Option>, } /// Adam state. #[derive(Record, Clone, new)] pub struct AdamState { - momentum: AdaptiveMomentumState, + momentum: AdaptiveMomentumState, } impl SimpleOptimizer for Adam { - type State = AdamState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - let mut state_momentum = None; - - if let Some(state) = state { - state_momentum = Some(state.momentum); - } + type State = AdamState; - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); - } + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + let mut state_momentum = None; + + if let Some(state) = state { + state_momentum = Some(state.momentum); + } + + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); + } - let (grad, state_momentum) = self.momentum.transform(grad, state_momentum); + let (grad, state_momentum) = self.momentum.transform(grad, state_momentum); - let state = AdamState::new(state_momentum); - let delta = grad.mul_scalar(lr); + let state = AdamState::new(state_momentum); + let delta = grad.mul_scalar(lr); - (tensor - delta, Some(state)) - } + (tensor - delta, Some(state)) + } - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.momentum = state.momentum.to_device(device); - state - } + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.momentum = state.momentum.to_device(device); + state + } } impl AdamConfig { - /// Initialize Adam optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>(&self) -> impl Optimizer { - let optim = Adam { - momentum: AdaptiveMomentum { - beta_1: self.beta_1, - beta_2: self.beta_2, - epsilon: self.epsilon, - }, - weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), - }; - - let mut optim = OptimizerAdaptor::from(optim); - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); + /// Initialize Adam optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>(&self) -> impl Optimizer { + let optim = Adam { + momentum: AdaptiveMomentum { + beta_1: self.beta_1, + beta_2: self.beta_2, + epsilon: self.epsilon, + }, + weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), + }; + + let mut optim = OptimizerAdaptor::from(optim); + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); + } + optim } - optim - } } /// Adaptive momentum state. 
#[derive(Record, new, Clone)] pub struct AdaptiveMomentumState { - time: usize, - moment_1: Tensor, - moment_2: Tensor, + time: usize, + moment_1: Tensor, + moment_2: Tensor, } struct AdaptiveMomentum { - beta_1: f32, - beta_2: f32, - epsilon: f32, + beta_1: f32, + beta_2: f32, + epsilon: f32, } impl AdaptiveMomentum { - pub fn transform( - &self, - grad: Tensor, - momentum_state: Option>, - ) -> (Tensor, AdaptiveMomentumState) { - let state = if let Some(mut state) = momentum_state { - let factor = 1.0 - self.beta_1; - state.moment_1 = state - .moment_1 - .mul_scalar(self.beta_1) - .add(grad.clone().mul_scalar(factor)); - - let factor = 1.0 - self.beta_2; - state.moment_2 = state - .moment_2 - .mul_scalar(self.beta_2) - .add(grad.powf(2.0).mul_scalar(factor)); - - state.time += 1; - - state - } else { - let factor = 1.0 - self.beta_1; - let moment_1 = grad.clone().mul_scalar(factor); - - let factor = 1.0 - self.beta_2; - let moment_2 = grad.powf(2.0).mul_scalar(factor); - - AdaptiveMomentumState::new(1, moment_1, moment_2) - }; - - let time = (state.time as i32).elem(); - let moment_1_corrected = state - .moment_1 - .clone() - .div_scalar(1f32 - self.beta_1.powi(time)); - let moment_2_corrected = state - .moment_2 - .clone() - .div_scalar(1f32 - self.beta_2.powi(time)); - - let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); - - (grad, state) - } + pub fn transform( + &self, + grad: Tensor, + momentum_state: Option>, + ) -> (Tensor, AdaptiveMomentumState) { + let state = if let Some(mut state) = momentum_state { + let factor = 1.0 - self.beta_1; + state.moment_1 = state + .moment_1 + .mul_scalar(self.beta_1) + .add(grad.clone().mul_scalar(factor)); + + let factor = 1.0 - self.beta_2; + state.moment_2 = state + .moment_2 + .mul_scalar(self.beta_2) + .add(grad.powf(2.0).mul_scalar(factor)); + + state.time += 1; + + state + } else { + let factor = 1.0 - self.beta_1; + let moment_1 = grad.clone().mul_scalar(factor); + + let factor = 1.0 - self.beta_2; + let moment_2 = grad.powf(2.0).mul_scalar(factor); + + AdaptiveMomentumState::new(1, moment_1, moment_2) + }; + + let time = (state.time as i32).elem(); + let moment_1_corrected = state + .moment_1 + .clone() + .div_scalar(1f32 - self.beta_1.powi(time)); + let moment_2_corrected = state + .moment_2 + .clone() + .div_scalar(1f32 - self.beta_2.powi(time)); + + let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); + + (grad, state) + } } impl AdaptiveMomentumState { - /// Move state to device. - /// - /// # Arguments - /// - /// * `device` - Device to move state to. - /// - /// # Returns - /// - /// Returns state moved to device. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.moment_1 = self.moment_1.to_device(device); - self.moment_2 = self.moment_2.to_device(device); - self - } + /// Move state to device. + /// + /// # Arguments + /// + /// * `device` - Device to move state to. + /// + /// # Returns + /// + /// Returns state moved to device. 
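// For intuition, the update performed by `AdaptiveMomentum::transform` and
// `SimpleOptimizer::step` above can be written as a minimal, dependency-free
// scalar sketch. The free function below is illustrative only (it is not part
// of this patch) and assumes plain f64 arithmetic in place of backend tensors.
fn adam_step_scalar(
    theta: f64,         // parameter value
    grad: f64,          // gradient of the loss w.r.t. the parameter
    moment_1: &mut f64, // running average of the gradient
    moment_2: &mut f64, // running average of the squared gradient
    time: i32,          // 1-based step counter
    lr: f64,
    beta_1: f64,
    beta_2: f64,
    epsilon: f64,
) -> f64 {
    // Exponential moving averages of the gradient and the squared gradient.
    *moment_1 = beta_1 * *moment_1 + (1.0 - beta_1) * grad;
    *moment_2 = beta_2 * *moment_2 + (1.0 - beta_2) * grad * grad;
    // Bias correction compensates for the zero initialization of the moments.
    let moment_1_hat = *moment_1 / (1.0 - beta_1.powi(time));
    let moment_2_hat = *moment_2 / (1.0 - beta_2.powi(time));
    // The step itself: normalized first moment, scaled by the learning rate.
    theta - lr * moment_1_hat / (moment_2_hat.sqrt() + epsilon)
}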
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.moment_1 = self.moment_1.to_device(device); + self.moment_2 = self.moment_2.to_device(device); + self + } } #[cfg(test)] mod tests { - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - - const LEARNING_RATE: LearningRate = 0.01; - - #[test] - fn test_adam_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_adam(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - BinFileRecorder::::default() - .record(optimizer.to_record(), "/tmp/test_optim".into()) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_adam(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - const ASSERT_PRECISION: usize = 2; - - #[test] - fn test_adam_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(Some(WeightDecayConfig::new(0.5))) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154], - [ - 0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133, - ], - [ - -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047, - ], - [ - -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651, - ], - [ - 0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343, - ], - [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346], - ]); - let bias_expected = Data::from([ - -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - 
weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - #[test] - fn test_adam_optimizer_no_nan() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - - let x = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(Some(WeightDecayConfig::new(0.5))) - .init(); - - let grads = linear.forward(x.clone()).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - assert!(!state_updated.weight.to_data().value[0].is_nan()); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - fn create_adam( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> { - let config = AdamConfig::new(); - Adam { - momentum: AdaptiveMomentum { - beta_1: config.beta_1, - beta_2: config.beta_2, - epsilon: config.epsilon, - }, - weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, TestAutodiffBackend, TestBackend}; + + const LEARNING_RATE: LearningRate = 0.01; + + #[test] + fn test_adam_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_adam(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + BinFileRecorder::::default() + .record(optimizer.to_record(), "/tmp/test_optim".into()) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_adam(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + const ASSERT_PRECISION: usize = 2; + + #[test] + fn test_adam_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + 
Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(Some(WeightDecayConfig::new(0.5))) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154], + [ + 0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133, + ], + [ + -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047, + ], + [ + -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651, + ], + [ + 0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343, + ], + [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346], + ]); + let bias_expected = Data::from([ + -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + #[test] + fn test_adam_optimizer_no_nan() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + + let x = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(Some(WeightDecayConfig::new(0.5))) + .init(); + + let grads = linear.forward(x.clone()).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + assert!(!state_updated.weight.to_data().value[0].is_nan()); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + fn create_adam( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> + { + let config = AdamConfig::new(); + Adam { + momentum: AdaptiveMomentum { + beta_1: config.beta_1, + beta_2: 
config.beta_2, + epsilon: config.epsilon, + }, + weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), + } + .into() } - .into() - } } diff --git a/burn-core/src/optim/adamw.rs b/burn-core/src/optim/adamw.rs index 8f8441b489..befbeb88cb 100644 --- a/burn-core/src/optim/adamw.rs +++ b/burn-core/src/optim/adamw.rs @@ -1,6 +1,6 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use std::marker::PhantomData; @@ -13,355 +13,356 @@ use burn_tensor::{backend::Backend, ElementConversion}; /// AdamW configuration. #[derive(Config)] pub struct AdamWConfig { - /// Parameter for AdamW. - #[config(default = 0.9)] - beta_1: f32, - /// Parameter for AdamW. - #[config(default = 0.999)] - beta_2: f32, - /// A value required for numerical stability. - #[config(default = 1e-5)] - epsilon: f32, - /// Weight decay config. - #[config(default = 1e-4)] - weight_decay: f32, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + /// Parameter for AdamW. + #[config(default = 0.9)] + beta_1: f32, + /// Parameter for AdamW. + #[config(default = 0.999)] + beta_2: f32, + /// A value required for numerical stability. + #[config(default = 1e-5)] + epsilon: f32, + /// Weight decay config. + #[config(default = 1e-4)] + weight_decay: f32, + /// [Gradient Clipping](GradientClippingConfig) config. + grad_clipping: Option, } /// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101). pub struct AdamW { - momentum: AdaptiveMomentumW, - weight_decay: f32, - _phantom: PhantomData, + momentum: AdaptiveMomentumW, + weight_decay: f32, + _phantom: PhantomData, } /// AdamW state. #[derive(Record, Clone, new)] pub struct AdamWState { - momentum: AdaptiveMomentumWState, + momentum: AdaptiveMomentumWState, } impl SimpleOptimizer for AdamW { - type State = AdamWState; - - /// A single optimization step for any tensor that represents the parameters of a model. - fn step( - &self, - // Learning rate. - lr: LearningRate, - // Any tensor that represents the parameters of a model. - tensor: Tensor, - // Gradient of the loss w.r.t. the parameters. - grad: Tensor, - // State of the optimizer. - state: Option>, - ) -> (Tensor, Option>) { - let tensor_updated = tensor.clone() - tensor.mul_scalar(lr).mul_scalar(self.weight_decay); - - let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum)); - - let state = AdamWState { - momentum: momentum_state, - }; - - (tensor_updated - raw_delta.mul_scalar(lr), Some(state)) - } - - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.momentum = state.momentum.to_device(device); - state - } + type State = AdamWState; + + /// A single optimization step for any tensor that represents the parameters of a model. + fn step( + &self, + // Learning rate. + lr: LearningRate, + // Any tensor that represents the parameters of a model. + tensor: Tensor, + // Gradient of the loss w.r.t. the parameters. + grad: Tensor, + // State of the optimizer. 
+ state: Option>, + ) -> (Tensor, Option>) { + let tensor_updated = tensor.clone() - tensor.mul_scalar(lr).mul_scalar(self.weight_decay); + + let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum)); + + let state = AdamWState { + momentum: momentum_state, + }; + + (tensor_updated - raw_delta.mul_scalar(lr), Some(state)) + } + + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.momentum = state.momentum.to_device(device); + state + } } impl AdamWConfig { - /// Initialize AdamW optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>(&self) -> impl Optimizer { - let optim = AdamW { - momentum: AdaptiveMomentumW { - beta_1: self.beta_1, - beta_2: self.beta_2, - epsilon: self.epsilon, - }, - weight_decay: self.weight_decay, - _phantom: Default::default(), - }; - - let mut optim = OptimizerAdaptor::from(optim); - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); + /// Initialize AdamW optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>(&self) -> impl Optimizer { + let optim = AdamW { + momentum: AdaptiveMomentumW { + beta_1: self.beta_1, + beta_2: self.beta_2, + epsilon: self.epsilon, + }, + weight_decay: self.weight_decay, + _phantom: Default::default(), + }; + + let mut optim = OptimizerAdaptor::from(optim); + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); + } + optim } - optim - } } /// Adaptive momentum state. #[derive(Record, new, Clone)] pub struct AdaptiveMomentumWState { - time: usize, - moment_1: Tensor, - moment_2: Tensor, + time: usize, + moment_1: Tensor, + moment_2: Tensor, } struct AdaptiveMomentumW { - beta_1: f32, - beta_2: f32, - epsilon: f32, + beta_1: f32, + beta_2: f32, + epsilon: f32, } impl AdaptiveMomentumW { - pub fn transform( - &self, - grad: Tensor, - state: Option>, - ) -> (Tensor, AdaptiveMomentumWState) { - let state = if let Some(mut state) = state { - // Update first moment estimate. - let factor = 1.0 - self.beta_1; - state.moment_1 = state - .moment_1 - .mul_scalar(self.beta_1) - .add(grad.clone().mul_scalar(factor)); - - // Update second moment estimate. - let factor = 1.0 - self.beta_2; - state.moment_2 = state - .moment_2 - .mul_scalar(self.beta_2) - .add(grad.powf(2.0).mul_scalar(factor)); - - // Update time. - state.time += 1; - - state - } else { - // Initialize first moment estimate. - let factor = 1.0 - self.beta_1; - let moment_1 = grad.clone().mul_scalar(factor); - - // Initialize second moment estimate. - let factor = 1.0 - self.beta_2; - let moment_2 = grad.powf(2.0).mul_scalar(factor); - - AdaptiveMomentumWState::new(1, moment_1, moment_2) - }; - - let time: i32 = (state.time as i32).elem(); - - // Compute bias-corrected first and second moment estimates. - let moment_1_corrected = state - .moment_1 - .clone() - .div_scalar(1f32 - self.beta_1.powi(time)); - - let moment_2_corrected = state - .moment_2 - .clone() - .div_scalar(1f32 - self.beta_2.powi(time)); - - // Compute update delta. This still needs to be scaled by the learning rate. 
- let update_delta = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); - - ( - update_delta, - AdaptiveMomentumWState::new(state.time, state.moment_1, state.moment_2), - ) - } + pub fn transform( + &self, + grad: Tensor, + state: Option>, + ) -> (Tensor, AdaptiveMomentumWState) { + let state = if let Some(mut state) = state { + // Update first moment estimate. + let factor = 1.0 - self.beta_1; + state.moment_1 = state + .moment_1 + .mul_scalar(self.beta_1) + .add(grad.clone().mul_scalar(factor)); + + // Update second moment estimate. + let factor = 1.0 - self.beta_2; + state.moment_2 = state + .moment_2 + .mul_scalar(self.beta_2) + .add(grad.powf(2.0).mul_scalar(factor)); + + // Update time. + state.time += 1; + + state + } else { + // Initialize first moment estimate. + let factor = 1.0 - self.beta_1; + let moment_1 = grad.clone().mul_scalar(factor); + + // Initialize second moment estimate. + let factor = 1.0 - self.beta_2; + let moment_2 = grad.powf(2.0).mul_scalar(factor); + + AdaptiveMomentumWState::new(1, moment_1, moment_2) + }; + + let time: i32 = (state.time as i32).elem(); + + // Compute bias-corrected first and second moment estimates. + let moment_1_corrected = state + .moment_1 + .clone() + .div_scalar(1f32 - self.beta_1.powi(time)); + + let moment_2_corrected = state + .moment_2 + .clone() + .div_scalar(1f32 - self.beta_2.powi(time)); + + // Compute update delta. This still needs to be scaled by the learning rate. + let update_delta = + moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); + + ( + update_delta, + AdaptiveMomentumWState::new(state.time, state.moment_1, state.moment_2), + ) + } } impl AdaptiveMomentumWState { - /// Move state to device. - /// - /// # Arguments - /// - /// * `device` - Device to move state to. - /// - /// # Returns - /// - /// Returns state moved to device. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.moment_1 = self.moment_1.to_device(device); - self.moment_2 = self.moment_2.to_device(device); - self - } + /// Move state to device. + /// + /// # Arguments + /// + /// * `device` - Device to move state to. + /// + /// # Returns + /// + /// Returns state moved to device. 
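// The key difference from Adam is visible in `AdamW::step` above: the weight
// decay is applied directly to the parameter (decoupled decay) rather than
// being folded into the gradient. A minimal scalar sketch of one AdamW step,
// illustrative only and not part of this patch:
fn adamw_step_scalar(
    theta: f64,
    grad: f64,
    moment_1: &mut f64,
    moment_2: &mut f64,
    time: i32, // 1-based step counter
    lr: f64,
    beta_1: f64,
    beta_2: f64,
    epsilon: f64,
    weight_decay: f64,
) -> f64 {
    // Decoupled weight decay: shrink the parameter before the adaptive step.
    let theta = theta - theta * lr * weight_decay;
    // Same bias-corrected moment estimates as Adam.
    *moment_1 = beta_1 * *moment_1 + (1.0 - beta_1) * grad;
    *moment_2 = beta_2 * *moment_2 + (1.0 - beta_2) * grad * grad;
    let moment_1_hat = *moment_1 / (1.0 - beta_1.powi(time));
    let moment_2_hat = *moment_2 / (1.0 - beta_2.powi(time));
    // `transform` returns this normalized delta; `step` then scales it by `lr`.
    let raw_delta = moment_1_hat / (moment_2_hat.sqrt() + epsilon);
    theta - lr * raw_delta
}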
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.moment_1 = self.moment_1.to_device(device); + self.moment_2 = self.moment_2.to_device(device); + self + } } #[cfg(test)] mod tests { - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - use tempfile::TempDir; - - const LEARNING_RATE: LearningRate = 0.01; - - #[test] - fn test_adamw_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_adamw(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - let temp_dir = TempDir::new().unwrap(); - BinFileRecorder::::default() - .record(optimizer.to_record(), temp_dir.path().join("test_optim")) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_adamw(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - - const ASSERT_PRECISION: usize = 2; - - #[test] - fn test_adamw_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamWConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(0.5) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534], - [ - 0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182, - ], - [ - -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981, - ], - [ - -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081, - ], - [ - 0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993, - ], - [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580], - ]); - let bias_expected = Data::from([ - -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - 
bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - #[test] - fn test_adam_optimizer_no_nan() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - - let x = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = AdamWConfig::new() - .with_epsilon(1e-8) - .with_beta_1(0.9) - .with_beta_2(0.999) - .with_weight_decay(0.5) - .init(); - - let grads = linear.forward(x.clone()).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - assert!(!state_updated.weight.to_data().value[0].is_nan()); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - fn create_adamw( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> - { - let config = AdamWConfig::new(); - AdamW { - momentum: AdaptiveMomentumW { - beta_1: config.beta_1, - beta_2: config.beta_2, - epsilon: config.epsilon, - }, - weight_decay: config.weight_decay, - _phantom: Default::default(), + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, TestAutodiffBackend, TestBackend}; + use tempfile::TempDir; + + const LEARNING_RATE: LearningRate = 0.01; + + #[test] + fn test_adamw_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_adamw(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + let temp_dir = TempDir::new().unwrap(); + BinFileRecorder::::default() + .record(optimizer.to_record(), temp_dir.path().join("test_optim")) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_adamw(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + + const ASSERT_PRECISION: usize = 2; + + #[test] + fn test_adamw_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], 
+ [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamWConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(0.5) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534], + [ + 0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182, + ], + [ + -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981, + ], + [ + -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081, + ], + [ + 0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993, + ], + [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580], + ]); + let bias_expected = Data::from([ + -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + #[test] + fn test_adam_optimizer_no_nan() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], + [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + + let x = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = AdamWConfig::new() + .with_epsilon(1e-8) + .with_beta_1(0.9) + .with_beta_2(0.999) + .with_weight_decay(0.5) + .init(); + + let grads = linear.forward(x.clone()).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + assert!(!state_updated.weight.to_data().value[0].is_nan()); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + fn create_adamw( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> + { + let config = AdamWConfig::new(); + AdamW { + 
momentum: AdaptiveMomentumW { + beta_1: config.beta_1, + beta_2: config.beta_2, + epsilon: config.epsilon, + }, + weight_decay: config.weight_decay, + _phantom: Default::default(), + } + .into() } - .into() - } } diff --git a/burn-core/src/optim/base.rs b/burn-core/src/optim/base.rs index 3602efb57e..5fc54fede3 100644 --- a/burn-core/src/optim/base.rs +++ b/burn-core/src/optim/base.rs @@ -7,19 +7,19 @@ use crate::LearningRate; /// General trait to optimize [module](AutodiffModule). pub trait Optimizer: Send + Sync where - M: AutodiffModule, - B: AutodiffBackend, + M: AutodiffModule, + B: AutodiffBackend, { - /// Optimizer associative type to be used when saving and loading the state. - type Record: Record; + /// Optimizer associative type to be used when saving and loading the state. + type Record: Record; - /// Perform the optimizer step using the given learning rate and gradients. - /// The updated module is returned. - fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M; + /// Perform the optimizer step using the given learning rate and gradients. + /// The updated module is returned. + fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M; - /// Get the current state of the optimizer as a [record](Record). - fn to_record(&self) -> Self::Record; + /// Get the current state of the optimizer as a [record](Record). + fn to_record(&self) -> Self::Record; - /// Load the state of the optimizer as a [record](Record). - fn load_record(self, record: Self::Record) -> Self; + /// Load the state of the optimizer as a [record](Record). + fn load_record(self, record: Self::Record) -> Self; } diff --git a/burn-core/src/optim/decay.rs b/burn-core/src/optim/decay.rs index a7ca4d8dcb..eb1653990d 100644 --- a/burn-core/src/optim/decay.rs +++ b/burn-core/src/optim/decay.rs @@ -9,60 +9,60 @@ use crate::tensor::{ElementConversion, Tensor}; /// Configuration to create [weight decay](WeightDecay). #[derive(Config)] pub struct WeightDecayConfig { - /// L2 penalty. - pub penalty: f64, + /// L2 penalty. + pub penalty: f64, } /// State of [weight decay](WeightDecay). #[derive(Record, Clone, new)] pub struct WeightDecayState { - pub(crate) grad_last_step: Tensor, + pub(crate) grad_last_step: Tensor, } /// Weight decay implementation that transforms gradients. pub struct WeightDecay { - penalty: B::FloatElem, + penalty: B::FloatElem, } impl WeightDecay { - /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig). - pub fn new(config: &WeightDecayConfig) -> Self { - Self { - penalty: config.penalty.elem(), + /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig). + pub fn new(config: &WeightDecayConfig) -> Self { + Self { + penalty: config.penalty.elem(), + } } - } - /// Transforms a gradient. - /// - /// # Arguments - /// - /// * `grad` - Gradient to transform. - /// * `tensor` - Tensor param of the last iteration. - /// - /// # Returns - /// - /// * `grad` - Transformed gradient. - pub fn transform( - &self, - grad: Tensor, - tensor: Tensor, - ) -> Tensor { - tensor.mul_scalar(self.penalty).add(grad) - } + /// Transforms a gradient. + /// + /// # Arguments + /// + /// * `grad` - Gradient to transform. + /// * `tensor` - Tensor param of the last iteration. + /// + /// # Returns + /// + /// * `grad` - Transformed gradient. + pub fn transform( + &self, + grad: Tensor, + tensor: Tensor, + ) -> Tensor { + tensor.mul_scalar(self.penalty).add(grad) + } } impl WeightDecayState { - /// Moves the state to a device. 
- /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.grad_last_step = self.grad_last_step.to_device(device); - self - } + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.grad_last_step = self.grad_last_step.to_device(device); + self + } } diff --git a/burn-core/src/optim/grad_accum.rs b/burn-core/src/optim/grad_accum.rs index 6655525b4f..e0e455bdf7 100644 --- a/burn-core/src/optim/grad_accum.rs +++ b/burn-core/src/optim/grad_accum.rs @@ -8,116 +8,115 @@ use super::GradientsParams; /// Accumulate gradients into a single [Gradients](AutodiffBackend::Gradients) object. pub struct GradientsAccumulator { - grads: GradientsParams, - phantom: PhantomData, + grads: GradientsParams, + phantom: PhantomData, } impl Default for GradientsAccumulator { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl GradientsAccumulator { - /// Create a new gradients accumulator. - pub fn new() -> Self { - Self { - grads: GradientsParams::new(), - phantom: PhantomData, + /// Create a new gradients accumulator. + pub fn new() -> Self { + Self { + grads: GradientsParams::new(), + phantom: PhantomData, + } } - } } impl GradientsAccumulator { - /// Accumulate the given gradients for each parameter in the given module. - pub fn accumulate(&mut self, module: &M, grads: GradientsParams) - where - M: AutodiffModule, - { - let mut visitor = ModuleGradsAccumulator::::new(&mut self.grads, grads); - module.visit(&mut visitor); - } - - /// Return the accumulated gradients and reset the accumulator state. - pub fn grads(&mut self) -> GradientsParams { - let mut grads = GradientsParams::new(); - core::mem::swap(&mut self.grads, &mut grads); - - grads - } + /// Accumulate the given gradients for each parameter in the given module. + pub fn accumulate(&mut self, module: &M, grads: GradientsParams) + where + M: AutodiffModule, + { + let mut visitor = ModuleGradsAccumulator::::new(&mut self.grads, grads); + module.visit(&mut visitor); + } + + /// Return the accumulated gradients and reset the accumulator state. 
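// A sketch of the intended accumulation workflow, mirroring the tests further
// down in this file. The helper function, its name, and the generic bounds are
// illustrative assumptions (not part of this patch) and presume the usual
// `burn` imports for `Linear`, `Tensor`, `AutodiffBackend`, and the optim types.
fn accumulate_batches<B: AutodiffBackend>(
    layer: &Linear<B>,
    batches: Vec<Tensor<B, 2>>,
) -> GradientsParams {
    let mut accumulator = GradientsAccumulator::new();
    for x in batches {
        // One backward pass per batch; gradients are keyed by parameter id.
        let loss = layer.forward(x);
        let grads = GradientsParams::from_grads(loss.backward(), layer);
        accumulator.accumulate(layer, grads);
    }
    // Returns the summed gradients and resets the accumulator state.
    accumulator.grads()
}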
+ pub fn grads(&mut self) -> GradientsParams { + let mut grads = GradientsParams::new(); + core::mem::swap(&mut self.grads, &mut grads); + + grads + } } #[derive(new)] struct ModuleGradsAccumulator<'a, M> { - grads: &'a mut GradientsParams, - grads_new: GradientsParams, - phantom: PhantomData, + grads: &'a mut GradientsParams, + grads_new: GradientsParams, + phantom: PhantomData, } impl<'a, B: AutodiffBackend, M: AutodiffModule> ModuleVisitor - for ModuleGradsAccumulator<'a, M> + for ModuleGradsAccumulator<'a, M> { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { - let grad_updated = match self.grads_new.remove::(id) { - Some(new) => match self.grads.remove::(id) { - Some(grad) => grad.add(new), - None => new, - }, - None => match self.grads.remove::(id) { - Some(grad) => grad, - None => return, - }, - }; - - self - .grads - .register::(id.clone(), grad_updated); - } + fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + let grad_updated = match self.grads_new.remove::(id) { + Some(new) => match self.grads.remove::(id) { + Some(grad) => grad.add(new), + None => new, + }, + None => match self.grads.remove::(id) { + Some(grad) => grad, + None => return, + }, + }; + + self.grads + .register::(id.clone(), grad_updated); + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - nn::{Linear, LinearConfig}, - TestAutodiffBackend, - }; - use burn_tensor::Distribution; - - #[test] - fn test_accumulate_gradients_one_step() { - let mut accumulator = GradientsAccumulator::new(); - let layer = layer(); - let loss = layer.forward(random_tensor()); - let grads = GradientsParams::from_grads(loss.backward(), &layer); - - accumulator.accumulate(&layer, grads); - - let grads = accumulator.grads(); - assert!(!grads.is_empty()) - } - - #[test] - fn test_accumulate_gradients_two_steps() { - let mut accumulator = GradientsAccumulator::new(); - let layer = layer(); - let loss_1 = layer.forward(random_tensor()); - let loss_2 = layer.forward(random_tensor()); - let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer); - let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer); - - accumulator.accumulate(&layer, grads_1); - accumulator.accumulate(&layer, grads_2); - - let grads = accumulator.grads(); - assert_eq!(grads.len(), 2) - } - - fn layer() -> Linear { - LinearConfig::new(20, 20).with_bias(true).init() - } - - fn random_tensor() -> Tensor { - Tensor::::random([2, 20], Distribution::Default) - } + use super::*; + use crate::{ + nn::{Linear, LinearConfig}, + TestAutodiffBackend, + }; + use burn_tensor::Distribution; + + #[test] + fn test_accumulate_gradients_one_step() { + let mut accumulator = GradientsAccumulator::new(); + let layer = layer(); + let loss = layer.forward(random_tensor()); + let grads = GradientsParams::from_grads(loss.backward(), &layer); + + accumulator.accumulate(&layer, grads); + + let grads = accumulator.grads(); + assert!(!grads.is_empty()) + } + + #[test] + fn test_accumulate_gradients_two_steps() { + let mut accumulator = GradientsAccumulator::new(); + let layer = layer(); + let loss_1 = layer.forward(random_tensor()); + let loss_2 = layer.forward(random_tensor()); + let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer); + let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer); + + accumulator.accumulate(&layer, grads_1); + accumulator.accumulate(&layer, grads_2); + + let grads = accumulator.grads(); + assert_eq!(grads.len(), 2) + } + + fn layer() -> Linear { + LinearConfig::new(20, 20).with_bias(true).init() + } + + 
fn random_tensor() -> Tensor { + Tensor::::random([2, 20], Distribution::Default) + } } diff --git a/burn-core/src/optim/grads.rs b/burn-core/src/optim/grads.rs index 79b05ae1d1..81c3eb265c 100644 --- a/burn-core/src/optim/grads.rs +++ b/burn-core/src/optim/grads.rs @@ -1,7 +1,7 @@ use burn_tensor::{ - backend::{AutodiffBackend, Backend}, - container::TensorContainer, - Tensor, + backend::{AutodiffBackend, Backend}, + container::TensorContainer, + Tensor, }; use crate::module::{AutodiffModule, ParamId}; @@ -11,115 +11,115 @@ use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter}; /// Data type that contains gradients for parameters. #[derive(Default)] pub struct GradientsParams { - container: TensorContainer, + container: TensorContainer, } impl GradientsParams { - /// Creates a new [GradientsParams](GradientsParams). - pub fn new() -> Self { - Self::default() - } - - /// Get the gradients for the given [parameter id](ParamId). - /// - /// # Notes - /// - /// You should use [remove](GradientsParams::remove) if you want to get the gradients - /// only one time. - pub fn get(&self, id: &ParamId) -> Option> - where - B: Backend, - { - self.container.get(id) - } - - /// Remove the gradients for the given [parameter id](ParamId). - pub fn remove(&mut self, id: &ParamId) -> Option> - where - B: Backend, - { - self.container.remove(id) - } - - /// Register a gradients tensor for the given [parameter id](ParamId). - /// - /// # Notes - /// - /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced. - pub fn register(&mut self, id: ParamId, value: Tensor) - where - B: Backend, - { - self.container.register(id, value) - } - - /// The number of gradients tensors registered. - pub fn len(&self) -> usize { - self.container.len() - } - - /// If any tensor is contained. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Change the device of each tensor gradients registered for the given [module](AutodiffModule). - pub fn to_device>( - mut self, - device: &B::Device, - module: &M, - ) -> Self { - let mut visitor = GradientsParamsChangeDevice::::new(device, &mut self); - module.visit(&mut visitor); - self - } - - /// Extract each tensor gradients for the given [module](AutodiffModule). - pub fn from_grads>( - grads: B::Gradients, - module: &M, - ) -> Self { - let mut grads_params = GradientsParams::new(); - let mut visitor = GradientsParamsConverter::::new(grads, &mut grads_params); - - module.visit(&mut visitor); - grads_params - } + /// Creates a new [GradientsParams](GradientsParams). + pub fn new() -> Self { + Self::default() + } + + /// Get the gradients for the given [parameter id](ParamId). + /// + /// # Notes + /// + /// You should use [remove](GradientsParams::remove) if you want to get the gradients + /// only one time. + pub fn get(&self, id: &ParamId) -> Option> + where + B: Backend, + { + self.container.get(id) + } + + /// Remove the gradients for the given [parameter id](ParamId). + pub fn remove(&mut self, id: &ParamId) -> Option> + where + B: Backend, + { + self.container.remove(id) + } + + /// Register a gradients tensor for the given [parameter id](ParamId). + /// + /// # Notes + /// + /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced. + pub fn register(&mut self, id: ParamId, value: Tensor) + where + B: Backend, + { + self.container.register(id, value) + } + + /// The number of gradients tensors registered. 
+ pub fn len(&self) -> usize { + self.container.len() + } + + /// If any tensor is contained. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Change the device of each tensor gradients registered for the given [module](AutodiffModule). + pub fn to_device>( + mut self, + device: &B::Device, + module: &M, + ) -> Self { + let mut visitor = GradientsParamsChangeDevice::::new(device, &mut self); + module.visit(&mut visitor); + self + } + + /// Extract each tensor gradients for the given [module](AutodiffModule). + pub fn from_grads>( + grads: B::Gradients, + module: &M, + ) -> Self { + let mut grads_params = GradientsParams::new(); + let mut visitor = GradientsParamsConverter::::new(grads, &mut grads_params); + + module.visit(&mut visitor); + grads_params + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - module::{list_param_ids, Module}, - nn::{Linear, LinearConfig}, - TestAutodiffBackend, - }; - use burn_tensor::{backend::Backend, Distribution}; - - #[test] - fn test_convert_grads() { - let layer_1 = layer(); - let mut layer_2 = layer_1.clone(); - layer_2 = layer_2.fork(&::Device::default()); - let loss_1 = layer_1.forward(random_tensor()); - let loss_2 = layer_2.forward(random_tensor()); - let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1); - let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2); - - let param_ids_1 = list_param_ids(&layer_1); - let param_ids_2 = list_param_ids(&layer_2); - - assert_eq!(param_ids_1, param_ids_2); - assert_eq!(grads_1.len(), param_ids_1.len()); - assert_eq!(grads_2.len(), param_ids_2.len()); - } - - fn layer() -> Linear { - LinearConfig::new(20, 20).with_bias(true).init() - } - - fn random_tensor() -> Tensor { - Tensor::::random([2, 20], Distribution::Default) - } + use super::*; + use crate::{ + module::{list_param_ids, Module}, + nn::{Linear, LinearConfig}, + TestAutodiffBackend, + }; + use burn_tensor::{backend::Backend, Distribution}; + + #[test] + fn test_convert_grads() { + let layer_1 = layer(); + let mut layer_2 = layer_1.clone(); + layer_2 = layer_2.fork(&::Device::default()); + let loss_1 = layer_1.forward(random_tensor()); + let loss_2 = layer_2.forward(random_tensor()); + let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1); + let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2); + + let param_ids_1 = list_param_ids(&layer_1); + let param_ids_2 = list_param_ids(&layer_2); + + assert_eq!(param_ids_1, param_ids_2); + assert_eq!(grads_1.len(), param_ids_1.len()); + assert_eq!(grads_2.len(), param_ids_2.len()); + } + + fn layer() -> Linear { + LinearConfig::new(20, 20).with_bias(true).init() + } + + fn random_tensor() -> Tensor { + Tensor::::random([2, 20], Distribution::Default) + } } diff --git a/burn-core/src/optim/momentum.rs b/burn-core/src/optim/momentum.rs index 95f695ab77..ef2cb174f8 100644 --- a/burn-core/src/optim/momentum.rs +++ b/burn-core/src/optim/momentum.rs @@ -8,87 +8,86 @@ use burn_tensor::backend::Backend; /// Configuration to create [momentum](Momentum). #[derive(Config)] pub struct MomentumConfig { - /// Momemtum factor - #[config(default = 0.9)] - pub momentum: f64, - /// Dampening factor. - #[config(default = 0.1)] - pub dampening: f64, - /// Enables Nesterov momentum, see [On the importance of initialization and - /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf). 
- #[config(default = false)] - pub nesterov: bool, + /// Momemtum factor + #[config(default = 0.9)] + pub momentum: f64, + /// Dampening factor. + #[config(default = 0.1)] + pub dampening: f64, + /// Enables Nesterov momentum, see [On the importance of initialization and + /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf). + #[config(default = false)] + pub nesterov: bool, } /// State of [momentum](Momentum). #[derive(Record, Clone, new)] pub struct MomentumState { - velocity: Tensor, + velocity: Tensor, } /// Momemtum implementation that transforms gradients. pub struct Momentum { - momentum: B::FloatElem, - dampening: f64, - nesterov: bool, + momentum: B::FloatElem, + dampening: f64, + nesterov: bool, } impl Momentum { - /// Creates a new [momentum](Momentum) from a [config](MomentumConfig). - pub fn new(config: &MomentumConfig) -> Self { - Self { - momentum: config.momentum.elem(), - dampening: config.dampening, - nesterov: config.nesterov, + /// Creates a new [momentum](Momentum) from a [config](MomentumConfig). + pub fn new(config: &MomentumConfig) -> Self { + Self { + momentum: config.momentum.elem(), + dampening: config.dampening, + nesterov: config.nesterov, + } } - } - /// Transforms a gradient. - /// - /// # Arguments - /// - /// * `grad` - Gradient to transform. - /// * `state` - State of the optimizer. - /// - /// # Returns - /// - /// * `grad` - Transformed gradient. - /// * `state` - State of the optimizer. - pub fn transform( - &self, - grad: Tensor, - state: Option>, - ) -> (Tensor, MomentumState) { - let velocity = if let Some(state) = state { - grad - .clone() - .mul_scalar(1.0 - self.dampening) - .add(state.velocity.mul_scalar(self.momentum)) - } else { - grad.clone() - }; + /// Transforms a gradient. + /// + /// # Arguments + /// + /// * `grad` - Gradient to transform. + /// * `state` - State of the optimizer. + /// + /// # Returns + /// + /// * `grad` - Transformed gradient. + /// * `state` - State of the optimizer. + pub fn transform( + &self, + grad: Tensor, + state: Option>, + ) -> (Tensor, MomentumState) { + let velocity = if let Some(state) = state { + grad.clone() + .mul_scalar(1.0 - self.dampening) + .add(state.velocity.mul_scalar(self.momentum)) + } else { + grad.clone() + }; - let grad = match self.nesterov { - true => velocity.clone().mul_scalar(self.momentum).add(grad), - false => velocity.clone(), - }; + let grad = match self.nesterov { + true => velocity.clone().mul_scalar(self.momentum).add(grad), + false => velocity.clone(), + }; - (grad, MomentumState::new(velocity)) - } + (grad, MomentumState::new(velocity)) + } } impl MomentumState { - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.velocity = self.velocity.to_device(device); - self - } + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. 
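Element-wise, Momentum::transform above is the classical momentum recurrence: with gradient g and stored velocity v, it computes v <- (1 - dampening) * g + momentum * v (seeding v with the raw gradient on the first step) and returns v, or momentum * v + g when Nesterov is enabled. A scalar sketch of the same recurrence in plain f64, with no burn types (the function and argument names are illustrative):

    // Scalar mirror of Momentum::transform above; `momentum`, `dampening` and
    // `nesterov` play the roles of the MomentumConfig fields.
    fn momentum_step(
        grad: f64,
        velocity: Option<f64>,
        momentum: f64,
        dampening: f64,
        nesterov: bool,
    ) -> (f64, f64) {
        let velocity = match velocity {
            // Same as the `Some(state)` branch: dampened gradient plus decayed velocity.
            Some(v) => (1.0 - dampening) * grad + momentum * v,
            // First step: the velocity is seeded with the raw gradient.
            None => grad,
        };
        let grad_out = if nesterov {
            momentum * velocity + grad
        } else {
            velocity
        };
        // (transformed gradient, velocity to carry forward as MomentumState)
        (grad_out, velocity)
    }

For example, two steps with g = 1.0, momentum = 0.9, dampening = 0.1 give v = 1.0 and then v = 0.9 * 1.0 + 0.9 * 1.0 = 1.8; with Nesterov the second returned gradient is 0.9 * 1.8 + 1.0 = 2.62.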
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.velocity = self.velocity.to_device(device); + self + } } diff --git a/burn-core/src/optim/rmsprop.rs b/burn-core/src/optim/rmsprop.rs index 72ea2b2e80..ffe683db34 100644 --- a/burn-core/src/optim/rmsprop.rs +++ b/burn-core/src/optim/rmsprop.rs @@ -1,11 +1,11 @@ use crate::{ - self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, - LearningRate, + self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record, + LearningRate, }; use super::{ - decay::{WeightDecay, WeightDecayConfig}, - SimpleOptimizer, + decay::{WeightDecay, WeightDecayConfig}, + SimpleOptimizer, }; use crate::config::Config; use crate::optim::adaptor::OptimizerAdaptor; @@ -15,509 +15,510 @@ use burn_tensor::backend::Backend; /// Configuration to create the [RMSProp](RMSProp) optimizer. #[derive(Config)] pub struct RMSPropConfig { - /// Smoothing constant. - #[config(default = 0.99)] - alpha: f32, - /// momentum for RMSProp. - #[config(default = 0.9)] - momentum: f32, - /// A value required for numerical stability. - #[config(default = 1e-5)] - epsilon: f32, - /// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance - #[config(default = false)] - centered: bool, - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - grad_clipping: Option, + /// Smoothing constant. + #[config(default = 0.99)] + alpha: f32, + /// momentum for RMSProp. + #[config(default = 0.9)] + momentum: f32, + /// A value required for numerical stability. + #[config(default = 1e-5)] + epsilon: f32, + /// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance + #[config(default = false)] + centered: bool, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Gradient Clipping](GradientClippingConfig) config. + grad_clipping: Option, } impl RMSPropConfig { - /// Initialize RMSProp optimizer. - /// - /// # Returns - /// - /// Returns an optimizer that can be used to optimize a module. - pub fn init>( - &self, - ) -> OptimizerAdaptor, M, B> { - let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); - - let mut optim = OptimizerAdaptor::from(RMSProp { - alpha: self.alpha, - centered: self.centered, - weight_decay, - momentum: RMSPropMomentum { - momentum: self.momentum, - epsilon: self.epsilon, - }, - }); - - if let Some(config) = &self.grad_clipping { - optim = optim.with_grad_clipping(config.init()); + /// Initialize RMSProp optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. + pub fn init>( + &self, + ) -> OptimizerAdaptor, M, B> { + let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); + + let mut optim = OptimizerAdaptor::from(RMSProp { + alpha: self.alpha, + centered: self.centered, + weight_decay, + momentum: RMSPropMomentum { + momentum: self.momentum, + epsilon: self.epsilon, + }, + }); + + if let Some(config) = &self.grad_clipping { + optim = optim.with_grad_clipping(config.init()); + } + + optim } - - optim - } } /// Optimizer that implements stochastic gradient descent with momentum. /// The optimizer can be configured with [RMSPropConfig](RMSPropConfig). 
pub struct RMSProp { - alpha: f32, - // epsilon: f32, - centered: bool, - // momentum: Option>, - momentum: RMSPropMomentum, - weight_decay: Option>, + alpha: f32, + // epsilon: f32, + centered: bool, + // momentum: Option>, + momentum: RMSPropMomentum, + weight_decay: Option>, } impl SimpleOptimizer for RMSProp { - type State = RMSPropState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - // fetch state for params - let mut state_square_avg = None; - let mut state_centered = None; - let mut state_momentum = None; - if let Some(state) = state { - state_square_avg = Some(state.square_avg); - state_centered = Some(state.centered); - state_momentum = state.momentum; + type State = RMSPropState; + + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + // fetch state for params + let mut state_square_avg = None; + let mut state_centered = None; + let mut state_momentum = None; + if let Some(state) = state { + state_square_avg = Some(state.square_avg); + state_centered = Some(state.centered); + state_momentum = state.momentum; + } + + // weight_decay transform + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); + } + + // square_avg transform + let (grad, state_square_avg) = + SquareAvgState::transform(self.alpha, grad, state_square_avg); + + // centered transform + let (grad, state_square_avg, state_centered) = CenteredState::transform( + self.alpha, + self.centered, + grad, + state_square_avg, + state_centered, + ); + + // momentum transform + let (grad, state_centered, state_momentum) = + self.momentum + .transform(grad, state_centered, state_momentum); + + // transition state + let state = RMSPropState::new(state_square_avg, state_centered, state_momentum); + + // tensor param transform + let delta = grad.mul_scalar(lr); + (tensor - delta, Some(state)) } - // weight_decay transform - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); + fn to_device( + mut state: Self::State, + device: &::Device, + ) -> Self::State { + state.square_avg = state.square_avg.to_device(device); + state.centered = state.centered.to_device(device); + state.momentum = state.momentum.map(|momentum| momentum.to_device(device)); + state } - - // square_avg transform - let (grad, state_square_avg) = SquareAvgState::transform(self.alpha, grad, state_square_avg); - - // centered transform - let (grad, state_square_avg, state_centered) = CenteredState::transform( - self.alpha, - self.centered, - grad, - state_square_avg, - state_centered, - ); - - // momentum transform - let (grad, state_centered, state_momentum) = - self - .momentum - .transform(grad, state_centered, state_momentum); - - // transition state - let state = RMSPropState::new(state_square_avg, state_centered, state_momentum); - - // tensor param transform - let delta = grad.mul_scalar(lr); - (tensor - delta, Some(state)) - } - - fn to_device( - mut state: Self::State, - device: &::Device, - ) -> Self::State { - state.square_avg = state.square_avg.to_device(device); - state.centered = state.centered.to_device(device); - state.momentum = state.momentum.map(|momentum| momentum.to_device(device)); - state - } } /// State of [RMSProp](RMSProp) #[derive(Record, Clone, new)] pub struct RMSPropState { - square_avg: SquareAvgState, - centered: CenteredState, - momentum: Option>, + square_avg: SquareAvgState, + 
centered: CenteredState, + momentum: Option>, } /// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct SquareAvgState { - square_avg: Tensor, + square_avg: Tensor, } impl SquareAvgState { - /// transform [SquareAvgState] to the next step - fn transform(alpha: f32, grad: Tensor, state: Option) -> (Tensor, Self) { - match state { - Some(state) => { - let square_avg = state - .square_avg - .mul_scalar(alpha) - .add(grad.clone().powf(2.).mul_scalar(1. - alpha)); - (grad, Self { square_avg }) - } - _ => { - let square_avg = grad.clone().powf(2.).mul_scalar(1. - alpha); - (grad, Self { square_avg }) - } + /// transform [SquareAvgState] to the next step + fn transform(alpha: f32, grad: Tensor, state: Option) -> (Tensor, Self) { + match state { + Some(state) => { + let square_avg = state + .square_avg + .mul_scalar(alpha) + .add(grad.clone().powf(2.).mul_scalar(1. - alpha)); + (grad, Self { square_avg }) + } + _ => { + let square_avg = grad.clone().powf(2.).mul_scalar(1. - alpha); + (grad, Self { square_avg }) + } + } + } + + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.square_avg = self.square_avg.to_device(device); + self } - } - - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.square_avg = self.square_avg.to_device(device); - self - } } /// [CenteredState](CenteredState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct CenteredState { - grad_avg: Option>, - avg: Tensor, + grad_avg: Option>, + avg: Tensor, } impl CenteredState { - /// transform [CenteredState] to the next step - fn transform( - alpha: f32, - centered: bool, - grad: Tensor, - square_avg_state: SquareAvgState, - centered_state: Option, - ) -> (Tensor, SquareAvgState, Self) { - if centered { - let grad_avg_constant = grad.clone().mul_scalar(1. - alpha); - let grad_avg = match centered_state { - Some(state) => state - .grad_avg - .map_or(grad_avg_constant.clone(), move |grad_avg| { - grad_avg.mul_scalar(alpha).add(grad_avg_constant) - }), - _ => grad_avg_constant, - }; - let avg = square_avg_state - .square_avg - .clone() - .sub(grad_avg.clone().powf(2.)); - - ( - grad, - square_avg_state, - Self { - grad_avg: Some(grad_avg), - avg, - }, - ) - } else { - ( - grad, - square_avg_state.clone(), - Self { - grad_avg: None, - avg: square_avg_state.square_avg, - }, - ) + /// transform [CenteredState] to the next step + fn transform( + alpha: f32, + centered: bool, + grad: Tensor, + square_avg_state: SquareAvgState, + centered_state: Option, + ) -> (Tensor, SquareAvgState, Self) { + if centered { + let grad_avg_constant = grad.clone().mul_scalar(1. 
- alpha); + let grad_avg = match centered_state { + Some(state) => state + .grad_avg + .map_or(grad_avg_constant.clone(), move |grad_avg| { + grad_avg.mul_scalar(alpha).add(grad_avg_constant) + }), + _ => grad_avg_constant, + }; + let avg = square_avg_state + .square_avg + .clone() + .sub(grad_avg.clone().powf(2.)); + + ( + grad, + square_avg_state, + Self { + grad_avg: Some(grad_avg), + avg, + }, + ) + } else { + ( + grad, + square_avg_state.clone(), + Self { + grad_avg: None, + avg: square_avg_state.square_avg, + }, + ) + } + } + + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. + pub fn to_device(mut self, device: &B::Device) -> Self { + self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device)); + self.avg = self.avg.to_device(device); + self } - } - - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device)); - self.avg = self.avg.to_device(device); - self - } } /// [RMSPropMomentum](RMSPropMomentum) is to store config status for optimizer. /// (, which is stored in [optimizer](RMSProp) itself and not passed in during `step()` calculation) pub struct RMSPropMomentum { - momentum: f32, - epsilon: f32, + momentum: f32, + epsilon: f32, } impl RMSPropMomentum { - /// transform [grad](Tensor) and [RMSPropMomentumState] to the next step - fn transform( - &self, - grad: Tensor, - centered_state: CenteredState, - momentum_state: Option>, - ) -> ( - Tensor, - CenteredState, - Option>, - ) { - let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon)); - - if self.momentum > 0. { - let buf = match momentum_state { - Some(state) => state.buf.mul_scalar(self.momentum).add(grad), - _ => grad, - }; - ( - buf.clone(), - centered_state, - Some(RMSPropMomentumState { buf }), - ) - } else { - (grad, centered_state, None) + /// transform [grad](Tensor) and [RMSPropMomentumState] to the next step + fn transform( + &self, + grad: Tensor, + centered_state: CenteredState, + momentum_state: Option>, + ) -> ( + Tensor, + CenteredState, + Option>, + ) { + let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon)); + + if self.momentum > 0. { + let buf = match momentum_state { + Some(state) => state.buf.mul_scalar(self.momentum).add(grad), + _ => grad, + }; + ( + buf.clone(), + centered_state, + Some(RMSPropMomentumState { buf }), + ) + } else { + (grad, centered_state, None) + } } - } } /// [RMSPropMomentumState](RMSPropMomentumState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct RMSPropMomentumState { - buf: Tensor, + buf: Tensor, } impl RMSPropMomentumState { - /// Moves the state to a device. - /// - /// # Arguments - /// - /// * `device` - Device to move the state to. - /// - /// # Returns - /// - /// * `self` - Moved state. - pub fn to_device(mut self, device: &B::Device) -> Self { - self.buf = self.buf.to_device(device); - self - } + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. 
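Taken together, the transforms above implement the standard RMSProp recurrences: after the optional weight-decay pre-transform in RMSProp::step, the square average is updated as s <- alpha * s + (1 - alpha) * g * g; when `centered` is set, a gradient average g_avg <- alpha * g_avg + (1 - alpha) * g is maintained and the denominator becomes s - g_avg^2, otherwise it is s; the gradient is then divided by (sqrt(denominator) + epsilon) and, if momentum > 0, folded into a buffer buf <- momentum * buf + g that is returned instead. A scalar sketch of one such step in plain f64 (illustrative names; starting all three state values at 0.0 reproduces the `None` branches above):

    // Scalar mirror of the pipeline in RMSProp::step above:
    // SquareAvgState -> CenteredState -> RMSPropMomentum.
    // Returns (update, new square_avg, new grad_avg, new momentum buffer);
    // the caller subtracts `update * lr` from the parameter.
    fn rmsprop_step(
        grad: f64,
        alpha: f64,
        epsilon: f64,
        momentum: f64,
        centered: bool,
        square_avg: f64,
        grad_avg: f64,
        buf: f64,
    ) -> (f64, f64, f64, f64) {
        // SquareAvgState::transform
        let square_avg = alpha * square_avg + (1.0 - alpha) * grad * grad;

        // CenteredState::transform
        let (grad_avg, avg) = if centered {
            let grad_avg = alpha * grad_avg + (1.0 - alpha) * grad;
            (grad_avg, square_avg - grad_avg * grad_avg)
        } else {
            // Non-centered: the gradient average is unused and the denominator
            // is the square average itself.
            (grad_avg, square_avg)
        };

        // RMSPropMomentum::transform
        let grad = grad / (avg.sqrt() + epsilon);
        let (update, buf) = if momentum > 0.0 {
            let buf = momentum * buf + grad;
            (buf, buf)
        } else {
            (grad, buf)
        };

        (update, square_avg, grad_avg, buf)
    }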
+ pub fn to_device(mut self, device: &B::Device) -> Self { + self.buf = self.buf.to_device(device); + self + } } #[cfg(test)] mod tests { - use burn_tensor::Shape; - - use super::*; - use crate::module::{Module, Param}; - use crate::optim::{GradientsParams, Optimizer}; - use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; - use crate::tensor::{Data, Distribution, Tensor}; - use crate::{nn, TestAutodiffBackend, TestBackend}; - use tempfile::TempDir; - - const LEARNING_RATE: LearningRate = 0.01; - const ASSERT_PRECISION: usize = 6; - - #[test] - fn test_rmsprop_optimizer_save_load_state() { - let linear = nn::LinearConfig::new(6, 6).init(); - let x = Tensor::::random([2, 6], Distribution::Default); - let mut optimizer = create_rmsprop(); - let grads = linear.forward(x).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let _linear = optimizer.step(LEARNING_RATE, linear, grads); - let temp_dir = TempDir::new().unwrap(); - BinFileRecorder::::default() - .record(optimizer.to_record(), temp_dir.path().join("test_optim")) - .unwrap(); - - let state_optim_before = optimizer.to_record(); - let state_optim_before_copy = optimizer.to_record(); - let optimizer = create_rmsprop(); - let optimizer = optimizer.load_record(state_optim_before_copy); - let state_optim_after = optimizer.to_record(); - - assert_eq!(state_optim_before.len(), state_optim_after.len()); - } - - /// used for test differences and debug - #[test] - fn test_rmsprop_optimizer_with_numbers_basic() { - let linear = given_linear_layer( - Data::from([ - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - ]), - Data::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = RMSPropConfig::new() - .with_alpha(0.99) - .with_epsilon(1e-8) - .with_weight_decay(WeightDecayConfig::new(0.05).into()) - .with_momentum(0.9) - .with_centered(false) - .init(); - - // println!("linear is {:?}", linear); - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - // println!("linear is {:?}", linear); - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - // println!("linear is {:?}", linear); - let state_updated = linear.into_record(); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - // println!("\nweight_updated\n{:?}", weight_updated); - // println!("\nbias_updated\n{:?}", bias_updated); - - let weights_expected = Data::from([ - [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937], - [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809], - [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881], - [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366], - [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005], - [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710], - ]); - let bias_expected = Data::from([0.239199, 0.239199, 0.239199, 
0.239199, 0.239199, 0.239199]); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - #[test] - fn test_rmsprop_optimizer_with_numbers() { - let linear = given_linear_layer( - Data::from([ - [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], - [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], - [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], - [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], - [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], - [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], - ]), - Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), - ); - let x_1 = Tensor::from_floats([ - [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], - [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], - ]) - .require_grad(); - let x_2 = Tensor::from_floats([ - [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], - [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], - ]) - .require_grad(); - - let mut optimizer = RMSPropConfig::new() - .with_alpha(0.99) - .with_epsilon(1e-8) - .with_weight_decay(WeightDecayConfig::new(0.05).into()) - .with_momentum(0.9) - .with_centered(false) - .init(); - - let grads = linear.forward(x_1).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let grads = linear.forward(x_2).backward(); - let grads = GradientsParams::from_grads(grads, &linear); - let linear = optimizer.step(LEARNING_RATE, linear, grads); - - let state_updated = linear.into_record(); - let weights_expected = Data::from([ - [ - -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779, - ], - [ - -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207, - ], - [ - -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967, - ], - [ - -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997, - ], - [ - 0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912, - ], - [ - -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126, - ], - ]); - let bias_expected = Data::from([ - -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800, - ]); - - let (weight_updated, bias_updated) = ( - state_updated.weight.to_data(), - state_updated.bias.unwrap().to_data(), - ); - - // println!("\nweight_updated\n{:?}", weight_updated); - // println!("\nbias_updated\n{:?}", bias_updated); - - bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); - weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); - } - - fn given_linear_layer( - weight: Data, - bias: Data, - ) -> nn::Linear { - let record = nn::LinearRecord { - weight: Param::from(Tensor::from_data(weight)), - bias: Some(Param::from(Tensor::from_data(bias))), - }; - - nn::LinearConfig::new(6, 6).init_with(record) - } - - #[allow(dead_code)] - fn create_random_tensor() -> Tensor { - Tensor::::random(Shape::new([2, 20]), Distribution::Default) - } - - fn create_rmsprop( - ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> - { - RMSPropConfig { - alpha: 0.99, - epsilon: 1e-9, - centered: false, - weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), - momentum: 0.9, - grad_clipping: None, + use burn_tensor::Shape; + + use super::*; + use crate::module::{Module, Param}; + use crate::optim::{GradientsParams, Optimizer}; + use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; + use crate::tensor::{Data, Distribution, Tensor}; + use crate::{nn, TestAutodiffBackend, TestBackend}; + 
use tempfile::TempDir; + + const LEARNING_RATE: LearningRate = 0.01; + const ASSERT_PRECISION: usize = 6; + + #[test] + fn test_rmsprop_optimizer_save_load_state() { + let linear = nn::LinearConfig::new(6, 6).init(); + let x = Tensor::::random([2, 6], Distribution::Default); + let mut optimizer = create_rmsprop(); + let grads = linear.forward(x).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let _linear = optimizer.step(LEARNING_RATE, linear, grads); + let temp_dir = TempDir::new().unwrap(); + BinFileRecorder::::default() + .record(optimizer.to_record(), temp_dir.path().join("test_optim")) + .unwrap(); + + let state_optim_before = optimizer.to_record(); + let state_optim_before_copy = optimizer.to_record(); + let optimizer = create_rmsprop(); + let optimizer = optimizer.load_record(state_optim_before_copy); + let state_optim_after = optimizer.to_record(); + + assert_eq!(state_optim_before.len(), state_optim_after.len()); + } + + /// used for test differences and debug + #[test] + fn test_rmsprop_optimizer_with_numbers_basic() { + let linear = given_linear_layer( + Data::from([ + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + ]), + Data::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = RMSPropConfig::new() + .with_alpha(0.99) + .with_epsilon(1e-8) + .with_weight_decay(WeightDecayConfig::new(0.05).into()) + .with_momentum(0.9) + .with_centered(false) + .init(); + + // println!("linear is {:?}", linear); + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + // println!("linear is {:?}", linear); + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + // println!("linear is {:?}", linear); + let state_updated = linear.into_record(); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + // println!("\nweight_updated\n{:?}", weight_updated); + // println!("\nbias_updated\n{:?}", bias_updated); + + let weights_expected = Data::from([ + [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937], + [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809], + [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881], + [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366], + [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005], + [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710], + ]); + let bias_expected = + Data::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + #[test] + fn test_rmsprop_optimizer_with_numbers() { + let linear = given_linear_layer( + Data::from([ + [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], + [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], + [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], 
+ [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], + [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], + [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], + ]), + Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), + ); + let x_1 = Tensor::from_floats([ + [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], + [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], + ]) + .require_grad(); + let x_2 = Tensor::from_floats([ + [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], + [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], + ]) + .require_grad(); + + let mut optimizer = RMSPropConfig::new() + .with_alpha(0.99) + .with_epsilon(1e-8) + .with_weight_decay(WeightDecayConfig::new(0.05).into()) + .with_momentum(0.9) + .with_centered(false) + .init(); + + let grads = linear.forward(x_1).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let grads = linear.forward(x_2).backward(); + let grads = GradientsParams::from_grads(grads, &linear); + let linear = optimizer.step(LEARNING_RATE, linear, grads); + + let state_updated = linear.into_record(); + let weights_expected = Data::from([ + [ + -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779, + ], + [ + -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207, + ], + [ + -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967, + ], + [ + -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997, + ], + [ + 0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912, + ], + [ + -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126, + ], + ]); + let bias_expected = Data::from([ + -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800, + ]); + + let (weight_updated, bias_updated) = ( + state_updated.weight.to_data(), + state_updated.bias.unwrap().to_data(), + ); + + // println!("\nweight_updated\n{:?}", weight_updated); + // println!("\nbias_updated\n{:?}", bias_updated); + + bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION); + weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION); + } + + fn given_linear_layer( + weight: Data, + bias: Data, + ) -> nn::Linear { + let record = nn::LinearRecord { + weight: Param::from(Tensor::from_data(weight)), + bias: Some(Param::from(Tensor::from_data(bias))), + }; + + nn::LinearConfig::new(6, 6).init_with(record) + } + + #[allow(dead_code)] + fn create_random_tensor() -> Tensor { + Tensor::::random(Shape::new([2, 20]), Distribution::Default) + } + + fn create_rmsprop( + ) -> OptimizerAdaptor, nn::Linear, TestAutodiffBackend> + { + RMSPropConfig { + alpha: 0.99, + epsilon: 1e-9, + centered: false, + weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), + momentum: 0.9, + grad_clipping: None, + } + .init() } - .init() - } } diff --git a/burn-core/src/optim/sgd.rs b/burn-core/src/optim/sgd.rs index b5d1526264..b2ed4a4b75 100644 --- a/burn-core/src/optim/sgd.rs +++ b/burn-core/src/optim/sgd.rs @@ -14,163 +14,163 @@ use burn_tensor::backend::{AutodiffBackend, Backend}; /// Configuration to create the [Sgd](Sgd) optimizer. #[derive(Config)] pub struct SgdConfig { - /// [Weight decay](WeightDecayConfig) config. - weight_decay: Option, - /// [Momentum](MomentumConfig) config. - momentum: Option, - /// [Gradient Clipping](GradientClippingConfig) config. - gradient_clipping: Option, + /// [Weight decay](WeightDecayConfig) config. + weight_decay: Option, + /// [Momentum](MomentumConfig) config. 
+ momentum: Option, + /// [Gradient Clipping](GradientClippingConfig) config. + gradient_clipping: Option, } /// Optimizer that implements stochastic gradient descent with momentum. /// /// The optimizer can be configured with [SgdConfig](SgdConfig). pub struct Sgd { - momentum: Option>, - weight_decay: Option>, + momentum: Option>, + weight_decay: Option>, } /// State of [Sgd](Sgd). #[derive(Record, Clone, new)] pub struct SgdState { - momentum: Option>, + momentum: Option>, } impl SgdConfig { - /// Creates a new [SgdConfig](SgdConfig) with default values. - pub fn init>( - &self, - ) -> OptimizerAdaptor, M, B> { - let momentum = self.momentum.as_ref().map(Momentum::new); - let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); - - let mut optim = OptimizerAdaptor::from(Sgd { - momentum, - weight_decay, - }); - if let Some(config) = &self.gradient_clipping { - optim = optim.with_grad_clipping(config.init()); + /// Creates a new [SgdConfig](SgdConfig) with default values. + pub fn init>( + &self, + ) -> OptimizerAdaptor, M, B> { + let momentum = self.momentum.as_ref().map(Momentum::new); + let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); + + let mut optim = OptimizerAdaptor::from(Sgd { + momentum, + weight_decay, + }); + if let Some(config) = &self.gradient_clipping { + optim = optim.with_grad_clipping(config.init()); + } + optim } - optim - } } impl SimpleOptimizer for Sgd { - type State = SgdState; - - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - mut grad: Tensor, - state: Option>, - ) -> (Tensor, Option>) { - let mut state_momemtum = None; - - if let Some(state) = state { - state_momemtum = state.momentum; + type State = SgdState; + + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + mut grad: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + let mut state_momemtum = None; + + if let Some(state) = state { + state_momemtum = state.momentum; + } + + if let Some(weight_decay) = &self.weight_decay { + grad = weight_decay.transform(grad, tensor.clone()); + } + + if let Some(momentum) = &self.momentum { + let (grad_out, state) = momentum.transform(grad, state_momemtum); + state_momemtum = Some(state); + grad = grad_out; + } + + let state = SgdState::new(state_momemtum); + let delta = grad.mul_scalar(lr); + + (tensor - delta, Some(state)) } - if let Some(weight_decay) = &self.weight_decay { - grad = weight_decay.transform(grad, tensor.clone()); + fn to_device(mut state: Self::State, device: &B::Device) -> Self::State { + state.momentum = state.momentum.map(|state| state.to_device(device)); + state } +} - if let Some(momentum) = &self.momentum { - let (grad_out, state) = momentum.transform(grad, state_momemtum); - state_momemtum = Some(state); - grad = grad_out; +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + grad_clipping::GradientClipping, + nn::{Linear, LinearConfig}, + optim::{GradientsParams, Optimizer}, + tensor::{Distribution, Shape}, + TestAutodiffBackend, TestBackend, + }; + + const LEARNING_RATE: LearningRate = 0.02; + + #[test] + fn with_updated_params_should_have_state() { + let layer = layer(); + let mut optim = sgd_with_all(); + let loss = layer.forward(random_tensor()); + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &layer); + let _layer = optim.step(LEARNING_RATE, layer, grads); + + let record = optim.to_record(); + + assert!(!record.is_empty()); } - let state = SgdState::new(state_momemtum); - let delta = grad.mul_scalar(lr); + #[test] + fn 
without_updated_params_should_not_have_state() { + let optim = sgd_with_all(); + let record = optim.to_record(); + assert!(record.is_empty()); + } - (tensor - delta, Some(state)) - } + #[test] + fn can_attach_gradient_clipping() { + let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5)); + assert!(optim.has_gradient_clipping()); + } - fn to_device(mut state: Self::State, device: &B::Device) -> Self::State { - state.momentum = state.momentum.map(|state| state.to_device(device)); - state - } -} + #[test] + fn should_load_state() { + let layer = layer(); + let mut optim = sgd_with_all(); + let loss = layer.forward(random_tensor()); + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &layer); + let _layer = optim.step(LEARNING_RATE, layer, grads); + + let record = optim.to_record(); + let optim_new = sgd_with_all(); + let record_new = optim_new.to_record(); + let optim_new = optim_new.load_record(record.clone()); + let state_restored = optim_new.to_record(); + + assert_ne!(record.len(), record_new.len()); + assert_eq!(record.len(), state_restored.len()); + } -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - grad_clipping::GradientClipping, - nn::{Linear, LinearConfig}, - optim::{GradientsParams, Optimizer}, - tensor::{Distribution, Shape}, - TestAutodiffBackend, TestBackend, - }; - - const LEARNING_RATE: LearningRate = 0.02; - - #[test] - fn with_updated_params_should_have_state() { - let layer = layer(); - let mut optim = sgd_with_all(); - let loss = layer.forward(random_tensor()); - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &layer); - let _layer = optim.step(LEARNING_RATE, layer, grads); - - let record = optim.to_record(); - - assert!(!record.is_empty()); - } - - #[test] - fn without_updated_params_should_not_have_state() { - let optim = sgd_with_all(); - let record = optim.to_record(); - assert!(record.is_empty()); - } - - #[test] - fn can_attach_gradient_clipping() { - let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5)); - assert!(optim.has_gradient_clipping()); - } - - #[test] - fn should_load_state() { - let layer = layer(); - let mut optim = sgd_with_all(); - let loss = layer.forward(random_tensor()); - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &layer); - let _layer = optim.step(LEARNING_RATE, layer, grads); - - let record = optim.to_record(); - let optim_new = sgd_with_all(); - let record_new = optim_new.to_record(); - let optim_new = optim_new.load_record(record.clone()); - let state_restored = optim_new.to_record(); - - assert_ne!(record.len(), record_new.len()); - assert_eq!(record.len(), state_restored.len()); - } - - fn random_tensor() -> Tensor { - Tensor::::random(Shape::new([2, 20]), Distribution::Default) - } - - fn layer() -> Linear { - LinearConfig::new(20, 20).with_bias(true).init() - } - - fn sgd_with_all( - ) -> OptimizerAdaptor, Linear, TestAutodiffBackend> { - SgdConfig { - weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), - momentum: Some(MomentumConfig { - momentum: 0.9, - dampening: 0.1, - nesterov: true, - }), - gradient_clipping: None, + fn random_tensor() -> Tensor { + Tensor::::random(Shape::new([2, 20]), Distribution::Default) + } + + fn layer() -> Linear { + LinearConfig::new(20, 20).with_bias(true).init() + } + + fn sgd_with_all( + ) -> OptimizerAdaptor, Linear, TestAutodiffBackend> { + SgdConfig { + weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), + momentum: Some(MomentumConfig { + momentum: 
0.9, + dampening: 0.1, + nesterov: true, + }), + gradient_clipping: None, + } + .init() } - .init() - } } diff --git a/burn-core/src/optim/simple/adaptor.rs b/burn-core/src/optim/simple/adaptor.rs index 996d9b6d27..0b44c84183 100644 --- a/burn-core/src/optim/simple/adaptor.rs +++ b/burn-core/src/optim/simple/adaptor.rs @@ -1,9 +1,9 @@ use super::{record::AdaptorRecord, SimpleOptimizer}; use crate::{ - grad_clipping::GradientClipping, - module::{AutodiffModule, ModuleMapper, ParamId}, - optim::{GradientsParams, Optimizer}, - LearningRate, + grad_clipping::GradientClipping, + module::{AutodiffModule, ModuleMapper, ParamId}, + optim::{GradientsParams, Optimizer}, + LearningRate, }; use burn_tensor::{backend::AutodiffBackend, Tensor}; use core::marker::PhantomData; @@ -13,143 +13,143 @@ use hashbrown::HashMap; /// an [optimizer](Optimizer). pub struct OptimizerAdaptor where - O: SimpleOptimizer, - M: AutodiffModule, - B: AutodiffBackend, + O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, { - optim: O, - records: HashMap>, - module: PhantomData, - grad_clipping: Option, + optim: O, + records: HashMap>, + module: PhantomData, + grad_clipping: Option, } impl From for OptimizerAdaptor where - B: AutodiffBackend, - M: AutodiffModule, - O: SimpleOptimizer, + B: AutodiffBackend, + M: AutodiffModule, + O: SimpleOptimizer, { - fn from(optim: O) -> Self { - Self { - optim, - records: HashMap::new(), - module: PhantomData, - grad_clipping: None, + fn from(optim: O) -> Self { + Self { + optim, + records: HashMap::new(), + module: PhantomData, + grad_clipping: None, + } } - } } impl OptimizerAdaptor where - O: SimpleOptimizer, - M: AutodiffModule, - B: AutodiffBackend, + O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, { - /// Sets the gradient clipping. - /// - /// # Arguments - /// - /// * `gradient_clipping` - The gradient clipping. - /// - /// # Returns - /// - /// The optimizer. - pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self { - self.grad_clipping = Some(gradient_clipping); - self - } - - #[cfg(test)] - pub(crate) fn has_gradient_clipping(&self) -> bool { - self.grad_clipping.is_some() - } + /// Sets the gradient clipping. + /// + /// # Arguments + /// + /// * `gradient_clipping` - The gradient clipping. + /// + /// # Returns + /// + /// The optimizer. 
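Returning to Sgd::step above: per parameter it applies the optional weight-decay transform to the gradient, then the optional momentum transform, and finally subtracts lr times the result from the parameter tensor. A scalar sketch of that composition, reusing the momentum_step helper sketched earlier (plain f64, illustrative names; the weight-decay rule itself lives in decay.rs, which is not part of this hunk, so it is passed in here as an opaque closure):

    // Scalar mirror of Sgd::step above. `decay` stands in for WeightDecay::transform.
    fn sgd_step(
        param: f64,
        mut grad: f64,
        lr: f64,
        decay: Option<&dyn Fn(f64, f64) -> f64>, // (grad, param) -> transformed grad
        momentum: Option<(f64, f64, bool)>,      // (momentum, dampening, nesterov)
        velocity: Option<f64>,                   // previous MomentumState, if any
    ) -> (f64, Option<f64>) {
        if let Some(decay) = decay {
            grad = decay(grad, param);
        }

        let mut velocity = velocity;
        if let Some((m, dampening, nesterov)) = momentum {
            let (grad_out, v) = momentum_step(grad, velocity, m, dampening, nesterov);
            grad = grad_out;
            velocity = Some(v);
        }

        // Same as `tensor - grad.mul_scalar(lr)` in the hunk above; the velocity
        // is what gets stored back as SgdState.
        (param - lr * grad, velocity)
    }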
+ pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self { + self.grad_clipping = Some(gradient_clipping); + self + } + + #[cfg(test)] + pub(crate) fn has_gradient_clipping(&self) -> bool { + self.grad_clipping.is_some() + } } impl Optimizer for OptimizerAdaptor where - B: AutodiffBackend, - M: AutodiffModule, - O: SimpleOptimizer, + B: AutodiffBackend, + M: AutodiffModule, + O: SimpleOptimizer, { - type Record = HashMap>; - - fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M { - let mut mapper = SimpleOptimizerMapper::::new( - &self.optim, - &mut self.records, - &mut grads, - lr, - self.grad_clipping.as_ref(), - ); - module.map(&mut mapper) - } - - fn to_record(&self) -> Self::Record { - self.records.clone() - } - - fn load_record(mut self, record: Self::Record) -> Self { - self.records = record; - self - } + type Record = HashMap>; + + fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M { + let mut mapper = SimpleOptimizerMapper::::new( + &self.optim, + &mut self.records, + &mut grads, + lr, + self.grad_clipping.as_ref(), + ); + module.map(&mut mapper) + } + + fn to_record(&self) -> Self::Record { + self.records.clone() + } + + fn load_record(mut self, record: Self::Record) -> Self { + self.records = record; + self + } } #[derive(new)] struct SimpleOptimizerMapper<'a, M, B, O> where - M: AutodiffModule, - B: AutodiffBackend, - O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, + O: SimpleOptimizer, { - optimizer: &'a O, - records: &'a mut HashMap>, - grads: &'a mut GradientsParams, - lr: LearningRate, - phantom: PhantomData, - grad_clipping: Option<&'a GradientClipping>, + optimizer: &'a O, + records: &'a mut HashMap>, + grads: &'a mut GradientsParams, + lr: LearningRate, + phantom: PhantomData, + grad_clipping: Option<&'a GradientClipping>, } impl<'a, M, B, O> ModuleMapper for SimpleOptimizerMapper<'a, M, B, O> where - M: AutodiffModule, - B: AutodiffBackend, - O: SimpleOptimizer, + M: AutodiffModule, + B: AutodiffBackend, + O: SimpleOptimizer, { - fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor { - let grad = self.grads.remove(id); - - if let Some(grad) = grad { - let device = grad.device(); - let is_require_grad = tensor.is_require_grad(); - let (key, record) = self.records.remove_entry(id).unzip(); - - let clipped_grad = if let Some(g_clipping) = self.grad_clipping { - g_clipping.clip_gradient(grad) - } else { - grad - }; - - let (tensor, state) = self.optimizer.step( - self.lr, - tensor.inner(), - clipped_grad, - record.map(|record| O::to_device(record.into_state(), &device)), - ); - - if let Some(state) = state { - self.records.insert( - key.unwrap_or_else(|| id.clone()), - AdaptorRecord::from_state(state), - ); - } - - let mut tensor = Tensor::from_inner(tensor); - if is_require_grad { - tensor = tensor.require_grad(); - } - return tensor; + fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor { + let grad = self.grads.remove(id); + + if let Some(grad) = grad { + let device = grad.device(); + let is_require_grad = tensor.is_require_grad(); + let (key, record) = self.records.remove_entry(id).unzip(); + + let clipped_grad = if let Some(g_clipping) = self.grad_clipping { + g_clipping.clip_gradient(grad) + } else { + grad + }; + + let (tensor, state) = self.optimizer.step( + self.lr, + tensor.inner(), + clipped_grad, + record.map(|record| O::to_device(record.into_state(), &device)), + ); + + if let Some(state) = state { + self.records.insert( + key.unwrap_or_else(|| 
id.clone()), + AdaptorRecord::from_state(state), + ); + } + + let mut tensor = Tensor::from_inner(tensor); + if is_require_grad { + tensor = tensor.require_grad(); + } + return tensor; + } + + tensor } - - tensor - } } diff --git a/burn-core/src/optim/simple/base.rs b/burn-core/src/optim/simple/base.rs index 97d2fb961d..5737960ad9 100644 --- a/burn-core/src/optim/simple/base.rs +++ b/burn-core/src/optim/simple/base.rs @@ -8,26 +8,26 @@ use burn_tensor::{backend::Backend, Tensor}; /// module parameter structure, handle tracked and untracked tensors, and the likes. pub trait SimpleOptimizer: Send + Sync where - B: Backend, + B: Backend, { - /// The state of the optimizer. It also implements [record](Record), so that it can be saved. - type State: Record + Clone + 'static; + /// The state of the optimizer. It also implements [record](Record), so that it can be saved. + type State: Record + Clone + 'static; - /// The optimizer step is performed for one tensor at a time with its gradient and state. - /// - /// Note that the state is passed as parameter, so implementations don't have to handle - /// the saving and loading of recorded states. - fn step( - &self, - lr: LearningRate, - tensor: Tensor, - grad: Tensor, - state: Option>, - ) -> (Tensor, Option>); + /// The optimizer step is performed for one tensor at a time with its gradient and state. + /// + /// Note that the state is passed as parameter, so implementations don't have to handle + /// the saving and loading of recorded states. + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + grad: Tensor, + state: Option>, + ) -> (Tensor, Option>); - /// Change the device of the state. - /// - /// This function will be called accordindly to have the state on the same device as the - /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called. - fn to_device(state: Self::State, device: &B::Device) -> Self::State; + /// Change the device of the state. + /// + /// This function will be called accordindly to have the state on the same device as the + /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called. + fn to_device(state: Self::State, device: &B::Device) -> Self::State; } diff --git a/burn-core/src/optim/simple/record/base.rs b/burn-core/src/optim/simple/record/base.rs index 75b2196fa9..e0bc9199d5 100644 --- a/burn-core/src/optim/simple/record/base.rs +++ b/burn-core/src/optim/simple/record/base.rs @@ -1,7 +1,7 @@ use super::{AdaptorRecordItemV1, AdaptorRecordV1}; use crate::{ - optim::SimpleOptimizer, - record::{PrecisionSettings, Record}, + optim::SimpleOptimizer, + record::{PrecisionSettings, Record}, }; use burn_tensor::backend::Backend; use serde::{Deserialize, Serialize}; @@ -10,76 +10,76 @@ use serde::{Deserialize, Serialize}; /// /// Records are versioned for backward compatibility, so old records can be loaded. pub enum AdaptorRecord, B: Backend> { - /// Version 1. - V1(AdaptorRecordV1), + /// Version 1. + V1(AdaptorRecordV1), } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub enum AdaptorRecordItem, B: Backend, S: PrecisionSettings> { - /// Version 1. - V1(AdaptorRecordItemV1), + /// Version 1. 
+ V1(AdaptorRecordItemV1), } impl Record for AdaptorRecord where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - type Item = AdaptorRecordItem; + type Item = AdaptorRecordItem; - fn into_item(self) -> Self::Item { - match self { - AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()), + fn into_item(self) -> Self::Item { + match self { + AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()), + } } - } - fn from_item(item: Self::Item) -> Self { - match item { - AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item)), + fn from_item(item: Self::Item) -> Self { + match item { + AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item)), + } } - } } impl Clone for AdaptorRecord where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - fn clone(&self) -> Self { - match self { - AdaptorRecord::V1(record) => Self::V1(record.clone()), + fn clone(&self) -> Self { + match self { + AdaptorRecord::V1(record) => Self::V1(record.clone()), + } } - } } impl AdaptorRecord where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - /// Converts the record into the optimizer state. - /// - /// # Returns - /// - /// The optimizer state. - pub fn into_state(self) -> O::State { - match self { - AdaptorRecord::V1(record) => record.into_state(), + /// Converts the record into the optimizer state. + /// + /// # Returns + /// + /// The optimizer state. + pub fn into_state(self) -> O::State { + match self { + AdaptorRecord::V1(record) => record.into_state(), + } } - } - /// Converts the optimizer state into the record. - /// - /// # Arguments - /// - /// * `state`: The optimizer state. - /// - /// # Returns - /// - /// The record. - pub fn from_state(state: O::State) -> Self { - Self::V1(AdaptorRecordV1::from_state(state)) - } + /// Converts the optimizer state into the record. + /// + /// # Arguments + /// + /// * `state`: The optimizer state. + /// + /// # Returns + /// + /// The record. + pub fn from_state(state: O::State) -> Self { + Self::V1(AdaptorRecordV1::from_state(state)) + } } diff --git a/burn-core/src/optim/simple/record/v1.rs b/burn-core/src/optim/simple/record/v1.rs index 7c721f6347..9c47403473 100644 --- a/burn-core/src/optim/simple/record/v1.rs +++ b/burn-core/src/optim/simple/record/v1.rs @@ -1,6 +1,6 @@ use crate::{ - optim::SimpleOptimizer, - record::{PrecisionSettings, Record}, + optim::SimpleOptimizer, + record::{PrecisionSettings, Record}, }; use burn_tensor::backend::Backend; use core::any::Any; @@ -8,178 +8,178 @@ use serde::{Deserialize, Serialize}; /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. pub enum AdaptorRecordV1, B: Backend> { - /// Rank 1. - Rank1(O::State<1>), + /// Rank 1. + Rank1(O::State<1>), - /// Rank 2. - Rank2(O::State<2>), + /// Rank 2. + Rank2(O::State<2>), - /// Rank 3. - Rank3(O::State<3>), + /// Rank 3. + Rank3(O::State<3>), - /// Rank 4. - Rank4(O::State<4>), + /// Rank 4. + Rank4(O::State<4>), - /// Rank 5. - Rank5(O::State<5>), + /// Rank 5. + Rank5(O::State<5>), - /// Rank 6. - Rank6(O::State<6>), + /// Rank 6. + Rank6(O::State<6>), - /// Rank 7. - Rank7(O::State<7>), + /// Rank 7. + Rank7(O::State<7>), - /// Rank 8. - Rank8(O::State<8>), + /// Rank 8. 
+ Rank8(O::State<8>), } impl, B: Backend> Clone for AdaptorRecordV1 { - fn clone(&self) -> Self { - match self { - AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()), - AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()), - AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()), - AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()), - AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()), - AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()), - AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()), - AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()), + fn clone(&self) -> Self { + match self { + AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()), + AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()), + AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()), + AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()), + AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()), + AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()), + AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()), + AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()), + } } - } } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub enum AdaptorRecordItemV1, B: Backend, S: PrecisionSettings> { - /// Rank 1. - Rank1( as Record>::Item), + /// Rank 1. + Rank1( as Record>::Item), - /// Rank 2. - Rank2( as Record>::Item), + /// Rank 2. + Rank2( as Record>::Item), - /// Rank 3. - Rank3( as Record>::Item), + /// Rank 3. + Rank3( as Record>::Item), - /// Rank 4. - Rank4( as Record>::Item), + /// Rank 4. + Rank4( as Record>::Item), - /// Rank 5. - Rank5( as Record>::Item), + /// Rank 5. + Rank5( as Record>::Item), - /// Rank 6. - Rank6( as Record>::Item), + /// Rank 6. + Rank6( as Record>::Item), - /// Rank 7. - Rank7( as Record>::Item), + /// Rank 7. + Rank7( as Record>::Item), - /// Rank 8. - Rank8( as Record>::Item), + /// Rank 8. + Rank8( as Record>::Item), } impl AdaptorRecordV1 where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - /// Convert the record into the state. - /// - /// # Returns - /// - /// The state. - /// - /// # Panics - /// - /// Panics if the state dimension is not supported. - pub fn into_state(self) -> O::State { - let boxed_state: Box = match self { - AdaptorRecordV1::Rank1(s) => Box::new(s), - AdaptorRecordV1::Rank2(s) => Box::new(s), - AdaptorRecordV1::Rank3(s) => Box::new(s), - AdaptorRecordV1::Rank4(s) => Box::new(s), - AdaptorRecordV1::Rank5(s) => Box::new(s), - AdaptorRecordV1::Rank6(s) => Box::new(s), - AdaptorRecordV1::Rank7(s) => Box::new(s), - AdaptorRecordV1::Rank8(s) => Box::new(s), - }; - let state = boxed_state - .downcast::>() - .expect("Unsupported state dimension, dimension up to 8 are supported."); - *state - } - - /// Convert the state into the record. - /// - /// # Arguments - /// - /// * `state`: The state. - /// - /// # Returns - /// - /// The record. 
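The into_state/from_state pair above erases the const-generic rank D behind Box<dyn Any> and recovers it with a checked downcast into (or out of) the RankN variant of matching rank. The pattern in isolation, as a self-contained sketch with a stand-in state type (none of these names exist in the patch, and only two ranks are shown):

    use std::any::Any;

    // Stand-in for O::State<D>: any 'static type works with `Any`.
    #[derive(Debug, Clone, PartialEq)]
    struct State<const D: usize>([usize; D]);

    enum RankErased {
        Rank1(State<1>),
        Rank2(State<2>),
    }

    // Mirrors AdaptorRecordV1::from_state: match on the const generic and move
    // the boxed value into the variant of the same rank.
    fn erase<const D: usize>(state: State<D>) -> RankErased {
        let state: Box<dyn Any> = Box::new(state);
        match D {
            1 => RankErased::Rank1(*state.downcast().unwrap()),
            2 => RankErased::Rank2(*state.downcast().unwrap()),
            _ => panic!("unsupported rank"),
        }
    }

    // Mirrors AdaptorRecordV1::into_state: re-box the stored value and downcast
    // to the rank the caller requested, panicking on a mismatch.
    fn restore<const D: usize>(erased: RankErased) -> State<D> {
        let boxed: Box<dyn Any> = match erased {
            RankErased::Rank1(s) => Box::new(s),
            RankErased::Rank2(s) => Box::new(s),
        };
        *boxed.downcast::<State<D>>().expect("rank mismatch")
    }

    fn main() {
        let erased = erase(State::<2>([4, 4]));
        assert_eq!(restore::<2>(erased), State([4, 4]));
    }

The rank has to be re-stated at the call site (the turbofish on restore), which is why the optimizer adaptor only converts records back into state inside the rank-generic mapper shown earlier, where D is already fixed.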
- pub fn from_state(state: O::State) -> Self { - let state: Box = Box::new(state); - - match D { - 1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()), - 2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()), - 3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()), - 4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()), - 5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()), - 6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()), - 7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()), - 8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()), - _ => panic!("Unsupported state dimension, dimension up to 8 are supported."), + /// Convert the record into the state. + /// + /// # Returns + /// + /// The state. + /// + /// # Panics + /// + /// Panics if the state dimension is not supported. + pub fn into_state(self) -> O::State { + let boxed_state: Box = match self { + AdaptorRecordV1::Rank1(s) => Box::new(s), + AdaptorRecordV1::Rank2(s) => Box::new(s), + AdaptorRecordV1::Rank3(s) => Box::new(s), + AdaptorRecordV1::Rank4(s) => Box::new(s), + AdaptorRecordV1::Rank5(s) => Box::new(s), + AdaptorRecordV1::Rank6(s) => Box::new(s), + AdaptorRecordV1::Rank7(s) => Box::new(s), + AdaptorRecordV1::Rank8(s) => Box::new(s), + }; + let state = boxed_state + .downcast::>() + .expect("Unsupported state dimension, dimension up to 8 are supported."); + *state + } + + /// Convert the state into the record. + /// + /// # Arguments + /// + /// * `state`: The state. + /// + /// # Returns + /// + /// The record. + pub fn from_state(state: O::State) -> Self { + let state: Box = Box::new(state); + + match D { + 1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()), + 2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()), + 3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()), + 4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()), + 5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()), + 6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()), + 7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()), + 8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()), + _ => panic!("Unsupported state dimension, dimension up to 8 are supported."), + } } - } } impl Record for AdaptorRecordV1 where - O: SimpleOptimizer, - B: Backend, + O: SimpleOptimizer, + B: Backend, { - type Item = AdaptorRecordItemV1; - - fn into_item(self) -> Self::Item { - match self { - AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()), - AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()), - AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()), - AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()), - AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()), - AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()), - AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()), - AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()), + type Item = AdaptorRecordItemV1; + + fn into_item(self) -> Self::Item { + match self { + AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()), + AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()), + AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()), + AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()), + AdaptorRecordV1::Rank5(record) => 
AdaptorRecordItemV1::Rank5(record.into_item()), + AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()), + AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()), + AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()), + } } - } - - fn from_item(item: Self::Item) -> Self { - match item { - AdaptorRecordItemV1::Rank1(item) => { - AdaptorRecordV1::Rank1( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank2(item) => { - AdaptorRecordV1::Rank2( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank3(item) => { - AdaptorRecordV1::Rank3( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank4(item) => { - AdaptorRecordV1::Rank4( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank5(item) => { - AdaptorRecordV1::Rank5( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank6(item) => { - AdaptorRecordV1::Rank6( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank7(item) => { - AdaptorRecordV1::Rank7( as Record>::from_item(item)) - } - AdaptorRecordItemV1::Rank8(item) => { - AdaptorRecordV1::Rank8( as Record>::from_item(item)) - } + + fn from_item(item: Self::Item) -> Self { + match item { + AdaptorRecordItemV1::Rank1(item) => { + AdaptorRecordV1::Rank1( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank2(item) => { + AdaptorRecordV1::Rank2( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank3(item) => { + AdaptorRecordV1::Rank3( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank4(item) => { + AdaptorRecordV1::Rank4( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank5(item) => { + AdaptorRecordV1::Rank5( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank6(item) => { + AdaptorRecordV1::Rank6( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank7(item) => { + AdaptorRecordV1::Rank7( as Record>::from_item(item)) + } + AdaptorRecordItemV1::Rank8(item) => { + AdaptorRecordV1::Rank8( as Record>::from_item(item)) + } + } } - } } diff --git a/burn-core/src/optim/visitor.rs b/burn-core/src/optim/visitor.rs index 54c864fcbb..1631fcf74e 100644 --- a/burn-core/src/optim/visitor.rs +++ b/burn-core/src/optim/visitor.rs @@ -5,42 +5,40 @@ use core::marker::PhantomData; #[derive(new)] pub struct GradientsParamsConverter<'a, M: AutodiffModule, B: AutodiffBackend> { - grads: B::Gradients, - grads_params: &'a mut GradientsParams, - phatom: PhantomData, + grads: B::Gradients, + grads_params: &'a mut GradientsParams, + phatom: PhantomData, } #[derive(new)] pub struct GradientsParamsChangeDevice<'a, M: AutodiffModule, B: AutodiffBackend> { - device: &'a B::Device, - grads: &'a mut GradientsParams, - phatom: PhantomData, + device: &'a B::Device, + grads: &'a mut GradientsParams, + phatom: PhantomData, } impl<'a, B, M> ModuleVisitor for GradientsParamsConverter<'a, M, B> where - B: AutodiffBackend, - M: AutodiffModule, + B: AutodiffBackend, + M: AutodiffModule, { - fn visit(&mut self, id: &ParamId, tensor: &Tensor) { - if let Some(grad) = tensor.grad_remove(&mut self.grads) { - self - .grads_params - .register::(id.clone(), grad); + fn visit(&mut self, id: &ParamId, tensor: &Tensor) { + if let Some(grad) = tensor.grad_remove(&mut self.grads) { + self.grads_params + .register::(id.clone(), grad); + } } - } } impl<'a, B, M> ModuleVisitor for GradientsParamsChangeDevice<'a, M, B> where - B: AutodiffBackend, - M: AutodiffModule, + B: AutodiffBackend, + M: AutodiffModule, { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { - if let Some(grad) = 
self.grads.remove::(id) { - self - .grads - .register::(id.clone(), grad.to_device(self.device)); + fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + if let Some(grad) = self.grads.remove::(id) { + self.grads + .register::(id.clone(), grad.to_device(self.device)); + } } - } } diff --git a/burn-core/src/record/base.rs b/burn-core/src/record/base.rs index 522075d44f..0c1633e8ec 100644 --- a/burn-core/src/record/base.rs +++ b/burn-core/src/record/base.rs @@ -5,12 +5,12 @@ use serde::{de::DeserializeOwned, Serialize}; /// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings). pub trait Record: Send + Sync { - /// Type of the item that can be serialized and deserialized. - type Item: Serialize + DeserializeOwned; + /// Type of the item that can be serialized and deserialized. + type Item: Serialize + DeserializeOwned; - /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings). - fn into_item(self) -> Self::Item; + /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings). + fn into_item(self) -> Self::Item; - /// Convert the given item into a record. - fn from_item(item: Self::Item) -> Self; + /// Convert the given item into a record. + fn from_item(item: Self::Item) -> Self; } diff --git a/burn-core/src/record/file.rs b/burn-core/src/record/file.rs index 9cbbb16dcc..00c9b0192e 100644 --- a/burn-core/src/record/file.rs +++ b/burn-core/src/record/file.rs @@ -7,10 +7,10 @@ use std::{fs::File, path::PathBuf}; /// Recorder trait specialized to save and load data to and from files. pub trait FileRecorder: - Recorder + Recorder { - /// File extension of the format used by the recorder. - fn file_extension() -> &'static str; + /// File extension of the format used by the recorder. + fn file_extension() -> &'static str; } /// Default [file recorder](FileRecorder). @@ -19,360 +19,361 @@ pub type DefaultFileRecorder = NamedMpkGzFileRecorder; /// File recorder using the [bincode format](bincode). #[derive(new, Debug, Default, Clone)] pub struct BinFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [bincode format](bincode) compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct BinGzFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [json format](serde_json) compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct JsonGzFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using [pretty json format](serde_json) for easy readability. #[derive(new, Debug, Default, Clone)] pub struct PrettyJsonFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [named msgpack](rmp_serde) format compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct NamedMpkGzFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } /// File recorder using the [named msgpack](rmp_serde) format. 
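The `Record` trait reformatted above pairs every record with a serializable `Item` plus two conversion methods. Burn's real trait is additionally generic over `PrecisionSettings`; the standalone sketch below drops that parameter and uses hypothetical `CounterState`/`CounterItem` types purely to show the round trip through serde.

```rust
use serde::{de::DeserializeOwned, Deserialize, Serialize};

trait SimpleRecord {
    type Item: Serialize + DeserializeOwned;
    fn into_item(self) -> Self::Item;
    fn from_item(item: Self::Item) -> Self;
}

struct CounterState {
    count: u64,
}

#[derive(Serialize, Deserialize)]
struct CounterItem {
    count: u64,
}

impl SimpleRecord for CounterState {
    type Item = CounterItem;

    fn into_item(self) -> Self::Item {
        CounterItem { count: self.count }
    }

    fn from_item(item: Self::Item) -> Self {
        CounterState { count: item.count }
    }
}

fn main() {
    // The record itself never needs to be serializable; only its item does.
    let json = serde_json::to_string(&CounterState { count: 3 }.into_item()).unwrap();
    let restored = CounterState::from_item(serde_json::from_str(&json).unwrap());
    assert_eq!(restored.count, 3);
}
```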
#[derive(new, Debug, Default, Clone)] pub struct NamedMpkFileRecorder { - _settings: PhantomData, + _settings: PhantomData, } impl FileRecorder for BinGzFileRecorder { - fn file_extension() -> &'static str { - "bin.gz" - } + fn file_extension() -> &'static str { + "bin.gz" + } } impl FileRecorder for BinFileRecorder { - fn file_extension() -> &'static str { - "bin" - } + fn file_extension() -> &'static str { + "bin" + } } impl FileRecorder for JsonGzFileRecorder { - fn file_extension() -> &'static str { - "json.gz" - } + fn file_extension() -> &'static str { + "json.gz" + } } impl FileRecorder for PrettyJsonFileRecorder { - fn file_extension() -> &'static str { - "json" - } + fn file_extension() -> &'static str { + "json" + } } impl FileRecorder for NamedMpkGzFileRecorder { - fn file_extension() -> &'static str { - "mpk.gz" - } + fn file_extension() -> &'static str { + "mpk.gz" + } } impl FileRecorder for NamedMpkFileRecorder { - fn file_extension() -> &'static str { - "mpk" - } + fn file_extension() -> &'static str { + "mpk" + } } macro_rules! str2reader { - ( + ( $file:expr ) => {{ - $file.set_extension(::file_extension()); - let path = $file.as_path(); - - File::open(path) - .map_err(|err| match err.kind() { - std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), - _ => RecorderError::Unknown(err.to_string()), - }) - .map(|file| BufReader::new(file)) - }}; + $file.set_extension(::file_extension()); + let path = $file.as_path(); + + File::open(path) + .map_err(|err| match err.kind() { + std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), + _ => RecorderError::Unknown(err.to_string()), + }) + .map(|file| BufReader::new(file)) + }}; } macro_rules! str2writer { - ( + ( $file:expr ) => {{ - $file.set_extension(::file_extension()); - let path = $file.as_path(); + $file.set_extension(::file_extension()); + let path = $file.as_path(); + + if path.exists() { + log::info!("File exists, replacing"); + std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?; + } + + File::create(path) + .map_err(|err| match err.kind() { + std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), + _ => RecorderError::Unknown(err.to_string()), + }) + .map(|file| BufWriter::new(file)) + }}; +} - if path.exists() { - log::info!("File exists, replacing"); - std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?; +impl Recorder for BinGzFileRecorder { + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let config = bin_config(); + let writer = str2writer!(file)?; + let mut writer = GzEncoder::new(writer, Compression::default()); + + bincode::serde::encode_into_std_write(&item, &mut writer, config) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) } - File::create(path) - .map_err(|err| match err.kind() { - std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), - _ => RecorderError::Unknown(err.to_string()), - }) - .map(|file| BufWriter::new(file)) - }}; -} + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let mut reader = GzDecoder::new(reader); + let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; -impl Recorder for BinGzFileRecorder { - type Settings = S; - type 
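The `str2writer!` macro above normalizes the target path before writing: it forces the recorder's extension, deletes any existing file (logging "File exists, replacing"), and wraps the handle in a `BufWriter`. A simplified free-function version of that behaviour, using plain `io::Error` instead of `RecorderError`, might look like this:

```rust
use std::fs::File;
use std::io::{self, BufWriter};
use std::path::PathBuf;

/// Force `extension` onto the path, replace any existing file, and return a
/// buffered writer (the macro additionally logs the replacement via `log`).
fn open_record_writer(mut path: PathBuf, extension: &str) -> io::Result<BufWriter<File>> {
    path.set_extension(extension);

    if path.exists() {
        std::fs::remove_file(&path)?;
    }

    File::create(&path).map(BufWriter::new)
}
```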
RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let config = bin_config(); - let writer = str2writer!(file)?; - let mut writer = GzEncoder::new(writer, Compression::default()); - - bincode::serde::encode_into_std_write(&item, &mut writer, config) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let mut reader = GzDecoder::new(reader); - let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + Ok(state) + } } impl Recorder for BinFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let config = bin_config(); - let mut writer = str2writer!(file)?; - bincode::serde::encode_into_std_write(&item, &mut writer, config) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let mut reader = str2reader!(file)?; - let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let config = bin_config(); + let mut writer = str2writer!(file)?; + bincode::serde::encode_into_std_write(&item, &mut writer, config) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let mut reader = str2reader!(file)?; + let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(state) + } } impl Recorder for JsonGzFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let writer = str2writer!(file)?; - let writer = GzEncoder::new(writer, Compression::default()); - serde_json::to_writer(writer, &item).map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let reader = GzDecoder::new(reader); - let state = - serde_json::from_reader(reader).map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let writer = str2writer!(file)?; + let writer = GzEncoder::new(writer, Compression::default()); + serde_json::to_writer(writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let reader = GzDecoder::new(reader); + let state = serde_json::from_reader(reader) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } impl Recorder 
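For the gzip-backed recorders above, serialization is bincode piped through a `GzEncoder` and deserialization is bincode read back through a `GzDecoder`. The sketch below shows the same round trip over in-memory buffers; it assumes the `flate2` and bincode 2.x `serde` APIs these hunks already use, with error handling reduced to `expect`.

```rust
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use serde::{de::DeserializeOwned, Serialize};

fn save_gz<T: Serialize>(item: &T) -> Vec<u8> {
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    bincode::serde::encode_into_std_write(item, &mut encoder, bincode::config::standard())
        .expect("bincode serialization should succeed");
    // Finishing the encoder flushes the gzip stream and returns the buffer.
    encoder.finish().expect("gzip stream should flush")
}

fn load_gz<T: DeserializeOwned>(bytes: &[u8]) -> T {
    let mut decoder = GzDecoder::new(bytes);
    bincode::serde::decode_from_std_read(&mut decoder, bincode::config::standard())
        .expect("bincode deserialization should succeed")
}
```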
for PrettyJsonFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let writer = str2writer!(file)?; - serde_json::to_writer_pretty(writer, &item) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let state = - serde_json::from_reader(reader).map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let writer = str2writer!(file)?; + serde_json::to_writer_pretty(writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let state = serde_json::from_reader(reader) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } impl Recorder for NamedMpkGzFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let writer = str2writer!(file)?; - let mut writer = GzEncoder::new(writer, Compression::default()); - rmp_serde::encode::write_named(&mut writer, &item) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let reader = GzDecoder::new(reader); - let state = rmp_serde::decode::from_read(reader) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let writer = str2writer!(file)?; + let mut writer = GzEncoder::new(writer, Compression::default()); + rmp_serde::encode::write_named(&mut writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let reader = GzDecoder::new(reader); + let state = rmp_serde::decode::from_read(reader) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } impl Recorder for NamedMpkFileRecorder { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let mut writer = str2writer!(file)?; - - rmp_serde::encode::write_named(&mut writer, &item) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let reader = str2reader!(file)?; - let state = rmp_serde::decode::from_read(reader) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - - Ok(state) - } + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let mut writer = str2writer!(file)?; + + 
rmp_serde::encode::write_named(&mut writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let reader = str2reader!(file)?; + let state = rmp_serde::decode::from_read(reader) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + + Ok(state) + } } #[cfg(test)] mod tests { - use burn_tensor::backend::Backend; - - use super::*; - use crate::{ - module::Module, - nn::{ - conv::{Conv2d, Conv2dConfig}, - Linear, LinearConfig, - }, - record::{BinBytesRecorder, FullPrecisionSettings}, - TestBackend, - }; - - use crate as burn; - - static FILE_PATH: &str = "/tmp/burn_test_file_recorder"; - - #[test] - fn test_can_save_and_load_jsongz_format() { - test_can_save_and_load(JsonGzFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_bin_format() { - test_can_save_and_load(BinFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_bingz_format() { - test_can_save_and_load(BinGzFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_pretty_json_format() { - test_can_save_and_load(PrettyJsonFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_mpkgz_format() { - test_can_save_and_load(NamedMpkGzFileRecorder::::default()) - } - - #[test] - fn test_can_save_and_load_mpk_format() { - test_can_save_and_load(NamedMpkFileRecorder::::default()) - } - - fn test_can_save_and_load(recorder: Recorder) { - let model_before = create_model(); - recorder - .record(model_before.clone().into_record(), FILE_PATH.into()) - .unwrap(); - - let model_after = create_model().load_record(recorder.load(FILE_PATH.into()).unwrap()); - - let byte_recorder = BinBytesRecorder::::default(); - let model_bytes_before = byte_recorder - .record(model_before.into_record(), ()) - .unwrap(); - let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap(); - - assert_eq!(model_bytes_after, model_bytes_before); - } - - #[derive(Module, Debug)] - pub struct Model { - conv2d1: Conv2d, - linear1: Linear, - phantom: core::marker::PhantomData, - } - - pub fn create_model() -> Model { - let conv2d1 = Conv2dConfig::new([1, 8], [3, 3]).init(); - - let linear1 = LinearConfig::new(32, 32).with_bias(true).init(); - - Model { - conv2d1, - linear1, - phantom: core::marker::PhantomData, + use burn_tensor::backend::Backend; + + use super::*; + use crate::{ + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + Linear, LinearConfig, + }, + record::{BinBytesRecorder, FullPrecisionSettings}, + TestBackend, + }; + + use crate as burn; + + static FILE_PATH: &str = "/tmp/burn_test_file_recorder"; + + #[test] + fn test_can_save_and_load_jsongz_format() { + test_can_save_and_load(JsonGzFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_bin_format() { + test_can_save_and_load(BinFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_bingz_format() { + test_can_save_and_load(BinGzFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_pretty_json_format() { + test_can_save_and_load(PrettyJsonFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_mpkgz_format() { + test_can_save_and_load(NamedMpkGzFileRecorder::::default()) + } + + #[test] + fn test_can_save_and_load_mpk_format() { + test_can_save_and_load(NamedMpkFileRecorder::::default()) + } + + fn test_can_save_and_load(recorder: Recorder) { + let model_before = create_model(); + recorder + .record(model_before.clone().into_record(), FILE_PATH.into()) + 
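As the test being re-indented here shows, using a file recorder is a two-step affair: turn the module into a record and hand it to `record`, then feed the result of `load` into `load_record` on a freshly built module. A hedged usage sketch follows (arbitrary path and layer sizes, import paths written as they would appear from outside `burn-core`):

```rust
use burn::module::Module;
use burn::nn;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn::tensor::backend::Backend;

fn roundtrip_through_file<B: Backend>() -> nn::Linear<B> {
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();

    // `record` appends the recorder's extension, so this writes
    // "/tmp/example_model.mpk".
    let source: nn::Linear<B> = nn::LinearConfig::new(32, 32).with_bias(true).init();
    recorder
        .record(source.into_record(), "/tmp/example_model".into())
        .expect("record should be written");

    // A freshly initialized module picks up the stored weights.
    let target: nn::Linear<B> = nn::LinearConfig::new(32, 32).with_bias(true).init();
    target.load_record(
        recorder
            .load("/tmp/example_model".into())
            .expect("record should load"),
    )
}
```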
.unwrap(); + + let model_after = create_model().load_record(recorder.load(FILE_PATH.into()).unwrap()); + + let byte_recorder = BinBytesRecorder::::default(); + let model_bytes_before = byte_recorder + .record(model_before.into_record(), ()) + .unwrap(); + let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap(); + + assert_eq!(model_bytes_after, model_bytes_before); + } + + #[derive(Module, Debug)] + pub struct Model { + conv2d1: Conv2d, + linear1: Linear, + phantom: core::marker::PhantomData, + } + + pub fn create_model() -> Model { + let conv2d1 = Conv2dConfig::new([1, 8], [3, 3]).init(); + + let linear1 = LinearConfig::new(32, 32).with_bias(true).init(); + + Model { + conv2d1, + linear1, + phantom: core::marker::PhantomData, + } } - } } diff --git a/burn-core/src/record/memory.rs b/burn-core/src/record/memory.rs index f96f1f8b17..545ad87e64 100644 --- a/burn-core/src/record/memory.rs +++ b/burn-core/src/record/memory.rs @@ -9,42 +9,42 @@ use serde::{de::DeserializeOwned, Serialize}; /// This is especially useful in no_std environment where weights are stored directly in /// compiled binaries. pub trait BytesRecorder: - Recorder, LoadArgs = Vec> + Recorder, LoadArgs = Vec> { } /// In memory recorder using the [bincode format](bincode). #[derive(new, Debug, Default, Clone)] pub struct BinBytesRecorder { - _settings: core::marker::PhantomData, + _settings: core::marker::PhantomData, } impl BytesRecorder for BinBytesRecorder {} impl Recorder for BinBytesRecorder { - type Settings = S; - type RecordArgs = (); - type RecordOutput = Vec; - type LoadArgs = Vec; - - fn save_item( - &self, - item: I, - _args: Self::RecordArgs, - ) -> Result { - Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap()) - } - fn load_item(&self, args: Self::LoadArgs) -> Result { - let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap(); - Ok(state) - } + type Settings = S; + type RecordArgs = (); + type RecordOutput = Vec; + type LoadArgs = Vec; + + fn save_item( + &self, + item: I, + _args: Self::RecordArgs, + ) -> Result { + Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap()) + } + fn load_item(&self, args: Self::LoadArgs) -> Result { + let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap(); + Ok(state) + } } #[cfg(feature = "std")] /// In memory recorder using the [Named MessagePack](rmp_serde). 
#[derive(new, Debug, Default, Clone)] pub struct NamedMpkBytesRecorder { - _settings: core::marker::PhantomData, + _settings: core::marker::PhantomData, } #[cfg(feature = "std")] @@ -52,53 +52,53 @@ impl BytesRecorder for NamedMpkBytesRecorder {} #[cfg(feature = "std")] impl Recorder for NamedMpkBytesRecorder { - type Settings = S; - type RecordArgs = (); - type RecordOutput = Vec; - type LoadArgs = Vec; - - fn save_item( - &self, - item: I, - _args: Self::RecordArgs, - ) -> Result { - rmp_serde::encode::to_vec_named(&item).map_err(|e| RecorderError::Unknown(e.to_string())) - } - fn load_item(&self, args: Self::LoadArgs) -> Result { - rmp_serde::decode::from_slice(&args).map_err(|e| RecorderError::Unknown(e.to_string())) - } + type Settings = S; + type RecordArgs = (); + type RecordOutput = Vec; + type LoadArgs = Vec; + + fn save_item( + &self, + item: I, + _args: Self::RecordArgs, + ) -> Result { + rmp_serde::encode::to_vec_named(&item).map_err(|e| RecorderError::Unknown(e.to_string())) + } + fn load_item(&self, args: Self::LoadArgs) -> Result { + rmp_serde::decode::from_slice(&args).map_err(|e| RecorderError::Unknown(e.to_string())) + } } #[cfg(test)] mod tests { - use super::*; - use crate::{module::Module, nn, record::FullPrecisionSettings, TestBackend}; - - #[test] - fn test_can_save_and_load_bin_format() { - test_can_save_and_load(BinBytesRecorder::::default()) - } - - #[cfg(feature = "std")] - #[test] - fn test_can_save_and_load_named_mpk_format() { - test_can_save_and_load(NamedMpkBytesRecorder::::default()) - } - - fn test_can_save_and_load(recorder: Recorder) { - let model1 = create_model(); - let model2 = create_model(); - let bytes1 = recorder.record(model1.into_record(), ()).unwrap(); - let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap(); - - let model2_after = model2.load_record(recorder.load(bytes1.clone()).unwrap()); - let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap(); - - assert_ne!(bytes1, bytes2); - assert_eq!(bytes1, bytes2_after); - } - - pub fn create_model() -> nn::Linear { - nn::LinearConfig::new(32, 32).with_bias(true).init() - } + use super::*; + use crate::{module::Module, nn, record::FullPrecisionSettings, TestBackend}; + + #[test] + fn test_can_save_and_load_bin_format() { + test_can_save_and_load(BinBytesRecorder::::default()) + } + + #[cfg(feature = "std")] + #[test] + fn test_can_save_and_load_named_mpk_format() { + test_can_save_and_load(NamedMpkBytesRecorder::::default()) + } + + fn test_can_save_and_load(recorder: Recorder) { + let model1 = create_model(); + let model2 = create_model(); + let bytes1 = recorder.record(model1.into_record(), ()).unwrap(); + let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap(); + + let model2_after = model2.load_record(recorder.load(bytes1.clone()).unwrap()); + let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap(); + + assert_ne!(bytes1, bytes2); + assert_eq!(bytes1, bytes2_after); + } + + pub fn create_model() -> nn::Linear { + nn::LinearConfig::new(32, 32).with_bias(true).init() + } } diff --git a/burn-core/src/record/primitive.rs b/burn-core/src/record/primitive.rs index c24cd1046b..507635b205 100644 --- a/burn-core/src/record/primitive.rs +++ b/burn-core/src/record/primitive.rs @@ -13,124 +13,123 @@ use burn_tensor::{DataSerialize, Element}; use hashbrown::HashMap; impl Record for () { - type Item = (); + type Item = (); - fn into_item(self) -> Self::Item {} + fn into_item(self) -> Self::Item {} - fn from_item(_item: Self::Item) -> Self 
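The in-memory recorders above do the same job without touching the filesystem, which is what the `no_std` note refers to: `record` returns a `Vec<u8>` and `load` consumes one. A sketch of that round trip with a single linear layer (sizes arbitrary, import paths as seen from an external crate):

```rust
use burn::module::Module;
use burn::nn;
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;

fn roundtrip_through_bytes<B: Backend>() -> nn::Linear<B> {
    let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();

    let source: nn::Linear<B> = nn::LinearConfig::new(32, 32).with_bias(true).init();
    let bytes: Vec<u8> = recorder
        .record(source.into_record(), ())
        .expect("record should serialize to bytes");

    // The byte buffer could just as well come from `include_bytes!` in a
    // no_std build, which is the use case called out above.
    let target: nn::Linear<B> = nn::LinearConfig::new(32, 32).with_bias(true).init();
    target.load_record(recorder.load(bytes).expect("record should load from bytes"))
}
```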
{} + fn from_item(_item: Self::Item) -> Self {} } impl Record for Vec { - type Item = Vec>; + type Item = Vec>; - fn into_item(self) -> Self::Item { - self.into_iter().map(Record::into_item).collect() - } + fn into_item(self) -> Self::Item { + self.into_iter().map(Record::into_item).collect() + } - fn from_item(item: Self::Item) -> Self { - item.into_iter().map(Record::from_item).collect() - } + fn from_item(item: Self::Item) -> Self { + item.into_iter().map(Record::from_item).collect() + } } impl Record for Option { - type Item = Option>; + type Item = Option>; - fn into_item(self) -> Self::Item { - self.map(Record::into_item) - } + fn into_item(self) -> Self::Item { + self.map(Record::into_item) + } - fn from_item(item: Self::Item) -> Self { - item.map(Record::from_item) - } + fn from_item(item: Self::Item) -> Self { + item.map(Record::from_item) + } } impl Record for [T; N] { - type Item = Vec>; - - fn into_item(self) -> Self::Item { - self.map(Record::into_item).into_iter().collect() - } - - fn from_item(item: Self::Item) -> Self { - item - .into_iter() - .map(Record::from_item) - .collect::>() - .try_into() - .unwrap_or_else(|_| panic!("An arrar of size {N}")) - } + type Item = Vec>; + + fn into_item(self) -> Self::Item { + self.map(Record::into_item).into_iter().collect() + } + + fn from_item(item: Self::Item) -> Self { + item.into_iter() + .map(Record::from_item) + .collect::>() + .try_into() + .unwrap_or_else(|_| panic!("An arrar of size {N}")) + } } impl Record for HashMap { - type Item = HashMap>; - - fn into_item(self) -> Self::Item { - let mut items = HashMap::with_capacity(self.len()); - self.into_iter().for_each(|(id, record)| { - items.insert(id.to_string(), record.into_item()); - }); - items - } - - fn from_item(item: Self::Item) -> Self { - let mut record = HashMap::with_capacity(item.len()); - item.into_iter().for_each(|(id, item)| { - record.insert(ParamId::from(id), T::from_item(item)); - }); - record - } + type Item = HashMap>; + + fn into_item(self) -> Self::Item { + let mut items = HashMap::with_capacity(self.len()); + self.into_iter().for_each(|(id, record)| { + items.insert(id.to_string(), record.into_item()); + }); + items + } + + fn from_item(item: Self::Item) -> Self { + let mut record = HashMap::with_capacity(item.len()); + item.into_iter().for_each(|(id, item)| { + record.insert(ParamId::from(id), T::from_item(item)); + }); + record + } } impl Record for DataSerialize { - type Item = DataSerialize; + type Item = DataSerialize; - fn into_item(self) -> Self::Item { - self.convert() - } + fn into_item(self) -> Self::Item { + self.convert() + } - fn from_item(item: Self::Item) -> Self { - item.convert() - } + fn from_item(item: Self::Item) -> Self { + item.convert() + } } /// (De)serialize parameters into a clean format. #[derive(new, Debug, Clone, Serialize, Deserialize)] pub struct ParamSerde { - id: String, - param: T, + id: String, + param: T, } impl Record for Param> { - type Item = ParamSerde>; - - fn into_item(self) -> Self::Item { - ParamSerde::new(self.id.into_string(), self.value.into_item()) - } - - fn from_item(item: Self::Item) -> Self { - Param::new( - ParamId::from(item.id), - Tensor::from_item(item.param).require_grad(), // Same behavior as when we create a new - // Param from a tensor. 
- ) - } + type Item = ParamSerde>; + + fn into_item(self) -> Self::Item { + ParamSerde::new(self.id.into_string(), self.value.into_item()) + } + + fn from_item(item: Self::Item) -> Self { + Param::new( + ParamId::from(item.id), + Tensor::from_item(item.param).require_grad(), // Same behavior as when we create a new + // Param from a tensor. + ) + } } // Type that can be serialized as is without any conversion. macro_rules! primitive { - ($type:ty) => { - impl Record for $type { - type Item = $type; - - fn into_item(self) -> Self::Item { - self - } - - fn from_item(item: Self::Item) -> Self { - item - } - } - }; + ($type:ty) => { + impl Record for $type { + type Item = $type; + + fn into_item(self) -> Self::Item { + self + } + + fn from_item(item: Self::Item) -> Self { + item + } + } + }; } // General Types diff --git a/burn-core/src/record/recorder.rs b/burn-core/src/record/recorder.rs index d278881e85..8c76199ff3 100644 --- a/burn-core/src/record/recorder.rs +++ b/burn-core/src/record/recorder.rs @@ -8,148 +8,148 @@ use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record}; #[cfg(feature = "std")] use super::{ - BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings, - PrettyJsonFileRecorder, + BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings, + PrettyJsonFileRecorder, }; /// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned). pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone { - /// Type of the settings used by the recorder. - type Settings: PrecisionSettings; - - /// Arguments used to record objects. - type RecordArgs: Clone; - - /// Record output type. - type RecordOutput; - - /// Arguments used to load recorded objects. - type LoadArgs: Clone; - - /// Records an item. - /// - /// # Arguments - /// - /// * `record` - The item to record. - /// * `args` - Arguments used to record the item. - /// - /// # Returns - /// - /// The output of the recording. - fn record( - &self, - record: R, - args: Self::RecordArgs, - ) -> Result { - let item = record.into_item::(); - let item = BurnRecord::new::(item); - - self.save_item(item, args) - } - - /// Load an item from the given arguments. - fn load(&self, args: Self::LoadArgs) -> Result { - let item: BurnRecord> = - self.load_item(args.clone()).map_err(|err| { - if let Ok(record) = self.load_item::(args.clone()) { - let mut message = "Unable to load record.".to_string(); - let metadata = recorder_metadata::(); - if metadata.float != record.metadata.float { - message += format!( - "\nMetadata has a different float type: Actual {:?}, Expected {:?}", - record.metadata.float, metadata.float - ) - .as_str(); - } - if metadata.int != record.metadata.int { - message += format!( - "\nMetadata has a different int type: Actual {:?}, Expected {:?}", - record.metadata.int, metadata.int - ) - .as_str(); - } - if metadata.format != record.metadata.format { - message += format!( - "\nMetadata has a different format: Actual {:?}, Expected {:?}", - record.metadata.format, metadata.format - ) - .as_str(); - } - if metadata.version != record.metadata.version { - message += format!( - "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}", - record.metadata.version, metadata.version - ) - .as_str(); - } - - message += format!("\nError: {:?}", err).as_str(); - - return RecorderError::Unknown(message); - } + /// Type of the settings used by the recorder. 
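The `primitive!` macro reformatted above simply stamps out identity implementations for types that need no conversion. The same pattern in isolation, with a hypothetical `PassThrough` trait standing in for `Record`:

```rust
trait PassThrough {
    type Item;
    fn into_item(self) -> Self::Item;
    fn from_item(item: Self::Item) -> Self;
}

macro_rules! pass_through {
    ($type:ty) => {
        impl PassThrough for $type {
            type Item = $type;

            fn into_item(self) -> Self::Item {
                self
            }

            fn from_item(item: Self::Item) -> Self {
                item
            }
        }
    };
}

// One invocation per primitive that is stored as-is.
pass_through!(usize);
pass_through!(f32);
pass_through!(bool);
```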
+ type Settings: PrecisionSettings; + + /// Arguments used to record objects. + type RecordArgs: Clone; + + /// Record output type. + type RecordOutput; + + /// Arguments used to load recorded objects. + type LoadArgs: Clone; + + /// Records an item. + /// + /// # Arguments + /// + /// * `record` - The item to record. + /// * `args` - Arguments used to record the item. + /// + /// # Returns + /// + /// The output of the recording. + fn record( + &self, + record: R, + args: Self::RecordArgs, + ) -> Result { + let item = record.into_item::(); + let item = BurnRecord::new::(item); + + self.save_item(item, args) + } + + /// Load an item from the given arguments. + fn load(&self, args: Self::LoadArgs) -> Result { + let item: BurnRecord> = + self.load_item(args.clone()).map_err(|err| { + if let Ok(record) = self.load_item::(args.clone()) { + let mut message = "Unable to load record.".to_string(); + let metadata = recorder_metadata::(); + if metadata.float != record.metadata.float { + message += format!( + "\nMetadata has a different float type: Actual {:?}, Expected {:?}", + record.metadata.float, metadata.float + ) + .as_str(); + } + if metadata.int != record.metadata.int { + message += format!( + "\nMetadata has a different int type: Actual {:?}, Expected {:?}", + record.metadata.int, metadata.int + ) + .as_str(); + } + if metadata.format != record.metadata.format { + message += format!( + "\nMetadata has a different format: Actual {:?}, Expected {:?}", + record.metadata.format, metadata.format + ) + .as_str(); + } + if metadata.version != record.metadata.version { + message += format!( + "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}", + record.metadata.version, metadata.version + ) + .as_str(); + } + + message += format!("\nError: {:?}", err).as_str(); + + return RecorderError::Unknown(message); + } + + err + })?; + + Ok(R::from_item(item.item)) + } - err - })?; - - Ok(R::from_item(item.item)) - } - - /// Saves an item. - /// - /// This method is used by [record](Recorder::record) to save the item. - /// - /// # Arguments - /// - /// * `item` - Item to save. - /// * `args` - Arguments to use to save the item. - /// - /// # Returns - /// - /// The output of the save operation. - fn save_item( - &self, - item: I, - args: Self::RecordArgs, - ) -> Result; - - /// Loads an item. - /// - /// This method is used by [load](Recorder::load) to load the item. - /// - /// # Arguments - /// - /// * `args` - Arguments to use to load the item. - /// - /// # Returns - /// - /// The loaded item. - fn load_item(&self, args: Self::LoadArgs) -> Result; + /// Saves an item. + /// + /// This method is used by [record](Recorder::record) to save the item. + /// + /// # Arguments + /// + /// * `item` - Item to save. + /// * `args` - Arguments to use to save the item. + /// + /// # Returns + /// + /// The output of the save operation. + fn save_item( + &self, + item: I, + args: Self::RecordArgs, + ) -> Result; + + /// Loads an item. + /// + /// This method is used by [load](Recorder::load) to load the item. + /// + /// # Arguments + /// + /// * `args` - Arguments to use to load the item. + /// + /// # Returns + /// + /// The loaded item. 
+ fn load_item(&self, args: Self::LoadArgs) -> Result; } fn recorder_metadata() -> BurnMetadata { - BurnMetadata::new( - type_name::<::FloatElem>().to_string(), - type_name::<::IntElem>().to_string(), - type_name::().to_string(), - env!("CARGO_PKG_VERSION").to_string(), - format!("{:?}", R::Settings::default()), - ) + BurnMetadata::new( + type_name::<::FloatElem>().to_string(), + type_name::<::IntElem>().to_string(), + type_name::().to_string(), + env!("CARGO_PKG_VERSION").to_string(), + format!("{:?}", R::Settings::default()), + ) } /// Error that can occur when using a [Recorder](Recorder). #[derive(Debug)] pub enum RecorderError { - /// File not found. - FileNotFound(String), + /// File not found. + FileNotFound(String), - /// Other error. - Unknown(String), + /// Other error. + Unknown(String), } impl core::fmt::Display for RecorderError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("{self:?}").as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(format!("{self:?}").as_str()) + } } // TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765) @@ -157,60 +157,60 @@ impl core::fmt::Display for RecorderError { impl std::error::Error for RecorderError {} pub(crate) fn bin_config() -> bincode::config::Configuration { - bincode::config::standard() + bincode::config::standard() } /// Metadata of a record. #[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct BurnMetadata { - /// Float type used to record the item. - pub float: String, + /// Float type used to record the item. + pub float: String, - /// Int type used to record the item. - pub int: String, + /// Int type used to record the item. + pub int: String, - /// Format used to record the item. - pub format: String, + /// Format used to record the item. + pub format: String, - /// Burn record version used to record the item. - pub version: String, + /// Burn record version used to record the item. + pub version: String, - /// Settings used to record the item. - pub settings: String, + /// Settings used to record the item. + pub settings: String, } /// Record that can be saved by a [Recorder](Recorder). #[derive(Serialize, Deserialize, Debug)] pub struct BurnRecord { - /// Metadata of the record. - pub metadata: BurnMetadata, + /// Metadata of the record. + pub metadata: BurnMetadata, - /// Item to record. - pub item: I, + /// Item to record. + pub item: I, } impl BurnRecord { - /// Creates a new record. - /// - /// # Arguments - /// - /// * `item` - Item to record. - /// - /// # Returns - /// - /// The new record. - pub fn new(item: I) -> Self { - let metadata = recorder_metadata::(); - - Self { metadata, item } - } + /// Creates a new record. + /// + /// # Arguments + /// + /// * `item` - Item to record. + /// + /// # Returns + /// + /// The new record. + pub fn new(item: I) -> Self { + let metadata = recorder_metadata::(); + + Self { metadata, item } + } } /// Record that can be saved by a [Recorder](Recorder) without the item. #[derive(new, Debug, Serialize, Deserialize)] pub struct BurnRecordNoItem { - /// Metadata of the record. - pub metadata: BurnMetadata, + /// Metadata of the record. + pub metadata: BurnMetadata, } /// Default recorder. 
@@ -252,45 +252,45 @@ pub type DebugRecordSettings = PrettyJsonFileRecorder; #[cfg(all(test, feature = "std"))] mod tests { - static FILE_PATH: &str = "/tmp/burn_test_record"; + static FILE_PATH: &str = "/tmp/burn_test_record"; - use super::*; - use burn_tensor::ElementConversion; + use super::*; + use burn_tensor::ElementConversion; - #[test] - #[should_panic] - fn err_when_invalid_item() { - #[derive(new, Serialize, Deserialize)] - struct Item { - value: S::FloatElem, - } + #[test] + #[should_panic] + fn err_when_invalid_item() { + #[derive(new, Serialize, Deserialize)] + struct Item { + value: S::FloatElem, + } - impl Record for Item { - type Item = Item; + impl Record for Item { + type Item = Item; - fn into_item(self) -> Self::Item { - Item { - value: self.value.elem(), - } - } + fn into_item(self) -> Self::Item { + Item { + value: self.value.elem(), + } + } - fn from_item(item: Self::Item) -> Self { - Item { - value: item.value.elem(), + fn from_item(item: Self::Item) -> Self { + Item { + value: item.value.elem(), + } + } } - } - } - let item = Item::::new(16.elem()); + let item = Item::::new(16.elem()); - // Serialize in f32. - let recorder = DefaultFileRecorder::::new(); - recorder.record(item, FILE_PATH.into()).unwrap(); + // Serialize in f32. + let recorder = DefaultFileRecorder::::new(); + recorder.record(item, FILE_PATH.into()).unwrap(); - // Can't deserialize f32 into f16. - let recorder = DefaultFileRecorder::::new(); - recorder - .load::>(FILE_PATH.into()) - .unwrap(); - } + // Can't deserialize f32 into f16. + let recorder = DefaultFileRecorder::::new(); + recorder + .load::>(FILE_PATH.into()) + .unwrap(); + } } diff --git a/burn-core/src/record/settings.rs b/burn-core/src/record/settings.rs index 202a5fb183..a59c6ec331 100644 --- a/burn-core/src/record/settings.rs +++ b/burn-core/src/record/settings.rs @@ -3,13 +3,13 @@ use serde::{de::DeserializeOwned, Serialize}; /// Settings allowing to control the precision when (de)serializing items. pub trait PrecisionSettings: - Send + Sync + core::fmt::Debug + core::default::Default + Clone + Send + Sync + core::fmt::Debug + core::default::Default + Clone { - /// Float element type. - type FloatElem: Element + Serialize + DeserializeOwned; + /// Float element type. + type FloatElem: Element + Serialize + DeserializeOwned; - /// Integer element type. - type IntElem: Element + Serialize + DeserializeOwned; + /// Integer element type. + type IntElem: Element + Serialize + DeserializeOwned; } /// Default precision settings. @@ -25,16 +25,16 @@ pub struct HalfPrecisionSettings; pub struct DoublePrecisionSettings; impl PrecisionSettings for FullPrecisionSettings { - type FloatElem = f32; - type IntElem = f32; + type FloatElem = f32; + type IntElem = f32; } impl PrecisionSettings for DoublePrecisionSettings { - type FloatElem = f64; - type IntElem = i64; + type FloatElem = f64; + type IntElem = i64; } impl PrecisionSettings for HalfPrecisionSettings { - type FloatElem = half::f16; - type IntElem = i16; + type FloatElem = half::f16; + type IntElem = i16; } diff --git a/burn-core/src/record/tensor.rs b/burn-core/src/record/tensor.rs index e60897fa68..70badf2169 100644 --- a/burn-core/src/record/tensor.rs +++ b/burn-core/src/record/tensor.rs @@ -6,129 +6,129 @@ use serde::{Deserialize, Serialize}; /// using the given [record settings](RecordSettings). 
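Since every recorder in these hunks is parameterized by a `PrecisionSettings` type, choosing a storage precision and a file format is just a matter of picking the two type parameters. The aliases below are illustrative only; the recorder and settings names are the ones defined in this file.

```rust
use burn::record::{
    BinFileRecorder, FullPrecisionSettings, HalfPrecisionSettings, NamedMpkGzFileRecorder,
};

/// Full precision in a gzip-compressed, named MessagePack file.
type ArchiveRecorder = NamedMpkGzFileRecorder<FullPrecisionSettings>;

/// Half precision (f16 floats, i16 ints) in a compact binary file.
type CompactRecorder = BinFileRecorder<HalfPrecisionSettings>;

fn recorders() -> (ArchiveRecorder, CompactRecorder) {
    (ArchiveRecorder::default(), CompactRecorder::default())
}
```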
#[derive(new, Clone, Debug)] pub struct FloatTensorSerde { - data: DataSerialize, + data: DataSerialize, } /// This struct implements serde to lazily serialize and deserialize an int tensor /// using the given [record settings](RecordSettings). #[derive(new, Clone, Debug)] pub struct IntTensorSerde { - data: DataSerialize, + data: DataSerialize, } /// This struct implements serde to lazily serialize and deserialize an bool tensor. #[derive(new, Clone, Debug)] pub struct BoolTensorSerde { - data: DataSerialize, + data: DataSerialize, } // --- SERDE IMPLEMENTATIONS --- // impl Serialize for FloatTensorSerde { - fn serialize(&self, serializer: Se) -> Result - where - Se: serde::Serializer, - { - self.data.serialize(serializer) - } + fn serialize(&self, serializer: Se) -> Result + where + Se: serde::Serializer, + { + self.data.serialize(serializer) + } } impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde { - fn deserialize(deserializer: De) -> Result - where - De: serde::Deserializer<'de>, - { - let data = DataSerialize::::deserialize(deserializer)?; - - Ok(Self::new(data)) - } + fn deserialize(deserializer: De) -> Result + where + De: serde::Deserializer<'de>, + { + let data = DataSerialize::::deserialize(deserializer)?; + + Ok(Self::new(data)) + } } impl Serialize for IntTensorSerde { - fn serialize(&self, serializer: Se) -> Result - where - Se: serde::Serializer, - { - self.data.serialize(serializer) - } + fn serialize(&self, serializer: Se) -> Result + where + Se: serde::Serializer, + { + self.data.serialize(serializer) + } } impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde { - fn deserialize(deserializer: De) -> Result - where - De: serde::Deserializer<'de>, - { - let data = DataSerialize::::deserialize(deserializer)?; - Ok(Self::new(data)) - } + fn deserialize(deserializer: De) -> Result + where + De: serde::Deserializer<'de>, + { + let data = DataSerialize::::deserialize(deserializer)?; + Ok(Self::new(data)) + } } impl Serialize for BoolTensorSerde { - fn serialize(&self, serializer: Se) -> Result - where - Se: serde::Serializer, - { - self.data.serialize(serializer) - } + fn serialize(&self, serializer: Se) -> Result + where + Se: serde::Serializer, + { + self.data.serialize(serializer) + } } impl<'de> Deserialize<'de> for BoolTensorSerde { - fn deserialize(deserializer: De) -> Result - where - De: serde::Deserializer<'de>, - { - let data = DataSerialize::::deserialize(deserializer)?; - - Ok(Self::new(data)) - } + fn deserialize(deserializer: De) -> Result + where + De: serde::Deserializer<'de>, + { + let data = DataSerialize::::deserialize(deserializer)?; + + Ok(Self::new(data)) + } } // --- RECORD IMPLEMENTATIONS --- // impl Record for Tensor { - type Item = FloatTensorSerde; + type Item = FloatTensorSerde; - fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording float tensors isn't yet supported on wasm."); + fn into_item(self) -> Self::Item { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + todo!("Recording float tensors isn't yet supported on wasm."); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - FloatTensorSerde::new(self.into_data().convert().serialize()) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + FloatTensorSerde::new(self.into_data().convert().serialize()) + } - fn from_item(item: Self::Item) -> Self { - Tensor::from_data(item.data.convert::()) - } + fn from_item(item: Self::Item) -> Self { + 
Tensor::from_data(item.data.convert::()) + } } impl Record for Tensor { - type Item = IntTensorSerde; + type Item = IntTensorSerde; - fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording int tensors isn't yet supported on wasm."); + fn into_item(self) -> Self::Item { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + todo!("Recording int tensors isn't yet supported on wasm."); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - IntTensorSerde::new(self.into_data().convert().serialize()) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + IntTensorSerde::new(self.into_data().convert().serialize()) + } - fn from_item(item: Self::Item) -> Self { - Tensor::from_data(item.data.convert()) - } + fn from_item(item: Self::Item) -> Self { + Tensor::from_data(item.data.convert()) + } } impl Record for Tensor { - type Item = BoolTensorSerde; + type Item = BoolTensorSerde; - fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording bool tensors isn't yet supported on wasm."); + fn into_item(self) -> Self::Item { + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + todo!("Recording bool tensors isn't yet supported on wasm."); - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - BoolTensorSerde::new(self.into_data().serialize()) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + BoolTensorSerde::new(self.into_data().serialize()) + } - fn from_item(item: Self::Item) -> Self { - Tensor::from_data(item.data) - } + fn from_item(item: Self::Item) -> Self { + Tensor::from_data(item.data) + } } diff --git a/burn-core/tests/derive_config.rs b/burn-core/tests/derive_config.rs index 227d85f336..dec636ee91 100644 --- a/burn-core/tests/derive_config.rs +++ b/burn-core/tests/derive_config.rs @@ -6,102 +6,102 @@ pub struct TestEmptyStructConfig {} #[derive(Config, Debug, PartialEq)] pub struct TestStructConfig { - int: i32, - #[config(default = 2)] - int_default: i32, - float: f32, - #[config(default = 2.0)] - float_default: f32, - string: String, - other_config: TestEmptyStructConfig, + int: i32, + #[config(default = 2)] + int_default: i32, + float: f32, + #[config(default = 2.0)] + float_default: f32, + string: String, + other_config: TestEmptyStructConfig, } #[derive(Config, Debug, PartialEq)] pub enum TestEnumConfig { - None, - Single(f32), - Multiple(f32, String), - Named { first: f32, second: String }, + None, + Single(f32), + Multiple(f32, String), + Named { first: f32, second: String }, } #[cfg(feature = "std")] #[test] fn struct_config_should_impl_serde() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - let file_path = "/tmp/test_struct_config.json"; + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); + let file_path = "/tmp/test_struct_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestStructConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestStructConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[test] fn struct_config_should_impl_clone() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - assert_eq!(config, config.clone()); + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), 
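The tensor `Record` implementations above convert the tensor data to the settings' element type before serializing, and convert it back to the backend's element type on load. A hedged sketch of that conversion written as free functions, using `f32` as the stored element and mirroring the `into_data().convert().serialize()` / `from_data(... .convert())` calls in the hunk (exact signatures assume the burn version this patch targets):

```rust
use burn::tensor::{backend::Backend, DataSerialize, Tensor};

/// Convert tensor data to the stored element type (here f32) and serialize it.
fn to_f32_item<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> DataSerialize<f32> {
    tensor.into_data().convert::<f32>().serialize()
}

/// Convert the stored data back to the backend's float element and rebuild
/// the tensor.
fn from_f32_item<B: Backend, const D: usize>(item: DataSerialize<f32>) -> Tensor<B, D> {
    Tensor::from_data(item.convert::<B::FloatElem>())
}
```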
TestEmptyStructConfig::new()); + assert_eq!(config, config.clone()); } #[test] fn struct_config_should_impl_display() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - assert_eq!(burn::config::config_to_json(&config), config.to_string()); + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); + assert_eq!(burn::config::config_to_json(&config), config.to_string()); } #[cfg(feature = "std")] #[test] fn enum_config_no_value_should_impl_serde() { - let config = TestEnumConfig::None; - let file_path = "/tmp/test_enum_no_value_config.json"; + let config = TestEnumConfig::None; + let file_path = "/tmp/test_enum_no_value_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestEnumConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestEnumConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[cfg(feature = "std")] #[test] fn enum_config_one_value_should_impl_serde() { - let config = TestEnumConfig::Single(42.0); - let file_path = "/tmp/test_enum_one_value_config.json"; + let config = TestEnumConfig::Single(42.0); + let file_path = "/tmp/test_enum_one_value_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestEnumConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestEnumConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[cfg(feature = "std")] #[test] fn enum_config_multiple_values_should_impl_serde() { - let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); - let file_path = "/tmp/test_enum_multiple_values_config.json"; + let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); + let file_path = "/tmp/test_enum_multiple_values_config.json"; - config.save(file_path).unwrap(); + config.save(file_path).unwrap(); - let config_loaded = TestEnumConfig::load(file_path).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestEnumConfig::load(file_path).unwrap(); + assert_eq!(config, config_loaded); } #[test] fn enum_config_should_impl_clone() { - let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); - assert_eq!(config, config.clone()); + let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); + assert_eq!(config, config.clone()); } #[test] fn enum_config_should_impl_display() { - let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); - assert_eq!(burn::config::config_to_json(&config), config.to_string()); + let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); + assert_eq!(burn::config::config_to_json(&config), config.to_string()); } #[test] fn struct_config_can_load_binary() { - let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); + let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); - let binary = config_to_json(&config).as_bytes().to_vec(); + let binary = config_to_json(&config).as_bytes().to_vec(); - let config_loaded = TestStructConfig::load_binary(&binary).unwrap(); - assert_eq!(config, config_loaded); + let config_loaded = TestStructConfig::load_binary(&binary).unwrap(); + assert_eq!(config, config_loaded); } diff --git a/burn-core/tests/derive_module.rs b/burn-core/tests/derive_module.rs index 427bd668b3..87beafc422 100644 --- a/burn-core/tests/derive_module.rs +++ b/burn-core/tests/derive_module.rs @@ 
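The `derive(Config)` tests re-indented above rely on the generated `new`, `save`, and `load` methods, where fields carrying `#[config(default = ...)]` are omitted from `new`. A small illustrative config follows (field names and path are hypothetical, not part of the patch):

```rust
use burn::config::Config;

#[derive(Config, Debug, PartialEq)]
pub struct TrainingConfig {
    pub learning_rate: f64,
    #[config(default = 42)]
    pub seed: u64,
}

fn roundtrip() -> TrainingConfig {
    // `new` only takes the fields without a default; `seed` starts at 42.
    let config = TrainingConfig::new(1e-3);
    config.save("/tmp/training_config.json").unwrap();
    TrainingConfig::load("/tmp/training_config.json").unwrap()
}
```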
-9,139 +9,139 @@ pub type TestAutodiffBackend = burn_autodiff::Autodiff; #[derive(Module, Debug)] pub struct ModuleBasic { - weight_basic: Param>, + weight_basic: Param>, } impl ModuleBasic { - fn new() -> Self { - let weight_basic = Tensor::random(Shape::new([20, 20]), Distribution::Default); - Self { - weight_basic: Param::from(weight_basic), + fn new() -> Self { + let weight_basic = Tensor::random(Shape::new([20, 20]), Distribution::Default); + Self { + weight_basic: Param::from(weight_basic), + } } - } } #[derive(Module, Debug)] pub struct ModuleComposed { - weight: Param>, - basic: ModuleBasic, + weight: Param>, + basic: ModuleBasic, } impl ModuleComposed { - fn new() -> Self { - let weight = Tensor::random(Shape::new([20, 20]), Distribution::Default); - Self { - weight: Param::from(weight), - basic: ModuleBasic::new(), + fn new() -> Self { + let weight = Tensor::random(Shape::new([20, 20]), Distribution::Default); + Self { + weight: Param::from(weight), + basic: ModuleBasic::new(), + } } - } } mod state { - use super::*; - - #[test] - fn should_load_from_record_basic() { - let module_1 = ModuleBasic::::new(); - let mut module_2 = ModuleBasic::::new(); - let state_1 = module_1.clone().into_record(); - - assert_ne!( - module_1.weight_basic.to_data(), - module_2.weight_basic.to_data() - ); - - module_2 = module_2.load_record(state_1); - - assert_eq!( - module_1.weight_basic.to_data(), - module_2.weight_basic.to_data() - ); - } - - #[test] - fn should_load_from_record_compose() { - let module_1 = ModuleComposed::::new(); - let mut module_2 = ModuleComposed::::new(); - assert_ne!(module_1.weight.to_data(), module_2.weight.to_data()); - assert_ne!( - module_1.basic.weight_basic.to_data(), - module_2.basic.weight_basic.to_data() - ); - - let state_1 = module_1.clone().into_record(); - module_2 = module_2.load_record(state_1); - - assert_eq!(module_1.weight.to_data(), module_2.weight.to_data()); - assert_eq!( - module_1.basic.weight_basic.to_data(), - module_2.basic.weight_basic.to_data() - ); - } + use super::*; + + #[test] + fn should_load_from_record_basic() { + let module_1 = ModuleBasic::::new(); + let mut module_2 = ModuleBasic::::new(); + let state_1 = module_1.clone().into_record(); + + assert_ne!( + module_1.weight_basic.to_data(), + module_2.weight_basic.to_data() + ); + + module_2 = module_2.load_record(state_1); + + assert_eq!( + module_1.weight_basic.to_data(), + module_2.weight_basic.to_data() + ); + } + + #[test] + fn should_load_from_record_compose() { + let module_1 = ModuleComposed::::new(); + let mut module_2 = ModuleComposed::::new(); + assert_ne!(module_1.weight.to_data(), module_2.weight.to_data()); + assert_ne!( + module_1.basic.weight_basic.to_data(), + module_2.basic.weight_basic.to_data() + ); + + let state_1 = module_1.clone().into_record(); + module_2 = module_2.load_record(state_1); + + assert_eq!(module_1.weight.to_data(), module_2.weight.to_data()); + assert_eq!( + module_1.basic.weight_basic.to_data(), + module_2.basic.weight_basic.to_data() + ); + } } mod num_params { - use super::*; - - #[test] - fn should_calculate_num_params_basic() { - let module = ModuleBasic::::new(); - assert_eq!(20 * 20, module.num_params()); - } - - #[test] - fn should_output_state_composed() { - let module = ModuleComposed::::new(); - assert_eq!(2 * 20 * 20, module.num_params()); - } + use super::*; + + #[test] + fn should_calculate_num_params_basic() { + let module = ModuleBasic::::new(); + assert_eq!(20 * 20, module.num_params()); + } + + #[test] + fn 
should_output_state_composed() { + let module = ModuleComposed::::new(); + assert_eq!(2 * 20 * 20, module.num_params()); + } } #[cfg(feature = "std")] mod require_grad { - use burn_tensor::backend::AutodiffBackend; + use burn_tensor::backend::AutodiffBackend; - use super::*; + use super::*; - #[test] - fn should_have_grad_by_default() { - let module = ModuleBasic::::new(); - let mut grads = calculate_grads(&module); + #[test] + fn should_have_grad_by_default() { + let module = ModuleBasic::::new(); + let mut grads = calculate_grads(&module); - let grad_x = module.weight_basic.grad_remove(&mut grads); + let grad_x = module.weight_basic.grad_remove(&mut grads); - assert!(grad_x.is_some()); - } + assert!(grad_x.is_some()); + } - #[test] - fn should_have_no_grad_after_no_grad() { - let module = ModuleBasic::::new().no_grad(); - let mut grads = calculate_grads(&module); + #[test] + fn should_have_no_grad_after_no_grad() { + let module = ModuleBasic::::new().no_grad(); + let mut grads = calculate_grads(&module); - let grad_x = module.weight_basic.grad_remove(&mut grads); + let grad_x = module.weight_basic.grad_remove(&mut grads); - assert!(grad_x.is_none()); - } + assert!(grad_x.is_none()); + } - #[test] - fn should_have_grad_when_from_record() { - let module = ModuleBasic::::new(); - let record = ModuleBasicRecord { - weight_basic: module.weight_basic.clone(), // Even when param is no_grad, - }; - let module = module.load_record(record); - let mut grads = calculate_grads(&module); + #[test] + fn should_have_grad_when_from_record() { + let module = ModuleBasic::::new(); + let record = ModuleBasicRecord { + weight_basic: module.weight_basic.clone(), // Even when param is no_grad, + }; + let module = module.load_record(record); + let mut grads = calculate_grads(&module); - let grad_x = module.weight_basic.grad_remove(&mut grads); + let grad_x = module.weight_basic.grad_remove(&mut grads); - assert!(grad_x.is_some()); - } + assert!(grad_x.is_some()); + } - fn calculate_grads( - module: &ModuleBasic, - ) -> ::Gradients { - let x = Tensor::ones([20, 20]).require_grad(); - let y = module.weight_basic.val().matmul(x); + fn calculate_grads( + module: &ModuleBasic, + ) -> ::Gradients { + let x = Tensor::ones([20, 20]).require_grad(); + let y = module.weight_basic.val().matmul(x); - y.backward() - } + y.backward() + } } diff --git a/burn-core/tests/derive_record.rs b/burn-core/tests/derive_record.rs index d11bd58181..c0a6653731 100644 --- a/burn-core/tests/derive_record.rs +++ b/burn-core/tests/derive_record.rs @@ -7,11 +7,11 @@ use burn_tensor::Tensor; // It compiles #[derive(Record)] pub struct TestWithBackendRecord { - tensor: Tensor, + tensor: Tensor, } // It compiles #[derive(Record)] pub struct TestWithoutBackendRecord { - tensor: usize, + tensor: usize, } diff --git a/burn-core/tests/record_resilience.rs b/burn-core/tests/record_resilience.rs index b021dc5a72..9bf4d651cb 100644 --- a/burn-core/tests/record_resilience.rs +++ b/burn-core/tests/record_resilience.rs @@ -1,290 +1,298 @@ #[cfg(feature = "std")] mod tests { - use burn::{ - module::Module, - nn, - record::{ - BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings, - PrettyJsonFileRecorder, RecorderError, - }, - }; - use burn_core as burn; - use burn_tensor::backend::Backend; - use std::path::PathBuf; - - type TestBackend = burn_ndarray::NdArray; - - #[derive(Module, Debug)] - pub struct Model { - single_const: f32, - linear1: nn::Linear, - array_const: [usize; 2], - linear2: nn::Linear, - } - - #[derive(Module, Debug)] - 
pub struct ModelNewOptionalField { - single_const: f32, - linear1: nn::Linear, - array_const: [usize; 2], - linear2: nn::Linear, - new_field: Option, - } - - #[derive(Module, Debug)] - pub struct ModelNewConstantField { - single_const: f32, - linear1: nn::Linear, - array_const: [usize; 2], - linear2: nn::Linear, - new_field: usize, - } - - #[derive(Module, Debug)] - pub struct ModelNewFieldOrders { - array_const: [usize; 2], - linear2: nn::Linear, - single_const: f32, - linear1: nn::Linear, - } - - #[test] - fn deserialize_with_new_optional_field_works_with_default_file_recorder() { - deserialize_with_new_optional_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_optional_field_works_with_default_file_recorder() { - deserialize_with_removed_optional_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_constant_field_works_with_default_file_recorder() { - deserialize_with_new_constant_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_constant_field_works_with_default_file_recorder() { - deserialize_with_removed_constant_field( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_field_order_works_with_default_file_recorder() { - deserialize_with_new_field_order( - "default", - DefaultFileRecorder::::new(), - ) - .unwrap(); - } - #[test] - fn deserialize_with_new_optional_field_works_with_pretty_json() { - deserialize_with_new_optional_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_optional_field_works_with_pretty_json() { - deserialize_with_removed_optional_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_constant_field_works_with_pretty_json() { - deserialize_with_new_constant_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_constant_field_works_with_pretty_json() { - deserialize_with_removed_constant_field( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - fn deserialize_with_new_field_order_works_with_pretty_json() { - deserialize_with_new_field_order( - "pretty-json", - PrettyJsonFileRecorder::::new(), - ) - .unwrap(); - } - - #[test] - #[should_panic] - fn deserialize_with_new_optional_field_doesnt_works_with_bin_file_recorder() { - deserialize_with_new_optional_field("bin", BinFileRecorder::::new()) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() { - deserialize_with_removed_optional_field("bin", BinFileRecorder::::new()) - .unwrap(); - } - - #[test] - fn deserialize_with_new_constant_field_works_with_bin_file_recorder() { - deserialize_with_new_constant_field("bin", BinFileRecorder::::new()) - .unwrap(); - } - - #[test] - fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() { - deserialize_with_removed_constant_field("bin", BinFileRecorder::::new()) - .unwrap(); - } - - #[test] - #[should_panic] - fn deserialize_with_new_field_order_works_with_bin_file_recorder() { - deserialize_with_new_field_order("bin", BinFileRecorder::::new()) - .unwrap(); - } - - fn deserialize_with_new_optional_field(name: &str, recorder: R) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = 
format!("/tmp/deserialize_with_new_optional_field-{name}").into(); - let model = Model { - single_const: 32.0, - linear1: nn::LinearConfig::new(20, 20).init::(), - array_const: [2, 2], - linear2: nn::LinearConfig::new(20, 20).init::(), + use burn::{ + module::Module, + nn, + record::{ + BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings, + PrettyJsonFileRecorder, RecorderError, + }, }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_removed_optional_field( - name: &str, - recorder: R, - ) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = format!("/tmp/deserialize_with_removed_optional_field-{name}").into(); - let model = ModelNewOptionalField { - single_const: 32.0, - linear1: nn::LinearConfig::new(20, 20).init::(), - array_const: [2, 2], - linear2: nn::LinearConfig::new(20, 20).init::(), - new_field: None, - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_new_constant_field(name: &str, recorder: R) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = format!("/tmp/deserialize_with_new_constant_field-{name}").into(); - let model = Model { - single_const: 32.0, - array_const: [2, 2], - linear1: nn::LinearConfig::new(20, 20).init::(), - linear2: nn::LinearConfig::new(20, 20).init::(), - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_removed_constant_field( - name: &str, - recorder: R, - ) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = format!("/tmp/deserialize_with_removed_constant_field-{name}").into(); - let model = ModelNewConstantField { - single_const: 32.0, - array_const: [2, 2], - linear1: nn::LinearConfig::new(20, 20).init::(), - linear2: nn::LinearConfig::new(20, 20).init::(), - new_field: 0, - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } - - fn deserialize_with_new_field_order(name: &str, recorder: R) -> Result<(), RecorderError> - where - R: FileRecorder, - { - let file_path: PathBuf = format!("/tmp/deserialize_with_new_field_order-{name}").into(); - let model = Model { - array_const: [2, 2], - single_const: 32.0, - linear1: nn::LinearConfig::new(20, 20).init::(), - linear2: nn::LinearConfig::new(20, 20).init::(), - }; - - recorder - .record(model.into_record(), file_path.clone()) - .unwrap(); - - let result = recorder.load::>(file_path.clone()); - std::fs::remove_file(file_path).ok(); - - result?; - Ok(()) - } + use burn_core as burn; + use burn_tensor::backend::Backend; + use std::path::PathBuf; + + type TestBackend = burn_ndarray::NdArray; + + #[derive(Module, Debug)] + pub struct Model { + single_const: f32, + linear1: nn::Linear, + array_const: [usize; 2], + linear2: nn::Linear, + } + + #[derive(Module, Debug)] + pub struct ModelNewOptionalField { + single_const: f32, + linear1: nn::Linear, + array_const: [usize; 2], + linear2: nn::Linear, + new_field: Option, + } + + #[derive(Module, 
Debug)] + pub struct ModelNewConstantField { + single_const: f32, + linear1: nn::Linear, + array_const: [usize; 2], + linear2: nn::Linear, + new_field: usize, + } + + #[derive(Module, Debug)] + pub struct ModelNewFieldOrders { + array_const: [usize; 2], + linear2: nn::Linear, + single_const: f32, + linear1: nn::Linear, + } + + #[test] + fn deserialize_with_new_optional_field_works_with_default_file_recorder() { + deserialize_with_new_optional_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_optional_field_works_with_default_file_recorder() { + deserialize_with_removed_optional_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_constant_field_works_with_default_file_recorder() { + deserialize_with_new_constant_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_constant_field_works_with_default_file_recorder() { + deserialize_with_removed_constant_field( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_field_order_works_with_default_file_recorder() { + deserialize_with_new_field_order( + "default", + DefaultFileRecorder::::new(), + ) + .unwrap(); + } + #[test] + fn deserialize_with_new_optional_field_works_with_pretty_json() { + deserialize_with_new_optional_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_optional_field_works_with_pretty_json() { + deserialize_with_removed_optional_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_constant_field_works_with_pretty_json() { + deserialize_with_new_constant_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_constant_field_works_with_pretty_json() { + deserialize_with_removed_constant_field( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_field_order_works_with_pretty_json() { + deserialize_with_new_field_order( + "pretty-json", + PrettyJsonFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + #[should_panic] + fn deserialize_with_new_optional_field_doesnt_works_with_bin_file_recorder() { + deserialize_with_new_optional_field("bin", BinFileRecorder::::new()) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() { + deserialize_with_removed_optional_field( + "bin", + BinFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + fn deserialize_with_new_constant_field_works_with_bin_file_recorder() { + deserialize_with_new_constant_field("bin", BinFileRecorder::::new()) + .unwrap(); + } + + #[test] + fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() { + deserialize_with_removed_constant_field( + "bin", + BinFileRecorder::::new(), + ) + .unwrap(); + } + + #[test] + #[should_panic] + fn deserialize_with_new_field_order_works_with_bin_file_recorder() { + deserialize_with_new_field_order("bin", BinFileRecorder::::new()) + .unwrap(); + } + + fn deserialize_with_new_optional_field(name: &str, recorder: R) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_new_optional_field-{name}").into(); + let model = Model { + single_const: 32.0, + linear1: nn::LinearConfig::new(20, 20).init::(), + array_const: [2, 
2], + linear2: nn::LinearConfig::new(20, 20).init::(), + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_removed_optional_field( + name: &str, + recorder: R, + ) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = + format!("/tmp/deserialize_with_removed_optional_field-{name}").into(); + let model = ModelNewOptionalField { + single_const: 32.0, + linear1: nn::LinearConfig::new(20, 20).init::(), + array_const: [2, 2], + linear2: nn::LinearConfig::new(20, 20).init::(), + new_field: None, + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_new_constant_field(name: &str, recorder: R) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_new_constant_field-{name}").into(); + let model = Model { + single_const: 32.0, + array_const: [2, 2], + linear1: nn::LinearConfig::new(20, 20).init::(), + linear2: nn::LinearConfig::new(20, 20).init::(), + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_removed_constant_field( + name: &str, + recorder: R, + ) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = + format!("/tmp/deserialize_with_removed_constant_field-{name}").into(); + let model = ModelNewConstantField { + single_const: 32.0, + array_const: [2, 2], + linear1: nn::LinearConfig::new(20, 20).init::(), + linear2: nn::LinearConfig::new(20, 20).init::(), + new_field: 0, + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } + + fn deserialize_with_new_field_order(name: &str, recorder: R) -> Result<(), RecorderError> + where + R: FileRecorder, + { + let file_path: PathBuf = format!("/tmp/deserialize_with_new_field_order-{name}").into(); + let model = Model { + array_const: [2, 2], + single_const: 32.0, + linear1: nn::LinearConfig::new(20, 20).init::(), + linear2: nn::LinearConfig::new(20, 20).init::(), + }; + + recorder + .record(model.into_record(), file_path.clone()) + .unwrap(); + + let result = recorder.load::>(file_path.clone()); + std::fs::remove_file(file_path).ok(); + + result?; + Ok(()) + } } diff --git a/burn-dataset/examples/speech_commands.rs b/burn-dataset/examples/speech_commands.rs index 5b4ff7791d..cce7f131e1 100644 --- a/burn-dataset/examples/speech_commands.rs +++ b/burn-dataset/examples/speech_commands.rs @@ -3,21 +3,21 @@ use burn_dataset::{audio::SpeechCommandsDataset, Dataset}; #[cfg(feature = "audio")] fn speech_command() { - let index: usize = 4835; - let test = SpeechCommandsDataset::test(); - let item = test.get(index).unwrap(); + let index: usize = 4835; + let test = SpeechCommandsDataset::test(); + let item = test.get(index).unwrap(); - println!("Item: {:?}", item); - println!("Item Length: {:?}", item.audio_samples.len()); - println!("Label: {}", item.label.to_string()); + println!("Item: {:?}", item); + println!("Item Length: {:?}", item.audio_samples.len()); + println!("Label: {}", 
item.label.to_string()); - assert_eq!(test.len(), 4890); - assert_eq!(item.label.to_string(), "Yes"); - assert_eq!(item.sample_rate, 16000); - assert_eq!(item.audio_samples.len(), 16000); + assert_eq!(test.len(), 4890); + assert_eq!(item.label.to_string(), "Yes"); + assert_eq!(item.sample_rate, 16000); + assert_eq!(item.audio_samples.len(), 16000); } fn main() { - #[cfg(feature = "audio")] - speech_command() + #[cfg(feature = "audio")] + speech_command() } diff --git a/burn-dataset/src/audio/speech_commands.rs b/burn-dataset/src/audio/speech_commands.rs index 401f8965ac..28c2d34f20 100644 --- a/burn-dataset/src/audio/speech_commands.rs +++ b/burn-dataset/src/audio/speech_commands.rs @@ -1,6 +1,6 @@ use crate::{ - transform::{Mapper, MapperDataset}, - Dataset, HuggingfaceDatasetLoader, SqliteDataset, + transform::{Mapper, MapperDataset}, + Dataset, HuggingfaceDatasetLoader, SqliteDataset, }; use hound::WavReader; @@ -17,65 +17,65 @@ type MappedDataset = MapperDataset, ConvertSamples, #[allow(missing_docs)] #[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)] pub enum SpeechCommandClass { - // Target command words - Yes = 0, - No = 1, - Up = 2, - Down = 3, - Left = 4, - Right = 5, - On = 6, - Off = 7, - Stop = 8, - Go = 9, - Zero = 10, - One = 11, - Two = 12, - Three = 13, - Four = 14, - Five = 15, - Six = 16, - Seven = 17, - Eight = 18, - Nine = 19, - - // Non-target words that can be grouped into "Other" - Bed = 20, - Bird = 21, - Cat = 22, - Dog = 23, - Happy = 24, - House = 25, - Marvin = 26, - Sheila = 27, - Tree = 28, - Wow = 29, - - // Commands from v2 dataset, that can be grouped into "Other" - Backward = 30, - Forward = 31, - Follow = 32, - Learn = 33, - Visual = 34, - - // Background noise - Silence = 35, - - // Other miscellaneous words - Other = 36, + // Target command words + Yes = 0, + No = 1, + Up = 2, + Down = 3, + Left = 4, + Right = 5, + On = 6, + Off = 7, + Stop = 8, + Go = 9, + Zero = 10, + One = 11, + Two = 12, + Three = 13, + Four = 14, + Five = 15, + Six = 16, + Seven = 17, + Eight = 18, + Nine = 19, + + // Non-target words that can be grouped into "Other" + Bed = 20, + Bird = 21, + Cat = 22, + Dog = 23, + Happy = 24, + House = 25, + Marvin = 26, + Sheila = 27, + Tree = 28, + Wow = 29, + + // Commands from v2 dataset, that can be grouped into "Other" + Backward = 30, + Forward = 31, + Follow = 32, + Learn = 33, + Visual = 34, + + // Background noise + Silence = 35, + + // Other miscellaneous words + Other = 36, } /// Struct containing raw speech data returned from a database. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SpeechItemRaw { - /// Audio file bytes. - pub audio_bytes: Vec, + /// Audio file bytes. + pub audio_bytes: Vec, - /// Label index. - pub label: usize, + /// Label index. + pub label: usize, - /// Indicates if the label is unknown. - pub is_unknown: bool, + /// Indicates if the label is unknown. + pub is_unknown: bool, } /// Speech item with audio samples and label. @@ -88,14 +88,14 @@ pub struct SpeechItemRaw { /// The original label is also stored in the `label_original` field for debugging and remapping if needed. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SpeechItem { - /// Audio samples in the range [-1.0, 1.0]. - pub audio_samples: Vec, + /// Audio samples in the range [-1.0, 1.0]. + pub audio_samples: Vec, - /// The sample rate of the audio. - pub sample_rate: usize, + /// The sample rate of the audio. + pub sample_rate: usize, - /// The label of the audio. 
- pub label: SpeechCommandClass, + /// The label of the audio. + pub label: SpeechCommandClass, } /// Speech Commands dataset from Huggingface v0.02. @@ -114,95 +114,96 @@ pub struct SpeechItem { /// - test: 4,890 audio files /// - validation: 9,982 audio files pub struct SpeechCommandsDataset { - dataset: MappedDataset, + dataset: MappedDataset, } impl SpeechCommandsDataset { - /// Create a new dataset with the given split. - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("speech_commands") - .with_subset("v0.02") - .dataset(split) - .unwrap(); - let dataset = MapperDataset::new(dataset, ConvertSamples); - Self { dataset } - } - - /// Create a new dataset with the train split. - pub fn train() -> Self { - Self::new("train") - } - - /// Create a new dataset with the test split. - pub fn test() -> Self { - Self::new("test") - } - - /// Create a new dataset with the validation split. - pub fn validation() -> Self { - Self::new("validation") - } - - /// Returns the number of classes in the dataset - pub fn num_classes() -> usize { - SpeechCommandClass::COUNT - } + /// Create a new dataset with the given split. + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = + HuggingfaceDatasetLoader::new("speech_commands") + .with_subset("v0.02") + .dataset(split) + .unwrap(); + let dataset = MapperDataset::new(dataset, ConvertSamples); + Self { dataset } + } + + /// Create a new dataset with the train split. + pub fn train() -> Self { + Self::new("train") + } + + /// Create a new dataset with the test split. + pub fn test() -> Self { + Self::new("test") + } + + /// Create a new dataset with the validation split. + pub fn validation() -> Self { + Self::new("validation") + } + + /// Returns the number of classes in the dataset + pub fn num_classes() -> usize { + SpeechCommandClass::COUNT + } } impl Dataset for SpeechCommandsDataset { - fn get(&self, index: usize) -> Option { - self.dataset.get(index) - } + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } /// Mapper converting audio bytes into audio samples and the label to enum class. struct ConvertSamples; impl ConvertSamples { - /// Convert label to enum class. - fn to_speechcommandclass(label: usize) -> SpeechCommandClass { - SpeechCommandClass::from_repr(label).unwrap() - } - - /// Convert audio bytes into samples of floats [-1.0, 1.0]. - fn to_audiosamples(bytes: &Vec) -> (Vec, usize) { - let reader = WavReader::new(bytes.as_slice()).unwrap(); - let spec = reader.spec(); - - // Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample). - let max_value = (1 << (spec.bits_per_sample - 1)) as f32; - - // The sample rate of the audio. - let sample_rate = spec.sample_rate as usize; - - // Convert the audio samples to floats [-1.0, 1.0]. - let audio_samples: Vec = reader - .into_samples::() - .filter_map(Result::ok) - .map(|sample| sample as f32 / max_value) - .collect(); - - (audio_samples, sample_rate) - } + /// Convert label to enum class. + fn to_speechcommandclass(label: usize) -> SpeechCommandClass { + SpeechCommandClass::from_repr(label).unwrap() + } + + /// Convert audio bytes into samples of floats [-1.0, 1.0]. 
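A quick aside on the sample conversion documented above: each signed PCM value is divided by 2^(bits_per_sample - 1) so the result lands in [-1.0, 1.0]. A minimal sketch of just that step, with a hypothetical normalize_pcm helper, 16-bit samples assumed and no hound dependency:

// Scale signed 16-bit PCM samples to floats in [-1.0, 1.0].
fn normalize_pcm(samples: &[i16], bits_per_sample: u16) -> Vec<f32> {
    // 2^(bits_per_sample - 1), e.g. 32768 for 16-bit audio.
    let max_value = (1i32 << (bits_per_sample - 1)) as f32;
    samples.iter().map(|&s| s as f32 / max_value).collect()
}

fn main() {
    let normalized = normalize_pcm(&[0, 16_384, -32_768, 32_767], 16);
    println!("{normalized:?}"); // roughly [0.0, 0.5, -1.0, 1.0]
}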
+ fn to_audiosamples(bytes: &Vec) -> (Vec, usize) { + let reader = WavReader::new(bytes.as_slice()).unwrap(); + let spec = reader.spec(); + + // Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample). + let max_value = (1 << (spec.bits_per_sample - 1)) as f32; + + // The sample rate of the audio. + let sample_rate = spec.sample_rate as usize; + + // Convert the audio samples to floats [-1.0, 1.0]. + let audio_samples: Vec = reader + .into_samples::() + .filter_map(Result::ok) + .map(|sample| sample as f32 / max_value) + .collect(); + + (audio_samples, sample_rate) + } } impl Mapper for ConvertSamples { - /// Convert audio bytes into samples of floats [-1.0, 1.0] - /// and the label to enum class with the target word, other and silence classes. - fn map(&self, item: &SpeechItemRaw) -> SpeechItem { - let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes); - - // Convert the label to enum class, with the target words, other and silence classes. - let label = Self::to_speechcommandclass(item.label); - - SpeechItem { - audio_samples, - sample_rate, - label, + /// Convert audio bytes into samples of floats [-1.0, 1.0] + /// and the label to enum class with the target word, other and silence classes. + fn map(&self, item: &SpeechItemRaw) -> SpeechItem { + let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes); + + // Convert the label to enum class, with the target words, other and silence classes. + let label = Self::to_speechcommandclass(item.label); + + SpeechItem { + audio_samples, + sample_rate, + label, + } } - } } diff --git a/burn-dataset/src/dataset/base.rs b/burn-dataset/src/dataset/base.rs index 6f4caead7d..eb53980c94 100644 --- a/burn-dataset/src/dataset/base.rs +++ b/burn-dataset/src/dataset/base.rs @@ -4,68 +4,68 @@ use crate::DatasetIterator; /// The dataset trait defines a basic collection of items with a predefined size. pub trait Dataset: Send + Sync { - /// Gets the item at the given index. - fn get(&self, index: usize) -> Option; + /// Gets the item at the given index. + fn get(&self, index: usize) -> Option; - /// Gets the number of items in the dataset. - fn len(&self) -> usize; + /// Gets the number of items in the dataset. + fn len(&self) -> usize; - /// Checks if the dataset is empty. - fn is_empty(&self) -> bool { - self.len() == 0 - } + /// Checks if the dataset is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } - /// Returns an iterator over the dataset. - fn iter(&self) -> DatasetIterator<'_, I> - where - Self: Sized, - { - DatasetIterator::new(self) - } + /// Returns an iterator over the dataset. 
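The trait in this hunk has two required methods, get and len, while is_empty and iter are provided. A short sketch of a custom implementation, assuming the re-exported path burn_dataset::Dataset and a made-up Squares type:

use burn_dataset::Dataset;

// Hypothetical dataset yielding the squares of 0..n.
struct Squares {
    n: usize,
}

impl Dataset<usize> for Squares {
    fn get(&self, index: usize) -> Option<usize> {
        (index < self.n).then(|| index * index)
    }

    fn len(&self) -> usize {
        self.n
    }
}

fn main() {
    let dataset = Squares { n: 5 };
    assert_eq!(dataset.get(3), Some(9));
    assert!(!dataset.is_empty());
    // `iter` is the provided method shown above: it walks items by index.
    let items: Vec<usize> = dataset.iter().collect();
    assert_eq!(items, vec![0, 1, 4, 9, 16]);
}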
+ fn iter(&self) -> DatasetIterator<'_, I> + where + Self: Sized, + { + DatasetIterator::new(self) + } } impl Dataset for Arc where - D: Dataset, + D: Dataset, { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } impl Dataset for Arc> { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } impl Dataset for Box where - D: Dataset, + D: Dataset, { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } impl Dataset for Box> { - fn get(&self, index: usize) -> Option { - self.as_ref().get(index) - } + fn get(&self, index: usize) -> Option { + self.as_ref().get(index) + } - fn len(&self) -> usize { - self.as_ref().len() - } + fn len(&self) -> usize { + self.as_ref().len() + } } diff --git a/burn-dataset/src/dataset/fake.rs b/burn-dataset/src/dataset/fake.rs index af8762d5e0..c27f8cf0d3 100644 --- a/burn-dataset/src/dataset/fake.rs +++ b/burn-dataset/src/dataset/fake.rs @@ -3,36 +3,36 @@ use fake::{Dummy, Fake, Faker}; /// Dataset filled with fake items generated from the [fake](fake) crate. pub struct FakeDataset { - dataset: InMemDataset, + dataset: InMemDataset, } impl> FakeDataset { - /// Create a new fake dataset with the given size. - pub fn new(size: usize) -> Self { - let mut items = Vec::with_capacity(size); - for _ in 0..size { - items.push(Faker.fake()); - } - let dataset = InMemDataset::new(items); + /// Create a new fake dataset with the given size. + pub fn new(size: usize) -> Self { + let mut items = Vec::with_capacity(size); + for _ in 0..size { + items.push(Faker.fake()); + } + let dataset = InMemDataset::new(items); - Self { dataset } - } + Self { dataset } + } } impl Dataset for FakeDataset { - fn iter(&self) -> DatasetIterator<'_, I> { - DatasetIterator::new(self) - } + fn iter(&self) -> DatasetIterator<'_, I> { + DatasetIterator::new(self) + } - fn get(&self, index: usize) -> Option { - self.dataset.get(index) - } + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } - fn is_empty(&self) -> bool { - self.dataset.is_empty() - } + fn is_empty(&self) -> bool { + self.dataset.is_empty() + } } diff --git a/burn-dataset/src/dataset/in_memory.rs b/burn-dataset/src/dataset/in_memory.rs index 1091e8f08c..a3b167f0c7 100644 --- a/burn-dataset/src/dataset/in_memory.rs +++ b/burn-dataset/src/dataset/in_memory.rs @@ -1,7 +1,7 @@ use std::{ - fs::File, - io::{BufRead, BufReader}, - path::Path, + fs::File, + io::{BufRead, BufReader}, + path::Path, }; use serde::de::DeserializeOwned; @@ -10,162 +10,162 @@ use crate::Dataset; /// Dataset where all items are stored in ram. pub struct InMemDataset { - items: Vec, + items: Vec, } impl InMemDataset { - /// Creates a new in memory dataset from the given items. - pub fn new(items: Vec) -> Self { - InMemDataset { items } - } + /// Creates a new in memory dataset from the given items. 
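For completeness, a tiny usage sketch of the FakeDataset shown above, assuming it is re-exported at the crate root; String already implements the fake crate's Dummy trait:

use burn_dataset::{Dataset, FakeDataset};

fn main() {
    // Ten randomly generated strings, handy for smoke-testing a data pipeline
    // without loading a real dataset.
    let dataset = FakeDataset::<String>::new(10);
    assert_eq!(dataset.len(), 10);
    println!("sample item: {:?}", dataset.get(0));
}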
+ pub fn new(items: Vec) -> Self { + InMemDataset { items } + } } impl Dataset for InMemDataset where - I: Clone + Send + Sync, + I: Clone + Send + Sync, { - fn get(&self, index: usize) -> Option { - self.items.get(index).cloned() - } - fn len(&self) -> usize { - self.items.len() - } + fn get(&self, index: usize) -> Option { + self.items.get(index).cloned() + } + fn len(&self) -> usize { + self.items.len() + } } impl InMemDataset where - I: Clone + DeserializeOwned, + I: Clone + DeserializeOwned, { - /// Create from a dataset. All items are loaded in memory. - pub fn from_dataset(dataset: &impl Dataset) -> Self { - let items: Vec = dataset.iter().collect(); - Self::new(items) - } - - /// Create from a json rows file (one json per line). - /// - /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html) - pub fn from_json_rows>(path: P) -> Result { - let file = File::open(path)?; - let reader = BufReader::new(file); - let mut items = Vec::new(); - - for line in reader.lines() { - let item = serde_json::from_str(line.unwrap().as_str()).unwrap(); - items.push(item); + /// Create from a dataset. All items are loaded in memory. + pub fn from_dataset(dataset: &impl Dataset) -> Self { + let items: Vec = dataset.iter().collect(); + Self::new(items) } - let dataset = Self::new(items); + /// Create from a json rows file (one json per line). + /// + /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html) + pub fn from_json_rows>(path: P) -> Result { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut items = Vec::new(); - Ok(dataset) - } + for line in reader.lines() { + let item = serde_json::from_str(line.unwrap().as_str()).unwrap(); + items.push(item); + } - /// Create from a csv file. - /// - /// The first line of the csv file must be the header. The header must contain the name of the fields in the struct. - /// - /// The supported field types are: String, integer, float, and bool. - /// - /// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde) - pub fn from_csv>(path: P) -> Result { - let file = File::open(path)?; - let reader = BufReader::new(file); - let mut rdr = csv::Reader::from_reader(reader); + let dataset = Self::new(items); - let mut items = Vec::new(); - - for result in rdr.deserialize() { - let item: I = result?; - items.push(item); + Ok(dataset) } - let dataset = Self::new(items); + /// Create from a csv file. + /// + /// The first line of the csv file must be the header. The header must contain the name of the fields in the struct. + /// + /// The supported field types are: String, integer, float, and bool. 
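The loaders documented here pair a serde-derived row type with a file on disk. A hedged usage sketch for the CSV path, with a hypothetical Row type and an illustrative file path:

use burn_dataset::{Dataset, InMemDataset};
use serde::{Deserialize, Serialize};

// Hypothetical row type; field names must match the CSV header line.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Row {
    column_str: String,
    column_int: i64,
    column_bool: bool,
    column_float: f64,
}

fn main() {
    // Path is illustrative; it mirrors the layout of burn-dataset's own test file.
    let dataset = InMemDataset::<Row>::from_csv("tests/data/dataset.csv").unwrap();
    assert!(!dataset.is_empty());
    println!("first row: {:?}", dataset.get(0));
}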
+ /// + /// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde) + pub fn from_csv>(path: P) -> Result { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut rdr = csv::Reader::from_reader(reader); + + let mut items = Vec::new(); - Ok(dataset) - } + for result in rdr.deserialize() { + let item: I = result?; + items.push(item); + } + + let dataset = Self::new(items); + + Ok(dataset) + } } #[cfg(test)] mod tests { - use super::*; - use crate::{test_data, SqliteDataset}; - - use rstest::{fixture, rstest}; - use serde::{Deserialize, Serialize}; - - const DB_FILE: &str = "tests/data/sqlite-dataset.db"; - const JSON_FILE: &str = "tests/data/dataset.json"; - const CSV_FILE: &str = "tests/data/dataset.csv"; - - type SqlDs = SqliteDataset; - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct Sample { - column_str: String, - column_bytes: Vec, - column_int: i64, - column_bool: bool, - column_float: f64, - } - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct SampleCvs { - column_str: String, - column_int: i64, - column_bool: bool, - column_float: f64, - } - - #[fixture] - fn train_dataset() -> SqlDs { - SqliteDataset::from_db_file(DB_FILE, "train").unwrap() - } - - #[rstest] - pub fn from_dataset(train_dataset: SqlDs) { - let dataset = InMemDataset::from_dataset(&train_dataset); - - let non_existing_record_index: usize = 10; - let record_index: usize = 0; - - assert_eq!(train_dataset.get(non_existing_record_index), None); - assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1"); - } - - #[test] - pub fn from_json_rows() { - let dataset = InMemDataset::::from_json_rows(JSON_FILE).unwrap(); - - let non_existing_record_index: usize = 10; - let record_index: usize = 1; - - assert_eq!(dataset.get(non_existing_record_index), None); - assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); - assert!(!dataset.get(record_index).unwrap().column_bool); - } - - #[test] - pub fn from_csv_rows() { - let dataset = InMemDataset::::from_csv(CSV_FILE).unwrap(); - - let non_existing_record_index: usize = 10; - let record_index: usize = 1; - - assert_eq!(dataset.get(non_existing_record_index), None); - assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); - assert_eq!(dataset.get(record_index).unwrap().column_int, 1); - assert!(!dataset.get(record_index).unwrap().column_bool); - assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0); - } - - #[test] - pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() { - let items_original = test_data::string_items(); - let dataset = InMemDataset::new(items_original.clone()); - - let items: Vec = dataset.iter().collect(); - - assert_eq!(items_original, items); - } + use super::*; + use crate::{test_data, SqliteDataset}; + + use rstest::{fixture, rstest}; + use serde::{Deserialize, Serialize}; + + const DB_FILE: &str = "tests/data/sqlite-dataset.db"; + const JSON_FILE: &str = "tests/data/dataset.json"; + const CSV_FILE: &str = "tests/data/dataset.csv"; + + type SqlDs = SqliteDataset; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct Sample { + column_str: String, + column_bytes: Vec, + column_int: i64, + column_bool: bool, + column_float: f64, + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct SampleCvs { + column_str: String, + column_int: i64, + column_bool: bool, + column_float: f64, + } + + #[fixture] + fn train_dataset() -> SqlDs { + 
SqliteDataset::from_db_file(DB_FILE, "train").unwrap() + } + + #[rstest] + pub fn from_dataset(train_dataset: SqlDs) { + let dataset = InMemDataset::from_dataset(&train_dataset); + + let non_existing_record_index: usize = 10; + let record_index: usize = 0; + + assert_eq!(train_dataset.get(non_existing_record_index), None); + assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1"); + } + + #[test] + pub fn from_json_rows() { + let dataset = InMemDataset::::from_json_rows(JSON_FILE).unwrap(); + + let non_existing_record_index: usize = 10; + let record_index: usize = 1; + + assert_eq!(dataset.get(non_existing_record_index), None); + assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); + assert!(!dataset.get(record_index).unwrap().column_bool); + } + + #[test] + pub fn from_csv_rows() { + let dataset = InMemDataset::::from_csv(CSV_FILE).unwrap(); + + let non_existing_record_index: usize = 10; + let record_index: usize = 1; + + assert_eq!(dataset.get(non_existing_record_index), None); + assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); + assert_eq!(dataset.get(record_index).unwrap().column_int, 1); + assert!(!dataset.get(record_index).unwrap().column_bool); + assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0); + } + + #[test] + pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() { + let items_original = test_data::string_items(); + let dataset = InMemDataset::new(items_original.clone()); + + let items: Vec = dataset.iter().collect(); + + assert_eq!(items_original, items); + } } diff --git a/burn-dataset/src/dataset/iterator.rs b/burn-dataset/src/dataset/iterator.rs index 4d4045d16a..e513e1a08c 100644 --- a/burn-dataset/src/dataset/iterator.rs +++ b/burn-dataset/src/dataset/iterator.rs @@ -3,29 +3,29 @@ use std::iter::Iterator; /// Dataset iterator. pub struct DatasetIterator<'a, I> { - current: usize, - dataset: &'a dyn Dataset, + current: usize, + dataset: &'a dyn Dataset, } impl<'a, I> DatasetIterator<'a, I> { - /// Creates a new dataset iterator. - pub fn new(dataset: &'a D) -> Self - where - D: Dataset, - { - DatasetIterator { - current: 0, - dataset, + /// Creates a new dataset iterator. 
+ pub fn new(dataset: &'a D) -> Self + where + D: Dataset, + { + DatasetIterator { + current: 0, + dataset, + } } - } } impl<'a, I> Iterator for DatasetIterator<'a, I> { - type Item = I; + type Item = I; - fn next(&mut self) -> Option { - let item = self.dataset.get(self.current); - self.current += 1; - item - } + fn next(&mut self) -> Option { + let item = self.dataset.get(self.current); + self.current += 1; + item + } } diff --git a/burn-dataset/src/dataset/sqlite.rs b/burn-dataset/src/dataset/sqlite.rs index 1f6aa55877..3428642982 100644 --- a/burn-dataset/src/dataset/sqlite.rs +++ b/burn-dataset/src/dataset/sqlite.rs @@ -1,21 +1,21 @@ use std::{ - collections::HashSet, - fs, io, - marker::PhantomData, - path::{Path, PathBuf}, - sync::{Arc, RwLock}, + collections::HashSet, + fs, io, + marker::PhantomData, + path::{Path, PathBuf}, + sync::{Arc, RwLock}, }; use crate::Dataset; use gix_tempfile::{ - handle::{persist, Writable}, - AutoRemove, ContainingDirectory, Handle, + handle::{persist, Writable}, + AutoRemove, ContainingDirectory, Handle, }; use r2d2::{Pool, PooledConnection}; use r2d2_sqlite::{ - rusqlite::{OpenFlags, OptionalExtension}, - SqliteConnectionManager, + rusqlite::{OpenFlags, OptionalExtension}, + SqliteConnectionManager, }; use sanitize_filename::sanitize; use serde::{de::DeserializeOwned, Serialize}; @@ -27,39 +27,39 @@ pub type Result = core::result::Result; /// Sqlite dataset error. #[derive(thiserror::Error, Debug)] pub enum SqliteDatasetError { - /// IO related error. - #[error("IO error: {0}")] - Io(#[from] io::Error), + /// IO related error. + #[error("IO error: {0}")] + Io(#[from] io::Error), - /// Sql related error. - #[error("Sql error: {0}")] - Sql(#[from] serde_rusqlite::rusqlite::Error), + /// Sql related error. + #[error("Sql error: {0}")] + Sql(#[from] serde_rusqlite::rusqlite::Error), - /// Serde related error. - #[error("Serde error: {0}")] - Serde(#[from] rmp_serde::encode::Error), + /// Serde related error. + #[error("Serde error: {0}")] + Serde(#[from] rmp_serde::encode::Error), - /// The database file already exists error. - #[error("Overwrite flag is set to false and the database file already exists: {0}")] - FileExists(PathBuf), + /// The database file already exists error. + #[error("Overwrite flag is set to false and the database file already exists: {0}")] + FileExists(PathBuf), - /// Error when creating the connection pool. - #[error("Failed to create connection pool: {0}")] - ConnectionPool(#[from] r2d2::Error), + /// Error when creating the connection pool. + #[error("Failed to create connection pool: {0}")] + ConnectionPool(#[from] r2d2::Error), - /// Error when persisting the temporary database file. - #[error("Could not persist the temporary database file: {0}")] - PersistDbFile(#[from] persist::Error), + /// Error when persisting the temporary database file. + #[error("Could not persist the temporary database file: {0}")] + PersistDbFile(#[from] persist::Error), - /// Any other error. - #[error("{0}")] - Other(&'static str), + /// Any other error. + #[error("{0}")] + Other(&'static str), } impl From<&'static str> for SqliteDatasetError { - fn from(s: &'static str) -> Self { - SqliteDatasetError::Other(s) - } + fn from(s: &'static str) -> Self { + SqliteDatasetError::Other(s) + } } /// This struct represents a dataset where all items are stored in an SQLite database. @@ -89,322 +89,323 @@ impl From<&'static str> for SqliteDatasetError { /// method to read the data from the table. 
#[derive(Debug)] pub struct SqliteDataset { - db_file: PathBuf, - split: String, - conn_pool: Pool, - columns: Vec, - len: usize, - select_statement: String, - row_serialized: bool, - phantom: PhantomData, + db_file: PathBuf, + split: String, + conn_pool: Pool, + columns: Vec, + len: usize, + select_statement: String, + row_serialized: bool, + phantom: PhantomData, } impl SqliteDataset { - /// Initializes a `SqliteDataset` from a SQLite database file and a split name. - pub fn from_db_file>(db_file: P, split: &str) -> Result { - // Create a connection pool - let conn_pool = create_conn_pool(&db_file, false)?; - - // Determine how the table is stored - let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?; - - // Create a select statement and save it - let select_statement = if row_serialized { - format!("select item from {split} where row_id = ?") - } else { - format!("select * from {split} where row_id = ?") - }; - - // Save the column names and the number of rows - let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?; - - Ok(SqliteDataset { - db_file: db_file.as_ref().to_path_buf(), - split: split.to_string(), - conn_pool, - columns, - len, - select_statement, - row_serialized, - phantom: PhantomData, - }) - } - - /// Returns true if table has two columns: row_id (integer) and item (blob). - /// - /// This is used to determine if the table is row serialized or not. - fn check_if_row_serialized( - conn_pool: &Pool, - split: &str, - ) -> Result { - // This struct is used to store the column name and type - struct Column { - name: String, - ty: String, - } - - const COLUMN_NAME: usize = 1; - const COLUMN_TYPE: usize = 2; - - let sql_statement = format!("PRAGMA table_info({split})"); - - let conn = conn_pool.get()?; - - let mut stmt = conn.prepare(sql_statement.as_str())?; - let column_iter = stmt.query_map([], |row| { - Ok(Column { - name: row - .get::(COLUMN_NAME) - .unwrap() - .to_lowercase(), - ty: row - .get::(COLUMN_TYPE) - .unwrap() - .to_lowercase(), - }) - })?; + /// Initializes a `SqliteDataset` from a SQLite database file and a split name. + pub fn from_db_file>(db_file: P, split: &str) -> Result { + // Create a connection pool + let conn_pool = create_conn_pool(&db_file, false)?; + + // Determine how the table is stored + let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?; + + // Create a select statement and save it + let select_statement = if row_serialized { + format!("select item from {split} where row_id = ?") + } else { + format!("select * from {split} where row_id = ?") + }; + + // Save the column names and the number of rows + let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?; + + Ok(SqliteDataset { + db_file: db_file.as_ref().to_path_buf(), + split: split.to_string(), + conn_pool, + columns, + len, + select_statement, + row_serialized, + phantom: PhantomData, + }) + } - let mut columns: Vec = vec![]; + /// Returns true if table has two columns: row_id (integer) and item (blob). + /// + /// This is used to determine if the table is row serialized or not. 
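As a usage note for the constructor above, reading a split only takes a database path, a table name, and a deserializable item type. A sketch with a hypothetical Sample item and illustrative paths:

use burn_dataset::{Dataset, SqliteDataset};
use serde::Deserialize;

// Hypothetical item type; reading only needs Deserialize plus Clone/Send/Sync.
#[derive(Clone, Deserialize)]
struct Sample {
    column_str: String,
    column_int: i64,
}

fn main() {
    // Path and split are illustrative; the file must already exist.
    let dataset: SqliteDataset<Sample> =
        SqliteDataset::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap();

    println!("rows: {}", dataset.len());
    if let Some(item) = dataset.get(0) {
        println!("first item: {} / {}", item.column_str, item.column_int);
    }
}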
+ fn check_if_row_serialized( + conn_pool: &Pool, + split: &str, + ) -> Result { + // This struct is used to store the column name and type + struct Column { + name: String, + ty: String, + } + + const COLUMN_NAME: usize = 1; + const COLUMN_TYPE: usize = 2; + + let sql_statement = format!("PRAGMA table_info({split})"); + + let conn = conn_pool.get()?; + + let mut stmt = conn.prepare(sql_statement.as_str())?; + let column_iter = stmt.query_map([], |row| { + Ok(Column { + name: row + .get::(COLUMN_NAME) + .unwrap() + .to_lowercase(), + ty: row + .get::(COLUMN_TYPE) + .unwrap() + .to_lowercase(), + }) + })?; + + let mut columns: Vec = vec![]; + + for column in column_iter { + columns.push(column?); + } + + if columns.len() != 2 { + Ok(false) + } else { + // Check if the column names and types match the expected values + Ok(columns[0].name == "row_id" + && columns[0].ty == "integer" + && columns[1].name == "item" + && columns[1].ty == "blob") + } + } - for column in column_iter { - columns.push(column?); + /// Get the database file name. + pub fn db_file(&self) -> PathBuf { + self.db_file.clone() } - if columns.len() != 2 { - Ok(false) - } else { - // Check if the column names and types match the expected values - Ok( - columns[0].name == "row_id" - && columns[0].ty == "integer" - && columns[1].name == "item" - && columns[1].ty == "blob", - ) - } - } - - /// Get the database file name. - pub fn db_file(&self) -> PathBuf { - self.db_file.clone() - } - - /// Get the split name. - pub fn split(&self) -> &str { - self.split.as_str() - } + /// Get the split name. + pub fn split(&self) -> &str { + self.split.as_str() + } } impl Dataset for SqliteDataset where - I: Clone + Send + Sync + DeserializeOwned, + I: Clone + Send + Sync + DeserializeOwned, { - /// Get an item from the dataset. - fn get(&self, index: usize) -> Option { - // Row ids start with 1 (one) and index starts with 0 (zero) - let row_id = index + 1; - - // Get a connection from the pool - let connection = self.conn_pool.get().unwrap(); - let mut statement = connection.prepare(self.select_statement.as_str()).unwrap(); - - if self.row_serialized { - // Fetch with a single column `item` and deserialize it with MessagePack - statement - .query_row([row_id], |row| { - // Deserialize item (blob) with MessagePack (rmp-serde) - Ok(rmp_serde::from_slice::(row.get_ref(0).unwrap().as_blob().unwrap()).unwrap()) - }) - .optional() //Converts Error (not found) to None - .unwrap() - } else { - // Fetch a row with multiple columns and deserialize it serde_rusqlite - statement - .query_row([row_id], |row| { - // Deserialize the row with serde_rusqlite - Ok(from_row_with_columns::(row, &self.columns).unwrap()) - }) - .optional() //Converts Error (not found) to None - .unwrap() + /// Get an item from the dataset. 
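The check above boils down to parsing PRAGMA table_info output. A self-contained sketch of the same query against an in-memory database, assuming a direct rusqlite dependency rather than the r2d2 pool the dataset uses:

use rusqlite::Connection;

fn main() -> rusqlite::Result<()> {
    // In-memory table shaped like a "row serialized" split.
    let conn = Connection::open_in_memory()?;
    conn.execute("create table train (row_id integer primary key, item blob)", [])?;

    // Same query the check runs: column 1 of each result row is the column
    // name, column 2 its declared type.
    let mut stmt = conn.prepare("PRAGMA table_info(train)")?;
    let columns = stmt
        .query_map([], |row| {
            Ok((row.get::<_, String>(1)?, row.get::<_, String>(2)?))
        })?
        .collect::<rusqlite::Result<Vec<_>>>()?;

    assert_eq!(columns.len(), 2);
    assert_eq!(columns[0].0.to_lowercase(), "row_id");
    assert_eq!(columns[1].1.to_lowercase(), "blob");
    Ok(())
}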
+ fn get(&self, index: usize) -> Option { + // Row ids start with 1 (one) and index starts with 0 (zero) + let row_id = index + 1; + + // Get a connection from the pool + let connection = self.conn_pool.get().unwrap(); + let mut statement = connection.prepare(self.select_statement.as_str()).unwrap(); + + if self.row_serialized { + // Fetch with a single column `item` and deserialize it with MessagePack + statement + .query_row([row_id], |row| { + // Deserialize item (blob) with MessagePack (rmp-serde) + Ok( + rmp_serde::from_slice::(row.get_ref(0).unwrap().as_blob().unwrap()) + .unwrap(), + ) + }) + .optional() //Converts Error (not found) to None + .unwrap() + } else { + // Fetch a row with multiple columns and deserialize it serde_rusqlite + statement + .query_row([row_id], |row| { + // Deserialize the row with serde_rusqlite + Ok(from_row_with_columns::(row, &self.columns).unwrap()) + }) + .optional() //Converts Error (not found) to None + .unwrap() + } } - } - /// Return the number of rows in the dataset. - fn len(&self) -> usize { - self.len - } + /// Return the number of rows in the dataset. + fn len(&self) -> usize { + self.len + } } /// Fetch the column names and the number of rows from the database. fn fetch_columns_and_len( - conn_pool: &Pool, - select_statement: &str, - split: &str, + conn_pool: &Pool, + select_statement: &str, + split: &str, ) -> Result<(Vec, usize)> { - // Save the column names - let connection = conn_pool.get()?; - let statement = connection.prepare(select_statement)?; - let columns = columns_from_statement(&statement); - - // Count the number of rows and save it as len - // - // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables. - // The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id, - // which corresponds to the number of rows in the table. - // The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps. - // This is true for all the datasets that we are using, otherwise row_id will not correspond to the index. - let mut statement = - connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?; - - let len = statement.query_row([], |row| { - let len: usize = row.get(0)?; - Ok(len) - })?; - Ok((columns, len)) + // Save the column names + let connection = conn_pool.get()?; + let statement = connection.prepare(select_statement)?; + let columns = columns_from_statement(&statement); + + // Count the number of rows and save it as len + // + // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables. + // The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id, + // which corresponds to the number of rows in the table. + // The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps. + // This is true for all the datasets that we are using, otherwise row_id will not correspond to the index. 
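Since get relies on the item blob being MessagePack, a minimal rmp-serde round trip (hypothetical Item type) shows the encode/decode pair the reader and writer agree on:

use serde::{Deserialize, Serialize};

// Hypothetical item type standing in for whatever the dataset stores.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Item {
    label: usize,
    text: String,
}

fn main() {
    let item = Item { label: 3, text: "hello".to_string() };

    // Encode the way the writer fills the `item` blob column...
    let blob: Vec<u8> = rmp_serde::to_vec(&item).unwrap();

    // ...and decode the way `get` reads it back (index 0 maps to row_id 1).
    let decoded: Item = rmp_serde::from_slice(&blob).unwrap();
    assert_eq!(item, decoded);
}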
+ let mut statement = + connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?; + + let len = statement.query_row([], |row| { + let len: usize = row.get(0)?; + Ok(len) + })?; + Ok((columns, len)) } /// Helper function to create a connection pool fn create_conn_pool>( - db_file: P, - write: bool, + db_file: P, + write: bool, ) -> Result> { - let sqlite_flags = if write { - OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE - } else { - OpenFlags::SQLITE_OPEN_READ_ONLY - }; + let sqlite_flags = if write { + OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE + } else { + OpenFlags::SQLITE_OPEN_READ_ONLY + }; - // Create a connection pool and make sure the connections are read only - let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags); + // Create a connection pool and make sure the connections are read only + let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags); - Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool) + Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool) } /// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets. /// It consists of an optional name, a database file path, and a base directory for storage. #[derive(Clone, Debug)] pub struct SqliteDatasetStorage { - name: Option, - db_file: Option, - base_dir: Option, + name: Option, + db_file: Option, + base_dir: Option, } impl SqliteDatasetStorage { - /// Creates a new instance of `SqliteDatasetStorage` using a dataset name. - /// - /// # Arguments - /// - /// * `name` - A string slice that holds the name of the dataset. - pub fn from_name(name: &str) -> Self { - SqliteDatasetStorage { - name: Some(name.to_string()), - db_file: None, - base_dir: None, - } - } - - /// Creates a new instance of `SqliteDatasetStorage` using a database file path. - /// - /// # Arguments - /// - /// * `db_file` - A reference to the Path that represents the database file path. - pub fn from_file>(db_file: P) -> Self { - SqliteDatasetStorage { - name: None, - db_file: Some(db_file.as_ref().to_path_buf()), - base_dir: None, - } - } - - /// Sets the base directory for storing the dataset. - /// - /// # Arguments - /// - /// * `base_dir` - A string slice that represents the base directory. - pub fn with_base_dir>(mut self, base_dir: P) -> Self { - self.base_dir = Some(base_dir.as_ref().to_path_buf()); - self - } - - /// Checks if the database file exists in the given path. - /// - /// # Returns - /// - /// * A boolean value indicating whether the file exists or not. - pub fn exists(&self) -> bool { - self.db_file().exists() - } - - /// Fetches the database file path. - /// - /// # Returns - /// - /// * A `PathBuf` instance representing the file path. - pub fn db_file(&self) -> PathBuf { - let db_file = match &self.db_file { - Some(db_file) => db_file.clone(), - None => { - let name = sanitize(self.name.as_ref().expect("Name is not set")); - Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db")) - } - }; - db_file - } - - /// Determines the base directory for storing the dataset. - /// - /// # Arguments - /// - /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory. - /// - /// # Returns - /// - /// * A `PathBuf` instance representing the base directory. 
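A short usage sketch of the storage builder above, assuming the type is re-exported at the crate root; the dataset name and base directory are illustrative:

use burn_dataset::SqliteDatasetStorage;

fn main() {
    // Resolve where a named dataset lives. Without with_base_dir this would
    // fall back to ~/.cache/burn-dataset/<name>.db, per base_dir above.
    let storage = SqliteDatasetStorage::from_name("my-dataset").with_base_dir("/tmp/burn-example");

    println!("db file: {:?}", storage.db_file());
    println!("already exists: {}", storage.exists());
}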
- pub fn base_dir(base_dir: Option) -> PathBuf { - match base_dir { - Some(base_dir) => base_dir, - None => { - let home_dir = dirs::home_dir().expect("Could not get home directory"); - - home_dir.join(".cache").join("burn-dataset") - } - } - } - - /// Provides a writer instance for the SQLite dataset. - /// - /// # Arguments - /// - /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. - pub fn writer(&self, overwrite: bool) -> Result> - where - I: Clone + Send + Sync + Serialize + DeserializeOwned, - { - SqliteDatasetWriter::new(self.db_file(), overwrite) - } - - /// Provides a reader instance for the SQLite dataset. - /// - /// # Arguments - /// - /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test"). - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise. - pub fn reader(&self, split: &str) -> Result> - where - I: Clone + Send + Sync + Serialize + DeserializeOwned, - { - if !self.exists() { - panic!("The database file does not exist"); + /// Creates a new instance of `SqliteDatasetStorage` using a dataset name. + /// + /// # Arguments + /// + /// * `name` - A string slice that holds the name of the dataset. + pub fn from_name(name: &str) -> Self { + SqliteDatasetStorage { + name: Some(name.to_string()), + db_file: None, + base_dir: None, + } + } + + /// Creates a new instance of `SqliteDatasetStorage` using a database file path. + /// + /// # Arguments + /// + /// * `db_file` - A reference to the Path that represents the database file path. + pub fn from_file>(db_file: P) -> Self { + SqliteDatasetStorage { + name: None, + db_file: Some(db_file.as_ref().to_path_buf()), + base_dir: None, + } + } + + /// Sets the base directory for storing the dataset. + /// + /// # Arguments + /// + /// * `base_dir` - A string slice that represents the base directory. + pub fn with_base_dir>(mut self, base_dir: P) -> Self { + self.base_dir = Some(base_dir.as_ref().to_path_buf()); + self } - SqliteDataset::from_db_file(self.db_file(), split) - } + /// Checks if the database file exists in the given path. + /// + /// # Returns + /// + /// * A boolean value indicating whether the file exists or not. + pub fn exists(&self) -> bool { + self.db_file().exists() + } + + /// Fetches the database file path. + /// + /// # Returns + /// + /// * A `PathBuf` instance representing the file path. + pub fn db_file(&self) -> PathBuf { + let db_file = match &self.db_file { + Some(db_file) => db_file.clone(), + None => { + let name = sanitize(self.name.as_ref().expect("Name is not set")); + Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db")) + } + }; + db_file + } + + /// Determines the base directory for storing the dataset. + /// + /// # Arguments + /// + /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory. + /// + /// # Returns + /// + /// * A `PathBuf` instance representing the base directory. + pub fn base_dir(base_dir: Option) -> PathBuf { + match base_dir { + Some(base_dir) => base_dir, + None => { + let home_dir = dirs::home_dir().expect("Could not get home directory"); + + home_dir.join(".cache").join("burn-dataset") + } + } + } + + /// Provides a writer instance for the SQLite dataset. + /// + /// # Arguments + /// + /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. 
+ /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. + pub fn writer(&self, overwrite: bool) -> Result> + where + I: Clone + Send + Sync + Serialize + DeserializeOwned, + { + SqliteDatasetWriter::new(self.db_file(), overwrite) + } + + /// Provides a reader instance for the SQLite dataset. + /// + /// # Arguments + /// + /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test"). + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise. + pub fn reader(&self, split: &str) -> Result> + where + I: Clone + Send + Sync + Serialize + DeserializeOwned, + { + if !self.exists() { + panic!("The database file does not exist"); + } + + SqliteDataset::from_db_file(self.db_file(), split) + } } /// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets. @@ -419,190 +420,190 @@ impl SqliteDatasetStorage { /// - Enlargement of a dataset's item count post preprocessing #[derive(Debug)] pub struct SqliteDatasetWriter { - db_file: PathBuf, - db_file_tmp: Option>, - splits: Arc>>, - overwrite: bool, - conn_pool: Option>, - is_completed: Arc>, - phantom: PhantomData, + db_file: PathBuf, + db_file_tmp: Option>, + splits: Arc>>, + overwrite: bool, + conn_pool: Option>, + is_completed: Arc>, + phantom: PhantomData, } impl SqliteDatasetWriter where - I: Clone + Send + Sync + Serialize + DeserializeOwned, + I: Clone + Send + Sync + Serialize + DeserializeOwned, { - /// Creates a new instance of `SqliteDatasetWriter`. - /// - /// # Arguments - /// - /// * `db_file` - A reference to the Path that represents the database file path. - /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. - pub fn new>(db_file: P, overwrite: bool) -> Result { - let writer = Self { - db_file: db_file.as_ref().to_path_buf(), - db_file_tmp: None, - splits: Arc::new(RwLock::new(HashSet::new())), - overwrite, - conn_pool: None, - is_completed: Arc::new(RwLock::new(false)), - phantom: PhantomData, - }; + /// Creates a new instance of `SqliteDatasetWriter`. + /// + /// # Arguments + /// + /// * `db_file` - A reference to the Path that represents the database file path. + /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. + pub fn new>(db_file: P, overwrite: bool) -> Result { + let writer = Self { + db_file: db_file.as_ref().to_path_buf(), + db_file_tmp: None, + splits: Arc::new(RwLock::new(HashSet::new())), + overwrite, + conn_pool: None, + is_completed: Arc::new(RwLock::new(false)), + phantom: PhantomData, + }; + + writer.init() + } - writer.init() - } - - /// Initializes the dataset writer by creating the database file, tables, and connection pool. - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise. 
- fn init(mut self) -> Result { - // Remove the db file if it already exists - if self.db_file.exists() { - if self.overwrite { - fs::remove_file(&self.db_file)?; - } else { - return Err(SqliteDatasetError::FileExists(self.db_file)); - } - } - - // Create the database file directory if it does not exist - let db_file_dir = self - .db_file - .parent() - .ok_or("Unable to get parent directory")?; - - if !db_file_dir.exists() { - fs::create_dir_all(db_file_dir)?; - } - - // Create a temp database file name as {base_dir}/{name}.db.tmp - let mut db_file_tmp = self.db_file.clone(); - db_file_tmp.set_extension("db.tmp"); - if db_file_tmp.exists() { - fs::remove_file(&db_file_tmp)?; - } - - // Create the temp database file and wrap it with a gix_tempfile::Handle - // This will ensure that the temp file is deleted when the writer is dropped - // or when process exits with SIGINT or SIGTERM (tempfile crate does not do this) - gix_tempfile::signal::setup(Default::default()); - self.db_file_tmp = Some(gix_tempfile::writable_at( - &db_file_tmp, - ContainingDirectory::Exists, - AutoRemove::Tempfile, - )?); - - let conn_pool = create_conn_pool(db_file_tmp, true)?; - self.conn_pool = Some(conn_pool); - - Ok(self) - } - - /// Serializes and writes an item to the database. The item is written to the table for the - /// specified split. If the table does not exist, it is created. If the table exists, the item - /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/) - /// - /// # Arguments - /// - /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test"). - /// * `item` - A reference to the item to be written to the database. - /// - /// # Returns - /// - /// * A `Result` containing the index of the inserted row if successful, an error otherwise. - pub fn write(&self, split: &str, item: &I) -> Result { - // Acquire the read lock (wont't block other reads) - let is_completed = self.is_completed.read().unwrap(); - - // If the writer is completed, return an error - if *is_completed { - return Err(SqliteDatasetError::Other( - "Cannot save to a completed dataset writer", - )); - } - - // create the table for the split if it does not exist - if !self.splits.read().unwrap().contains(split) { - self.create_table(split)?; - } - - // Get a connection from the pool - let conn_pool = self.conn_pool.as_ref().unwrap(); - let conn = conn_pool.get()?; - - // Serialize the item using MessagePack - let serialized_item = rmp_serde::to_vec(item)?; - - // Turn off the synchronous and journal mode for speed up - // We are sacrificing durability for speed but it's okay because - // we always recreate the dataset if it is not completed. - pragma_update_with_error_handling(&conn, "synchronous", "OFF")?; - pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?; - - // Insert the serialized item into the database - let insert_statement = format!("insert into {split} (item) values (?)", split = split); - conn.execute(insert_statement.as_str(), [serialized_item])?; - - // Get the primary key of the last inserted row and convert to index (row_id-1) - let index = (conn.last_insert_rowid() - 1) as usize; - - Ok(index) - } - - /// Marks the dataset as completed and persists the temporary database file. 
- pub fn set_completed(&mut self) -> Result<()> { - let mut is_completed = self.is_completed.write().unwrap(); - - // Rename the database file from tmp to db - let _file_result = self - .db_file_tmp - .take() // take ownership of the temporary file and set to None - .unwrap() // unwrap the temporary file - .persist(&self.db_file)? - .ok_or("Unable to persist the database file")?; - - *is_completed = true; - Ok(()) - } - - /// Creates table for the data split. - /// - /// Note: call is idempotent and thread-safe. - /// - /// # Arguments - /// - /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test"). - /// - /// # Returns - /// - /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise. - /// - /// TODO (@antimora): add support creating a table with columns corresponding to the item fields - fn create_table(&self, split: &str) -> Result<()> { - // Check if the split already exists - if self.splits.read().unwrap().contains(split) { - return Ok(()); - } - - let conn_pool = self.conn_pool.as_ref().unwrap(); - let connection = conn_pool.get()?; - let create_table_statement = format!( + /// Initializes the dataset writer by creating the database file, tables, and connection pool. + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise. + fn init(mut self) -> Result { + // Remove the db file if it already exists + if self.db_file.exists() { + if self.overwrite { + fs::remove_file(&self.db_file)?; + } else { + return Err(SqliteDatasetError::FileExists(self.db_file)); + } + } + + // Create the database file directory if it does not exist + let db_file_dir = self + .db_file + .parent() + .ok_or("Unable to get parent directory")?; + + if !db_file_dir.exists() { + fs::create_dir_all(db_file_dir)?; + } + + // Create a temp database file name as {base_dir}/{name}.db.tmp + let mut db_file_tmp = self.db_file.clone(); + db_file_tmp.set_extension("db.tmp"); + if db_file_tmp.exists() { + fs::remove_file(&db_file_tmp)?; + } + + // Create the temp database file and wrap it with a gix_tempfile::Handle + // This will ensure that the temp file is deleted when the writer is dropped + // or when process exits with SIGINT or SIGTERM (tempfile crate does not do this) + gix_tempfile::signal::setup(Default::default()); + self.db_file_tmp = Some(gix_tempfile::writable_at( + &db_file_tmp, + ContainingDirectory::Exists, + AutoRemove::Tempfile, + )?); + + let conn_pool = create_conn_pool(db_file_tmp, true)?; + self.conn_pool = Some(conn_pool); + + Ok(self) + } + + /// Serializes and writes an item to the database. The item is written to the table for the + /// specified split. If the table does not exist, it is created. If the table exists, the item + /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/) + /// + /// # Arguments + /// + /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test"). + /// * `item` - A reference to the item to be written to the database. + /// + /// # Returns + /// + /// * A `Result` containing the index of the inserted row if successful, an error otherwise. 
+ pub fn write(&self, split: &str, item: &I) -> Result { + // Acquire the read lock (wont't block other reads) + let is_completed = self.is_completed.read().unwrap(); + + // If the writer is completed, return an error + if *is_completed { + return Err(SqliteDatasetError::Other( + "Cannot save to a completed dataset writer", + )); + } + + // create the table for the split if it does not exist + if !self.splits.read().unwrap().contains(split) { + self.create_table(split)?; + } + + // Get a connection from the pool + let conn_pool = self.conn_pool.as_ref().unwrap(); + let conn = conn_pool.get()?; + + // Serialize the item using MessagePack + let serialized_item = rmp_serde::to_vec(item)?; + + // Turn off the synchronous and journal mode for speed up + // We are sacrificing durability for speed but it's okay because + // we always recreate the dataset if it is not completed. + pragma_update_with_error_handling(&conn, "synchronous", "OFF")?; + pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?; + + // Insert the serialized item into the database + let insert_statement = format!("insert into {split} (item) values (?)", split = split); + conn.execute(insert_statement.as_str(), [serialized_item])?; + + // Get the primary key of the last inserted row and convert to index (row_id-1) + let index = (conn.last_insert_rowid() - 1) as usize; + + Ok(index) + } + + /// Marks the dataset as completed and persists the temporary database file. + pub fn set_completed(&mut self) -> Result<()> { + let mut is_completed = self.is_completed.write().unwrap(); + + // Rename the database file from tmp to db + let _file_result = self + .db_file_tmp + .take() // take ownership of the temporary file and set to None + .unwrap() // unwrap the temporary file + .persist(&self.db_file)? + .ok_or("Unable to persist the database file")?; + + *is_completed = true; + Ok(()) + } + + /// Creates table for the data split. + /// + /// Note: call is idempotent and thread-safe. + /// + /// # Arguments + /// + /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test"). + /// + /// # Returns + /// + /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise. + /// + /// TODO (@antimora): add support creating a table with columns corresponding to the item fields + fn create_table(&self, split: &str) -> Result<()> { + // Check if the split already exists + if self.splits.read().unwrap().contains(split) { + return Ok(()); + } + + let conn_pool = self.conn_pool.as_ref().unwrap(); + let connection = conn_pool.get()?; + let create_table_statement = format!( "create table if not exists {split} (row_id integer primary key autoincrement not null, item blob not null)" ); - connection.execute(create_table_statement.as_str(), [])?; + connection.execute(create_table_statement.as_str(), [])?; - // Add the split to the splits - self.splits.write().unwrap().insert(split.to_string()); + // Add the split to the splits + self.splits.write().unwrap().insert(split.to_string()); - Ok(()) - } + Ok(()) + } } /// Runs a pragma update and ignores the `ExecuteReturnedResults` error. @@ -611,235 +612,237 @@ where /// and can be ignored. This function runs the pragma update and ignores the error if it is /// `ExecuteReturnedResults`. 
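// A minimal usage sketch of the storage API documented above: create a storage
// handle, write one item into a "train" split, mark the writer as completed,
// then read the split back. Assumptions not taken from this patch: the types
// are re-exported from the crate root as `burn_dataset::{Dataset,
// SqliteDatasetStorage}`, serde's `derive` feature is available, and the
// dataset name, base directory, and item type below are illustrative only.
use burn_dataset::{Dataset, SqliteDatasetStorage};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
struct MyItem {
    text: String,
    label: i64,
}

fn sqlite_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
    // Resolves to `/tmp/burn-demo/demo.db`; without `with_base_dir` it would
    // default to `~/.cache/burn-dataset/demo.db`.
    let storage = SqliteDatasetStorage::from_name("demo").with_base_dir("/tmp/burn-demo");

    // `writer` fails if the database file already exists and `overwrite` is false.
    let mut writer = storage.writer::<MyItem>(true)?;
    let index = writer.write("train", &MyItem { text: "hi".into(), label: 1 })?;
    assert_eq!(index, 0); // row ids start at 1, returned indices at 0

    // Persist the temporary `.db.tmp` file as the final database file.
    writer.set_completed()?;

    // Read the split back; `reader` panics if the database file is missing.
    let train = storage.reader::<MyItem>("train")?;
    assert_eq!(train.len(), 1);
    assert_eq!(train.get(0).unwrap().label, 1);
    Ok(())
}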
fn pragma_update_with_error_handling( - conn: &PooledConnection, - setting: &str, - value: &str, + conn: &PooledConnection, + setting: &str, + value: &str, ) -> Result<()> { - let result = conn.pragma_update(None, setting, value); - if let Err(error) = result { - if error != rusqlite::Error::ExecuteReturnedResults { - return Err(SqliteDatasetError::Sql(error)); + let result = conn.pragma_update(None, setting, value); + if let Err(error) = result { + if error != rusqlite::Error::ExecuteReturnedResults { + return Err(SqliteDatasetError::Sql(error)); + } } - } - Ok(()) + Ok(()) } #[cfg(test)] mod tests { - use rayon::prelude::*; - use rstest::{fixture, rstest}; - use serde::{Deserialize, Serialize}; - use tempfile::{tempdir, NamedTempFile, TempDir}; - - use super::*; - - type SqlDs = SqliteDataset; - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct Sample { - column_str: String, - column_bytes: Vec, - column_int: i64, - column_bool: bool, - column_float: f64, - } - - #[fixture] - fn train_dataset() -> SqlDs { - SqliteDataset::::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap() - } - - #[rstest] - pub fn len(train_dataset: SqlDs) { - assert_eq!(train_dataset.len(), 2); - } - - #[rstest] - pub fn get_some(train_dataset: SqlDs) { - let item = train_dataset.get(0).unwrap(); - assert_eq!(item.column_str, "HI1"); - assert_eq!(item.column_bytes, vec![55, 231, 159]); - assert_eq!(item.column_int, 1); - assert!(item.column_bool); - assert_eq!(item.column_float, 1.0); - } - - #[rstest] - pub fn get_none(train_dataset: SqlDs) { - assert_eq!(train_dataset.get(10), None); - } - - #[rstest] - pub fn multi_thread(train_dataset: SqlDs) { - let indices: Vec = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1]; - let results: Vec> = indices.par_iter().map(|&i| train_dataset.get(i)).collect(); - - let mut match_count = 0; - for (_index, result) in indices.iter().zip(results.iter()) { - match result { - Some(_val) => match_count += 1, - None => (), - } - } - - assert_eq!(match_count, 5); - } - - #[test] - fn sqlite_dataset_storage() { - // Test with non-existing file - let storage = SqliteDatasetStorage::from_file("non-existing.db"); - assert!(!storage.exists()); - - // Test with non-existing name - let storage = SqliteDatasetStorage::from_name("non-existing.db"); - assert!(!storage.exists()); - - // Test with existing file - let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db"); - assert!(storage.exists()); - let result = storage.reader::("train"); - assert!(result.is_ok()); - let train = result.unwrap(); - assert_eq!(train.len(), 2); - - // Test get writer - let temp_file = NamedTempFile::new().unwrap(); - let storage = SqliteDatasetStorage::from_file(temp_file.path()); - assert!(storage.exists()); - let result = storage.writer::(true); - assert!(result.is_ok()); - } - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - pub struct Complex { - column_str: String, - column_bytes: Vec, - column_int: i64, - column_bool: bool, - column_float: f64, - column_complex: Vec>>, - } - - /// Create a temporary directory. - #[fixture] - fn tmp_dir() -> TempDir { - // Create a TempDir. This object will be automatically - // deleted when it goes out of scope. - tempdir().unwrap() - } - type Writer = SqliteDatasetWriter; - - /// Create a SqliteDatasetWriter with a temporary directory. - /// Make sure to return the temporary directory so that it is not deleted. 
- #[fixture] - fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) { - let temp_dir_str = tmp_dir.path(); - let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str); - let overwrite = true; - let result = storage.writer::(overwrite); - assert!(result.is_ok()); - let writer = result.unwrap(); - (writer, tmp_dir) - } - - #[test] - fn test_new() { - // Test that the constructor works with overwrite = true - let test_path = NamedTempFile::new().unwrap(); - let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); - assert!(!test_path.path().exists()); - - // Test that the constructor works with overwrite = false - let test_path = NamedTempFile::new().unwrap(); - let result = SqliteDatasetWriter::::new(&test_path, false); - assert!(result.is_err()); - - // Test that the constructor works with no existing file - let temp = NamedTempFile::new().unwrap(); - let test_path = temp.path().to_path_buf(); - assert!(temp.close().is_ok()); - assert!(!test_path.exists()); - let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); - assert!(!test_path.exists()); - } - - #[rstest] - pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) { - // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) - let (writer, _tmp_dir) = writer_fixture; - - assert!(writer.overwrite); - assert!(!writer.db_file.exists()); - - let new_item = Complex { - column_str: "HI1".to_string(), - column_bytes: vec![1_u8, 2, 3], - column_int: 0, - column_bool: true, - column_float: 1.0, - column_complex: vec![vec![vec![[1, 23_u8, 3]]]], - }; + use rayon::prelude::*; + use rstest::{fixture, rstest}; + use serde::{Deserialize, Serialize}; + use tempfile::{tempdir, NamedTempFile, TempDir}; + + use super::*; + + type SqlDs = SqliteDataset; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct Sample { + column_str: String, + column_bytes: Vec, + column_int: i64, + column_bool: bool, + column_float: f64, + } - let index = writer.write("train", &new_item).unwrap(); - assert_eq!(index, 0); + #[fixture] + fn train_dataset() -> SqlDs { + SqliteDataset::::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap() + } - let mut writer = writer; + #[rstest] + pub fn len(train_dataset: SqlDs) { + assert_eq!(train_dataset.len(), 2); + } - writer.set_completed().expect("Failed to set completed"); + #[rstest] + pub fn get_some(train_dataset: SqlDs) { + let item = train_dataset.get(0).unwrap(); + assert_eq!(item.column_str, "HI1"); + assert_eq!(item.column_bytes, vec![55, 231, 159]); + assert_eq!(item.column_int, 1); + assert!(item.column_bool); + assert_eq!(item.column_float, 1.0); + } - assert!(writer.db_file.exists()); - assert!(writer.db_file_tmp.is_none()); + #[rstest] + pub fn get_none(train_dataset: SqlDs) { + assert_eq!(train_dataset.get(10), None); + } - let result = writer.write("train", &new_item); + #[rstest] + pub fn multi_thread(train_dataset: SqlDs) { + let indices: Vec = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1]; + let results: Vec> = + indices.par_iter().map(|&i| train_dataset.get(i)).collect(); + + let mut match_count = 0; + for (_index, result) in indices.iter().zip(results.iter()) { + match result { + Some(_val) => match_count += 1, + None => (), + } + } + + assert_eq!(match_count, 5); + } + + #[test] + fn sqlite_dataset_storage() { + // Test with non-existing file + let storage = SqliteDatasetStorage::from_file("non-existing.db"); + assert!(!storage.exists()); + + // Test with non-existing name + let storage = 
SqliteDatasetStorage::from_name("non-existing.db"); + assert!(!storage.exists()); + + // Test with existing file + let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db"); + assert!(storage.exists()); + let result = storage.reader::("train"); + assert!(result.is_ok()); + let train = result.unwrap(); + assert_eq!(train.len(), 2); + + // Test get writer + let temp_file = NamedTempFile::new().unwrap(); + let storage = SqliteDatasetStorage::from_file(temp_file.path()); + assert!(storage.exists()); + let result = storage.writer::(true); + assert!(result.is_ok()); + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct Complex { + column_str: String, + column_bytes: Vec, + column_int: i64, + column_bool: bool, + column_float: f64, + column_complex: Vec>>, + } + + /// Create a temporary directory. + #[fixture] + fn tmp_dir() -> TempDir { + // Create a TempDir. This object will be automatically + // deleted when it goes out of scope. + tempdir().unwrap() + } + type Writer = SqliteDatasetWriter; + + /// Create a SqliteDatasetWriter with a temporary directory. + /// Make sure to return the temporary directory so that it is not deleted. + #[fixture] + fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) { + let temp_dir_str = tmp_dir.path(); + let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str); + let overwrite = true; + let result = storage.writer::(overwrite); + assert!(result.is_ok()); + let writer = result.unwrap(); + (writer, tmp_dir) + } + + #[test] + fn test_new() { + // Test that the constructor works with overwrite = true + let test_path = NamedTempFile::new().unwrap(); + let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); + assert!(!test_path.path().exists()); + + // Test that the constructor works with overwrite = false + let test_path = NamedTempFile::new().unwrap(); + let result = SqliteDatasetWriter::::new(&test_path, false); + assert!(result.is_err()); + + // Test that the constructor works with no existing file + let temp = NamedTempFile::new().unwrap(); + let test_path = temp.path().to_path_buf(); + assert!(temp.close().is_ok()); + assert!(!test_path.exists()); + let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); + assert!(!test_path.exists()); + } - // Should fail because the writer is completed - assert!(result.is_err()); + #[rstest] + pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) { + // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) + let (writer, _tmp_dir) = writer_fixture; - let dataset = SqliteDataset::::from_db_file(writer.db_file, "train").unwrap(); + assert!(writer.overwrite); + assert!(!writer.db_file.exists()); - let fetched_item = dataset.get(0).unwrap(); - assert_eq!(fetched_item, new_item); - assert_eq!(dataset.len(), 1); - } + let new_item = Complex { + column_str: "HI1".to_string(), + column_bytes: vec![1_u8, 2, 3], + column_int: 0, + column_bool: true, + column_float: 1.0, + column_complex: vec![vec![vec![[1, 23_u8, 3]]]], + }; - #[rstest] - pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) { - // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) - let (writer, _tmp_dir) = writer_fixture; + let index = writer.write("train", &new_item).unwrap(); + assert_eq!(index, 0); - let writer = Arc::new(writer); - let record_count = 20; + let mut writer = writer; - let splits = ["train", "test"]; + writer.set_completed().expect("Failed to set 
completed"); - (0..record_count).into_par_iter().for_each(|index: i64| { - let thread_id: std::thread::ThreadId = std::thread::current().id(); - let sample = Complex { - column_str: format!("test_{:?}_{}", thread_id, index), - column_bytes: vec![index as u8, 2, 3], - column_int: index, - column_bool: true, - column_float: 1.0, - column_complex: vec![vec![vec![[1, index as u8, 3]]]], - }; + assert!(writer.db_file.exists()); + assert!(writer.db_file_tmp.is_none()); - // half for train and half for test - let split = splits[index as usize % 2]; + let result = writer.write("train", &new_item); - let _index = writer.write(split, &sample).unwrap(); - }); + // Should fail because the writer is completed + assert!(result.is_err()); - let mut writer = Arc::try_unwrap(writer).unwrap(); + let dataset = SqliteDataset::::from_db_file(writer.db_file, "train").unwrap(); - writer - .set_completed() - .expect("Should set completed successfully"); + let fetched_item = dataset.get(0).unwrap(); + assert_eq!(fetched_item, new_item); + assert_eq!(dataset.len(), 1); + } + + #[rstest] + pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) { + // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) + let (writer, _tmp_dir) = writer_fixture; + + let writer = Arc::new(writer); + let record_count = 20; + + let splits = ["train", "test"]; + + (0..record_count).into_par_iter().for_each(|index: i64| { + let thread_id: std::thread::ThreadId = std::thread::current().id(); + let sample = Complex { + column_str: format!("test_{:?}_{}", thread_id, index), + column_bytes: vec![index as u8, 2, 3], + column_int: index, + column_bool: true, + column_float: 1.0, + column_complex: vec![vec![vec![[1, index as u8, 3]]]], + }; - let train = SqliteDataset::::from_db_file(writer.db_file.clone(), "train").unwrap(); - let test = SqliteDataset::::from_db_file(writer.db_file, "test").unwrap(); + // half for train and half for test + let split = splits[index as usize % 2]; - assert_eq!(train.len(), record_count as usize / 2); - assert_eq!(test.len(), record_count as usize / 2); - } + let _index = writer.write(split, &sample).unwrap(); + }); + + let mut writer = Arc::try_unwrap(writer).unwrap(); + + writer + .set_completed() + .expect("Should set completed successfully"); + + let train = + SqliteDataset::::from_db_file(writer.db_file.clone(), "train").unwrap(); + let test = SqliteDataset::::from_db_file(writer.db_file, "test").unwrap(); + + assert_eq!(train.len(), record_count as usize / 2); + assert_eq!(test.len(), record_count as usize / 2); + } } diff --git a/burn-dataset/src/lib.rs b/burn-dataset/src/lib.rs index 1719a87de2..1d8045878a 100644 --- a/burn-dataset/src/lib.rs +++ b/burn-dataset/src/lib.rs @@ -26,12 +26,12 @@ pub use source::huggingface::downloader::*; #[cfg(test)] mod test_data { - pub fn string_items() -> Vec { - vec![ - "1 Item".to_string(), - "2 Items".to_string(), - "3 Items".to_string(), - "4 Items".to_string(), - ] - } + pub fn string_items() -> Vec { + vec![ + "1 Item".to_string(), + "2 Items".to_string(), + "3 Items".to_string(), + "4 Items".to_string(), + ] + } } diff --git a/burn-dataset/src/source/huggingface/downloader.rs b/burn-dataset/src/source/huggingface/downloader.rs index b6a8121b20..9d3ef277dc 100644 --- a/burn-dataset/src/source/huggingface/downloader.rs +++ b/burn-dataset/src/source/huggingface/downloader.rs @@ -17,25 +17,25 @@ const VENV_BIN_PYTHON: &str = "Scripts\\python"; /// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader). 
#[derive(Error, Debug)] pub enum ImporterError { - /// Unknown error. - #[error("unknown: `{0}`")] - Unknown(String), + /// Unknown error. + #[error("unknown: `{0}`")] + Unknown(String), - /// Fail to download python dependencies. - #[error("fail to download python dependencies: `{0}`")] - FailToDownloadPythonDependencies(String), + /// Fail to download python dependencies. + #[error("fail to download python dependencies: `{0}`")] + FailToDownloadPythonDependencies(String), - /// Fail to create sqlite dataset. - #[error("sqlite dataset: `{0}`")] - SqliteDataset(#[from] SqliteDatasetError), + /// Fail to create sqlite dataset. + #[error("sqlite dataset: `{0}`")] + SqliteDataset(#[from] SqliteDatasetError), - /// python3 is not installed. - #[error("python3 is not installed")] - PythonNotInstalled, + /// python3 is not installed. + #[error("python3 is not installed")] + PythonNotInstalled, - /// venv environment is not initialized. - #[error("venv environment is not initialized")] - VenvNotInitialized, + /// venv environment is not initialized. + #[error("venv environment is not initialized")] + VenvNotInitialized, } /// Load a dataset from [huggingface datasets](https://huggingface.co/datasets). @@ -58,237 +58,237 @@ pub enum ImporterError { /// .dataset("train") /// .unwrap(); pub struct HuggingfaceDatasetLoader { - name: String, - subset: Option, - base_dir: Option, - huggingface_token: Option, - huggingface_cache_dir: Option, + name: String, + subset: Option, + base_dir: Option, + huggingface_token: Option, + huggingface_cache_dir: Option, } impl HuggingfaceDatasetLoader { - /// Create a huggingface dataset loader. - pub fn new(name: &str) -> Self { - Self { - name: name.to_string(), - subset: None, - base_dir: None, - huggingface_token: None, - huggingface_cache_dir: None, + /// Create a huggingface dataset loader. + pub fn new(name: &str) -> Self { + Self { + name: name.to_string(), + subset: None, + base_dir: None, + huggingface_token: None, + huggingface_cache_dir: None, + } + } + + /// Create a huggingface dataset loader for a subset of the dataset. + /// + /// The subset name must be one of the subsets listed in the dataset page. + /// + /// If no subset names are listed, then do not use this method. + pub fn with_subset(mut self, subset: &str) -> Self { + self.subset = Some(subset.to_string()); + self + } + + /// Specify a base directory to store the dataset. + /// + /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`. + pub fn with_base_dir(mut self, base_dir: &str) -> Self { + self.base_dir = Some(base_dir.into()); + self + } + + /// Specify a huggingface token to download datasets behind authentication. + /// + /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens) + pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self { + self.huggingface_token = Some(huggingface_token.to_string()); + self } - } - - /// Create a huggingface dataset loader for a subset of the dataset. - /// - /// The subset name must be one of the subsets listed in the dataset page. - /// - /// If no subset names are listed, then do not use this method. - pub fn with_subset(mut self, subset: &str) -> Self { - self.subset = Some(subset.to_string()); - self - } - - /// Specify a base directory to store the dataset. - /// - /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`. 
- pub fn with_base_dir(mut self, base_dir: &str) -> Self { - self.base_dir = Some(base_dir.into()); - self - } - - /// Specify a huggingface token to download datasets behind authentication. - /// - /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens) - pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self { - self.huggingface_token = Some(huggingface_token.to_string()); - self - } - - /// Specify a huggingface cache directory to store the downloaded datasets. - /// - /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`. - pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self { - self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string()); - self - } - - /// Load the dataset. - pub fn dataset( - self, - split: &str, - ) -> Result, ImporterError> { - let db_file = self.db_file()?; - let dataset = SqliteDataset::from_db_file(db_file, split)?; - Ok(dataset) - } - - /// Get the path to the sqlite database file. - /// - /// If the database file does not exist, it will be downloaded and imported. - pub fn db_file(self) -> Result { - // determine (and create if needed) the base directory - let base_dir = SqliteDatasetStorage::base_dir(self.base_dir); - - if !base_dir.exists() { - create_dir_all(&base_dir).expect("Failed to create base directory"); + + /// Specify a huggingface cache directory to store the downloaded datasets. + /// + /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`. + pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self { + self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string()); + self } - //sanitize the name and subset - let name = sanitize(self.name.as_str()); - - // create the db file path - let db_file_name = if let Some(subset) = self.subset.clone() { - format!("{}-{}.db", name, sanitize(subset.as_str())) - } else { - format!("{}.db", name) - }; - - let db_file = base_dir.join(db_file_name); - - // import the dataset if needed - if !Path::new(&db_file).exists() { - import( - self.name, - self.subset, - db_file.clone(), - base_dir, - self.huggingface_token, - self.huggingface_cache_dir, - )?; + /// Load the dataset. + pub fn dataset( + self, + split: &str, + ) -> Result, ImporterError> { + let db_file = self.db_file()?; + let dataset = SqliteDataset::from_db_file(db_file, split)?; + Ok(dataset) } - Ok(db_file) - } + /// Get the path to the sqlite database file. + /// + /// If the database file does not exist, it will be downloaded and imported. + pub fn db_file(self) -> Result { + // determine (and create if needed) the base directory + let base_dir = SqliteDatasetStorage::base_dir(self.base_dir); + + if !base_dir.exists() { + create_dir_all(&base_dir).expect("Failed to create base directory"); + } + + //sanitize the name and subset + let name = sanitize(self.name.as_str()); + + // create the db file path + let db_file_name = if let Some(subset) = self.subset.clone() { + format!("{}-{}.db", name, sanitize(subset.as_str())) + } else { + format!("{}.db", name) + }; + + let db_file = base_dir.join(db_file_name); + + // import the dataset if needed + if !Path::new(&db_file).exists() { + import( + self.name, + self.subset, + db_file.clone(), + base_dir, + self.huggingface_token, + self.huggingface_cache_dir, + )?; + } + + Ok(db_file) + } } /// Import a dataset from huggingface. The transformed dataset is stored as sqlite database. 
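// A usage sketch of the loader builder documented above, placed before the
// `import` helper that backs it. Assumptions not taken from this patch: the
// loader and its error type are re-exported as
// `burn_dataset::{HuggingfaceDatasetLoader, ImporterError, SqliteDataset}`,
// the dataset name "dbpedia_14" and the `DbPediaItem` fields are illustrative
// only, and a Python 3 interpreter plus network access are available for the
// first (importing) run.
use burn_dataset::{Dataset, HuggingfaceDatasetLoader, ImporterError, SqliteDataset};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
struct DbPediaItem {
    title: String,
    content: String,
    label: usize,
}

fn load_train() -> Result<SqliteDataset<DbPediaItem>, ImporterError> {
    // The first call builds a Python venv under the base directory, installs
    // the `datasets` dependencies, and converts the data into
    // `~/.cache/burn-dataset/dbpedia_14.db`; later calls reuse that file.
    let dataset = HuggingfaceDatasetLoader::new("dbpedia_14").dataset::<DbPediaItem>("train")?;
    println!("train items: {}", dataset.len());
    Ok(dataset)
}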
fn import( - name: String, - subset: Option, - base_file: PathBuf, - base_dir: PathBuf, - huggingface_token: Option, - huggingface_cache_dir: Option, + name: String, + subset: Option, + base_file: PathBuf, + base_dir: PathBuf, + huggingface_token: Option, + huggingface_cache_dir: Option, ) -> Result<(), ImporterError> { - let venv_python_path = install_python_deps(&base_dir)?; + let venv_python_path = install_python_deps(&base_dir)?; - let mut command = Command::new(venv_python_path); + let mut command = Command::new(venv_python_path); - command.arg(importer_script_path(&base_dir)); + command.arg(importer_script_path(&base_dir)); - command.arg("--name"); - command.arg(name); + command.arg("--name"); + command.arg(name); - command.arg("--file"); - command.arg(base_file); + command.arg("--file"); + command.arg(base_file); - if let Some(subset) = subset { - command.arg("--subset"); - command.arg(subset); - } + if let Some(subset) = subset { + command.arg("--subset"); + command.arg(subset); + } - if let Some(huggingface_token) = huggingface_token { - command.arg("--token"); - command.arg(huggingface_token); - } + if let Some(huggingface_token) = huggingface_token { + command.arg("--token"); + command.arg(huggingface_token); + } - if let Some(huggingface_cache_dir) = huggingface_cache_dir { - command.arg("--cache_dir"); - command.arg(huggingface_cache_dir); - } + if let Some(huggingface_cache_dir) = huggingface_cache_dir { + command.arg("--cache_dir"); + command.arg(huggingface_cache_dir); + } - let mut handle = command.spawn().unwrap(); - handle - .wait() - .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?; + let mut handle = command.spawn().unwrap(); + handle + .wait() + .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?; - Ok(()) + Ok(()) } /// check python --version output is `Python 3.x.x` fn check_python_version_is_3(python: &str) -> bool { - let output = Command::new(python).arg("--version").output(); - match output { - Ok(output) => { - if output.status.success() { - let version_string = String::from_utf8_lossy(&output.stdout); - if let Some(index) = version_string.find(' ') { - let version = &version_string[index + 1..]; - version.starts_with("3.") - } else { - false + let output = Command::new(python).arg("--version").output(); + match output { + Ok(output) => { + if output.status.success() { + let version_string = String::from_utf8_lossy(&output.stdout); + if let Some(index) = version_string.find(' ') { + let version = &version_string[index + 1..]; + version.starts_with("3.") + } else { + false + } + } else { + false + } } - } else { - false - } + Err(_error) => false, } - Err(_error) => false, - } } /// get python3 name `python` `python3` or `py` fn get_python_name() -> Result<&'static str, ImporterError> { - let python_name_list = ["python3", "python", "py"]; - for python_name in python_name_list.iter() { - if check_python_version_is_3(python_name) { - return Ok(python_name); + let python_name_list = ["python3", "python", "py"]; + for python_name in python_name_list.iter() { + if check_python_version_is_3(python_name) { + return Ok(python_name); + } } - } - Err(ImporterError::PythonNotInstalled) + Err(ImporterError::PythonNotInstalled) } fn importer_script_path(base_dir: &Path) -> PathBuf { - let path_file = base_dir.join("importer.py"); + let path_file = base_dir.join("importer.py"); - fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader"); - path_file + fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader"); + 
path_file } fn install_python_deps(base_dir: &Path) -> Result { - let venv_dir = base_dir.join("venv"); - let venv_python_path = venv_dir.join(VENV_BIN_PYTHON); - // If the venv environment is already initialized, skip the initialization. - if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { - let python_name = get_python_name()?; - let mut command = Command::new(python_name); + let venv_dir = base_dir.join("venv"); + let venv_python_path = venv_dir.join(VENV_BIN_PYTHON); + // If the venv environment is already initialized, skip the initialization. + if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { + let python_name = get_python_name()?; + let mut command = Command::new(python_name); + command.args([ + "-m", + "venv", + venv_dir + .as_os_str() + .to_str() + .expect("Path utf8 conversion should not fail"), + ]); + + // Spawn the venv creation process and wait for it to complete. + let mut handle = command.spawn().unwrap(); + + handle.wait().map_err(|err| { + ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)) + })?; + // Check if the venv environment can be used successfully." + if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { + return Err(ImporterError::VenvNotInitialized); + } + } + + let mut command = Command::new(&venv_python_path); command.args([ - "-m", - "venv", - venv_dir - .as_os_str() - .to_str() - .expect("Path utf8 conversion should not fail"), + "-m", + "pip", + "--quiet", + "install", + "pyarrow", + "sqlalchemy", + "Pillow", + "soundfile", + "datasets", ]); - // Spawn the venv creation process and wait for it to complete. + // Spawn the pip install process and wait for it to complete. let mut handle = command.spawn().unwrap(); + handle.wait().map_err(|err| { + ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)) + })?; - handle - .wait() - .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)))?; - // Check if the venv environment can be used successfully." - if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { - return Err(ImporterError::VenvNotInitialized); - } - } - - let mut command = Command::new(&venv_python_path); - command.args([ - "-m", - "pip", - "--quiet", - "install", - "pyarrow", - "sqlalchemy", - "Pillow", - "soundfile", - "datasets", - ]); - - // Spawn the pip install process and wait for it to complete. - let mut handle = command.spawn().unwrap(); - handle - .wait() - .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)))?; - - Ok(venv_python_path) + Ok(venv_python_path) } diff --git a/burn-dataset/src/source/huggingface/mnist.rs b/burn-dataset/src/source/huggingface/mnist.rs index 6126141dbc..88b37180ac 100644 --- a/burn-dataset/src/source/huggingface/mnist.rs +++ b/burn-dataset/src/source/huggingface/mnist.rs @@ -11,43 +11,43 @@ const HEIGHT: usize = 28; /// MNIST item. #[derive(Deserialize, Serialize, Debug, Clone)] pub struct MNISTItem { - /// Image as a 2D array of floats. - pub image: [[f32; WIDTH]; HEIGHT], + /// Image as a 2D array of floats. + pub image: [[f32; WIDTH]; HEIGHT], - /// Label of the image. - pub label: usize, + /// Label of the image. + pub label: usize, } #[derive(Deserialize, Debug, Clone)] struct MNISTItemRaw { - pub image_bytes: Vec, - pub label: usize, + pub image_bytes: Vec, + pub label: usize, } struct BytesToImage; impl Mapper for BytesToImage { - /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image). 
- fn map(&self, item: &MNISTItemRaw) -> MNISTItem { - let image = image::load_from_memory(&item.image_bytes).unwrap(); - let image = image.as_luma8().unwrap(); - - // Ensure the image dimensions are correct. - debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32)); - - // Convert the image to a 2D array of floats. - let mut image_array = [[0f32; WIDTH]; HEIGHT]; - for (i, pixel) in image.as_raw().iter().enumerate() { - let x = i % WIDTH; - let y = i / HEIGHT; - image_array[y][x] = *pixel as f32; + /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image). + fn map(&self, item: &MNISTItemRaw) -> MNISTItem { + let image = image::load_from_memory(&item.image_bytes).unwrap(); + let image = image.as_luma8().unwrap(); + + // Ensure the image dimensions are correct. + debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32)); + + // Convert the image to a 2D array of floats. + let mut image_array = [[0f32; WIDTH]; HEIGHT]; + for (i, pixel) in image.as_raw().iter().enumerate() { + let x = i % WIDTH; + let y = i / HEIGHT; + image_array[y][x] = *pixel as f32; + } + + MNISTItem { + image: image_array, + label: item.label, + } } - - MNISTItem { - image: image_array, - label: item.label, - } - } } type MappedDataset = MapperDataset, BytesToImage, MNISTItemRaw>; @@ -56,37 +56,37 @@ type MappedDataset = MapperDataset, BytesToImage, MN /// /// The data is downloaded from Huggingface and stored in a SQLite database. pub struct MNISTDataset { - dataset: MappedDataset, + dataset: MappedDataset, } impl Dataset for MNISTDataset { - fn get(&self, index: usize) -> Option { - self.dataset.get(index) - } + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } impl MNISTDataset { - /// Creates a new train dataset. - pub fn train() -> Self { - Self::new("train") - } + /// Creates a new train dataset. + pub fn train() -> Self { + Self::new("train") + } - /// Creates a new test dataset. - pub fn test() -> Self { - Self::new("test") - } + /// Creates a new test dataset. + pub fn test() -> Self { + Self::new("test") + } - fn new(split: &str) -> Self { - let dataset = HuggingfaceDatasetLoader::new("mnist") - .dataset(split) - .unwrap(); + fn new(split: &str) -> Self { + let dataset = HuggingfaceDatasetLoader::new("mnist") + .dataset(split) + .unwrap(); - let dataset = MapperDataset::new(dataset, BytesToImage); + let dataset = MapperDataset::new(dataset, BytesToImage); - Self { dataset } - } + Self { dataset } + } } diff --git a/burn-dataset/src/transform/composed.rs b/burn-dataset/src/transform/composed.rs index 0903e7059e..8f26bd5976 100644 --- a/burn-dataset/src/transform/composed.rs +++ b/burn-dataset/src/transform/composed.rs @@ -3,29 +3,29 @@ use crate::Dataset; /// Compose multiple datasets together to create a bigger one. 
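// A short sketch of the `ComposedDataset` declared just below, assuming (not
// shown in this patch) that it and `InMemDataset` are re-exported from the
// crate root as `burn_dataset::{ComposedDataset, Dataset, InMemDataset}` and
// that `ComposedDataset::new` is the constructor generated by `#[derive(new)]`.
use burn_dataset::{ComposedDataset, Dataset, InMemDataset};

fn composed_example() {
    let first = InMemDataset::new(vec!["a".to_string(), "b".to_string()]);
    let second = InMemDataset::new(vec!["c".to_string()]);

    // Items of `second` are addressed after all items of `first`.
    let composed = ComposedDataset::new(vec![first, second]);
    assert_eq!(composed.len(), 3);
    assert_eq!(composed.get(2), Some("c".to_string()));
}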
#[derive(new)] pub struct ComposedDataset { - datasets: Vec, + datasets: Vec, } impl Dataset for ComposedDataset where - D: Dataset, - I: Clone, + D: Dataset, + I: Clone, { - fn get(&self, index: usize) -> Option { - let mut current_index = 0; - for dataset in self.datasets.iter() { - if index < dataset.len() + current_index { - return dataset.get(index - current_index); - } - current_index += dataset.len(); + fn get(&self, index: usize) -> Option { + let mut current_index = 0; + for dataset in self.datasets.iter() { + if index < dataset.len() + current_index { + return dataset.get(index - current_index); + } + current_index += dataset.len(); + } + None } - None - } - fn len(&self) -> usize { - let mut total = 0; - for dataset in self.datasets.iter() { - total += dataset.len(); + fn len(&self) -> usize { + let mut total = 0; + for dataset in self.datasets.iter() { + total += dataset.len(); + } + total } - total - } } diff --git a/burn-dataset/src/transform/mapper.rs b/burn-dataset/src/transform/mapper.rs index 1cf39ea2aa..b089a375ed 100644 --- a/burn-dataset/src/transform/mapper.rs +++ b/burn-dataset/src/transform/mapper.rs @@ -3,58 +3,58 @@ use std::marker::PhantomData; /// Basic mapper trait to be used with the [mapper dataset](MapperDataset). pub trait Mapper: Send + Sync { - /// Maps an item of type I to an item of type O. - fn map(&self, item: &I) -> O; + /// Maps an item of type I to an item of type O. + fn map(&self, item: &I) -> O; } /// Dataset mapping each element in an inner dataset to another element type lazily. #[derive(new)] pub struct MapperDataset { - dataset: D, - mapper: M, - input: PhantomData, + dataset: D, + mapper: M, + input: PhantomData, } impl Dataset for MapperDataset where - D: Dataset, - M: Mapper + Send + Sync, - I: Send + Sync, - O: Send + Sync, + D: Dataset, + M: Mapper + Send + Sync, + I: Send + Sync, + O: Send + Sync, { - fn get(&self, index: usize) -> Option { - let item = self.dataset.get(index); - item.map(|item| self.mapper.map(&item)) - } - - fn len(&self) -> usize { - self.dataset.len() - } + fn get(&self, index: usize) -> Option { + let item = self.dataset.get(index); + item.map(|item| self.mapper.map(&item)) + } + + fn len(&self) -> usize { + self.dataset.len() + } } #[cfg(test)] mod tests { - use super::*; - use crate::{test_data, InMemDataset}; - - #[test] - pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() { - struct StringToFirstChar; - - impl Mapper for StringToFirstChar { - fn map(&self, item: &String) -> String { - let mut item = item.clone(); - item.truncate(1); - item - } - } + use super::*; + use crate::{test_data, InMemDataset}; - let items_original = test_data::string_items(); - let dataset = InMemDataset::new(items_original); - let dataset = MapperDataset::new(dataset, StringToFirstChar); + #[test] + pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() { + struct StringToFirstChar; - let items: Vec = dataset.iter().collect(); + impl Mapper for StringToFirstChar { + fn map(&self, item: &String) -> String { + let mut item = item.clone(); + item.truncate(1); + item + } + } - assert_eq!(vec!["1", "2", "3", "4"], items); - } + let items_original = test_data::string_items(); + let dataset = InMemDataset::new(items_original); + let dataset = MapperDataset::new(dataset, StringToFirstChar); + + let items: Vec = dataset.iter().collect(); + + assert_eq!(vec!["1", "2", "3", "4"], items); + } } diff --git a/burn-dataset/src/transform/partial.rs b/burn-dataset/src/transform/partial.rs index 
2cc018b3ee..c8bd53f08b 100644 --- a/burn-dataset/src/transform/partial.rs +++ b/burn-dataset/src/transform/partial.rs @@ -4,136 +4,136 @@ use std::{marker::PhantomData, sync::Arc}; /// Only use a fraction of an existing dataset lazily. #[derive(new)] pub struct PartialDataset { - dataset: D, - start_index: usize, - end_index: usize, - input: PhantomData, + dataset: D, + start_index: usize, + end_index: usize, + input: PhantomData, } impl PartialDataset where - D: Dataset, + D: Dataset, { - /// Splits a dataset into multiple partial datasets. - pub fn split(dataset: D, num: usize) -> Vec, I>> { - let dataset = Arc::new(dataset); // cheap cloning. + /// Splits a dataset into multiple partial datasets. + pub fn split(dataset: D, num: usize) -> Vec, I>> { + let dataset = Arc::new(dataset); // cheap cloning. - let mut current = 0; - let mut datasets = Vec::with_capacity(num); + let mut current = 0; + let mut datasets = Vec::with_capacity(num); - let batch_size = dataset.len() / num; + let batch_size = dataset.len() / num; - for i in 0..num { - let start = current; - let mut end = current + batch_size; + for i in 0..num { + let start = current; + let mut end = current + batch_size; - if i == (num - 1) { - end = dataset.len(); - } + if i == (num - 1) { + end = dataset.len(); + } - let dataset = PartialDataset::new(dataset.clone(), start, end); + let dataset = PartialDataset::new(dataset.clone(), start, end); - current += batch_size; - datasets.push(dataset); - } + current += batch_size; + datasets.push(dataset); + } - datasets - } + datasets + } } impl Dataset for PartialDataset where - D: Dataset, - I: Clone + Send + Sync, + D: Dataset, + I: Clone + Send + Sync, { - fn get(&self, index: usize) -> Option { - let index = index + self.start_index; - if index < self.start_index || index >= self.end_index { - return None; + fn get(&self, index: usize) -> Option { + let index = index + self.start_index; + if index < self.start_index || index >= self.end_index { + return None; + } + self.dataset.get(index) } - self.dataset.get(index) - } - fn len(&self) -> usize { - usize::min(self.end_index - self.start_index, self.dataset.len()) - } + fn len(&self) -> usize { + usize::min(self.end_index - self.start_index, self.dataset.len()) + } } #[cfg(test)] mod tests { - use super::*; - use crate::FakeDataset; - use std::collections::HashSet; - - #[test] - fn test_start_from_beginning() { - let dataset_original = FakeDataset::::new(27); - let mut items_original_1 = HashSet::new(); - let mut items_original_2 = HashSet::new(); - let mut items_partial = HashSet::new(); - dataset_original.iter().enumerate().for_each(|(i, item)| { - match i >= 10 { - true => items_original_2.insert(item), - false => items_original_1.insert(item), - }; - }); - - let dataset_partial = PartialDataset::new(dataset_original, 0, 10); - - for item in dataset_partial.iter() { - items_partial.insert(item); + use super::*; + use crate::FakeDataset; + use std::collections::HashSet; + + #[test] + fn test_start_from_beginning() { + let dataset_original = FakeDataset::::new(27); + let mut items_original_1 = HashSet::new(); + let mut items_original_2 = HashSet::new(); + let mut items_partial = HashSet::new(); + dataset_original.iter().enumerate().for_each(|(i, item)| { + match i >= 10 { + true => items_original_2.insert(item), + false => items_original_1.insert(item), + }; + }); + + let dataset_partial = PartialDataset::new(dataset_original, 0, 10); + + for item in dataset_partial.iter() { + items_partial.insert(item); + } + + 
assert_eq!(dataset_partial.len(), 10); + assert_eq!(items_original_1, items_partial); + for item in items_original_2 { + assert!(!items_partial.contains(&item)); + } } - assert_eq!(dataset_partial.len(), 10); - assert_eq!(items_original_1, items_partial); - for item in items_original_2 { - assert!(!items_partial.contains(&item)); - } - } - - #[test] - fn test_start_inside() { - let dataset_original = FakeDataset::::new(27); - let mut items_original_1 = HashSet::new(); - let mut items_original_2 = HashSet::new(); - let mut items_partial = HashSet::new(); - - dataset_original.iter().enumerate().for_each(|(i, item)| { - match !(10..20).contains(&i) { - true => items_original_2.insert(item), - false => items_original_1.insert(item), - }; - }); - - let dataset_partial = PartialDataset::new(dataset_original, 10, 20); - for item in dataset_partial.iter() { - items_partial.insert(item); + #[test] + fn test_start_inside() { + let dataset_original = FakeDataset::::new(27); + let mut items_original_1 = HashSet::new(); + let mut items_original_2 = HashSet::new(); + let mut items_partial = HashSet::new(); + + dataset_original.iter().enumerate().for_each(|(i, item)| { + match !(10..20).contains(&i) { + true => items_original_2.insert(item), + false => items_original_1.insert(item), + }; + }); + + let dataset_partial = PartialDataset::new(dataset_original, 10, 20); + for item in dataset_partial.iter() { + items_partial.insert(item); + } + + assert_eq!(dataset_partial.len(), 10); + assert_eq!(items_original_1, items_partial); + for item in items_original_2 { + assert!(!items_partial.contains(&item)); + } } - assert_eq!(dataset_partial.len(), 10); - assert_eq!(items_original_1, items_partial); - for item in items_original_2 { - assert!(!items_partial.contains(&item)); - } - } - - #[test] - fn test_split_contains_all_items_without_duplicates() { - let dataset_original = FakeDataset::::new(27); - let mut items_original = Vec::new(); - let mut items_partial = Vec::new(); - for item in dataset_original.iter() { - items_original.push(item); - } + #[test] + fn test_split_contains_all_items_without_duplicates() { + let dataset_original = FakeDataset::::new(27); + let mut items_original = Vec::new(); + let mut items_partial = Vec::new(); + for item in dataset_original.iter() { + items_original.push(item); + } - let dataset_partials = PartialDataset::split(dataset_original, 4); + let dataset_partials = PartialDataset::split(dataset_original, 4); - for dataset in dataset_partials { - for item in dataset.iter() { - items_partial.push(item); - } - } + for dataset in dataset_partials { + for item in dataset.iter() { + items_partial.push(item); + } + } - assert_eq!(items_original, items_partial); - } + assert_eq!(items_original, items_partial); + } } diff --git a/burn-dataset/src/transform/random.rs b/burn-dataset/src/transform/random.rs index 0dd79754bc..5b9de9d8d3 100644 --- a/burn-dataset/src/transform/random.rs +++ b/burn-dataset/src/transform/random.rs @@ -5,51 +5,51 @@ use std::marker::PhantomData; /// Shuffled a dataset, consider using [sampler dataset](crate::transform::SamplerDataset) is you /// want a probability distribution that is computed lazily. pub struct ShuffledDataset { - dataset: D, - indices: Vec, - input: PhantomData, + dataset: D, + indices: Vec, + input: PhantomData, } impl ShuffledDataset where - D: Dataset, + D: Dataset, { - /// Creates a new shuffled dataset. 
- pub fn new(dataset: D, rng: &mut StdRng) -> Self { - let mut indices = Vec::with_capacity(dataset.len()); - for i in 0..dataset.len() { - indices.push(i); - } - indices.shuffle(rng); + /// Creates a new shuffled dataset. + pub fn new(dataset: D, rng: &mut StdRng) -> Self { + let mut indices = Vec::with_capacity(dataset.len()); + for i in 0..dataset.len() { + indices.push(i); + } + indices.shuffle(rng); - Self { - dataset, - indices, - input: PhantomData, + Self { + dataset, + indices, + input: PhantomData, + } } - } - /// Creates a new shuffled dataset with a fixed seed. - pub fn with_seed(dataset: D, seed: u64) -> Self { - let mut rng = StdRng::seed_from_u64(seed); - Self::new(dataset, &mut rng) - } + /// Creates a new shuffled dataset with a fixed seed. + pub fn with_seed(dataset: D, seed: u64) -> Self { + let mut rng = StdRng::seed_from_u64(seed); + Self::new(dataset, &mut rng) + } } impl Dataset for ShuffledDataset where - D: Dataset, - I: Clone + Send + Sync, + D: Dataset, + I: Clone + Send + Sync, { - fn get(&self, index: usize) -> Option { - let index = match self.indices.get(index) { - Some(index) => index, - None => return None, - }; - self.dataset.get(*index) - } + fn get(&self, index: usize) -> Option { + let index = match self.indices.get(index) { + Some(index) => index, + None => return None, + }; + self.dataset.get(*index) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } diff --git a/burn-dataset/src/transform/sampler.rs b/burn-dataset/src/transform/sampler.rs index 69b6c1639e..b3c0077ab8 100644 --- a/burn-dataset/src/transform/sampler.rs +++ b/burn-dataset/src/transform/sampler.rs @@ -15,132 +15,132 @@ use std::{marker::PhantomData, ops::DerefMut, sync::Mutex}; /// set the dataset to an arbitrary size. Once every item has been used, a new cycle is /// created with a new random suffle. pub struct SamplerDataset { - dataset: D, - size: usize, - state: Mutex, - input: PhantomData, + dataset: D, + size: usize, + state: Mutex, + input: PhantomData, } enum SamplerState { - WithReplacement(StdRng), - WithoutReplacement(StdRng, Vec), + WithReplacement(StdRng), + WithoutReplacement(StdRng, Vec), } impl SamplerDataset where - D: Dataset, - I: Send + Sync, + D: Dataset, + I: Send + Sync, { - /// Creates a new sampler dataset with replacement. - pub fn new(dataset: D, size: usize) -> Self { - Self { - dataset, - size, - state: Mutex::new(SamplerState::WithReplacement(StdRng::from_entropy())), - input: PhantomData, - } - } - - /// Creates a new sampler dataset with replacement. - pub fn with_replacement(dataset: D, size: usize) -> Self { - Self::new(dataset, size) - } - - /// Creates a new sampler dataset without replacement. - pub fn without_replacement(dataset: D, size: usize) -> Self { - Self { - dataset, - size, - state: Mutex::new(SamplerState::WithoutReplacement( - StdRng::from_entropy(), - Vec::new(), - )), - input: PhantomData, + /// Creates a new sampler dataset with replacement. + pub fn new(dataset: D, size: usize) -> Self { + Self { + dataset, + size, + state: Mutex::new(SamplerState::WithReplacement(StdRng::from_entropy())), + input: PhantomData, + } } - } - fn index(&self) -> usize { - let mut state = self.state.lock().unwrap(); + /// Creates a new sampler dataset with replacement. 
+ pub fn with_replacement(dataset: D, size: usize) -> Self { + Self::new(dataset, size) + } - match state.deref_mut() { - SamplerState::WithReplacement(rng) => rng.sample(Uniform::new(0, self.dataset.len())), - SamplerState::WithoutReplacement(rng, indices) => { - if indices.is_empty() { - // Refill the state. - *indices = (0..self.dataset.len()).choose_multiple(rng, self.dataset.len()); + /// Creates a new sampler dataset without replacement. + pub fn without_replacement(dataset: D, size: usize) -> Self { + Self { + dataset, + size, + state: Mutex::new(SamplerState::WithoutReplacement( + StdRng::from_entropy(), + Vec::new(), + )), + input: PhantomData, } + } - indices.pop().expect("Indices are refilled when empty.") - } + fn index(&self) -> usize { + let mut state = self.state.lock().unwrap(); + + match state.deref_mut() { + SamplerState::WithReplacement(rng) => rng.sample(Uniform::new(0, self.dataset.len())), + SamplerState::WithoutReplacement(rng, indices) => { + if indices.is_empty() { + // Refill the state. + *indices = (0..self.dataset.len()).choose_multiple(rng, self.dataset.len()); + } + + indices.pop().expect("Indices are refilled when empty.") + } + } } - } } impl Dataset for SamplerDataset where - D: Dataset, - I: Send + Sync, + D: Dataset, + I: Send + Sync, { - fn get(&self, index: usize) -> Option { - if index >= self.size { - return None; - } + fn get(&self, index: usize) -> Option { + if index >= self.size { + return None; + } - self.dataset.get(self.index()) - } + self.dataset.get(self.index()) + } - fn len(&self) -> usize { - self.size - } + fn len(&self) -> usize { + self.size + } } #[cfg(test)] mod tests { - use super::*; - use crate::FakeDataset; - use std::collections::HashMap; - - #[test] - fn sampler_dataset_with_replacement_iter() { - let factor = 3; - let len_original = 10; - let dataset_sampler = SamplerDataset::with_replacement( - FakeDataset::::new(len_original), - len_original * factor, - ); - let mut total = 0; - - for _item in dataset_sampler.iter() { - total += 1; - } + use super::*; + use crate::FakeDataset; + use std::collections::HashMap; + + #[test] + fn sampler_dataset_with_replacement_iter() { + let factor = 3; + let len_original = 10; + let dataset_sampler = SamplerDataset::with_replacement( + FakeDataset::::new(len_original), + len_original * factor, + ); + let mut total = 0; + + for _item in dataset_sampler.iter() { + total += 1; + } - assert_eq!(total, factor * len_original); - } - - #[test] - fn sampler_dataset_without_replacement_bucket_test() { - let factor = 3; - let len_original = 10; - let dataset_sampler = SamplerDataset::without_replacement( - FakeDataset::::new(len_original), - len_original * factor, - ); - let mut buckets = HashMap::new(); - - for item in dataset_sampler.iter() { - let count = match buckets.get(&item) { - Some(count) => count + 1, - None => 1, - }; - - buckets.insert(item, count); + assert_eq!(total, factor * len_original); } - let mut total = 0; - for count in buckets.into_values() { - assert_eq!(count, factor); - total += count; + #[test] + fn sampler_dataset_without_replacement_bucket_test() { + let factor = 3; + let len_original = 10; + let dataset_sampler = SamplerDataset::without_replacement( + FakeDataset::::new(len_original), + len_original * factor, + ); + let mut buckets = HashMap::new(); + + for item in dataset_sampler.iter() { + let count = match buckets.get(&item) { + Some(count) => count + 1, + None => 1, + }; + + buckets.insert(item, count); + } + + let mut total = 0; + for count in buckets.into_values() { 
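// A hedged usage sketch for the `SamplerDataset` above. The wrapper reports an arbitrary
// virtual `size`, draws a fresh index on every `get`, and guards its RNG state behind a
// `Mutex` so sampling works through `&self`. `InMemDataset` and the 3x oversampling
// factor are illustrative assumptions, not part of this patch.
fn sampler_dataset_example() {
    use burn_dataset::transform::SamplerDataset;
    use burn_dataset::{Dataset, InMemDataset};

    let base = InMemDataset::new((0..10).collect::<Vec<i32>>());

    // Without replacement: over 3 * len draws every item comes back exactly three times,
    // because a new shuffled cycle starts once the current index pool is exhausted
    // (this is what the bucket test below asserts).
    let sampler = SamplerDataset::without_replacement(base, 30);
    assert_eq!(sampler.len(), 30);

    let drawn: Vec<i32> = sampler.iter().collect();
    assert_eq!(drawn.len(), 30);

    // `SamplerDataset::new` / `with_replacement` sample independently instead, so
    // individual items may repeat or be skipped within a pass.
}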
+ assert_eq!(count, factor); + total += count; + } + assert_eq!(total, factor * len_original); } - assert_eq!(total, factor * len_original); - } } diff --git a/burn-derive/src/config/analyzer.rs b/burn-derive/src/config/analyzer.rs index af55e06b1e..e5e628585c 100644 --- a/burn-derive/src/config/analyzer.rs +++ b/burn-derive/src/config/analyzer.rs @@ -8,80 +8,80 @@ use syn::{Field, Ident}; pub struct ConfigAnalyzerFactory {} pub trait ConfigAnalyzer { - fn gen_new_fn(&self) -> TokenStream { - quote! {} - } - fn gen_builder_fns(&self) -> TokenStream { - quote! {} - } - fn gen_serde_impl(&self) -> TokenStream; - fn gen_clone_impl(&self) -> TokenStream; - fn gen_display_impl(&self) -> TokenStream; - fn gen_config_impl(&self) -> TokenStream; + fn gen_new_fn(&self) -> TokenStream { + quote! {} + } + fn gen_builder_fns(&self) -> TokenStream { + quote! {} + } + fn gen_serde_impl(&self) -> TokenStream; + fn gen_clone_impl(&self) -> TokenStream; + fn gen_display_impl(&self) -> TokenStream; + fn gen_config_impl(&self) -> TokenStream; } impl ConfigAnalyzerFactory { - pub fn new() -> Self { - Self {} - } + pub fn new() -> Self { + Self {} + } - pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box { - let name = item.ident.clone(); - let config_type = parse_asm(item); + pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box { + let name = item.ident.clone(); + let config_type = parse_asm(item); - match config_type { - ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)), - ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)), + match config_type { + ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)), + ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)), + } } - } - fn create_struct_analyzer(&self, name: Ident, fields: Vec) -> ConfigStructAnalyzer { - let fields = fields.into_iter().map(FieldTypeAnalyzer::new); + fn create_struct_analyzer(&self, name: Ident, fields: Vec) -> ConfigStructAnalyzer { + let fields = fields.into_iter().map(FieldTypeAnalyzer::new); - let mut fields_required = Vec::new(); - let mut fields_option = Vec::new(); - let mut fields_default = Vec::new(); + let mut fields_required = Vec::new(); + let mut fields_option = Vec::new(); + let mut fields_default = Vec::new(); - for field in fields { - let attributes: Vec = field - .attributes() - .filter(|attr| attr.has_name("config")) - .map(|attr| attr.item()) - .collect(); + for field in fields { + let attributes: Vec = field + .attributes() + .filter(|attr| attr.has_name("config")) + .map(|attr| attr.item()) + .collect(); - if !attributes.is_empty() { - let item = attributes.first().unwrap().clone(); - fields_default.push((field.clone(), item)); - continue; - } + if !attributes.is_empty() { + let item = attributes.first().unwrap().clone(); + fields_default.push((field.clone(), item)); + continue; + } - if field.is_of_type(&["Option"]) { - fields_option.push(field.clone()); - continue; - } + if field.is_of_type(&["Option"]) { + fields_option.push(field.clone()); + continue; + } - fields_required.push(field.clone()); - } + fields_required.push(field.clone()); + } - ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default) - } + ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default) + } - fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer { - ConfigEnumAnalyzer::new(name, data) - } + fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) 
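// A hedged sketch of how the field classification above plays out for a user-defined
// config. The struct, its fields, and the default values are hypothetical; only the
// classification rules come from `create_struct_analyzer` above: a `#[config(...)]`
// attribute makes a defaulted field, an `Option<_>` type makes an optional field, and
// everything else becomes a required argument of the generated `new`.
use burn::config::Config;

#[derive(Config)]
pub struct AdamLikeConfig {
    // Required: no attribute, not an `Option`, so it is a parameter of `new`.
    pub learning_rate: f64,
    // Optional: initialized to `None` by `new`, settable via `with_gradient_clip`.
    pub gradient_clip: Option<f64>,
    // Defaulted: initialized from the literal, settable via `with_beta`.
    #[config(default = 0.9)]
    pub beta: f64,
}

fn adam_like_config_example() {
    let config = AdamLikeConfig::new(1e-3)
        .with_beta(0.95)
        .with_gradient_clip(Some(1.0));
    // The generated `Display` prints the config as JSON.
    println!("{config}");
}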
-> ConfigEnumAnalyzer { + ConfigEnumAnalyzer::new(name, data) + } } enum ConfigType { - Struct(Vec), - Enum(syn::DataEnum), + Struct(Vec), + Enum(syn::DataEnum), } fn parse_asm(ast: &syn::DeriveInput) -> ConfigType { - match &ast.data { - syn::Data::Struct(struct_data) => { - ConfigType::Struct(struct_data.fields.clone().into_iter().collect()) + match &ast.data { + syn::Data::Struct(struct_data) => { + ConfigType::Struct(struct_data.fields.clone().into_iter().collect()) + } + syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()), + syn::Data::Union(_) => panic!("Only struct and enum can be derived"), } - syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()), - syn::Data::Union(_) => panic!("Only struct and enum can be derived"), - } } diff --git a/burn-derive/src/config/analyzer_enum.rs b/burn-derive/src/config/analyzer_enum.rs index 9c817bc2a4..2f7e2347b1 100644 --- a/burn-derive/src/config/analyzer_enum.rs +++ b/burn-derive/src/config/analyzer_enum.rs @@ -4,174 +4,174 @@ use quote::quote; use syn::{FieldsNamed, Variant}; pub struct ConfigEnumAnalyzer { - name: Ident, - data: syn::DataEnum, + name: Ident, + data: syn::DataEnum, } impl ConfigEnumAnalyzer { - pub fn new(name: Ident, data: syn::DataEnum) -> Self { - Self { name, data } - } - - fn serde_enum_ident(&self) -> Ident { - Ident::new(&format!("{}Serde", self.name), self.name.span()) - } - - fn gen_serde_enum(&self) -> TokenStream { - let enum_name = self.serde_enum_ident(); - let data = &self.data.variants; - - quote! { - #[derive(serde::Serialize, serde::Deserialize)] - enum #enum_name { - #data - } + pub fn new(name: Ident, data: syn::DataEnum) -> Self { + Self { name, data } + } + fn serde_enum_ident(&self) -> Ident { + Ident::new(&format!("{}Serde", self.name), self.name.span()) } - } - fn gen_variant_field(&self, variant: &Variant) -> (TokenStream, TokenStream) { - let gen_fields_unnamed = |num: usize| { - let mut input = Vec::new(); - let mut output = Vec::new(); + fn gen_serde_enum(&self) -> TokenStream { + let enum_name = self.serde_enum_ident(); + let data = &self.data.variants; - for i in 0..num { - let arg_name = Ident::new(&format!("arg_{i}"), self.name.span()); + quote! { + #[derive(serde::Serialize, serde::Deserialize)] + enum #enum_name { + #data + } - input.push(quote! { #arg_name }); - output.push(quote! { #arg_name.clone() }); - } + } + } - (quote! (( #(#input),* )), quote! (( #(#output),* ))) - }; - let gen_fields_named = |fields: &FieldsNamed| { - let mut input = Vec::new(); - let mut output = Vec::new(); + fn gen_variant_field(&self, variant: &Variant) -> (TokenStream, TokenStream) { + let gen_fields_unnamed = |num: usize| { + let mut input = Vec::new(); + let mut output = Vec::new(); - fields.named.iter().for_each(|field| { - let ident = &field.ident; + for i in 0..num { + let arg_name = Ident::new(&format!("arg_{i}"), self.name.span()); - input.push(quote! { - #ident - }); - output.push(quote! { - #ident: #ident.clone() - }); - }); - - (quote! {{ #(#input),* }}, quote! {{ #(#output),* }}) - }; + input.push(quote! { #arg_name }); + output.push(quote! { #arg_name.clone() }); + } - match &variant.fields { - syn::Fields::Named(fields) => gen_fields_named(fields), - syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()), - syn::Fields::Unit => (quote! {}, quote! {}), + (quote! (( #(#input),* )), quote! 
(( #(#output),* ))) + }; + let gen_fields_named = |fields: &FieldsNamed| { + let mut input = Vec::new(); + let mut output = Vec::new(); + + fields.named.iter().for_each(|field| { + let ident = &field.ident; + + input.push(quote! { + #ident + }); + output.push(quote! { + #ident: #ident.clone() + }); + }); + + (quote! {{ #(#input),* }}, quote! {{ #(#output),* }}) + }; + + match &variant.fields { + syn::Fields::Named(fields) => gen_fields_named(fields), + syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()), + syn::Fields::Unit => (quote! {}, quote! {}), + } } - } - fn gen_serialize_fn(&self) -> TokenStream { - let enum_name = self.serde_enum_ident(); - let variants = self.data.variants.iter().map(|variant| { + fn gen_serialize_fn(&self) -> TokenStream { + let enum_name = self.serde_enum_ident(); + let variants = self.data.variants.iter().map(|variant| { let variant_name = &variant.ident; let (variant_input, variant_output) = self.gen_variant_field(variant); quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output } }); - let name = &self.name; - - quote! { - impl serde::Serialize for #name { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer { - let serde_state = match self { - #(#variants),* - }; - serde_state.serialize(serializer) + let name = &self.name; + + quote! { + impl serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer { + let serde_state = match self { + #(#variants),* + }; + serde_state.serialize(serializer) + } } - } + } } - } - fn gen_deserialize_fn(&self) -> TokenStream { - let enum_name = self.serde_enum_ident(); - let variants = self.data.variants.iter().map(|variant| { + fn gen_deserialize_fn(&self) -> TokenStream { + let enum_name = self.serde_enum_ident(); + let variants = self.data.variants.iter().map(|variant| { let variant_name = &variant.ident; let (variant_input, variant_output) = self.gen_variant_field(variant); quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output } }); - let name = &self.name; - - quote! { - impl<'de> serde::Deserialize<'de> for #name { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de> { - let serde_state = #enum_name::deserialize(deserializer)?; - Ok(match serde_state { - #(#variants),* - }) + let name = &self.name; + + quote! { + impl<'de> serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + let serde_state = #enum_name::deserialize(deserializer)?; + Ok(match serde_state { + #(#variants),* + }) + } } - } + } } - } } impl ConfigAnalyzer for ConfigEnumAnalyzer { - fn gen_serde_impl(&self) -> TokenStream { - let struct_gen = self.gen_serde_enum(); - let serialize_gen = self.gen_serialize_fn(); - let deserialize_gen = self.gen_deserialize_fn(); - - quote! { - #struct_gen - #serialize_gen - #deserialize_gen + fn gen_serde_impl(&self) -> TokenStream { + let struct_gen = self.gen_serde_enum(); + let serialize_gen = self.gen_serialize_fn(); + let deserialize_gen = self.gen_deserialize_fn(); + + quote! 
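// A hedged sketch of the strategy implemented by `ConfigEnumAnalyzer` above: the derive
// emits a private `{Name}Serde` twin enum that derives serde, plus `Serialize` /
// `Deserialize` impls that convert variant-by-variant. The enum below is hypothetical and
// the expansion is paraphrased rather than the literal macro output.
pub enum LrSchedule {
    Constant(f64),
    Step { gamma: f64, step_size: usize },
}

// Roughly what `gen_serde_enum` produces for it:
#[derive(serde::Serialize, serde::Deserialize)]
enum LrScheduleSerde {
    Constant(f64),
    Step { gamma: f64, step_size: usize },
}

// `gen_serialize_fn` then maps, e.g., `LrSchedule::Step { gamma, step_size }` to
// `LrScheduleSerde::Step { gamma: gamma.clone(), step_size: step_size.clone() }` before
// delegating to `serde_state.serialize(serializer)`, and `gen_deserialize_fn` performs
// the reverse mapping after `LrScheduleSerde::deserialize(deserializer)?`.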
{ + #struct_gen + #serialize_gen + #deserialize_gen + } } - } - fn gen_clone_impl(&self) -> TokenStream { - let variants = self.data.variants.iter().map(|variant| { - let variant_name = &variant.ident; - let (variant_input, variant_output) = self.gen_variant_field(variant); + fn gen_clone_impl(&self) -> TokenStream { + let variants = self.data.variants.iter().map(|variant| { + let variant_name = &variant.ident; + let (variant_input, variant_output) = self.gen_variant_field(variant); - quote! { Self::#variant_name #variant_input => Self::#variant_name #variant_output } - }); - let name = &self.name; - - quote! { - impl Clone for #name { - fn clone(&self) -> Self { - match self { - #(#variants),* + quote! { Self::#variant_name #variant_input => Self::#variant_name #variant_output } + }); + let name = &self.name; + + quote! { + impl Clone for #name { + fn clone(&self) -> Self { + match self { + #(#variants),* + } } } - } + } } - } - fn gen_display_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_display_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl core::fmt::Display for #name { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(&burn::config::config_to_json(self)) + quote! { + impl core::fmt::Display for #name { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(&burn::config::config_to_json(self)) + } } } } - } - fn gen_config_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_config_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl burn::config::Config for #name { + quote! { + impl burn::config::Config for #name { + } } } - } } diff --git a/burn-derive/src/config/analyzer_struct.rs b/burn-derive/src/config/analyzer_struct.rs index 699bc49661..18ec62c169 100644 --- a/burn-derive/src/config/analyzer_struct.rs +++ b/burn-derive/src/config/analyzer_struct.rs @@ -4,294 +4,294 @@ use proc_macro2::{Ident, TokenStream}; use quote::quote; pub struct ConfigStructAnalyzer { - name: Ident, - fields_required: Vec, - fields_option: Vec, - fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, -} - -impl ConfigStructAnalyzer { - pub fn new( name: Ident, fields_required: Vec, fields_option: Vec, fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, - ) -> Self { - Self { - name, - fields_required, - fields_option, - fields_default, +} + +impl ConfigStructAnalyzer { + pub fn new( + name: Ident, + fields_required: Vec, + fields_option: Vec, + fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, + ) -> Self { + Self { + name, + fields_required, + fields_option, + fields_default, + } } - } - fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream { - let name = &self.name; + fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream { + let name = &self.name; - quote! { - impl #name { - #tokens + quote! 
{ + impl #name { + #tokens + } } } - } - fn names(&self) -> Vec { - let mut names = Vec::new(); + fn names(&self) -> Vec { + let mut names = Vec::new(); - for field in self.fields_required.iter() { - names.push(field.clone()); - } + for field in self.fields_required.iter() { + names.push(field.clone()); + } - for field in self.fields_option.iter() { - names.push(field.clone()); - } + for field in self.fields_option.iter() { + names.push(field.clone()); + } + + for (field, _) in self.fields_default.iter() { + names.push(field.clone()); + } - for (field, _) in self.fields_default.iter() { - names.push(field.clone()); + names } - names - } + fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec { + let mut name_types = Vec::new(); - fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec { - let mut name_types = Vec::new(); + for field in names.iter() { + let name = field.ident(); + let ty = &field.field.ty; - for field in names.iter() { - let name = field.ident(); - let ty = &field.field.ty; + name_types.push(quote! { + #name: #ty + }); + } + + name_types + } - name_types.push(quote! { - #name: #ty - }); + fn serde_struct_ident(&self) -> Ident { + Ident::new(&format!("{}Serde", self.name), self.name.span()) } - name_types - } - - fn serde_struct_ident(&self) -> Ident { - Ident::new(&format!("{}Serde", self.name), self.name.span()) - } - - fn gen_serialize_fn( - &self, - struct_name: &Ident, - struct_gen: &TokenStream, - names: &[FieldTypeAnalyzer], - ) -> TokenStream { - let name = &self.name; - let names = names.iter().map(|name| { - let name = name.ident(); - quote! { #name: self.#name.clone() } - }); - - quote! { - impl serde::Serialize for #name { - - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer { - #[derive(serde::Serialize)] - #struct_gen - - let serde_state = #struct_name { - #(#names),* - }; - serde_state.serialize(serializer) + fn gen_serialize_fn( + &self, + struct_name: &Ident, + struct_gen: &TokenStream, + names: &[FieldTypeAnalyzer], + ) -> TokenStream { + let name = &self.name; + let names = names.iter().map(|name| { + let name = name.ident(); + quote! { #name: self.#name.clone() } + }); + + quote! { + impl serde::Serialize for #name { + + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer { + #[derive(serde::Serialize)] + #struct_gen + + let serde_state = #struct_name { + #(#names),* + }; + serde_state.serialize(serializer) + } } - } + } } - } - - fn gen_deserialize_fn( - &self, - struct_name: &Ident, - struct_gen: &TokenStream, - names: &[FieldTypeAnalyzer], - ) -> TokenStream { - let name = &self.name; - let names = names.iter().map(|name| { - let name = name.ident(); - quote! { #name: serde_state.#name } - }); - - quote! { - impl<'de> serde::Deserialize<'de> for #name { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de> { - #[derive(serde::Deserialize)] - #struct_gen - - let serde_state = #struct_name::deserialize(deserializer)?; - Ok(#name { - #(#names),* - }) + + fn gen_deserialize_fn( + &self, + struct_name: &Ident, + struct_gen: &TokenStream, + names: &[FieldTypeAnalyzer], + ) -> TokenStream { + let name = &self.name; + let names = names.iter().map(|name| { + let name = name.ident(); + quote! { #name: serde_state.#name } + }); + + quote! 
{ + impl<'de> serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + #[derive(serde::Deserialize)] + #struct_gen + + let serde_state = #struct_name::deserialize(deserializer)?; + Ok(#name { + #(#names),* + }) + } } - } + } } - } - fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream { - let struct_name = self.serde_struct_ident(); + fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream { + let struct_name = self.serde_struct_ident(); - quote! { - struct #struct_name { - #(#names),* - } + quote! { + struct #struct_name { + #(#names),* + } + } } - } } impl ConfigAnalyzer for ConfigStructAnalyzer { - fn gen_new_fn(&self) -> TokenStream { - let mut body = quote! {}; - let mut names = Vec::new(); - - for field in self.fields_required.iter() { - let name = field.ident(); - let ty = &field.field.ty; - - body.extend(quote! { - #name: #name, - }); - names.push(quote! { - #name: #ty - }); - } + fn gen_new_fn(&self) -> TokenStream { + let mut body = quote! {}; + let mut names = Vec::new(); + + for field in self.fields_required.iter() { + let name = field.ident(); + let ty = &field.field.ty; + + body.extend(quote! { + #name: #name, + }); + names.push(quote! { + #name: #ty + }); + } - for field in self.fields_option.iter() { - let name = field.ident(); + for field in self.fields_option.iter() { + let name = field.ident(); - body.extend(quote! { - #name: None, - }); - } + body.extend(quote! { + #name: None, + }); + } - for (field, attribute) in self.fields_default.iter() { - let name = field.ident(); - let value = &attribute.value; - match value { - syn::Lit::Str(value) => { - let stream: proc_macro2::TokenStream = value.value().parse().unwrap(); + for (field, attribute) in self.fields_default.iter() { + let name = field.ident(); + let value = &attribute.value; + match value { + syn::Lit::Str(value) => { + let stream: proc_macro2::TokenStream = value.value().parse().unwrap(); - body.extend(quote! { - #name: #stream, - }); - } - _ => { - body.extend(quote! { - #name: #value, - }); + body.extend(quote! { + #name: #stream, + }); + } + _ => { + body.extend(quote! { + #name: #value, + }); + } + }; } - }; + + let body = quote! { + /// Create a new instance of the config. + pub fn new( + #(#names),* + ) -> Self { + Self { #body } + } + }; + self.wrap_impl_block(body) } - let body = quote! { - /// Create a new instance of the config. - pub fn new( - #(#names),* - ) -> Self { - Self { #body } + fn gen_builder_fns(&self) -> TokenStream { + let mut body = quote! {}; + + for (field, _) in self.fields_default.iter() { + let name = field.ident(); + let doc = field.doc().unwrap_or_else(|| { + quote! { + /// Set the default value for the field. + } + }); + let ty = &field.field.ty; + let fn_name = Ident::new(&format!("with_{name}"), name.span()); + + body.extend(quote! { + #doc + pub fn #fn_name(mut self, #name: #ty) -> Self { + self.#name = #name; + self + } + }); } - }; - self.wrap_impl_block(body) - } - fn gen_builder_fns(&self) -> TokenStream { - let mut body = quote! {}; + for field in self.fields_option.iter() { + let name = field.ident(); + let ty = &field.field.ty; + let fn_name = Ident::new(&format!("with_{name}"), name.span()); - for (field, _) in self.fields_default.iter() { - let name = field.ident(); - let doc = field.doc().unwrap_or_else(|| { - quote! { + body.extend(quote! { /// Set the default value for the field. 
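// A hedged sketch of what `gen_new_fn` and `gen_builder_fns` above expand to for the
// hypothetical `AdamLikeConfig` sketched earlier (one required, one `Option`, one
// `#[config(default = ...)]` field). This is a paraphrase of the generated code, not the
// literal macro output; the struct is repeated here only so the sketch is self-contained.
pub struct AdamLikeConfigExpanded {
    pub learning_rate: f64,
    pub gradient_clip: Option<f64>,
    pub beta: f64,
}

impl AdamLikeConfigExpanded {
    /// Create a new instance of the config.
    pub fn new(learning_rate: f64) -> Self {
        Self {
            learning_rate,
            gradient_clip: None,
            beta: 0.9,
        }
    }

    /// Set the default value for the field.
    pub fn with_beta(mut self, beta: f64) -> Self {
        self.beta = beta;
        self
    }

    /// Set the default value for the field.
    pub fn with_gradient_clip(mut self, gradient_clip: Option<f64>) -> Self {
        self.gradient_clip = gradient_clip;
        self
    }
}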
+ pub fn #fn_name(mut self, #name: #ty) -> Self { + self.#name = #name; + self + } + }); } - }); - let ty = &field.field.ty; - let fn_name = Ident::new(&format!("with_{name}"), name.span()); - - body.extend(quote! { - #doc - pub fn #fn_name(mut self, #name: #ty) -> Self { - self.#name = #name; - self - } - }); - } - for field in self.fields_option.iter() { - let name = field.ident(); - let ty = &field.field.ty; - let fn_name = Ident::new(&format!("with_{name}"), name.span()); - - body.extend(quote! { - /// Set the default value for the field. - pub fn #fn_name(mut self, #name: #ty) -> Self { - self.#name = #name; - self - } - }); + self.wrap_impl_block(body) } - self.wrap_impl_block(body) - } - - fn gen_serde_impl(&self) -> TokenStream { - let names = self.names(); + fn gen_serde_impl(&self) -> TokenStream { + let names = self.names(); - let struct_name = self.serde_struct_ident(); - let name_types = self.name_types(&names); - let struct_gen = self.gen_serde_struct(&name_types); + let struct_name = self.serde_struct_ident(); + let name_types = self.name_types(&names); + let struct_gen = self.gen_serde_struct(&name_types); - let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names); - let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names); + let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names); + let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names); - quote! { - #serialize_gen - #deserialize_gen + quote! { + #serialize_gen + #deserialize_gen + } } - } - - fn gen_clone_impl(&self) -> TokenStream { - let name = &self.name; - let names = self.names().into_iter().map(|name| { - let name = name.ident(); - quote! { #name: self.#name.clone() } - }); - - quote! { - impl Clone for #name { - fn clone(&self) -> Self { - Self { - #(#names),* + + fn gen_clone_impl(&self) -> TokenStream { + let name = &self.name; + let names = self.names().into_iter().map(|name| { + let name = name.ident(); + quote! { #name: self.#name.clone() } + }); + + quote! { + impl Clone for #name { + fn clone(&self) -> Self { + Self { + #(#names),* + } } } - } + } } - } - fn gen_display_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_display_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl core::fmt::Display for #name { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(&burn::config::config_to_json(self)) + quote! { + impl core::fmt::Display for #name { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(&burn::config::config_to_json(self)) + } } } } - } - fn gen_config_impl(&self) -> TokenStream { - let name = &self.name; + fn gen_config_impl(&self) -> TokenStream { + let name = &self.name; - quote! { - impl burn::config::Config for #name { + quote! 
{ + impl burn::config::Config for #name { + } } } - } } diff --git a/burn-derive/src/config/base.rs b/burn-derive/src/config/base.rs index c1a45e28f4..cca3f0430b 100644 --- a/burn-derive/src/config/base.rs +++ b/burn-derive/src/config/base.rs @@ -2,23 +2,23 @@ use super::ConfigAnalyzerFactory; use quote::quote; pub(crate) fn derive_impl(item: &syn::DeriveInput) -> proc_macro::TokenStream { - let factory = ConfigAnalyzerFactory::new(); - let analyzer = factory.create_analyzer(item); + let factory = ConfigAnalyzerFactory::new(); + let analyzer = factory.create_analyzer(item); - let constructor = analyzer.gen_new_fn(); - let builders = analyzer.gen_builder_fns(); - let serde = analyzer.gen_serde_impl(); - let clone = analyzer.gen_clone_impl(); - let display = analyzer.gen_display_impl(); - let config_impl = analyzer.gen_config_impl(); + let constructor = analyzer.gen_new_fn(); + let builders = analyzer.gen_builder_fns(); + let serde = analyzer.gen_serde_impl(); + let clone = analyzer.gen_clone_impl(); + let display = analyzer.gen_display_impl(); + let config_impl = analyzer.gen_config_impl(); - quote! { - #config_impl - #constructor - #builders - #serde - #clone - #display - } - .into() + quote! { + #config_impl + #constructor + #builders + #serde + #clone + #display + } + .into() } diff --git a/burn-derive/src/lib.rs b/burn-derive/src/lib.rs index 35a5e22f71..91f254b88f 100644 --- a/burn-derive/src/lib.rs +++ b/burn-derive/src/lib.rs @@ -15,20 +15,20 @@ pub(crate) mod shared; /// Derive macro for the module. #[proc_macro_derive(Module)] pub fn module_derive(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - module::derive_impl(&input) + let input = syn::parse(input).unwrap(); + module::derive_impl(&input) } /// Derive macro for the record. #[proc_macro_derive(Record)] pub fn record_derive(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - record::derive_impl(&input) + let input = syn::parse(input).unwrap(); + record::derive_impl(&input) } /// Derive macro for the config. 
#[proc_macro_derive(Config, attributes(config))] pub fn config_derive(input: TokenStream) -> TokenStream { - let item = syn::parse(input).unwrap(); - config::derive_impl(&item) + let item = syn::parse(input).unwrap(); + config::derive_impl(&item) } diff --git a/burn-derive/src/module/base.rs b/burn-derive/src/module/base.rs index 9536a7adc7..1cf41bb5ed 100644 --- a/burn-derive/src/module/base.rs +++ b/burn-derive/src/module/base.rs @@ -1,6 +1,6 @@ use super::{ - codegen::ModuleCodegen, codegen_struct::StructModuleCodegen, record::ModuleRecordCodegen, - record_struct::StructModuleRecordCodegen, + codegen::ModuleCodegen, codegen_struct::StructModuleCodegen, record::ModuleRecordCodegen, + record_struct::StructModuleRecordCodegen, }; use crate::module::display; use proc_macro::TokenStream; @@ -8,149 +8,149 @@ use quote::quote; use syn::{parse_quote, Ident}; pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; - let has_backend = ast - .generics - .type_params() - .map(|param| param.ident == "B") - .reduce(|accum, is_backend| is_backend || accum) - .unwrap_or(false); - - if !has_backend { - return constant_impl(ast); - } - - let (generics, generics_ty, generics_where) = ast.generics.split_for_impl(); - let backend_trait = fetch_backend_trait(&ast.generics); - - let display_fn = display::display_fn(name); - - let generator = StructModuleCodegen::from_ast(ast); - let num_params_fn = generator.gen_num_params(); - let visit = generator.gen_visit(); - let map_mut = generator.gen_map(); - let valid_fn = generator.gen_valid(); - let into_record_fn = generator.gen_into_record(); - let load_record_fn = generator.gen_load_record(); - let clone_fn = generator.gen_clone(); - let generics_names_except_backend = generics_names_except_backend(&ast.generics); - - let record_name = Ident::new(format!("{}Record", name).as_str(), name.span()); - let record_gen = StructModuleRecordCodegen::new(generator.fields); - let record_struct = record_gen.gen_record_type(&record_name, &ast.generics); - - let gen = quote! 
{ - impl #generics burn::module::Module for #name #generics_ty #generics_where { - type Record = #record_name #generics_ty; - - #load_record_fn - #into_record_fn - - #num_params_fn - - #visit - #map_mut - } - - impl #generics burn::module::AutodiffModule for #name #generics_ty - where - B: burn::tensor::backend::AutodiffBackend, - ::InnerBackend: #backend_trait, - { - type InnerModule=#name; - - #valid_fn - } - - impl #generics core::fmt::Display for #name #generics_ty #generics_where { - #display_fn - } - - impl #generics Clone for #name #generics_ty #generics_where { - #clone_fn - } - - #record_struct - }; - - gen.into() + let name = &ast.ident; + let has_backend = ast + .generics + .type_params() + .map(|param| param.ident == "B") + .reduce(|accum, is_backend| is_backend || accum) + .unwrap_or(false); + + if !has_backend { + return constant_impl(ast); + } + + let (generics, generics_ty, generics_where) = ast.generics.split_for_impl(); + let backend_trait = fetch_backend_trait(&ast.generics); + + let display_fn = display::display_fn(name); + + let generator = StructModuleCodegen::from_ast(ast); + let num_params_fn = generator.gen_num_params(); + let visit = generator.gen_visit(); + let map_mut = generator.gen_map(); + let valid_fn = generator.gen_valid(); + let into_record_fn = generator.gen_into_record(); + let load_record_fn = generator.gen_load_record(); + let clone_fn = generator.gen_clone(); + let generics_names_except_backend = generics_names_except_backend(&ast.generics); + + let record_name = Ident::new(format!("{}Record", name).as_str(), name.span()); + let record_gen = StructModuleRecordCodegen::new(generator.fields); + let record_struct = record_gen.gen_record_type(&record_name, &ast.generics); + + let gen = quote! { + impl #generics burn::module::Module for #name #generics_ty #generics_where { + type Record = #record_name #generics_ty; + + #load_record_fn + #into_record_fn + + #num_params_fn + + #visit + #map_mut + } + + impl #generics burn::module::AutodiffModule for #name #generics_ty + where + B: burn::tensor::backend::AutodiffBackend, + ::InnerBackend: #backend_trait, + { + type InnerModule=#name; + + #valid_fn + } + + impl #generics core::fmt::Display for #name #generics_ty #generics_where { + #display_fn + } + + impl #generics Clone for #name #generics_ty #generics_where { + #clone_fn + } + + #record_struct + }; + + gen.into() } // When there is no backend in the generic parameter, the struct is considered as a constant. fn constant_impl(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; - let (_, generics_ty, generics_where) = ast.generics.split_for_impl(); - - let backend: syn::Generics = parse_quote! { }; - let backend_ad: syn::Generics = parse_quote! { }; - - let mut generics_module = ast.generics.clone(); - let mut generics_module_ad = ast.generics.clone(); - - for param in backend.params.into_iter() { - generics_module.params.push(param); - } - for param in backend_ad.params.into_iter() { - generics_module_ad.params.push(param); - } - let (generics_module, _, _) = generics_module.split_for_impl(); - let (generics_module_ad, _, _) = generics_module_ad.split_for_impl(); - - let gen = quote! 
{ - impl #generics_module burn::module::Module for #name #generics_ty #generics_where { - burn::constant!(module); - } - - impl #generics_module_ad burn::module::AutodiffModule for #name #generics_ty #generics_where { - burn::constant!(ad_module, #name #generics_ty); - } - }; - - gen.into() + let name = &ast.ident; + let (_, generics_ty, generics_where) = ast.generics.split_for_impl(); + + let backend: syn::Generics = parse_quote! { }; + let backend_ad: syn::Generics = parse_quote! { }; + + let mut generics_module = ast.generics.clone(); + let mut generics_module_ad = ast.generics.clone(); + + for param in backend.params.into_iter() { + generics_module.params.push(param); + } + for param in backend_ad.params.into_iter() { + generics_module_ad.params.push(param); + } + let (generics_module, _, _) = generics_module.split_for_impl(); + let (generics_module_ad, _, _) = generics_module_ad.split_for_impl(); + + let gen = quote! { + impl #generics_module burn::module::Module for #name #generics_ty #generics_where { + burn::constant!(module); + } + + impl #generics_module_ad burn::module::AutodiffModule for #name #generics_ty #generics_where { + burn::constant!(ad_module, #name #generics_ty); + } + }; + + gen.into() } fn fetch_backend_trait(generics: &syn::Generics) -> proc_macro2::TokenStream { - static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str = "Modules should be generic over a backend. + static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str = "Modules should be generic over a backend. - The generic argument named `B` should have its first trait bound being a backend trait. - The default backend trait is `burn::tensor::backend::Backend`. - Any backend trait is supported."; - for param in generics.params.iter() { - if let syn::GenericParam::Type(ty) = ¶m { - if ty.ident == "B" { - let bound = ty - .bounds - .first() - .expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG); - - return quote! { - #bound - }; - } + for param in generics.params.iter() { + if let syn::GenericParam::Type(ty) = ¶m { + if ty.ident == "B" { + let bound = ty + .bounds + .first() + .expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG); + + return quote! { + #bound + }; + } + } } - } - panic!("{BACKEND_TRAIT_COMPILATION_ERROR_MSG}"); + panic!("{BACKEND_TRAIT_COMPILATION_ERROR_MSG}"); } fn generics_names_except_backend(generics: &syn::Generics) -> proc_macro2::TokenStream { - let mut named = quote! {}; - - generics.params.iter().for_each(|param| { - match param { - syn::GenericParam::Type(ty) => { - if ty.ident != "B" { - let ident = &ty.ident; - named.extend(quote! { #ident, }); - } - } - syn::GenericParam::Lifetime(_) => panic!("Lifetime not supported in module"), - syn::GenericParam::Const(c) => { - let ident = &c.ident; - named.extend(quote! { #ident, }); - } - }; - }); + let mut named = quote! {}; + + generics.params.iter().for_each(|param| { + match param { + syn::GenericParam::Type(ty) => { + if ty.ident != "B" { + let ident = &ty.ident; + named.extend(quote! { #ident, }); + } + } + syn::GenericParam::Lifetime(_) => panic!("Lifetime not supported in module"), + syn::GenericParam::Const(c) => { + let ident = &c.ident; + named.extend(quote! { #ident, }); + } + }; + }); - named + named } diff --git a/burn-derive/src/module/codegen.rs b/burn-derive/src/module/codegen.rs index 2a7d52cedf..852138018c 100644 --- a/burn-derive/src/module/codegen.rs +++ b/burn-derive/src/module/codegen.rs @@ -2,11 +2,11 @@ use proc_macro2::TokenStream; /// Basic trait to be implemented for Module generation. 
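// A hedged usage sketch for the Module derive implemented above. The struct and its
// fields are hypothetical; what is grounded in `derive_impl` above is that the type must
// be generic over a parameter literally named `B` whose first trait bound is a backend
// trait (otherwise `constant_impl` treats it as a constant module), and that the derive
// emits `Module<B>`, `AutodiffModule<B>`, `Clone`, a `Display` impl printing
// `MyModel[num_params=N]`, and a companion `MyModelRecord` type.
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Module, Debug)]
pub struct MyModel<B: Backend> {
    weight: Param<Tensor<B, 2>>,
    bias: Param<Tensor<B, 1>>,
}

// With the derive in place, `model.num_params()`, `model.into_record()`,
// `model.load_record(record)`, and `format!("{model}")` all come for free.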
pub(crate) trait ModuleCodegen { - fn gen_num_params(&self) -> TokenStream; - fn gen_visit(&self) -> TokenStream; - fn gen_map(&self) -> TokenStream; - fn gen_valid(&self) -> TokenStream; - fn gen_into_record(&self) -> TokenStream; - fn gen_load_record(&self) -> TokenStream; - fn gen_clone(&self) -> TokenStream; + fn gen_num_params(&self) -> TokenStream; + fn gen_visit(&self) -> TokenStream; + fn gen_map(&self) -> TokenStream; + fn gen_valid(&self) -> TokenStream; + fn gen_into_record(&self) -> TokenStream; + fn gen_load_record(&self) -> TokenStream; + fn gen_clone(&self) -> TokenStream; } diff --git a/burn-derive/src/module/codegen_struct.rs b/burn-derive/src/module/codegen_struct.rs index 1f163b413a..a6988b2bd4 100644 --- a/burn-derive/src/module/codegen_struct.rs +++ b/burn-derive/src/module/codegen_struct.rs @@ -5,164 +5,164 @@ use quote::quote; use super::codegen::ModuleCodegen; pub(crate) struct StructModuleCodegen { - pub fields: Vec, + pub fields: Vec, } impl ModuleCodegen for StructModuleCodegen { - fn gen_num_params(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - num_params += burn::module::Module::::num_params(&self.#name); - } - }); - - quote! { - fn num_params(&self) -> usize { - let mut num_params = 0; - #body - num_params + fn gen_num_params(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + num_params += burn::module::Module::::num_params(&self.#name); + } + }); + + quote! { + fn num_params(&self) -> usize { + let mut num_params = 0; + #body + num_params + } } } - } - - fn gen_visit(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - burn::module::Module::visit(&self.#name, visitor); - } - }); - - quote! { - fn visit>(&self, visitor: &mut V) { - #body + + fn gen_visit(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + burn::module::Module::visit(&self.#name, visitor); + } + }); + + quote! { + fn visit>(&self, visitor: &mut V) { + #body + } } } - } - fn gen_map(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = burn::module::Module::map(self.#name, mapper); - } - }); + fn gen_map(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = burn::module::Module::map(self.#name, mapper); + } + }); - quote! { - fn map>(self, mapper: &mut M) -> Self { - #body + quote! { + fn map>(self, mapper: &mut M) -> Self { + #body - Self { - #(#names),* + Self { + #(#names),* + } } } } - } - fn gen_valid(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = burn::module::AutodiffModule::::valid(&self.#name); - } - }); + fn gen_valid(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = burn::module::AutodiffModule::::valid(&self.#name); + } + }); - quote! { - fn valid(&self) -> Self::InnerModule { - #body + quote! { + fn valid(&self) -> Self::InnerModule { + #body - Self::InnerModule { - #(#names),* + Self::InnerModule { + #(#names),* + } } } } - } - - fn gen_into_record(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - #name: burn::module::Module::::into_record(self.#name), - } - }); - - quote! { - fn into_record(self) -> Self::Record { - Self::Record { - #body + + fn gen_into_record(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + #name: burn::module::Module::::into_record(self.#name), + } + }); + + quote! 
{ + fn into_record(self) -> Self::Record { + Self::Record { + #body + } } } } - } - - fn gen_load_record(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - #name: burn::module::Module::::load_record(self.#name, record.#name), - } - }); - - quote! { - fn load_record(self, record: Self::Record) -> Self { - Self { - #body + + fn gen_load_record(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + #name: burn::module::Module::::load_record(self.#name, record.#name), + } + }); + + quote! { + fn load_record(self, record: Self::Record) -> Self { + Self { + #body + } } } } - } - fn gen_clone(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = self.#name.clone(); - } - }); + fn gen_clone(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = self.#name.clone(); + } + }); - quote! { - fn clone(&self) -> Self { - #body + quote! { + fn clone(&self) -> Self { + #body - Self { - #(#names),* + Self { + #(#names),* + } } } } - } } impl StructModuleCodegen { - pub fn from_ast(ast: &syn::DeriveInput) -> Self { - Self { - fields: parse_fields(ast) - .into_iter() - .map(FieldTypeAnalyzer::new) - .collect(), + pub fn from_ast(ast: &syn::DeriveInput) -> Self { + Self { + fields: parse_fields(ast) + .into_iter() + .map(FieldTypeAnalyzer::new) + .collect(), + } } - } - fn gen_fields_fn_names(&self, func: F) -> (Vec, TokenStream) - where - F: Fn(Ident) -> TokenStream, - { - let mut body = quote! {}; - let mut names = Vec::new(); + fn gen_fields_fn_names(&self, func: F) -> (Vec, TokenStream) + where + F: Fn(Ident) -> TokenStream, + { + let mut body = quote! {}; + let mut names = Vec::new(); - for field in self.fields.iter() { - let name = field.ident(); + for field in self.fields.iter() { + let name = field.ident(); - names.push(name.clone()); - body.extend(func(field.ident())); + names.push(name.clone()); + body.extend(func(field.ident())); + } + + (names, body) } - (names, body) - } + fn gen_fields_fn(&self, func: F) -> TokenStream + where + F: Fn(Ident) -> TokenStream, + { + let mut body = quote! {}; - fn gen_fields_fn(&self, func: F) -> TokenStream - where - F: Fn(Ident) -> TokenStream, - { - let mut body = quote! {}; + for field in self.fields.iter() { + body.extend(func(field.ident())); + } - for field in self.fields.iter() { - body.extend(func(field.ident())); + body } - - body - } } diff --git a/burn-derive/src/module/display.rs b/burn-derive/src/module/display.rs index 3e15331ebe..f9c024ff49 100644 --- a/burn-derive/src/module/display.rs +++ b/burn-derive/src/module/display.rs @@ -2,10 +2,10 @@ use proc_macro2::Ident; use quote::quote; pub fn display_fn(name: &Ident) -> proc_macro2::TokenStream { - quote! { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}[num_params={}]", stringify!(#name), self.num_params()) + quote! { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}[num_params={}]", stringify!(#name), self.num_params()) - } - } + } + } } diff --git a/burn-derive/src/module/record.rs b/burn-derive/src/module/record.rs index be05e34f82..72e3cd2872 100644 --- a/burn-derive/src/module/record.rs +++ b/burn-derive/src/module/record.rs @@ -3,6 +3,6 @@ use syn::Generics; /// Basic trait to generate a record type based on the Module struct. 
pub(crate) trait ModuleRecordCodegen { - /// Generate the record type (i.e a struct) - fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream; + /// Generate the record type (i.e a struct) + fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream; } diff --git a/burn-derive/src/module/record_struct.rs b/burn-derive/src/module/record_struct.rs index 2521e52b03..47c330f96e 100644 --- a/burn-derive/src/module/record_struct.rs +++ b/burn-derive/src/module/record_struct.rs @@ -7,30 +7,30 @@ use super::record::ModuleRecordCodegen; #[derive(new)] pub(crate) struct StructModuleRecordCodegen { - fields: Vec, + fields: Vec, } impl ModuleRecordCodegen for StructModuleRecordCodegen { - fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { - let mut fields = quote! {}; + fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { + let mut fields = quote! {}; - for field in self.fields.iter() { - let ty = &field.field.ty; - let name = &field.field.ident; + for field in self.fields.iter() { + let ty = &field.field.ty; + let name = &field.field.ident; - fields.extend(quote! { - /// The module record associative type. - pub #name: <#ty as burn::module::Module>::Record, - }); - } + fields.extend(quote! { + /// The module record associative type. + pub #name: <#ty as burn::module::Module>::Record, + }); + } - quote! { + quote! { - /// The record type for the module. - #[derive(burn::record::Record, Debug, Clone)] - pub struct #record_name #generics { - #fields + /// The record type for the module. + #[derive(burn::record::Record, Debug, Clone)] + pub struct #record_name #generics { + #fields + } } } - } } diff --git a/burn-derive/src/record/base.rs b/burn-derive/src/record/base.rs index deeaf9551b..e446e961fc 100644 --- a/burn-derive/src/record/base.rs +++ b/burn-derive/src/record/base.rs @@ -6,83 +6,83 @@ use super::{codegen::RecordItemCodegen, codegen_struct::StructRecordItemCodegen} use crate::shared::field::{parse_fields, FieldTypeAnalyzer}; pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream { - let record_gen = RecordDeriveCodegen::from_ast(ast); - let item_struct = record_gen.gen_record_type(); - let record_impl = record_gen.gen_impl_record(); - - quote! { - #item_struct - #record_impl - } - .into() + let record_gen = RecordDeriveCodegen::from_ast(ast); + let item_struct = record_gen.gen_record_type(); + let record_impl = record_gen.gen_impl_record(); + + quote! 
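// A hedged sketch of the record type that `gen_record_type` above would emit for the
// hypothetical `MyModel<B>` from the earlier Module example (the real generated name
// would be `MyModelRecord`). The exact generics and bounds are paraphrased; only the
// shape, one field per module field typed as that field's `Record` associated type, and
// the derive list come from the code above.
use burn::module::Param;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(burn::record::Record, Debug, Clone)]
pub struct MyModelRecordSketch<B: Backend> {
    /// The module record associative type.
    pub weight: <Param<Tensor<B, 2>> as burn::module::Module<B>>::Record,
    /// The module record associative type.
    pub bias: <Param<Tensor<B, 1>> as burn::module::Module<B>>::Record,
}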
{ + #item_struct + #record_impl + } + .into() } struct RecordDeriveCodegen { - name_record: Ident, - name_item: Ident, - gen: StructRecordItemCodegen, - generics: Generics, + name_record: Ident, + name_item: Ident, + gen: StructRecordItemCodegen, + generics: Generics, } impl RecordDeriveCodegen { - pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self { - let name_record = ast.ident.clone(); - let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span()); - - Self { - name_record, - name_item, - gen: StructRecordItemCodegen::new( - parse_fields(ast) - .into_iter() - .map(FieldTypeAnalyzer::new) - .collect(), - ), - generics: ast.generics.clone(), + pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self { + let name_record = ast.ident.clone(); + let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span()); + + Self { + name_record, + name_item, + gen: StructRecordItemCodegen::new( + parse_fields(ast) + .into_iter() + .map(FieldTypeAnalyzer::new) + .collect(), + ), + generics: ast.generics.clone(), + } } - } - /// Generate the record type with the correct generics. - pub(crate) fn gen_record_type(&self) -> TokenStream { - let param: syn::Generics = parse_quote! { }; - let mut generics = self.generics.clone(); + /// Generate the record type with the correct generics. + pub(crate) fn gen_record_type(&self) -> TokenStream { + let param: syn::Generics = parse_quote! { }; + let mut generics = self.generics.clone(); - for param in param.params.into_iter() { - generics.params.push(param); - } + for param in param.params.into_iter() { + generics.params.push(param); + } - self.gen.gen_item_type(&self.name_item, &generics) - } + self.gen.gen_item_type(&self.name_item, &generics) + } - /// Generate the implementation for the Record trait. - pub(crate) fn gen_impl_record(&self) -> TokenStream { - let name = &self.name_record; - let item_generics = self.record_item_generics(); - let (_, ty_generics_item, _) = item_generics.split_for_impl(); - let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + /// Generate the implementation for the Record trait. + pub(crate) fn gen_impl_record(&self) -> TokenStream { + let name = &self.name_record; + let item_generics = self.record_item_generics(); + let (_, ty_generics_item, _) = item_generics.split_for_impl(); + let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); - let name_item = &self.name_item; - let into_item_fn = self.gen.gen_into_item(name_item); - let from_item_fn = self.gen.gen_from_item(); + let name_item = &self.name_item; + let into_item_fn = self.gen.gen_into_item(name_item); + let from_item_fn = self.gen.gen_from_item(); - quote! { - impl #impl_generics burn::record::Record for #name #ty_generics #where_clause { - type Item = #name_item #ty_generics_item; + quote! { + impl #impl_generics burn::record::Record for #name #ty_generics #where_clause { + type Item = #name_item #ty_generics_item; - #into_item_fn - #from_item_fn + #into_item_fn + #from_item_fn + } } } - } - fn record_item_generics(&self) -> Generics { - let param: syn::Generics = parse_quote! { }; - let mut generics = self.generics.clone(); - for param in param.params.into_iter() { - generics.params.push(param); - } + fn record_item_generics(&self) -> Generics { + let param: syn::Generics = parse_quote! 
{ }; + let mut generics = self.generics.clone(); + for param in param.params.into_iter() { + generics.params.push(param); + } - generics - } + generics + } } diff --git a/burn-derive/src/record/codegen.rs b/burn-derive/src/record/codegen.rs index 5af75f4678..aafcead97a 100644 --- a/burn-derive/src/record/codegen.rs +++ b/burn-derive/src/record/codegen.rs @@ -3,10 +3,10 @@ use syn::Generics; /// Basic trait to be implemented for record generation. pub(crate) trait RecordItemCodegen { - /// Generate the record item type (i.e a struct) - fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream; - /// Generate the into_item function. - fn gen_into_item(&self, item_name: &Ident) -> TokenStream; - /// Generate the from item function. - fn gen_from_item(&self) -> TokenStream; + /// Generate the record item type (i.e a struct) + fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream; + /// Generate the into_item function. + fn gen_into_item(&self, item_name: &Ident) -> TokenStream; + /// Generate the from item function. + fn gen_from_item(&self) -> TokenStream; } diff --git a/burn-derive/src/record/codegen_struct.rs b/burn-derive/src/record/codegen_struct.rs index 9fa1e7e1c5..331b38b308 100644 --- a/burn-derive/src/record/codegen_struct.rs +++ b/burn-derive/src/record/codegen_struct.rs @@ -7,76 +7,76 @@ use super::codegen::RecordItemCodegen; #[derive(new)] pub(crate) struct StructRecordItemCodegen { - fields: Vec, + fields: Vec, } impl RecordItemCodegen for StructRecordItemCodegen { - fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream { - let mut fields = quote! {}; - let mut bounds = quote! {}; + fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream { + let mut fields = quote! {}; + let mut bounds = quote! {}; - for field in self.fields.iter() { - let ty = &field.field.ty; - let name = &field.field.ident; + for field in self.fields.iter() { + let ty = &field.field.ty; + let name = &field.field.ident; - fields.extend(quote! { - /// Field to be serialized. - pub #name: <#ty as burn::record::Record>::Item, - }); - bounds.extend(quote! { + fields.extend(quote! { + /// Field to be serialized. + pub #name: <#ty as burn::record::Record>::Item, + }); + bounds.extend(quote! { <#ty as burn::record::Record>::Item: serde::Serialize + serde::de::DeserializeOwned, }); - } - let bound = bounds.to_string(); + } + let bound = bounds.to_string(); - quote! { + quote! { - /// The record item type for the module. - #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] - #[serde(bound = #bound)] - pub struct #item_name #generics { - #fields + /// The record item type for the module. + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + #[serde(bound = #bound)] + pub struct #item_name #generics { + #fields + } } } - } - fn gen_into_item(&self, item_name: &Ident) -> TokenStream { - let mut body_into_item = quote! {}; + fn gen_into_item(&self, item_name: &Ident) -> TokenStream { + let mut body_into_item = quote! {}; - for field in self.fields.iter() { - let name = &field.field.ident; + for field in self.fields.iter() { + let name = &field.field.ident; - body_into_item.extend(quote! { - #name: burn::record::Record::into_item::(self.#name), - }); - } + body_into_item.extend(quote! { + #name: burn::record::Record::into_item::(self.#name), + }); + } - quote! { - fn into_item(self) -> Self::Item { - #item_name { - #body_into_item + quote! 
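// A hedged sketch of what the Record derive above emits for a small hypothetical record
// struct. The extra generic parameter appended by `record_item_generics` and threaded
// through `Item`, `into_item`, and `from_item` is not spelled out in this hunk, so it is
// omitted from the paraphrase below; only the per-field delegation pattern
// (`<T as Record>::Item`, `into_item`, `from_item`) comes from the code above.
#[derive(burn::record::Record)]
pub struct TrainingRecordSketch {
    pub step: usize,
    pub loss: f32,
}

// Roughly the shape of the generated companion item type (paraphrased, in comments):
//
// #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// #[serde(bound = "...")]
// pub struct TrainingRecordSketchItem {
//     /// Field to be serialized.
//     pub step: <usize as burn::record::Record>::Item,
//     /// Field to be serialized.
//     pub loss: <f32 as burn::record::Record>::Item,
// }
//
// The generated `into_item` / `from_item` simply delegate field by field, e.g.
// `step: burn::record::Record::into_item(self.step)` and back again.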
{ + fn into_item(self) -> Self::Item { + #item_name { + #body_into_item + } } } } - } - fn gen_from_item(&self) -> TokenStream { - let mut body_from_item = quote! {}; + fn gen_from_item(&self) -> TokenStream { + let mut body_from_item = quote! {}; - for field in self.fields.iter() { - let name = &field.field.ident; + for field in self.fields.iter() { + let name = &field.field.ident; - body_from_item.extend(quote! { - #name: burn::record::Record::from_item::(item.#name), - }); - } + body_from_item.extend(quote! { + #name: burn::record::Record::from_item::(item.#name), + }); + } - quote! { - fn from_item(item: Self::Item) -> Self { - Self { - #body_from_item + quote! { + fn from_item(item: Self::Item) -> Self { + Self { + #body_from_item + } } } } - } } diff --git a/burn-derive/src/shared/attribute.rs b/burn-derive/src/shared/attribute.rs index e1cd0d6c33..b678b3a7ec 100644 --- a/burn-derive/src/shared/attribute.rs +++ b/burn-derive/src/shared/attribute.rs @@ -1,53 +1,53 @@ use syn::{Attribute, Ident, Meta}; pub struct AttributeAnalyzer { - attr: Attribute, + attr: Attribute, } #[derive(Clone)] pub struct AttributeItem { - pub ident: Ident, - pub value: syn::Lit, + pub ident: Ident, + pub value: syn::Lit, } impl AttributeAnalyzer { - pub fn new(attr: Attribute) -> Self { - Self { attr } - } - - pub fn item(&self) -> AttributeItem { - let value = match &self.attr.meta { - Meta::List(val) => val.parse_args::().unwrap(), - Meta::NameValue(meta) => meta.clone(), - Meta::Path(_) => panic!("Path meta unsupported"), - }; - - let lit = match value.value { - syn::Expr::Lit(lit) => lit.lit, - _ => panic!("Only literal is supported"), - }; - - AttributeItem { - ident: value.path.get_ident().unwrap().clone(), - value: lit, + pub fn new(attr: Attribute) -> Self { + Self { attr } } - } - - pub fn has_name(&self, name: &str) -> bool { - Self::path_syn_name(self.attr.path()) == name - } - - fn path_syn_name(path: &syn::Path) -> String { - let length = path.segments.len(); - let mut name = String::new(); - for (i, segment) in path.segments.iter().enumerate() { - if i == length - 1 { - name += segment.ident.to_string().as_str(); - } else { - let tmp = segment.ident.to_string() + "::"; - name += tmp.as_str(); - } + + pub fn item(&self) -> AttributeItem { + let value = match &self.attr.meta { + Meta::List(val) => val.parse_args::().unwrap(), + Meta::NameValue(meta) => meta.clone(), + Meta::Path(_) => panic!("Path meta unsupported"), + }; + + let lit = match value.value { + syn::Expr::Lit(lit) => lit.lit, + _ => panic!("Only literal is supported"), + }; + + AttributeItem { + ident: value.path.get_ident().unwrap().clone(), + value: lit, + } + } + + pub fn has_name(&self, name: &str) -> bool { + Self::path_syn_name(self.attr.path()) == name + } + + fn path_syn_name(path: &syn::Path) -> String { + let length = path.segments.len(); + let mut name = String::new(); + for (i, segment) in path.segments.iter().enumerate() { + if i == length - 1 { + name += segment.ident.to_string().as_str(); + } else { + let tmp = segment.ident.to_string() + "::"; + name += tmp.as_str(); + } + } + name } - name - } } diff --git a/burn-derive/src/shared/field.rs b/burn-derive/src/shared/field.rs index b48a3c3e2e..274cabcc15 100644 --- a/burn-derive/src/shared/field.rs +++ b/burn-derive/src/shared/field.rs @@ -5,103 +5,101 @@ use syn::{Field, Type, TypePath}; #[derive(Clone)] pub struct FieldTypeAnalyzer { - pub field: Field, + pub field: Field, } impl FieldTypeAnalyzer { - pub fn new(field: Field) -> Self { - FieldTypeAnalyzer { field } - 
} - - pub fn ident(&self) -> Ident { - self.field.ident.clone().unwrap() - } + pub fn new(field: Field) -> Self { + FieldTypeAnalyzer { field } + } - pub fn is_of_type(&self, paths: &[&str]) -> bool { - match &self.field.ty { - syn::Type::Path(path) => { - let name = Self::path_name(path); - paths.contains(&name.as_str()) - } - _ => false, + pub fn ident(&self) -> Ident { + self.field.ident.clone().unwrap() } - } - #[allow(dead_code)] - pub fn first_generic_field(&self) -> TypePath { - let err = || panic!("Field {} as no generic", self.field.ident.clone().unwrap()); - match &self.field.ty { - syn::Type::Path(path) => Self::path_generic_argument(path), - _ => err(), + pub fn is_of_type(&self, paths: &[&str]) -> bool { + match &self.field.ty { + syn::Type::Path(path) => { + let name = Self::path_name(path); + paths.contains(&name.as_str()) + } + _ => false, + } } - } - pub fn path_generic_argument(path: &TypePath) -> TypePath { - let segment = path.path.segments.last().unwrap(); - let err = || panic!("Path segment {} has no generic", segment.ident.clone(),); - match &segment.arguments { - syn::PathArguments::None => err(), - syn::PathArguments::AngleBracketed(param) => { - let first_param = param.args.first().unwrap(); - if let syn::GenericArgument::Type(Type::Path(path)) = first_param { - path.clone() - } else { - err() + #[allow(dead_code)] + pub fn first_generic_field(&self) -> TypePath { + let err = || panic!("Field {} as no generic", self.field.ident.clone().unwrap()); + match &self.field.ty { + syn::Type::Path(path) => Self::path_generic_argument(path), + _ => err(), } - } - syn::PathArguments::Parenthesized(_) => err(), } - } + pub fn path_generic_argument(path: &TypePath) -> TypePath { + let segment = path.path.segments.last().unwrap(); + let err = || panic!("Path segment {} has no generic", segment.ident.clone(),); + match &segment.arguments { + syn::PathArguments::None => err(), + syn::PathArguments::AngleBracketed(param) => { + let first_param = param.args.first().unwrap(); - fn path_name(path: &TypePath) -> String { - let length = path.path.segments.len(); - let mut name = String::new(); - for (i, segment) in path.path.segments.iter().enumerate() { - if i == length - 1 { - name += segment.ident.to_string().as_str(); - } else { - let tmp = segment.ident.to_string() + "::"; - name += tmp.as_str(); - } + if let syn::GenericArgument::Type(Type::Path(path)) = first_param { + path.clone() + } else { + err() + } + } + syn::PathArguments::Parenthesized(_) => err(), + } } - name - } - /// Returns the doc of the field if present. - pub fn doc(&self) -> Option { - self - .field - .attrs - .iter() - .find(|attr| attr.path().is_ident("doc")) - .map(|doc| { - quote! { - #doc + fn path_name(path: &TypePath) -> String { + let length = path.path.segments.len(); + let mut name = String::new(); + for (i, segment) in path.path.segments.iter().enumerate() { + if i == length - 1 { + name += segment.ident.to_string().as_str(); + } else { + let tmp = segment.ident.to_string() + "::"; + name += tmp.as_str(); + } } - }) - } + name + } - pub fn attributes(&self) -> impl Iterator { - self - .field - .attrs - .clone() - .into_iter() - .map(AttributeAnalyzer::new) - } + /// Returns the doc of the field if present. + pub fn doc(&self) -> Option { + self.field + .attrs + .iter() + .find(|attr| attr.path().is_ident("doc")) + .map(|doc| { + quote! 
{ + #doc + } + }) + } + + pub fn attributes(&self) -> impl Iterator { + self.field + .attrs + .clone() + .into_iter() + .map(AttributeAnalyzer::new) + } } pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec { - let mut fields = Vec::new(); + let mut fields = Vec::new(); - match &ast.data { - syn::Data::Struct(struct_data) => { - for field in struct_data.fields.iter() { - fields.push(field.clone()); - } - } - syn::Data::Enum(_) => panic!("Only struct can be derived"), - syn::Data::Union(_) => panic!("Only struct can be derived"), - }; - fields + match &ast.data { + syn::Data::Struct(struct_data) => { + for field in struct_data.fields.iter() { + fields.push(field.clone()); + } + } + syn::Data::Enum(_) => panic!("Only struct can be derived"), + syn::Data::Union(_) => panic!("Only struct can be derived"), + }; + fields } diff --git a/burn-fusion/src/backend.rs b/burn-fusion/src/backend.rs index 2db49004d3..1a9a30604a 100644 --- a/burn-fusion/src/backend.rs +++ b/burn-fusion/src/backend.rs @@ -1,6 +1,6 @@ use crate::{ - client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor, - HandleContainer, + client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor, + HandleContainer, }; use burn_tensor::{backend::Backend, Device, Shape}; use core::marker::PhantomData; @@ -9,62 +9,62 @@ use std::sync::Arc; pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new(); pub(crate) fn get_client(device: &B::FusionDevice) -> B::FusionClient { - CLIENTS.client(device) + CLIENTS.client(device) } /// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend). #[derive(Clone, Debug, Default)] pub struct Fusion { - _backend: PhantomData, + _backend: PhantomData, } impl Backend for Fusion { - type Device = B::Device; + type Device = B::Device; - // TODO: Find a better way to handle full precision. - type FullPrecisionBackend = Self; - type FullPrecisionElem = B::FloatElem; + // TODO: Find a better way to handle full precision. + type FullPrecisionBackend = Self; + type FullPrecisionElem = B::FloatElem; - type TensorPrimitive = FusionTensor; + type TensorPrimitive = FusionTensor; - type FloatElem = B::FloatElem; + type FloatElem = B::FloatElem; - type IntTensorPrimitive = FusionTensor; + type IntTensorPrimitive = FusionTensor; - type IntElem = B::IntElem; + type IntElem = B::IntElem; - type BoolTensorPrimitive = FusionTensor; + type BoolTensorPrimitive = FusionTensor; - fn name() -> String { - format!("fusion<{}>", B::name()) - } + fn name() -> String { + format!("fusion<{}>", B::name()) + } - fn seed(seed: u64) { - B::seed(seed); - } + fn seed(seed: u64) { + B::seed(seed); + } - fn sync(device: &Self::Device) { - let client = CLIENTS.client::(&device.clone().into()); - client.drain_graph(); - B::sync(device) - } + fn sync(device: &Self::Device) { + let client = CLIENTS.client::(&device.clone().into()); + client.drain_graph(); + B::sync(device) + } } /// The status of a [fusion ops](FusionOps). pub enum FusionStatus { - /// No more operations can be fused. - Closed(FusionProperties), - /// More operations can be fused. - Open(FusionProperties), + /// No more operations can be fused. + Closed(FusionProperties), + /// More operations can be fused. + Open(FusionProperties), } /// The properties of a [fusion ops](FusionOps). #[derive(Debug, Clone, Copy, Default)] pub struct FusionProperties { - /// The score of the optimization, higher is better. 
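Fusion<B> itself holds no state beyond a PhantomData: operations are recorded on a client rather than run eagerly, and `sync` drains that recorded graph before delegating to the inner backend. A rough sketch of the record-then-flush wrapper pattern, with a hypothetical `Exec` trait standing in for the real backend and client types:

    use std::cell::RefCell;

    /// Stand-in for the inner backend's eager execution.
    trait Exec {
        fn run(&self, op: &str);
        fn sync(&self);
    }

    /// Stand-in for the fusion wrapper: operations are queued, then flushed on sync.
    struct Lazy<E: Exec> {
        inner: E,
        queue: RefCell<Vec<String>>,
    }

    impl<E: Exec> Lazy<E> {
        fn record(&self, op: &str) {
            self.queue.borrow_mut().push(op.to_string());
        }

        /// Mirrors the shape of Fusion::sync: drain pending work, then sync the inner backend.
        fn sync(&self) {
            for op in self.queue.borrow_mut().drain(..) {
                self.inner.run(&op);
            }
            self.inner.sync();
        }
    }

    struct Printer;

    impl Exec for Printer {
        fn run(&self, op: &str) {
            println!("run {op}");
        }
        fn sync(&self) {
            println!("device sync");
        }
    }

    fn main() {
        let backend = Lazy { inner: Printer, queue: RefCell::new(Vec::new()) };
        backend.record("matmul");
        backend.record("relu");
        backend.sync(); // flushes matmul and relu, then syncs the device
    }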
- pub score: u64, - /// If the operation is ready to be executed. - pub ready: bool, + /// The score of the optimization, higher is better. + pub score: u64, + /// If the operation is ready to be executed. + pub ready: bool, } /// The fusion operation abstraction allows implementations to fuse many @@ -80,77 +80,77 @@ pub struct FusionProperties { /// Also, it is important to return (FusionStatus::Closed) when no more registered operation can /// improve the performance. pub trait FusionOps: Send { - /// Register a new [tensor operation](TensorOpsDescription). - /// - /// The return value should be either [closed](FusionStatus::Closed) or - /// [open](FusionStatus::Open). - /// - /// When [closed](FusionStatus::Closed), it's assumed that no more operation can be added - /// to the current fusion operation. No [tensor operation](TensorOpsDescription) can be - /// ignored, they are either accepted or rejected, and the [status](FusionStatus) describes it. - fn register(&mut self, ops: Arc>) -> FusionStatus; - /// Execute the operation. - fn execute(&mut self, handles: &mut HandleContainer); - /// Reset the state. - fn reset(&mut self); - /// The size of operations fused. - fn len(&self) -> usize; - /// If the current operation is empty. - fn is_empty(&self) -> bool { - self.len() == 0 - } + /// Register a new [tensor operation](TensorOpsDescription). + /// + /// The return value should be either [closed](FusionStatus::Closed) or + /// [open](FusionStatus::Open). + /// + /// When [closed](FusionStatus::Closed), it's assumed that no more operation can be added + /// to the current fusion operation. No [tensor operation](TensorOpsDescription) can be + /// ignored, they are either accepted or rejected, and the [status](FusionStatus) describes it. + fn register(&mut self, ops: Arc>) -> FusionStatus; + /// Execute the operation. + fn execute(&mut self, handles: &mut HandleContainer); + /// Reset the state. + fn reset(&mut self); + /// The size of operations fused. + fn len(&self) -> usize; + /// If the current operation is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } } /// The device id. #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)] pub struct DeviceId { - /// The type id identifies the type of the device. - pub type_id: u16, - /// The index id identifies the device number. - pub index_id: u32, + /// The type id identifies the type of the device. + pub type_id: u16, + /// The index id identifies the device number. + pub index_id: u32, } /// The handle device trait allows to get an id for a backend device. pub trait FusionDevice: Clone + Send + Sync + PartialEq { - /// Return the [device id](DeviceId). - fn id(&self) -> DeviceId; + /// Return the [device id](DeviceId). + fn id(&self) -> DeviceId; } /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [fusion operation](crate::FusionOps). pub trait FusionBackend: Backend { - /// The device type that can return an ID. - /// - /// It can be the same as (Backend::Device), but must implement (FusionDevice). - type FusionDevice: FusionDevice + From + Into + core::fmt::Debug; - /// The type that can be used to point to a tensor of any kind. - type Handle: Sync + Send + Clone; - /// What kind of client should be used. - type FusionClient: FusionClient; - - /// The list of operations that will be used to optimize the computational graph. - fn operations(device: &Device) -> Vec>>; - - /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive). 
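The FusionOps contract above is a small accept/reject state machine: every registered operation gets a verdict, and the returned status says whether the optimization can still grow. A minimal sketch of one such implementation with simplified stand-in types, a toy element-wise fuser that closes as soon as it sees an operation it cannot fuse; the `Op` enum here is invented, Burn passes full tensor-op descriptions instead:

    #[derive(Debug, Clone, Copy, Default)]
    struct Properties {
        score: u64,
        ready: bool,
    }

    #[derive(Debug)]
    enum Status {
        Open(Properties),
        Closed(Properties),
    }

    /// Invented operation kinds standing in for TensorOpsDescription.
    #[derive(Debug, Clone, Copy, PartialEq)]
    enum Op {
        Add,
        Mul,
        Matmul,
    }

    #[derive(Default)]
    struct ElemwiseFuser {
        ops: Vec<Op>,
    }

    impl ElemwiseFuser {
        /// Mirrors FusionOps::register: accept element-wise ops and stay open,
        /// reject anything else and close.
        fn register(&mut self, op: Op) -> Status {
            let fusable = matches!(op, Op::Add | Op::Mul);
            if fusable {
                self.ops.push(op);
            }
            let props = Properties {
                score: self.ops.len() as u64,
                ready: !self.ops.is_empty(),
            };
            if fusable {
                Status::Open(props)
            } else {
                Status::Closed(props)
            }
        }

        fn reset(&mut self) {
            self.ops.clear();
        }
    }

    fn main() {
        let mut fuser = ElemwiseFuser::default();
        println!("{:?}", fuser.register(Op::Add)); // Open, score 1
        println!("{:?}", fuser.register(Op::Mul)); // Open, score 2
        println!("{:?}", fuser.register(Op::Matmul)); // Closed, score still 2
        fuser.reset();
    }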
- fn float_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::TensorPrimitive; - /// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive). - fn int_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::IntTensorPrimitive; - /// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). - fn bool_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::BoolTensorPrimitive; - - /// Convert a [float tensor](Backend::TensorPrimitive) to a [handle](FusionBackend::Handle). - fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle; - /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle). - fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle; - /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle). - fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle; + /// The device type that can return an ID. + /// + /// It can be the same as (Backend::Device), but must implement (FusionDevice). + type FusionDevice: FusionDevice + From + Into + core::fmt::Debug; + /// The type that can be used to point to a tensor of any kind. + type Handle: Sync + Send + Clone; + /// What kind of client should be used. + type FusionClient: FusionClient; + + /// The list of operations that will be used to optimize the computational graph. + fn operations(device: &Device) -> Vec>>; + + /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive). + fn float_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::TensorPrimitive; + /// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive). + fn int_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::IntTensorPrimitive; + /// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). + fn bool_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::BoolTensorPrimitive; + + /// Convert a [float tensor](Backend::TensorPrimitive) to a [handle](FusionBackend::Handle). + fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle; + /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle). + fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle; + /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle). + fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle; } diff --git a/burn-fusion/src/client/base.rs b/burn-fusion/src/client/base.rs index e8329ff5cf..778c71030f 100644 --- a/burn-fusion/src/client/base.rs +++ b/burn-fusion/src/client/base.rs @@ -1,65 +1,65 @@ use crate::{ - graph::{GraphExecution, TensorOpsDescription}, - FusionBackend, FusionTensor, Handle, TensorDescription, TensorId, + graph::{GraphExecution, TensorOpsDescription}, + FusionBackend, FusionTensor, Handle, TensorDescription, TensorId, }; use burn_tensor::{ - ops::{FloatElem, IntElem}, - Data, Reader, + ops::{FloatElem, IntElem}, + Data, Reader, }; /// Define how to interact with the fusion server. pub trait FusionClient: Send + Sync + Clone { - /// The [fusion backend](FusionBackend) associated type. - type FusionBackend: FusionBackend; - /// The [graph execution](GraphExecution) associated type. - type GraphExecution: GraphExecution; + /// The [fusion backend](FusionBackend) associated type. 
+ type FusionBackend: FusionBackend; + /// The [graph execution](GraphExecution) associated type. + type GraphExecution: GraphExecution; - /// Create a new client for the given [fusion device](FusionBackend::FusionDevice). - fn new(device: ::FusionDevice) -> Self; - /// Register a new [tensor operation description](TensorOpsDescription). - fn register(&self, ops: TensorOpsDescription); - /// Register all lazy computation. - fn drain_graph(&self); - /// Get the current device used by all operations handled by this client. - fn device(&self) -> &::FusionDevice; - /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; - /// Create a tensor with the given handle and shape. - fn register_tensor( - &self, - handle: Handle, - shape: Vec, - ) -> FusionTensor; - /// Read the values contained by a float tensor. - fn read_tensor_float( - &self, - tensor: TensorDescription, - ) -> Reader, D>>; - /// Read the values contained by an int tensor. - fn read_tensor_int( - &self, - tensor: TensorDescription, - ) -> Reader, D>>; - /// Read the values contained by a bool tensor. - fn read_tensor_bool(&self, tensor: TensorDescription) -> Reader>; - /// Change the client of the given float tensor. - fn change_client_float( - &self, - tensor: TensorDescription, - client: Self, - ) -> FusionTensor; - /// Change the client of the given int tensor. - fn change_client_int( - &self, - tensor: TensorDescription, - client: Self, - ) -> FusionTensor; - /// Change the client of the given bool tensor. - fn change_client_bool( - &self, - tensor: TensorDescription, - client: Self, - ) -> FusionTensor; - /// Drop the tensor with the given [tensor id](TensorId). - fn register_orphan(&self, id: &TensorId); + /// Create a new client for the given [fusion device](FusionBackend::FusionDevice). + fn new(device: ::FusionDevice) -> Self; + /// Register a new [tensor operation description](TensorOpsDescription). + fn register(&self, ops: TensorOpsDescription); + /// Register all lazy computation. + fn drain_graph(&self); + /// Get the current device used by all operations handled by this client. + fn device(&self) -> &::FusionDevice; + /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. + fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; + /// Create a tensor with the given handle and shape. + fn register_tensor( + &self, + handle: Handle, + shape: Vec, + ) -> FusionTensor; + /// Read the values contained by a float tensor. + fn read_tensor_float( + &self, + tensor: TensorDescription, + ) -> Reader, D>>; + /// Read the values contained by an int tensor. + fn read_tensor_int( + &self, + tensor: TensorDescription, + ) -> Reader, D>>; + /// Read the values contained by a bool tensor. + fn read_tensor_bool(&self, tensor: TensorDescription) -> Reader>; + /// Change the client of the given float tensor. + fn change_client_float( + &self, + tensor: TensorDescription, + client: Self, + ) -> FusionTensor; + /// Change the client of the given int tensor. + fn change_client_int( + &self, + tensor: TensorDescription, + client: Self, + ) -> FusionTensor; + /// Change the client of the given bool tensor. + fn change_client_bool( + &self, + tensor: TensorDescription, + client: Self, + ) -> FusionTensor; + /// Drop the tensor with the given [tensor id](TensorId). 
+ fn register_orphan(&self, id: &TensorId); } diff --git a/burn-fusion/src/client/mutex.rs b/burn-fusion/src/client/mutex.rs index aa7408181a..db4bceb55a 100644 --- a/burn-fusion/src/client/mutex.rs +++ b/burn-fusion/src/client/mutex.rs @@ -1,7 +1,7 @@ use super::FusionClient; use crate::{ - graph::{GraphExecution, TensorOpsDescription}, - FusionBackend, FusionServer, FusionTensor, Handle, + graph::{GraphExecution, TensorOpsDescription}, + FusionBackend, FusionServer, FusionTensor, Handle, }; use burn_tensor::ops::FloatElem; use spin::Mutex; @@ -10,149 +10,150 @@ use std::sync::Arc; /// Use a mutex to communicate with the fusion server. pub struct MutexFusionClient where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - server: Arc>>, - device: B::FusionDevice, + server: Arc>>, + device: B::FusionDevice, } impl Clone for MutexFusionClient where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - device: self.device.clone(), + fn clone(&self) -> Self { + Self { + server: self.server.clone(), + device: self.device.clone(), + } } - } } impl FusionClient for MutexFusionClient where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - type FusionBackend = B; - type GraphExecution = G; + type FusionBackend = B; + type GraphExecution = G; + + fn new(device: B::FusionDevice) -> Self { + Self { + device: device.clone(), + server: Arc::new(Mutex::new(FusionServer::new(device))), + } + } + + fn register(&self, ops: TensorOpsDescription) { + self.server.lock().register(ops); + } + + fn drain_graph(&self) { + self.server.lock().drain_graph(); + } + + fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor { + let id = self.server.lock().create_empty_handle(); + + FusionTensor::new(id, shape, self.clone()) + } + + fn device(&self) -> &::FusionDevice { + &self.device + } + fn register_tensor( + &self, + handle: Handle, + shape: Vec, + ) -> FusionTensor { + let mut server = self.server.lock(); + let id = server.create_empty_handle(); + server.handles.register_handle(id.as_ref().clone(), handle); + core::mem::drop(server); + + FusionTensor::new(id, shape, self.clone()) + } + + fn read_tensor_float( + &self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> { + self.server.lock().read_float(tensor) + } + + fn read_tensor_int( + &self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> + { + self.server.lock().read_int(tensor) + } + + fn read_tensor_bool( + &self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader> { + self.server.lock().read_bool(tensor) + } + + fn change_client_float( + &self, + tensor: crate::TensorDescription, + client: Self, + ) -> FusionTensor { + let device = client.device.clone().into(); + + let mut other_server = client.server.lock(); + + let id = self + .server + .lock() + .change_server_float::(&tensor, &device, &mut other_server); + + core::mem::drop(other_server); + + FusionTensor::new(id, tensor.shape, client) + } + fn change_client_int( + &self, + tensor: crate::TensorDescription, + client: Self, + ) -> FusionTensor { + let device = client.device.clone().into(); + + let mut other_server = client.server.lock(); + + let id = self + .server + .lock() + .change_server_int::(&tensor, &device, &mut other_server); + + core::mem::drop(other_server); + + FusionTensor::new(id, tensor.shape, client) + } + + fn change_client_bool( + &self, + tensor: 
crate::TensorDescription, + client: Self, + ) -> FusionTensor { + let device = client.device.clone().into(); + + let mut other_server = client.server.lock(); + + let id = self + .server + .lock() + .change_server_bool::(&tensor, &device, &mut other_server); + + core::mem::drop(other_server); + + FusionTensor::new(id, tensor.shape, client) + } - fn new(device: B::FusionDevice) -> Self { - Self { - device: device.clone(), - server: Arc::new(Mutex::new(FusionServer::new(device))), + fn register_orphan(&self, id: &crate::TensorId) { + self.server.lock().drop_tensor_handle(id.clone()); } - } - - fn register(&self, ops: TensorOpsDescription) { - self.server.lock().register(ops); - } - - fn drain_graph(&self) { - self.server.lock().drain_graph(); - } - - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor { - let id = self.server.lock().create_empty_handle(); - - FusionTensor::new(id, shape, self.clone()) - } - - fn device(&self) -> &::FusionDevice { - &self.device - } - fn register_tensor( - &self, - handle: Handle, - shape: Vec, - ) -> FusionTensor { - let mut server = self.server.lock(); - let id = server.create_empty_handle(); - server.handles.register_handle(id.as_ref().clone(), handle); - core::mem::drop(server); - - FusionTensor::new(id, shape, self.clone()) - } - - fn read_tensor_float( - &self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> { - self.server.lock().read_float(tensor) - } - - fn read_tensor_int( - &self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> { - self.server.lock().read_int(tensor) - } - - fn read_tensor_bool( - &self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader> { - self.server.lock().read_bool(tensor) - } - - fn change_client_float( - &self, - tensor: crate::TensorDescription, - client: Self, - ) -> FusionTensor { - let device = client.device.clone().into(); - - let mut other_server = client.server.lock(); - - let id = self - .server - .lock() - .change_server_float::(&tensor, &device, &mut other_server); - - core::mem::drop(other_server); - - FusionTensor::new(id, tensor.shape, client) - } - fn change_client_int( - &self, - tensor: crate::TensorDescription, - client: Self, - ) -> FusionTensor { - let device = client.device.clone().into(); - - let mut other_server = client.server.lock(); - - let id = self - .server - .lock() - .change_server_int::(&tensor, &device, &mut other_server); - - core::mem::drop(other_server); - - FusionTensor::new(id, tensor.shape, client) - } - - fn change_client_bool( - &self, - tensor: crate::TensorDescription, - client: Self, - ) -> FusionTensor { - let device = client.device.clone().into(); - - let mut other_server = client.server.lock(); - - let id = self - .server - .lock() - .change_server_bool::(&tensor, &device, &mut other_server); - - core::mem::drop(other_server); - - FusionTensor::new(id, tensor.shape, client) - } - - fn register_orphan(&self, id: &crate::TensorId) { - self.server.lock().drop_tensor_handle(id.clone()); - } } diff --git a/burn-fusion/src/fusion.rs b/burn-fusion/src/fusion.rs index bc5ea0937e..f26630a561 100644 --- a/burn-fusion/src/fusion.rs +++ b/burn-fusion/src/fusion.rs @@ -6,65 +6,65 @@ pub type Handle = ::Handle; type Key = (core::any::TypeId, DeviceId); pub(crate) struct FusionClientLocator { - clients: spin::Mutex>>>, + clients: spin::Mutex>>>, } impl FusionClientLocator { - /// Create a new client locator. - pub const fn new() -> Self { - Self { - clients: spin::Mutex::new(None), + /// Create a new client locator. 
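MutexFusionClient is the usual shared-state handle: the server sits behind Arc<Mutex<..>>, so cloning a client is cheap and all clones feed the same graph. A small sketch of that shape, substituting std's Mutex for spin::Mutex and a plain Vec of strings for the fusion server:

    use std::sync::{Arc, Mutex};

    /// Stand-in for FusionServer: it just accumulates operation names.
    #[derive(Default)]
    struct Server {
        graph: Vec<String>,
    }

    /// Stand-in for MutexFusionClient: Clone shares the same server.
    #[derive(Clone)]
    struct Client {
        server: Arc<Mutex<Server>>,
    }

    impl Client {
        fn new() -> Self {
            Self { server: Arc::new(Mutex::new(Server::default())) }
        }

        fn register(&self, op: &str) {
            self.server.lock().unwrap().graph.push(op.to_string());
        }

        fn drain_graph(&self) -> Vec<String> {
            self.server.lock().unwrap().graph.drain(..).collect()
        }
    }

    fn main() {
        let a = Client::new();
        let b = a.clone(); // same underlying server
        a.register("conv2d");
        b.register("relu");
        assert_eq!(a.drain_graph(), vec!["conv2d".to_string(), "relu".to_string()]);
    }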
+ pub const fn new() -> Self { + Self { + clients: spin::Mutex::new(None), + } } - } - - /// Get the fusion client for the given device. - /// - /// Provide the init function to create a new client if it isn't already initialized. - pub fn client( - &self, - device: &::FusionDevice, - ) -> C { - let device_id = device.id(); - let client_id = (core::any::TypeId::of::(), device_id); - let mut clients = self.clients.lock(); - if clients.is_none() { - let client = C::new(device.clone()); - Self::register_inner::(client_id, client, &mut clients); - } + /// Get the fusion client for the given device. + /// + /// Provide the init function to create a new client if it isn't already initialized. + pub fn client( + &self, + device: &::FusionDevice, + ) -> C { + let device_id = device.id(); + let client_id = (core::any::TypeId::of::(), device_id); + let mut clients = self.clients.lock(); - match clients.deref_mut() { - Some(clients) => match clients.get(&client_id) { - Some(client) => { - let client: &C = client.downcast_ref().unwrap(); - client.clone() + if clients.is_none() { + let client = C::new(device.clone()); + Self::register_inner::(client_id, client, &mut clients); } - None => { - let client = C::new(device.clone()); - let any = Box::new(client.clone()); - clients.insert(client_id, any); - client + + match clients.deref_mut() { + Some(clients) => match clients.get(&client_id) { + Some(client) => { + let client: &C = client.downcast_ref().unwrap(); + client.clone() + } + None => { + let client = C::new(device.clone()); + let any = Box::new(client.clone()); + clients.insert(client_id, any); + client + } + }, + _ => unreachable!(), } - }, - _ => unreachable!(), } - } - fn register_inner( - key: Key, - client: C, - clients: &mut Option>>, - ) { - if clients.is_none() { - *clients = Some(HashMap::new()); - } + fn register_inner( + key: Key, + client: C, + clients: &mut Option>>, + ) { + if clients.is_none() { + *clients = Some(HashMap::new()); + } - if let Some(clients) = clients { - if clients.contains_key(&key) { - panic!("Client already created for device {:?}", key); - } + if let Some(clients) = clients { + if clients.contains_key(&key) { + panic!("Client already created for device {:?}", key); + } - clients.insert(key, Box::new(client)); + clients.insert(key, Box::new(client)); + } } - } } diff --git a/burn-fusion/src/graph/base.rs b/burn-fusion/src/graph/base.rs index 43fe5677cb..28bcade60f 100644 --- a/burn-fusion/src/graph/base.rs +++ b/burn-fusion/src/graph/base.rs @@ -4,95 +4,95 @@ use std::{ops::RangeBounds, sync::Arc, vec::Drain}; /// The computational graph containing a list of [tensor operation descriptions](TensorOpsDescription). pub struct Graph { - operations: Vec>>, + operations: Vec>>, } impl Graph { - pub(crate) fn new() -> Self { - Self { - operations: Vec::new(), + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), + } } - } - pub(crate) fn add(&mut self, ops: Arc>) { - self.operations.push(ops); - } - - /// The size of the graph. - pub fn len(&self) -> usize { - self.operations.len() - } - - /// If the graph is empty. 
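FusionClientLocator keys clients by (TypeId, DeviceId) and stores them type-erased, downcasting back to the concrete client type on lookup. A condensed sketch of that registry pattern, assuming a bare u32 device id and a Default-constructible client in place of Burn's types; it also skips the duplicate-registration panic:

    use std::any::{Any, TypeId};
    use std::collections::HashMap;
    use std::sync::Mutex;

    type Key = (TypeId, u32); // (client type, device id)

    struct Locator {
        clients: Mutex<HashMap<Key, Box<dyn Any + Send>>>,
    }

    impl Locator {
        fn new() -> Self {
            Self { clients: Mutex::new(HashMap::new()) }
        }

        /// Return the client for (C, device), creating it on first use.
        fn client<C: Clone + Default + Send + 'static>(&self, device: u32) -> C {
            let key = (TypeId::of::<C>(), device);
            let mut clients = self.clients.lock().unwrap();
            let entry = clients
                .entry(key)
                .or_insert_with(|| Box::new(C::default()) as Box<dyn Any + Send>);
            entry.downcast_ref::<C>().unwrap().clone()
        }
    }

    fn main() {
        let locator = Locator::new();
        let a: Vec<u8> = locator.client(0); // Vec<u8> plays the client role here
        let b: Vec<u8> = locator.client(0); // same key, so the same (cloned) client
        assert_eq!(a, b);
    }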
- pub fn is_empty(&self) -> bool { - self.operations.len() == 0 - } - - fn drain(&mut self, range: R) -> Drain<'_, Arc>> - where - R: RangeBounds, - { - self.operations.drain(range) - } - - fn remove>(&mut self, range: R, handles: &mut HandleContainer) { - for ops in self.operations.drain(range) { - ops.cleanup_tensor(handles) + pub(crate) fn add(&mut self, ops: Arc>) { + self.operations.push(ops); } - } - - fn nodes(&self) -> &[Arc>] { - &self.operations - } - - pub(crate) fn execute_optimization( - &mut self, - handles: &mut HandleContainer, - index: usize, - optimizations: &mut [Optimization], - ) { - let optimization = optimizations.get_mut(index).unwrap(); - let num_keep = optimization.ops.len(); - optimization.ops.execute(handles); - - self.remove(0..num_keep, handles); - - for optimization in optimizations.iter_mut() { - optimization.reset(); - - for node in self.nodes() { - optimization.register(node); - } + + /// The size of the graph. + pub fn len(&self) -> usize { + self.operations.len() + } + + /// If the graph is empty. + pub fn is_empty(&self) -> bool { + self.operations.len() == 0 + } + + fn drain(&mut self, range: R) -> Drain<'_, Arc>> + where + R: RangeBounds, + { + self.operations.drain(range) + } + + fn remove>(&mut self, range: R, handles: &mut HandleContainer) { + for ops in self.operations.drain(range) { + ops.cleanup_tensor(handles) + } } - } - pub(crate) fn execute(&mut self, handles: &mut HandleContainer) { - for ops in self.drain(..) { - ops.execute(handles); - ops.cleanup_tensor(handles); + fn nodes(&self) -> &[Arc>] { + &self.operations + } + + pub(crate) fn execute_optimization( + &mut self, + handles: &mut HandleContainer, + index: usize, + optimizations: &mut [Optimization], + ) { + let optimization = optimizations.get_mut(index).unwrap(); + let num_keep = optimization.ops.len(); + optimization.ops.execute(handles); + + self.remove(0..num_keep, handles); + + for optimization in optimizations.iter_mut() { + optimization.reset(); + + for node in self.nodes() { + optimization.register(node); + } + } + } + + pub(crate) fn execute(&mut self, handles: &mut HandleContainer) { + for ops in self.drain(..) { + ops.execute(handles); + ops.cleanup_tensor(handles); + } } - } } /// An optimization that can be executed. #[derive(new)] pub struct Optimization { - /// The [fusion operation](FusionOps) to potentially be executed. - pub ops: Box>, - /// The current status of the optimization. - pub status: FusionStatus, + /// The [fusion operation](FusionOps) to potentially be executed. + pub ops: Box>, + /// The current status of the optimization. 
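Graph::execute_optimization runs the winning fusion, strips exactly the operations it covered from the front of the list, and then offers whatever remains to every optimization again. The drain-then-re-register step in isolation, sketched over plain strings:

    fn main() {
        // Pending operations, oldest first (stand-in for the tensor-op descriptions).
        let mut graph = vec!["add", "mul", "matmul", "relu"];

        // Suppose the chosen optimization fused the first two operations.
        let num_keep = 2;
        let fused: Vec<_> = graph.drain(0..num_keep).collect();
        println!("executed as one kernel: {fused:?}");

        // Everything that remains is offered to the optimizations again.
        for op in &graph {
            println!("re-register {op}");
        }
    }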
+ pub status: FusionStatus, } impl Optimization { - pub(crate) fn register(&mut self, ops: &Arc>) { - if let FusionStatus::Closed(_) = self.status { - return; - } + pub(crate) fn register(&mut self, ops: &Arc>) { + if let FusionStatus::Closed(_) = self.status { + return; + } - self.status = self.ops.register(ops.clone()); - } + self.status = self.ops.register(ops.clone()); + } - pub(crate) fn reset(&mut self) { - self.ops.reset(); - self.status = FusionStatus::Open(FusionProperties::default()); - } + pub(crate) fn reset(&mut self) { + self.ops.reset(); + self.status = FusionStatus::Open(FusionProperties::default()); + } } diff --git a/burn-fusion/src/graph/execution.rs b/burn-fusion/src/graph/execution.rs index 85c5159a3c..36cbf1a6d3 100644 --- a/burn-fusion/src/graph/execution.rs +++ b/burn-fusion/src/graph/execution.rs @@ -3,15 +3,15 @@ use crate::{FusionBackend, FusionStatus, HandleContainer}; /// The graph execution trait abstracts the way the graph is executing optimizations. pub trait GraphExecution: Default + Send { - /// Execute the given graph using the list of potential [optimizations](Optimization). - /// May do nothing if empty or not ready - fn maybe_execute( - &mut self, - graph: &mut Graph, - handles: &mut HandleContainer, - optimizations: &mut [Optimization], - force: bool, - ); + /// Execute the given graph using the list of potential [optimizations](Optimization). + /// May do nothing if empty or not ready + fn maybe_execute( + &mut self, + graph: &mut Graph, + handles: &mut HandleContainer, + optimizations: &mut [Optimization], + force: bool, + ); } /// Execute an optimization following a greedy algorithm. @@ -19,65 +19,65 @@ pub trait GraphExecution: Default + Send { pub struct GreedyGraphExecution; impl GraphExecution for GreedyGraphExecution { - fn maybe_execute( - &mut self, - graph: &mut Graph, - handles: &mut HandleContainer, - optimizations: &mut [Optimization], - force: bool, - ) { - loop { - if !force && still_optimizing(optimizations) { - break; - } + fn maybe_execute( + &mut self, + graph: &mut Graph, + handles: &mut HandleContainer, + optimizations: &mut [Optimization], + force: bool, + ) { + loop { + if !force && still_optimizing(optimizations) { + break; + } - match find_best_optimization_index(optimizations) { - Some(index) => { - graph.execute_optimization(handles, index, optimizations); - } - None => { - graph.execute(handles); - optimizations.iter_mut().for_each(|ops| ops.reset()); - } - } + match find_best_optimization_index(optimizations) { + Some(index) => { + graph.execute_optimization(handles, index, optimizations); + } + None => { + graph.execute(handles); + optimizations.iter_mut().for_each(|ops| ops.reset()); + } + } - if graph.is_empty() { - // No more ops to fuse. - break; - } + if graph.is_empty() { + // No more ops to fuse. 
+ break; + } + } } - } } fn still_optimizing(optimizations: &[Optimization]) -> bool { - let mut num_stopped = 0; + let mut num_stopped = 0; - for optimization in optimizations.iter() { - if let FusionStatus::Closed(_) = optimization.status { - num_stopped += 1 + for optimization in optimizations.iter() { + if let FusionStatus::Closed(_) = optimization.status { + num_stopped += 1 + } } - } - num_stopped < optimizations.len() + num_stopped < optimizations.len() } fn find_best_optimization_index( - optimizations: &[Optimization], + optimizations: &[Optimization], ) -> Option { - let mut best_index = None; - let mut best_score = 0; + let mut best_index = None; + let mut best_score = 0; - for (i, optimization) in optimizations.iter().enumerate() { - let properties = match optimization.status { - FusionStatus::Closed(properties) => properties, - FusionStatus::Open(properties) => properties, - }; + for (i, optimization) in optimizations.iter().enumerate() { + let properties = match optimization.status { + FusionStatus::Closed(properties) => properties, + FusionStatus::Open(properties) => properties, + }; - if properties.ready && properties.score >= best_score { - best_index = Some(i); - best_score = properties.score; + if properties.ready && properties.score >= best_score { + best_index = Some(i); + best_score = properties.score; + } } - } - best_index + best_index } diff --git a/burn-fusion/src/graph/ops.rs b/burn-fusion/src/graph/ops.rs index ba1dd224bd..3f437bd1f5 100644 --- a/burn-fusion/src/graph/ops.rs +++ b/burn-fusion/src/graph/ops.rs @@ -2,1479 +2,1487 @@ use crate::FusionBackend; use crate::{HandleContainer, TensorDescription}; use burn_tensor::ops::FloatElem; use burn_tensor::{ - ops::{ConvOptions, ConvTransposeOptions}, - Distribution, Element, + ops::{ConvOptions, ConvTransposeOptions}, + Distribution, Element, }; use core::hash::Hash; use std::ops::Range; /// General trait to abstract how a single operation is executed. pub trait Ops: Send + Sync { - /// The argument necessary for the execution to happen. - type Args: Send + Sync; + /// The argument necessary for the execution to happen. + type Args: Send + Sync; - /// Execute the operation. - fn execute(&self, args: &Self::Args, handles: &mut HandleContainer); + /// Execute the operation. + fn execute(&self, args: &Self::Args, handles: &mut HandleContainer); } /// Describe all tensor operations possible. pub enum TensorOpsDescription { - /// Basic operation on a float tensor. - BaseOpsFloat(BaseOpsDescription), - /// Basic operation on an int tensor. - BaseOpsInt(BaseOpsDescription), - /// Basic operation on a bool tensor. - BaseOpsBool(BaseOpsDescription), - /// Numeric operation on a float tensor. - NumericOpsFloat(NumericOpsDescription), - /// Numeric operation on an int tensor. - NumericOpsInt(NumericOpsDescription), - /// Operation specific to a bool tensor. - BoolOps(BoolOpsDescription), - /// Operation specific to an int tensor. - IntOps(IntOpsDescription), - /// Operation specific to a float tensor. - FloatOps(FloatOpsDescription), - /// Module operation. - ModuleOps(ModuleOpsDescription), + /// Basic operation on a float tensor. + BaseOpsFloat(BaseOpsDescription), + /// Basic operation on an int tensor. + BaseOpsInt(BaseOpsDescription), + /// Basic operation on a bool tensor. + BaseOpsBool(BaseOpsDescription), + /// Numeric operation on a float tensor. + NumericOpsFloat(NumericOpsDescription), + /// Numeric operation on an int tensor. + NumericOpsInt(NumericOpsDescription), + /// Operation specific to a bool tensor. 
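GreedyGraphExecution keeps feeding operations to the candidates until none is still open, then executes the ready candidate with the highest score, falling back to an unfused execution of the graph when nothing is ready. The selection logic on its own, with local Properties/Status types mirroring the ones above:

    #[derive(Clone, Copy, Default)]
    struct Properties {
        score: u64,
        ready: bool,
    }

    enum Status {
        Open(Properties),
        Closed(Properties),
    }

    /// Keep optimizing while at least one candidate is still open.
    fn still_optimizing(statuses: &[Status]) -> bool {
        statuses
            .iter()
            .any(|status| matches!(status, Status::Open(_)))
    }

    /// Mirrors find_best_optimization_index: highest score among ready candidates.
    fn best_index(statuses: &[Status]) -> Option<usize> {
        let mut best = None;
        let mut best_score = 0;
        for (i, status) in statuses.iter().enumerate() {
            let props = match status {
                Status::Open(p) | Status::Closed(p) => *p,
            };
            if props.ready && props.score >= best_score {
                best = Some(i);
                best_score = props.score;
            }
        }
        best
    }

    fn main() {
        let statuses = [
            Status::Closed(Properties { score: 3, ready: true }),
            Status::Closed(Properties { score: 5, ready: true }),
            Status::Closed(Properties { score: 8, ready: false }),
        ];
        assert!(!still_optimizing(&statuses));
        assert_eq!(best_index(&statuses), Some(1)); // ready with the highest score
    }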
+ BoolOps(BoolOpsDescription), + /// Operation specific to an int tensor. + IntOps(IntOpsDescription), + /// Operation specific to a float tensor. + FloatOps(FloatOpsDescription), + /// Module operation. + ModuleOps(ModuleOpsDescription), } /// Operation description specific to a float tensor. pub enum FloatOpsDescription { - /// Operation corresponding to [exp](burn_tensor::ops::TensorOps::exp). - Exp( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [log](burn_tensor::ops::TensorOps::log). - Log( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [log1p](burn_tensor::ops::TensorOps::log1p). - Log1p( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [erf](burn_tensor::ops::TensorOps::erf). - Erf( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [powf](burn_tensor::ops::TensorOps::powf). - Powf( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to [sqrt](burn_tensor::ops::TensorOps::sqrt). - Sqrt( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [cos](burn_tensor::ops::TensorOps::cos). - Cos( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [sin](burn_tensor::ops::TensorOps::sin). - Sin( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [tanh](burn_tensor::ops::TensorOps::tanh). - Tanh( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [into_int](burn_tensor::ops::TensorOps::into_int). - IntoInt( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [matmul](burn_tensor::ops::TensorOps::matmul). - Matmul( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to [random](burn_tensor::ops::TensorOps::random). - Random( - (TensorDescription, Distribution>), - Box>)>>, - ), - /// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip). - Recip( - UnaryOpsDescription, - Box>, - ), + /// Operation corresponding to [exp](burn_tensor::ops::TensorOps::exp). + Exp( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [log](burn_tensor::ops::TensorOps::log). + Log( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [log1p](burn_tensor::ops::TensorOps::log1p). + Log1p( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [erf](burn_tensor::ops::TensorOps::erf). + Erf( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [powf](burn_tensor::ops::TensorOps::powf). + Powf( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to [sqrt](burn_tensor::ops::TensorOps::sqrt). + Sqrt( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [cos](burn_tensor::ops::TensorOps::cos). + Cos( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [sin](burn_tensor::ops::TensorOps::sin). + Sin( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [tanh](burn_tensor::ops::TensorOps::tanh). + Tanh( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [into_int](burn_tensor::ops::TensorOps::into_int). + IntoInt( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [matmul](burn_tensor::ops::TensorOps::matmul). + Matmul( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to [random](burn_tensor::ops::TensorOps::random). + Random( + (TensorDescription, Distribution>), + Box>)>>, + ), + /// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip). 
+ Recip( + UnaryOpsDescription, + Box>, + ), } /// Operation description specific to module. pub enum ModuleOpsDescription { - /// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding). - Embedding( - EmbeddingDescription, - Box>, - ), - /// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward). - EmbeddingBackward( - EmbeddingBackwardDescription, - Box>, - ), - /// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d). - Conv1d(Conv1dDescription, Box>), - /// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d). - Conv2d(Conv2dDescription, Box>), - /// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d). - ConvTranspose1d( - ConvTranspose1dDescription, - Box>, - ), - /// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d). - ConvTranspose2d( - ConvTranspose2dDescription, - Box>, - ), - /// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d). - AvgPool1d( - AvgPool1dDescription, - Box>, - ), - /// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d). - AvgPool2d( - AvgPool2dDescription, - Box>, - ), - /// Operation corresponding to - /// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward). - AvgPool1dBackward( - AvgPool1dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward). - AvgPool2dBackward( - AvgPool2dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d). - AdaptiveAvgPool1d( - AdaptiveAvgPool1dDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d). - AdaptiveAvgPool2d( - AdaptiveAvgPool2dDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward). - AdaptiveAvgPool1dBackward( - AdaptiveAvgPool1dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward). - AdaptiveAvgPool2dBackward( - AdaptiveAvgPool2dBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d). - MaxPool1d( - MaxPool1dDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices). - MaxPool1dWithIndices( - MaxPool1dWithIndicesDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward). - MaxPool1dWithIndicesBackward( - MaxPool1dWithIndicesBackwardDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool1d). - MaxPool2d( - MaxPool2dDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices). - MaxPool2dWithIndices( - MaxPool2dWithIndicesDescription, - Box>, - ), - /// Operation corresponding to - /// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward). 
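Every variant of these description enums pairs the data needed to replay an operation with a boxed Ops value that executes it against the handle container. A stripped-down sketch of that pairing, with a HashMap of i64 buffers standing in for HandleContainer and an invented BinaryDescription/AddOps pair:

    use std::collections::HashMap;

    /// Stand-in for HandleContainer: tensor id -> buffer.
    type Handles = HashMap<usize, Vec<i64>>;

    /// Stand-in for the Ops trait: each op knows how to run from its description.
    trait Ops: Send + Sync {
        type Args;
        fn execute(&self, args: &Self::Args, handles: &mut Handles);
    }

    /// Description of a binary op: two input ids and an output id.
    struct BinaryDescription {
        lhs: usize,
        rhs: usize,
        out: usize,
    }

    struct AddOps;

    impl Ops for AddOps {
        type Args = BinaryDescription;

        fn execute(&self, args: &Self::Args, handles: &mut Handles) {
            let lhs = handles[&args.lhs].clone();
            let rhs = handles[&args.rhs].clone();
            let out = lhs.iter().zip(&rhs).map(|(a, b)| a + b).collect();
            handles.insert(args.out, out);
        }
    }

    /// Stand-in for one enum variant: the description plus its boxed executor.
    struct AddVariant(BinaryDescription, Box<dyn Ops<Args = BinaryDescription>>);

    fn main() {
        let mut handles = Handles::new();
        handles.insert(0, vec![1, 2, 3]);
        handles.insert(1, vec![10, 20, 30]);

        let op = AddVariant(BinaryDescription { lhs: 0, rhs: 1, out: 2 }, Box::new(AddOps));
        op.1.execute(&op.0, &mut handles);
        assert_eq!(handles[&2], vec![11, 22, 33]);
    }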
- MaxPool2dWithIndicesBackward( - MaxPool2dWithIndicesBackwardDescription, - Box>, - ), + /// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding). + Embedding( + EmbeddingDescription, + Box>, + ), + /// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward). + EmbeddingBackward( + EmbeddingBackwardDescription, + Box>, + ), + /// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d). + Conv1d(Conv1dDescription, Box>), + /// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d). + Conv2d(Conv2dDescription, Box>), + /// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d). + ConvTranspose1d( + ConvTranspose1dDescription, + Box>, + ), + /// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d). + ConvTranspose2d( + ConvTranspose2dDescription, + Box>, + ), + /// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d). + AvgPool1d( + AvgPool1dDescription, + Box>, + ), + /// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d). + AvgPool2d( + AvgPool2dDescription, + Box>, + ), + /// Operation corresponding to + /// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward). + AvgPool1dBackward( + AvgPool1dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward). + AvgPool2dBackward( + AvgPool2dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d). + AdaptiveAvgPool1d( + AdaptiveAvgPool1dDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d). + AdaptiveAvgPool2d( + AdaptiveAvgPool2dDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward). + AdaptiveAvgPool1dBackward( + AdaptiveAvgPool1dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward). + AdaptiveAvgPool2dBackward( + AdaptiveAvgPool2dBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d). + MaxPool1d( + MaxPool1dDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices). + MaxPool1dWithIndices( + MaxPool1dWithIndicesDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward). + MaxPool1dWithIndicesBackward( + MaxPool1dWithIndicesBackwardDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool1d). + MaxPool2d( + MaxPool2dDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices). + MaxPool2dWithIndices( + MaxPool2dWithIndicesDescription, + Box>, + ), + /// Operation corresponding to + /// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward). 
+ MaxPool2dWithIndicesBackward( + MaxPool2dWithIndicesBackwardDescription, + Box>, + ), } /// Basic operations that can be done on any tensor type. pub enum BaseOpsDescription { - /// Operation corresponding to: - /// - /// Float => [to device](burn_tensor::ops::TensorOps::to_device). - /// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device). - /// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device). - ToDevice( - (TensorDescription, B::Device), - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [reshape](burn_tensor::ops::TensorOps::reshape). - /// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape). - /// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape). - Reshape( - ReshapeDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [swap_dims](burn_tensor::ops::TensorOps::swap_dims). - /// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims). - /// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims). - SwapDims( - SwapDimsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [slice](burn_tensor::ops::TensorOps::slice). - /// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice). - /// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice). - Slice( - SliceOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [slice assign](burn_tensor::ops::TensorOps::slice_assign). - /// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign). - /// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign). - SliceAssign( - SliceAssignOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [equal](burn_tensor::ops::TensorOps::equal). - /// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal). - /// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal). - Equal( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [repeat](burn_tensor::ops::TensorOps::repeat). - /// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat). - /// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat). - Repeat( - RepeatOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [cat](burn_tensor::ops::TensorOps::cat). - /// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat). - /// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat). - Cat(CatOpsDescription, Box>), + /// Operation corresponding to: + /// + /// Float => [to device](burn_tensor::ops::TensorOps::to_device). + /// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device). + /// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device). + ToDevice( + (TensorDescription, B::Device), + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [reshape](burn_tensor::ops::TensorOps::reshape). + /// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape). + /// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape). + Reshape( + ReshapeDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [swap_dims](burn_tensor::ops::TensorOps::swap_dims). + /// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims). + /// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims). + SwapDims( + SwapDimsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [slice](burn_tensor::ops::TensorOps::slice). 
+ /// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice). + /// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice). + Slice( + SliceOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [slice assign](burn_tensor::ops::TensorOps::slice_assign). + /// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign). + /// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign). + SliceAssign( + SliceAssignOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [equal](burn_tensor::ops::TensorOps::equal). + /// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal). + /// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal). + Equal( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [repeat](burn_tensor::ops::TensorOps::repeat). + /// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat). + /// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat). + Repeat( + RepeatOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [cat](burn_tensor::ops::TensorOps::cat). + /// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat). + /// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat). + Cat(CatOpsDescription, Box>), } /// Numeric operations on int and float tensors. pub enum NumericOpsDescription { - /// Operation corresponding to: - /// - /// Float => [add](burn_tensor::ops::TensorOps::add). - /// Int => [add](burn_tensor::ops::IntTensorOps::int_add). - Add( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [add scalar](burn_tensor::ops::TensorOps::add_scalar). - /// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar). - AddScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [sub](burn_tensor::ops::TensorOps::sub). - /// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub). - Sub( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [sub scalar](burn_tensor::ops::TensorOps::sub_scalar). - /// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar). - SubScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [div](burn_tensor::ops::TensorOps::div). - /// Int => [div](burn_tensor::ops::IntTensorOps::int_div). - Div( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [div scalar](burn_tensor::ops::TensorOps::div_scalar). - /// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar). - DivScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [mul](burn_tensor::ops::TensorOps::mul). - /// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul). - Mul( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [mul scalar](burn_tensor::ops::TensorOps::mul_scalar). - /// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar). - MulScalar( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [abs](burn_tensor::ops::TensorOps::abs). - /// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs). - Abs( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [ones](burn_tensor::ops::TensorOps::ones). - /// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones). 
- Ones(TensorDescription, Box>), - /// Operation corresponding to: - /// - /// Float => [zeros](burn_tensor::ops::TensorOps::zeros). - /// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros). - Zeros(TensorDescription, Box>), - /// Operation corresponding to: - /// - /// Float => [full](burn_tensor::ops::TensorOps::full). - /// Int => [full](burn_tensor::ops::IntTensorOps::int_full). - Full( - (TensorDescription, E), - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [gather](burn_tensor::ops::TensorOps::gather). - /// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather). - Gather( - GatherOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [scatter](burn_tensor::ops::TensorOps::scatter). - /// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter). - Scatter( - ScatterOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [select](burn_tensor::ops::TensorOps::select). - /// Int => [select](burn_tensor::ops::IntTensorOps::int_select). - Select( - SelectOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [select assign](burn_tensor::ops::TensorOps::select_assign). - /// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign). - SelectAssign( - SelectAssignOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [mask where](burn_tensor::ops::TensorOps::mask_where). - /// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where). - MaskWhere( - MaskWhereOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [mask fill](burn_tensor::ops::TensorOps::mask_fill). - /// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill). - MaskFill( - MaskFillOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [mean dim](burn_tensor::ops::TensorOps::mean_dim). - /// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim). - MeanDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [mean](burn_tensor::ops::TensorOps::mean). - /// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean). - Mean( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [sum](burn_tensor::ops::TensorOps::sum). - /// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum). - Sum( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [sum dim](burn_tensor::ops::TensorOps::sum_dim). - /// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim). - SumDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [equal elem](burn_tensor::ops::TensorOps::equal_elem). - /// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem). - EqualElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [greater](burn_tensor::ops::TensorOps::greater). - /// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater). - Greater( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [greater elem](burn_tensor::ops::TensorOps::greater_elem). - /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). - GreaterElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [greater equal](burn_tensor::ops::TensorOps::greater_elem). 
- /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). - GreaterEqual( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [greater equal elem](burn_tensor::ops::TensorOps::greater_equal_elem). - /// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem). - GreaterEqualElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [lower](burn_tensor::ops::TensorOps::lower). - /// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower). - Lower( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [lower elem](burn_tensor::ops::TensorOps::lower_elem). - /// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem). - LowerElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [lower equal](burn_tensor::ops::TensorOps::lower_equal). - /// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal). - LowerEqual( - BinaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [lower equal elem](burn_tensor::ops::TensorOps::lower_equal_elem). - /// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem). - LowerEqualElem( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [argmax](burn_tensor::ops::TensorOps::argmax). - /// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax). - ArgMax( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [argmin](burn_tensor::ops::TensorOps::argmin). - /// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin). - ArgMin( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [max](burn_tensor::ops::TensorOps::max). - /// Int => [max](burn_tensor::ops::IntTensorOps::int_max). - Max( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [max dim with indices](burn_tensor::ops::TensorOps::max_dim_with_indices). - /// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices). - MaxDimWithIndices( - ReduceDimWithIndicesDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [min dim with indices](burn_tensor::ops::TensorOps::min_dim_with_indices). - /// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices). - MinDimWithIndices( - ReduceDimWithIndicesDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [min](burn_tensor::ops::TensorOps::min). - /// Int => [min](burn_tensor::ops::IntTensorOps::int_min). - Min( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to: - /// - /// Float => [max dim](burn_tensor::ops::TensorOps::max_dim). - /// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim). - MaxDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [min dim](burn_tensor::ops::TensorOps::min_dim). - /// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim). - MinDim( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [clamp](burn_tensor::ops::TensorOps::clamp). - /// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp). - Clamp( - ClampOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max). 
- /// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max). - ClampMax( - ScalarOpsDescription, - Box>>, - ), - /// Operation corresponding to: - /// - /// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min). - /// Int => [cleamp min](burn_tensor::ops::IntTensorOps::int_clamp_min). - ClampMin( - ScalarOpsDescription, - Box>>, - ), + /// Operation corresponding to: + /// + /// Float => [add](burn_tensor::ops::TensorOps::add). + /// Int => [add](burn_tensor::ops::IntTensorOps::int_add). + Add( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [add scalar](burn_tensor::ops::TensorOps::add_scalar). + /// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar). + AddScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [sub](burn_tensor::ops::TensorOps::sub). + /// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub). + Sub( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [sub scalar](burn_tensor::ops::TensorOps::sub_scalar). + /// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar). + SubScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [div](burn_tensor::ops::TensorOps::div). + /// Int => [div](burn_tensor::ops::IntTensorOps::int_div). + Div( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [div scalar](burn_tensor::ops::TensorOps::div_scalar). + /// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar). + DivScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [mul](burn_tensor::ops::TensorOps::mul). + /// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul). + Mul( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [mul scalar](burn_tensor::ops::TensorOps::mul_scalar). + /// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar). + MulScalar( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [abs](burn_tensor::ops::TensorOps::abs). + /// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs). + Abs( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [ones](burn_tensor::ops::TensorOps::ones). + /// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones). + Ones(TensorDescription, Box>), + /// Operation corresponding to: + /// + /// Float => [zeros](burn_tensor::ops::TensorOps::zeros). + /// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros). + Zeros(TensorDescription, Box>), + /// Operation corresponding to: + /// + /// Float => [full](burn_tensor::ops::TensorOps::full). + /// Int => [full](burn_tensor::ops::IntTensorOps::int_full). + Full( + (TensorDescription, E), + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [gather](burn_tensor::ops::TensorOps::gather). + /// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather). + Gather( + GatherOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [scatter](burn_tensor::ops::TensorOps::scatter). + /// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter). + Scatter( + ScatterOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [select](burn_tensor::ops::TensorOps::select). + /// Int => [select](burn_tensor::ops::IntTensorOps::int_select). 
+ Select( + SelectOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [select assign](burn_tensor::ops::TensorOps::select_assign). + /// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign). + SelectAssign( + SelectAssignOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [mask where](burn_tensor::ops::TensorOps::mask_where). + /// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where). + MaskWhere( + MaskWhereOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [mask fill](burn_tensor::ops::TensorOps::mask_fill). + /// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill). + MaskFill( + MaskFillOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [mean dim](burn_tensor::ops::TensorOps::mean_dim). + /// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim). + MeanDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [mean](burn_tensor::ops::TensorOps::mean). + /// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean). + Mean( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [sum](burn_tensor::ops::TensorOps::sum). + /// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum). + Sum( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [sum dim](burn_tensor::ops::TensorOps::sum_dim). + /// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim). + SumDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [equal elem](burn_tensor::ops::TensorOps::equal_elem). + /// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem). + EqualElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [greater](burn_tensor::ops::TensorOps::greater). + /// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater). + Greater( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [greater elem](burn_tensor::ops::TensorOps::greater_elem). + /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). + GreaterElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [greater equal](burn_tensor::ops::TensorOps::greater_equal). + /// Int => [greater equal](burn_tensor::ops::IntTensorOps::int_greater_equal). + GreaterEqual( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [greater equal elem](burn_tensor::ops::TensorOps::greater_equal_elem). + /// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem). + GreaterEqualElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [lower](burn_tensor::ops::TensorOps::lower). + /// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower). + Lower( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [lower elem](burn_tensor::ops::TensorOps::lower_elem). + /// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem). + LowerElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [lower equal](burn_tensor::ops::TensorOps::lower_equal). + /// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal). 
+ LowerEqual( + BinaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [lower equal elem](burn_tensor::ops::TensorOps::lower_equal_elem). + /// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem). + LowerEqualElem( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [argmax](burn_tensor::ops::TensorOps::argmax). + /// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax). + ArgMax( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [argmin](burn_tensor::ops::TensorOps::argmin). + /// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin). + ArgMin( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [max](burn_tensor::ops::TensorOps::max). + /// Int => [max](burn_tensor::ops::IntTensorOps::int_max). + Max( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [max dim with indices](burn_tensor::ops::TensorOps::max_dim_with_indices). + /// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices). + MaxDimWithIndices( + ReduceDimWithIndicesDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [min dim with indices](burn_tensor::ops::TensorOps::min_dim_with_indices). + /// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices). + MinDimWithIndices( + ReduceDimWithIndicesDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [min](burn_tensor::ops::TensorOps::min). + /// Int => [min](burn_tensor::ops::IntTensorOps::int_min). + Min( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to: + /// + /// Float => [max dim](burn_tensor::ops::TensorOps::max_dim). + /// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim). + MaxDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [min dim](burn_tensor::ops::TensorOps::min_dim). + /// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim). + MinDim( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [clamp](burn_tensor::ops::TensorOps::clamp). + /// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp). + Clamp( + ClampOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max). + /// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max). + ClampMax( + ScalarOpsDescription, + Box>>, + ), + /// Operation corresponding to: + /// + /// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min). + /// Int => [clamp min](burn_tensor::ops::IntTensorOps::int_clamp_min). + ClampMin( + ScalarOpsDescription, + Box>>, + ), } /// Operation description specific to an int tensor. pub enum IntOpsDescription { - /// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float). - IntoFloat( - UnaryOpsDescription, - Box>, - ), + /// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float). + IntoFloat( + UnaryOpsDescription, + Box>, + ), } /// Operation description specific to a bool tensor. pub enum BoolOpsDescription { - /// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float). - IntoFloat( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int). 
- IntoInt( - UnaryOpsDescription, - Box>, - ), - /// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not). - Not( - UnaryOpsDescription, - Box>, - ), + /// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float). + IntoFloat( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int). + IntoInt( + UnaryOpsDescription, + Box>, + ), + /// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not). + Not( + UnaryOpsDescription, + Box>, + ), } #[derive(Hash)] /// Swap dim operation description. pub struct SwapDimsDescription { - /// Input tensor description. - pub input: TensorDescription, - /// output tensor description. - pub out: TensorDescription, - /// The first dim to swap. - pub dim1: usize, - /// The second dim to swap. - pub dim2: usize, + /// Input tensor description. + pub input: TensorDescription, + /// output tensor description. + pub out: TensorDescription, + /// The first dim to swap. + pub dim1: usize, + /// The second dim to swap. + pub dim2: usize, } #[derive(Hash)] #[allow(missing_docs)] pub struct ReshapeDescription { - pub input: TensorDescription, - pub out: TensorDescription, - pub shape: Vec, + pub input: TensorDescription, + pub out: TensorDescription, + pub shape: Vec, } #[derive(Hash)] #[allow(missing_docs)] pub struct BinaryOpsDescription { - pub lhs: TensorDescription, - pub rhs: TensorDescription, - pub out: TensorDescription, + pub lhs: TensorDescription, + pub rhs: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct UnaryOpsDescription { - pub input: TensorDescription, - pub out: TensorDescription, + pub input: TensorDescription, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct ScalarOpsDescription { - pub lhs: TensorDescription, - pub rhs: E, - pub out: TensorDescription, + pub lhs: TensorDescription, + pub rhs: E, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct GatherOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ScatterOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub value: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SelectOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SelectAssignOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub indices: TensorDescription, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub indices: TensorDescription, + pub value: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SliceOpsDescription { - pub tensor: TensorDescription, - pub ranges: Vec>, - pub out: TensorDescription, + pub tensor: 
TensorDescription, + pub ranges: Vec>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct SliceAssignOpsDescription { - pub tensor: TensorDescription, - pub ranges: Vec>, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub ranges: Vec>, + pub value: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaskWhereOpsDescription { - pub tensor: TensorDescription, - pub mask: TensorDescription, - pub value: TensorDescription, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub mask: TensorDescription, + pub value: TensorDescription, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct MaskFillOpsDescription { - pub tensor: TensorDescription, - pub mask: TensorDescription, - pub value: E, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub mask: TensorDescription, + pub value: E, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct ClampOpsDescription { - pub tensor: TensorDescription, - pub min: E, - pub max: E, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub min: E, + pub max: E, + pub out: TensorDescription, } #[allow(missing_docs)] pub struct RepeatOpsDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub times: usize, - pub shape: Vec, - pub out: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub times: usize, + pub shape: Vec, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct CatOpsDescription { - pub tensors: Vec, - pub dim: usize, - pub out: TensorDescription, + pub tensors: Vec, + pub dim: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ReduceDimWithIndicesDescription { - pub tensor: TensorDescription, - pub dim: usize, - pub out: TensorDescription, - pub out_indices: TensorDescription, + pub tensor: TensorDescription, + pub dim: usize, + pub out: TensorDescription, + pub out_indices: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct EmbeddingDescription { - pub weights: TensorDescription, - pub indices: TensorDescription, - pub out: TensorDescription, + pub weights: TensorDescription, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct EmbeddingBackwardDescription { - pub weights: TensorDescription, - pub out_grad: TensorDescription, - pub indices: TensorDescription, - pub out: TensorDescription, + pub weights: TensorDescription, + pub out_grad: TensorDescription, + pub indices: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct Conv1dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - pub options: ConvOptions<1>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvOptions<1>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct Conv2dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - pub options: ConvOptions<2>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvOptions<2>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ConvTranspose1dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - 
pub options: ConvTransposeOptions<1>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvTransposeOptions<1>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct ConvTranspose2dDescription { - pub x: TensorDescription, - pub weight: TensorDescription, - pub bias: Option, - pub options: ConvTransposeOptions<2>, - pub out: TensorDescription, + pub x: TensorDescription, + pub weight: TensorDescription, + pub bias: Option, + pub options: ConvTransposeOptions<2>, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool1dDescription { - pub x: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool2dDescription { - pub x: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool1dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool2dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub count_include_pad: bool, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub count_include_pad: bool, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dDescription { - pub x: TensorDescription, - pub output_size: usize, - pub out: TensorDescription, + pub x: TensorDescription, + pub output_size: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dDescription { - pub x: TensorDescription, - pub output_size: [usize; 2], - pub out: TensorDescription, + pub x: TensorDescription, + pub output_size: [usize; 2], + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dDescription { - pub x: TensorDescription, - pub kernel_size: usize, - pub 
stride: usize, - pub padding: usize, - pub dilation: usize, - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub dilation: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesDescription { - pub x: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub dilation: usize, - pub out: TensorDescription, - pub out_indices: TensorDescription, + pub x: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub dilation: usize, + pub out: TensorDescription, + pub out_indices: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub indices: TensorDescription, - pub kernel_size: usize, - pub stride: usize, - pub padding: usize, - pub dilation: usize, - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub indices: TensorDescription, + pub kernel_size: usize, + pub stride: usize, + pub padding: usize, + pub dilation: usize, + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dDescription { - pub x: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub dilation: [usize; 2], - pub out: TensorDescription, + pub x: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub dilation: [usize; 2], + pub out: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dWithIndicesDescription { - pub x: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub dilation: [usize; 2], - pub out: TensorDescription, - pub out_indices: TensorDescription, + pub x: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub dilation: [usize; 2], + pub out: TensorDescription, + pub out_indices: TensorDescription, } #[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dWithIndicesBackwardDescription { - pub x: TensorDescription, - pub grad: TensorDescription, - pub indices: TensorDescription, - pub kernel_size: [usize; 2], - pub stride: [usize; 2], - pub padding: [usize; 2], - pub dilation: [usize; 2], - pub out: TensorDescription, + pub x: TensorDescription, + pub grad: TensorDescription, + pub indices: TensorDescription, + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub dilation: [usize; 2], + pub out: TensorDescription, } impl TensorOpsDescription { - /// Cleanup the remaining tensor handles that have not been used. 
- pub(crate) fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - TensorOpsDescription::BaseOpsFloat(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::BaseOpsInt(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::BaseOpsBool(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::NumericOpsFloat(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::NumericOpsInt(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::BoolOps(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::IntOps(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::FloatOps(ops) => ops.cleanup_tensor(handles), - TensorOpsDescription::ModuleOps(ops) => ops.cleanup_tensor(handles), - } + /// Cleanup the remaining tensor handles that have not been used. + pub(crate) fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + TensorOpsDescription::BaseOpsFloat(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::BaseOpsInt(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::BaseOpsBool(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::NumericOpsFloat(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::NumericOpsInt(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::BoolOps(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::IntOps(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::FloatOps(ops) => ops.cleanup_tensor(handles), + TensorOpsDescription::ModuleOps(ops) => ops.cleanup_tensor(handles), + } - // Cleanup tensor handles that were outputted, but ignored. - handles.cleanup_orphans(); - } - /// Execute the operation. - pub(crate) fn execute(&self, handles: &mut HandleContainer) { - match self { - TensorOpsDescription::BaseOpsFloat(ops) => ops.execute(handles), - TensorOpsDescription::BaseOpsInt(ops) => ops.execute(handles), - TensorOpsDescription::BaseOpsBool(ops) => ops.execute(handles), - TensorOpsDescription::NumericOpsFloat(ops) => ops.execute(handles), - TensorOpsDescription::NumericOpsInt(ops) => ops.execute(handles), - TensorOpsDescription::BoolOps(ops) => ops.execute(handles), - TensorOpsDescription::IntOps(ops) => ops.execute(handles), - TensorOpsDescription::FloatOps(ops) => ops.execute(handles), - TensorOpsDescription::ModuleOps(ops) => ops.execute(handles), + // Cleanup tensor handles that were outputted, but ignored. + handles.cleanup_orphans(); + } + /// Execute the operation. 
+ pub(crate) fn execute(&self, handles: &mut HandleContainer) { + match self { + TensorOpsDescription::BaseOpsFloat(ops) => ops.execute(handles), + TensorOpsDescription::BaseOpsInt(ops) => ops.execute(handles), + TensorOpsDescription::BaseOpsBool(ops) => ops.execute(handles), + TensorOpsDescription::NumericOpsFloat(ops) => ops.execute(handles), + TensorOpsDescription::NumericOpsInt(ops) => ops.execute(handles), + TensorOpsDescription::BoolOps(ops) => ops.execute(handles), + TensorOpsDescription::IntOps(ops) => ops.execute(handles), + TensorOpsDescription::FloatOps(ops) => ops.execute(handles), + TensorOpsDescription::ModuleOps(ops) => ops.execute(handles), + } } - } } impl BaseOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - BaseOpsDescription::ToDevice(_, _) => (), - BaseOpsDescription::Reshape(desc, _) => { - handles.cleanup(&desc.input); - } - BaseOpsDescription::SwapDims(desc, _) => { - handles.cleanup(&desc.input); - } - BaseOpsDescription::Slice(desc, _) => { - handles.cleanup(&desc.tensor); - } - BaseOpsDescription::SliceAssign(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.value); - } - BaseOpsDescription::Equal(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - BaseOpsDescription::Repeat(desc, _) => { - handles.cleanup(&desc.tensor); - } - BaseOpsDescription::Cat(desc, _) => { - for t in desc.tensors.iter() { - handles.cleanup(t); + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + BaseOpsDescription::ToDevice(_, _) => (), + BaseOpsDescription::Reshape(desc, _) => { + handles.cleanup(&desc.input); + } + BaseOpsDescription::SwapDims(desc, _) => { + handles.cleanup(&desc.input); + } + BaseOpsDescription::Slice(desc, _) => { + handles.cleanup(&desc.tensor); + } + BaseOpsDescription::SliceAssign(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.value); + } + BaseOpsDescription::Equal(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + BaseOpsDescription::Repeat(desc, _) => { + handles.cleanup(&desc.tensor); + } + BaseOpsDescription::Cat(desc, _) => { + for t in desc.tensors.iter() { + handles.cleanup(t); + } + } } - } } - } - fn execute(&self, handles: &mut HandleContainer) { - match self { - BaseOpsDescription::ToDevice(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Reshape(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::SwapDims(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Slice(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::SliceAssign(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Equal(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Repeat(desc, ops) => ops.execute(desc, handles), - BaseOpsDescription::Cat(desc, ops) => ops.execute(desc, handles), + fn execute(&self, handles: &mut HandleContainer) { + match self { + BaseOpsDescription::ToDevice(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Reshape(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::SwapDims(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Slice(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::SliceAssign(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Equal(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Repeat(desc, ops) => ops.execute(desc, handles), + BaseOpsDescription::Cat(desc, ops) => ops.execute(desc, handles), + } } - } } impl 
NumericOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - NumericOpsDescription::Add(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::AddScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Sub(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::SubScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Mul(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::MulScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Div(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::DivScalar(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Ones(_, _) => {} - NumericOpsDescription::Gather(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - } - NumericOpsDescription::Scatter(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - handles.cleanup(&desc.value); - } - NumericOpsDescription::Select(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - } - NumericOpsDescription::SelectAssign(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.indices); - handles.cleanup(&desc.value); - } - NumericOpsDescription::MaskWhere(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.value); - handles.cleanup(&desc.mask); - } - NumericOpsDescription::MaskFill(desc, _) => { - handles.cleanup(&desc.tensor); - handles.cleanup(&desc.mask); - } - NumericOpsDescription::EqualElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::GreaterElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::GreaterEqualElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::LowerElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::LowerEqualElem(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Greater(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::GreaterEqual(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::Lower(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::LowerEqual(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - NumericOpsDescription::ArgMax(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::ArgMin(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Clamp(desc, _) => { - handles.cleanup(&desc.tensor); - } - NumericOpsDescription::ClampMin(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::ClampMax(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Abs(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::Zeros(_, _) => {} - NumericOpsDescription::Full(_, _) => {} - NumericOpsDescription::MeanDim(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Mean(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::Sum(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::SumDim(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::Max(desc, _) => { - handles.cleanup(&desc.input); - } - 
NumericOpsDescription::MaxDimWithIndices(desc, _) => { - handles.cleanup(&desc.tensor); - } - NumericOpsDescription::MinDimWithIndices(desc, _) => { - handles.cleanup(&desc.tensor); - } - NumericOpsDescription::Min(desc, _) => { - handles.cleanup(&desc.input); - } - NumericOpsDescription::MaxDim(desc, _) => { - handles.cleanup(&desc.lhs); - } - NumericOpsDescription::MinDim(desc, _) => { - handles.cleanup(&desc.lhs); - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + NumericOpsDescription::Add(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::AddScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Sub(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::SubScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Mul(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::MulScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Div(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::DivScalar(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Ones(_, _) => {} + NumericOpsDescription::Gather(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + } + NumericOpsDescription::Scatter(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + handles.cleanup(&desc.value); + } + NumericOpsDescription::Select(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + } + NumericOpsDescription::SelectAssign(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.indices); + handles.cleanup(&desc.value); + } + NumericOpsDescription::MaskWhere(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.value); + handles.cleanup(&desc.mask); + } + NumericOpsDescription::MaskFill(desc, _) => { + handles.cleanup(&desc.tensor); + handles.cleanup(&desc.mask); + } + NumericOpsDescription::EqualElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::GreaterElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::GreaterEqualElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::LowerElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::LowerEqualElem(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Greater(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::GreaterEqual(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::Lower(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::LowerEqual(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + NumericOpsDescription::ArgMax(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::ArgMin(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Clamp(desc, _) => { + handles.cleanup(&desc.tensor); + } + NumericOpsDescription::ClampMin(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::ClampMax(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Abs(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::Zeros(_, _) => {} + NumericOpsDescription::Full(_, _) => {} 
+ NumericOpsDescription::MeanDim(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Mean(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::Sum(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::SumDim(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::Max(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::MaxDimWithIndices(desc, _) => { + handles.cleanup(&desc.tensor); + } + NumericOpsDescription::MinDimWithIndices(desc, _) => { + handles.cleanup(&desc.tensor); + } + NumericOpsDescription::Min(desc, _) => { + handles.cleanup(&desc.input); + } + NumericOpsDescription::MaxDim(desc, _) => { + handles.cleanup(&desc.lhs); + } + NumericOpsDescription::MinDim(desc, _) => { + handles.cleanup(&desc.lhs); + } + } } - } - fn execute(&self, handles: &mut HandleContainer) { - match self { - NumericOpsDescription::Add(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::AddScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Sub(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::SubScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Div(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::DivScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Mul(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MulScalar(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Ones(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Gather(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Scatter(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Select(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::SelectAssign(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaskWhere(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaskFill(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::EqualElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Greater(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::GreaterElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::GreaterEqual(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::GreaterEqualElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Lower(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::LowerElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::LowerEqual(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::LowerEqualElem(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::ArgMax(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::ArgMin(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Clamp(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::ClampMin(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::ClampMax(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Abs(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Zeros(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Full(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MeanDim(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Mean(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Sum(desc, ops) => ops.execute(desc, handles), - 
NumericOpsDescription::SumDim(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Max(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaxDimWithIndices(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MinDimWithIndices(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::Min(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MaxDim(desc, ops) => ops.execute(desc, handles), - NumericOpsDescription::MinDim(desc, ops) => ops.execute(desc, handles), + fn execute(&self, handles: &mut HandleContainer) { + match self { + NumericOpsDescription::Add(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::AddScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Sub(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::SubScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Div(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::DivScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Mul(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MulScalar(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Ones(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Gather(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Scatter(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Select(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::SelectAssign(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaskWhere(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaskFill(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::EqualElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Greater(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::GreaterElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::GreaterEqual(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::GreaterEqualElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Lower(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::LowerElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::LowerEqual(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::LowerEqualElem(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::ArgMax(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::ArgMin(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Clamp(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::ClampMin(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::ClampMax(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Abs(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Zeros(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Full(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MeanDim(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Mean(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Sum(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::SumDim(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::Max(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaxDimWithIndices(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MinDimWithIndices(desc, ops) => ops.execute(desc, handles), + 
NumericOpsDescription::Min(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MaxDim(desc, ops) => ops.execute(desc, handles), + NumericOpsDescription::MinDim(desc, ops) => ops.execute(desc, handles), + } } - } } impl FloatOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - FloatOpsDescription::Matmul(desc, _) => { - handles.cleanup(&desc.lhs); - handles.cleanup(&desc.rhs); - } - FloatOpsDescription::Random(_, _) => {} - FloatOpsDescription::Exp(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs), - FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Sin(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::Tanh(desc, _) => handles.cleanup(&desc.input), - FloatOpsDescription::IntoInt(desc, _) => handles.cleanup(&desc.input), + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + FloatOpsDescription::Matmul(desc, _) => { + handles.cleanup(&desc.lhs); + handles.cleanup(&desc.rhs); + } + FloatOpsDescription::Random(_, _) => {} + FloatOpsDescription::Exp(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs), + FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Sin(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::Tanh(desc, _) => handles.cleanup(&desc.input), + FloatOpsDescription::IntoInt(desc, _) => handles.cleanup(&desc.input), + } } - } - fn execute(&self, handles: &mut HandleContainer) { - match self { - FloatOpsDescription::Matmul(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Random(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Exp(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Sin(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::Tanh(desc, ops) => ops.execute(desc, handles), - FloatOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), + fn execute(&self, handles: &mut HandleContainer) { + match self { + FloatOpsDescription::Matmul(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Random(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Exp(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Log(desc, ops) 
=> ops.execute(desc, handles), + FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Sin(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::Tanh(desc, ops) => ops.execute(desc, handles), + FloatOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), + } } - } } impl IntOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - IntOpsDescription::IntoFloat(desc, _) => { - handles.cleanup(&desc.input); - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + IntOpsDescription::IntoFloat(desc, _) => { + handles.cleanup(&desc.input); + } + } } - } - fn execute(&self, handles: &mut HandleContainer) { - match self { - IntOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), + fn execute(&self, handles: &mut HandleContainer) { + match self { + IntOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), + } } - } } impl BoolOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - BoolOpsDescription::IntoFloat(desc, _) => { - handles.cleanup(&desc.input); - } - BoolOpsDescription::IntoInt(desc, _) => { - handles.cleanup(&desc.input); - } - BoolOpsDescription::Not(desc, _) => { - handles.cleanup(&desc.input); - } + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + BoolOpsDescription::IntoFloat(desc, _) => { + handles.cleanup(&desc.input); + } + BoolOpsDescription::IntoInt(desc, _) => { + handles.cleanup(&desc.input); + } + BoolOpsDescription::Not(desc, _) => { + handles.cleanup(&desc.input); + } + } } - } - fn execute(&self, handles: &mut HandleContainer) { - match self { - BoolOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), - BoolOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), - BoolOpsDescription::Not(desc, ops) => ops.execute(desc, handles), + fn execute(&self, handles: &mut HandleContainer) { + match self { + BoolOpsDescription::IntoFloat(desc, ops) => ops.execute(desc, handles), + BoolOpsDescription::IntoInt(desc, ops) => ops.execute(desc, handles), + BoolOpsDescription::Not(desc, ops) => ops.execute(desc, handles), + } } - } } impl ModuleOpsDescription { - fn cleanup_tensor(&self, handles: &mut HandleContainer) { - match self { - ModuleOpsDescription::Embedding(desc, _) => { - handles.cleanup(&desc.weights); - handles.cleanup(&desc.indices); - } - ModuleOpsDescription::EmbeddingBackward(desc, _) => { - handles.cleanup(&desc.weights); - handles.cleanup(&desc.out_grad); - handles.cleanup(&desc.indices); - } - ModuleOpsDescription::Conv1d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + fn cleanup_tensor(&self, handles: &mut HandleContainer) { + match self { + ModuleOpsDescription::Embedding(desc, _) => { + handles.cleanup(&desc.weights); + handles.cleanup(&desc.indices); + } + ModuleOpsDescription::EmbeddingBackward(desc, _) => { + handles.cleanup(&desc.weights); + handles.cleanup(&desc.out_grad); + handles.cleanup(&desc.indices); + } + ModuleOpsDescription::Conv1d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - 
handles.cleanup(bias); - } - } - ModuleOpsDescription::Conv2d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + if let Some(bias) = &desc.bias { + handles.cleanup(bias); + } + } + ModuleOpsDescription::Conv2d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - handles.cleanup(bias); - } - } - ModuleOpsDescription::ConvTranspose1d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + if let Some(bias) = &desc.bias { + handles.cleanup(bias); + } + } + ModuleOpsDescription::ConvTranspose1d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - handles.cleanup(bias); - } - } - ModuleOpsDescription::ConvTranspose2d(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.weight); + if let Some(bias) = &desc.bias { + handles.cleanup(bias); + } + } + ModuleOpsDescription::ConvTranspose2d(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.weight); - if let Some(bias) = &desc.bias { - handles.cleanup(bias); + if let Some(bias) = &desc.bias { + handles.cleanup(bias); + } + } + ModuleOpsDescription::AvgPool1d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AvgPool2d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AvgPool1dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::AvgPool2dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::AdaptiveAvgPool1d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AdaptiveAvgPool2d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + } + ModuleOpsDescription::MaxPool1d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool1dWithIndices(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + handles.cleanup(&desc.indices); + } + ModuleOpsDescription::MaxPool2d(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool2dWithIndices(desc, _) => { + handles.cleanup(&desc.x); + } + ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, _) => { + handles.cleanup(&desc.x); + handles.cleanup(&desc.grad); + handles.cleanup(&desc.indices); + } } - } - ModuleOpsDescription::AvgPool1d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AvgPool2d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AvgPool1dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::AvgPool2dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::AdaptiveAvgPool1d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AdaptiveAvgPool2d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - } - ModuleOpsDescription::MaxPool1d(desc, _) => { - 
handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool1dWithIndices(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - handles.cleanup(&desc.indices); - } - ModuleOpsDescription::MaxPool2d(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool2dWithIndices(desc, _) => { - handles.cleanup(&desc.x); - } - ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, _) => { - handles.cleanup(&desc.x); - handles.cleanup(&desc.grad); - handles.cleanup(&desc.indices); - } } - } - fn execute(&self, handles: &mut HandleContainer) { - match self { - ModuleOpsDescription::Embedding(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::EmbeddingBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::Conv1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::Conv2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::ConvTranspose1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::ConvTranspose2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool1dBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AvgPool2dBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AdaptiveAvgPool1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AdaptiveAvgPool2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool1d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool1dWithIndices(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool2d(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool2dWithIndices(desc, ops) => ops.execute(desc, handles), - ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, ops) => ops.execute(desc, handles), + fn execute(&self, handles: &mut HandleContainer) { + match self { + ModuleOpsDescription::Embedding(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::EmbeddingBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::Conv1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::Conv2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::ConvTranspose1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::ConvTranspose2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool1dBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AvgPool2dBackward(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AdaptiveAvgPool1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AdaptiveAvgPool2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc, ops) => { + ops.execute(desc, handles) + } + ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc, ops) => { + 
ops.execute(desc, handles) + } + ModuleOpsDescription::MaxPool1d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool1dWithIndices(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc, ops) => { + ops.execute(desc, handles) + } + ModuleOpsDescription::MaxPool2d(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool2dWithIndices(desc, ops) => ops.execute(desc, handles), + ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc, ops) => { + ops.execute(desc, handles) + } + } } - } } diff --git a/burn-fusion/src/handle.rs b/burn-fusion/src/handle.rs index 8b2d36eb85..10a6dbbef3 100644 --- a/burn-fusion/src/handle.rs +++ b/burn-fusion/src/handle.rs @@ -6,132 +6,132 @@ use std::{collections::HashMap, sync::Arc}; /// are used optimally. #[derive(Default)] pub struct HandleContainer { - handles: HashMap>, - counter: u64, - pub(crate) handles_orphan: Vec, - /// The device on which all tensors are held. - pub device: B::Device, + handles: HashMap>, + counter: u64, + pub(crate) handles_orphan: Vec, + /// The device on which all tensors are held. + pub device: B::Device, } enum Handle { - NotInit, - Existing(B::Handle), + NotInit, + Existing(B::Handle), } impl HandleContainer { - pub(crate) fn new(device_handle: B::FusionDevice) -> Self { - Self { - handles: HashMap::new(), - handles_orphan: Vec::new(), - counter: 0, - device: device_handle.clone().into(), + pub(crate) fn new(device_handle: B::FusionDevice) -> Self { + Self { + handles: HashMap::new(), + handles_orphan: Vec::new(), + counter: 0, + device: device_handle.clone().into(), + } + } + + /// Register a handle for the given [tensor id](TensorId). + pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) { + self.handles.insert(id, Handle::Existing(handle)); } - } - - /// Register a handle for the given [tensor id](TensorId). - pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) { - self.handles.insert(id, Handle::Existing(handle)); - } - - /// Get the handle for the given [tensor id](TensorId). - pub fn get_handle(&mut self, tensor: &TensorDescription) -> B::Handle { - let (id, handle) = self - .handles - .remove_entry(&tensor.id) - .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id)); - - match handle { - Handle::Existing(handle) => match tensor.status { - TensorStatus::ReadOnly => { - self.handles.insert(id, Handle::Existing(handle.clone())); - handle + + /// Get the handle for the given [tensor id](TensorId). + pub fn get_handle(&mut self, tensor: &TensorDescription) -> B::Handle { + let (id, handle) = self + .handles + .remove_entry(&tensor.id) + .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id)); + + match handle { + Handle::Existing(handle) => match tensor.status { + TensorStatus::ReadOnly => { + self.handles.insert(id, Handle::Existing(handle.clone())); + handle + } + TensorStatus::ReadWrite => handle, + TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."), + }, + Handle::NotInit => panic!("Cannot get uninitialized handle."), } - TensorStatus::ReadWrite => handle, - TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."), - }, - Handle::NotInit => panic!("Cannot get uninitialized handle."), } - } - - /// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the - /// given [tensor description](TensorDescription). 
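get_handle above encodes the container's ownership rule: a ReadOnly tensor leaves a clone of its handle registered, a ReadWrite tensor removes the handle from the map so the underlying buffer can be reused or dropped, and an uninitialized entry panics. A minimal, self-contained sketch of that rule follows; SimpleHandles, FakeHandle and the u64-backed TensorId are illustrative stand-ins, not burn-fusion types.

    use std::collections::HashMap;

    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    struct TensorId(u64);

    enum TensorStatus {
        ReadOnly,
        ReadWrite,
        NotInit,
    }

    // Stand-in for a backend handle (e.g. a device buffer).
    #[derive(Clone, Debug)]
    struct FakeHandle(String);

    #[derive(Default)]
    struct SimpleHandles {
        handles: HashMap<TensorId, FakeHandle>,
    }

    impl SimpleHandles {
        fn register(&mut self, id: TensorId, handle: FakeHandle) {
            self.handles.insert(id, handle);
        }

        // Read-only access clones and keeps the entry; read-write access
        // consumes it, mirroring the TensorStatus match above.
        fn get(&mut self, id: TensorId, status: TensorStatus) -> FakeHandle {
            match status {
                TensorStatus::ReadOnly => self.handles.get(&id).cloned().expect("registered"),
                TensorStatus::ReadWrite => self.handles.remove(&id).expect("registered"),
                TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."),
            }
        }
    }

    fn main() {
        let mut handles = SimpleHandles::default();
        let id = TensorId(0);
        handles.register(id, FakeHandle("buffer-0".into()));

        let a = handles.get(id, TensorStatus::ReadOnly); // entry stays registered
        let b = handles.get(id, TensorStatus::ReadWrite); // entry is consumed
        println!("{a:?} {b:?}");
    }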
- pub fn get_float_tensor( - &mut self, - tensor: &TensorDescription, - ) -> B::TensorPrimitive { - B::float_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) - } - - /// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the - /// given [tensor description](TensorDescription). - pub fn get_int_tensor( - &mut self, - tensor: &TensorDescription, - ) -> B::IntTensorPrimitive { - B::int_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) - } - - /// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the - /// given [tensor description](TensorDescription). - pub fn get_bool_tensor( - &mut self, - tensor: &TensorDescription, - ) -> B::BoolTensorPrimitive { - B::bool_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) - } - - /// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_float_tensor( - &mut self, - id: &TensorId, - tensor: B::TensorPrimitive, - ) { - let handle = B::float_tensor_handle(tensor); - self.handles.insert(id.clone(), Handle::Existing(handle)); - } - - /// Register a new [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_int_tensor( - &mut self, - id: &TensorId, - tensor: B::IntTensorPrimitive, - ) { - let handle = B::int_tensor_handle(tensor); - self.handles.insert(id.clone(), Handle::Existing(handle)); - } - - /// Register a new [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_bool_tensor( - &mut self, - id: &TensorId, - tensor: B::BoolTensorPrimitive, - ) { - let handle = B::bool_tensor_handle(tensor); - self.handles.insert(id.clone(), Handle::Existing(handle)); - } - - /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId). - pub fn create_tensor_uninit(&mut self) -> Arc { - let id = TensorId::new(self.counter); - self.counter += 1; - self.handles.insert(id.clone(), Handle::NotInit); - - Arc::new(id) - } - - pub(crate) fn cleanup(&mut self, tensor: &TensorDescription) { - match tensor.status { - TensorStatus::ReadOnly => (), - TensorStatus::NotInit => (), - TensorStatus::ReadWrite => { - self.handles.remove(&tensor.id); - } + + /// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the + /// given [tensor description](TensorDescription). + pub fn get_float_tensor( + &mut self, + tensor: &TensorDescription, + ) -> B::TensorPrimitive { + B::float_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) + } + + /// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the + /// given [tensor description](TensorDescription). + pub fn get_int_tensor( + &mut self, + tensor: &TensorDescription, + ) -> B::IntTensorPrimitive { + B::int_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) + } + + /// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the + /// given [tensor description](TensorDescription). + pub fn get_bool_tensor( + &mut self, + tensor: &TensorDescription, + ) -> B::BoolTensorPrimitive { + B::bool_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) + } + + /// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId). 
+ pub fn register_float_tensor( + &mut self, + id: &TensorId, + tensor: B::TensorPrimitive, + ) { + let handle = B::float_tensor_handle(tensor); + self.handles.insert(id.clone(), Handle::Existing(handle)); + } + + /// Register a new [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). + pub fn register_int_tensor( + &mut self, + id: &TensorId, + tensor: B::IntTensorPrimitive, + ) { + let handle = B::int_tensor_handle(tensor); + self.handles.insert(id.clone(), Handle::Existing(handle)); + } + + /// Register a new [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). + pub fn register_bool_tensor( + &mut self, + id: &TensorId, + tensor: B::BoolTensorPrimitive, + ) { + let handle = B::bool_tensor_handle(tensor); + self.handles.insert(id.clone(), Handle::Existing(handle)); } - } - pub(crate) fn cleanup_orphans(&mut self) { - for id in self.handles_orphan.drain(..) { - self.handles.remove(&id); + /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId). + pub fn create_tensor_uninit(&mut self) -> Arc { + let id = TensorId::new(self.counter); + self.counter += 1; + self.handles.insert(id.clone(), Handle::NotInit); + + Arc::new(id) + } + + pub(crate) fn cleanup(&mut self, tensor: &TensorDescription) { + match tensor.status { + TensorStatus::ReadOnly => (), + TensorStatus::NotInit => (), + TensorStatus::ReadWrite => { + self.handles.remove(&tensor.id); + } + } + } + + pub(crate) fn cleanup_orphans(&mut self) { + for id in self.handles_orphan.drain(..) { + self.handles.remove(&id); + } } - } } diff --git a/burn-fusion/src/ops/binary.rs b/burn-fusion/src/ops/binary.rs index c3148eb3c4..05d859252a 100644 --- a/burn-fusion/src/ops/binary.rs +++ b/burn-fusion/src/ops/binary.rs @@ -1,101 +1,101 @@ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_float_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let rhs = handles.get_float_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let rhs = handles.get_float_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_float_tensor(&args.out.id, output); - } - } - }; + handles.register_float_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
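The binary_float_ops macro in this hunk stamps out one zero-sized struct per fused operation and gives it an execute method that pulls both operands out of the handle container, applies the backend function, and registers the result under the output id; the cmp and int variants that follow differ only in which getter and register call they use. A compact sketch of the same code-generation pattern, with simplified stand-in types (MiniHandles, BinaryArgs, MiniOps are illustrative, not the real HandleContainer/Ops machinery):

    use std::collections::HashMap;

    struct MiniHandles {
        values: HashMap<u64, f32>,
    }

    impl MiniHandles {
        fn get(&self, id: u64) -> f32 {
            *self.values.get(&id).expect("registered")
        }
        fn register(&mut self, id: u64, value: f32) {
            self.values.insert(id, value);
        }
    }

    struct BinaryArgs {
        lhs: u64,
        rhs: u64,
        out: u64,
    }

    trait MiniOps {
        fn execute(&self, args: &BinaryArgs, handles: &mut MiniHandles);
    }

    // Only the struct name and the applied function vary per invocation.
    macro_rules! mini_binary_ops {
        ($name:ident, $ops:expr) => {
            struct $name;

            impl MiniOps for $name {
                fn execute(&self, args: &BinaryArgs, handles: &mut MiniHandles) {
                    let lhs = handles.get(args.lhs);
                    let rhs = handles.get(args.rhs);
                    handles.register(args.out, $ops(lhs, rhs));
                }
            }
        };
    }

    mini_binary_ops!(AddOps, |a, b| a + b);
    mini_binary_ops!(MulOps, |a, b| a * b);

    fn main() {
        let mut handles = MiniHandles {
            values: [(0, 2.0), (1, 3.0)].into_iter().collect(),
        };
        AddOps.execute(&BinaryArgs { lhs: 0, rhs: 1, out: 2 }, &mut handles);
        MulOps.execute(&BinaryArgs { lhs: 0, rhs: 1, out: 3 }, &mut handles);
        println!("add = {}, mul = {}", handles.get(2), handles.get(3));
    }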
binary_float_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let rhs = handles.get_float_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let rhs = handles.get_float_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_int_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let rhs = handles.get_int_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let rhs = handles.get_int_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } pub(crate) fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec { - let mut shape_out = Vec::with_capacity(lhs.len()); + let mut shape_out = Vec::with_capacity(lhs.len()); - for (l, r) in lhs.iter().zip(rhs.iter()) { - shape_out.push(usize::max(*l, *r)); - } + for (l, r) in lhs.iter().zip(rhs.iter()) { + shape_out.push(usize::max(*l, *r)); + } - shape_out + shape_out } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
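binary_ops_shape, shown in this hunk, computes the broadcast output shape by taking the element-wise maximum of the two operand shapes, which matches broadcasting whenever every mismatched pair of dimensions has a 1 on one side. A quick standalone check of that rule (the helper below is an equivalent iterator-style rewrite, not the crate's function):

    fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
        lhs.iter().zip(rhs.iter()).map(|(l, r)| usize::max(*l, *r)).collect()
    }

    fn main() {
        // [2, 1, 4] broadcast against [1, 3, 4] gives [2, 3, 4].
        assert_eq!(binary_ops_shape(&[2, 1, 4], &[1, 3, 4]), vec![2, 3, 4]);
        // Equal shapes pass through unchanged.
        assert_eq!(binary_ops_shape(&[8, 8], &[8, 8]), vec![8, 8]);
    }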
binary_int_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = BinaryOpsDescription; + impl Ops for $name { + type Args = BinaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let rhs = handles.get_int_tensor(&args.rhs); - let output = $ops(lhs, rhs); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let rhs = handles.get_int_tensor(&args.rhs); + let output = $ops(lhs, rhs); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } diff --git a/burn-fusion/src/ops/boolean.rs b/burn-fusion/src/ops/boolean.rs index 43d1f63e72..179db25d4d 100644 --- a/burn-fusion/src/ops/boolean.rs +++ b/burn-fusion/src/ops/boolean.rs @@ -1,399 +1,402 @@ use crate::{ - client::FusionClient, - get_client, - graph::{ - BaseOpsDescription, BinaryOpsDescription, BoolOpsDescription, CatOpsDescription, Ops, - ReshapeDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, - TensorOpsDescription, UnaryOpsDescription, - }, - ops::binary::binary_ops_shape, - Fusion, FusionBackend, + client::FusionClient, + get_client, + graph::{ + BaseOpsDescription, BinaryOpsDescription, BoolOpsDescription, CatOpsDescription, Ops, + ReshapeDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, + TensorOpsDescription, UnaryOpsDescription, + }, + ops::binary::binary_ops_shape, + Fusion, FusionBackend, }; use burn_tensor::{ - ops::{BoolTensor, BoolTensorOps}, - Device, Shape, + ops::{BoolTensor, BoolTensorOps}, + Device, Shape, }; impl BoolTensorOps for Fusion { - fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::bool_empty(shape.clone(), device); - - client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) - } - - fn bool_shape(tensor: &BoolTensor) -> Shape { - tensor.shape() - } - - fn bool_into_data( - tensor: BoolTensor, - ) -> burn_tensor::Reader> { - tensor.bool_into_data() - } - - fn bool_from_data( - data: burn_tensor::Data, - device: &Device, - ) -> BoolTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::bool_from_data(data, device); - let shape = B::bool_shape(&tensor); - - client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) - } - - fn bool_into_int( - tensor: BoolTensor, - ) -> burn_tensor::ops::IntTensor { - struct IntoIntOps; - - impl Ops for IntoIntOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_into_int(input); - handles.register_int_tensor(&args.out.id, output); - } - } + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::bool_empty(shape.clone(), device); - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - .register(TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoIntOps::), - ))); - - out - } - - fn bool_into_float( - tensor: BoolTensor, - ) -> burn_tensor::ops::FloatTensor { - struct IntoFloatOps; - - impl Ops for IntoFloatOps 
{ - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_into_float(input); - handles.register_float_tensor(&args.out.id, output); - } + client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) } - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::BoolOps( - BoolOpsDescription::IntoFloat( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoFloatOps::), - ), - )); - - out - } - - fn bool_device(tensor: &BoolTensor) -> Device { - tensor.client.device().clone().into() - } - - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); - - if device_original == &device_target { - return tensor; + fn bool_shape(tensor: &BoolTensor) -> Shape { + tensor.shape() } - let client_target = get_client::(&device_target); - let client_original = tensor.client.clone(); + fn bool_into_data( + tensor: BoolTensor, + ) -> burn_tensor::Reader> { + tensor.bool_into_data() + } - client_original - .clone() - .change_client_bool::(tensor.into_description(), client_target) - } + fn bool_from_data( + data: burn_tensor::Data, + device: &Device, + ) -> BoolTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::bool_from_data(data, device); + let shape = B::bool_shape(&tensor); - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - struct ReshapeDimsOps; + client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) + } - impl Ops for ReshapeDimsOps { - type Args = ReshapeDescription; + fn bool_into_int( + tensor: BoolTensor, + ) -> burn_tensor::ops::IntTensor { + struct IntoIntOps; + + impl Ops for IntoIntOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_into_int(input); + handles.register_int_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client + .register(TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoIntOps::), + ))); + + out + } - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_reshape::(input, Shape::from(&args.shape)); - handles.register_bool_tensor(&args.out.id, output); - } + fn bool_into_float( + tensor: BoolTensor, + ) -> burn_tensor::ops::FloatTensor { + struct IntoFloatOps; + + impl Ops for IntoFloatOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_into_float(input); + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::BoolOps( + BoolOpsDescription::IntoFloat( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoFloatOps::), + ), + )); + + out } - let 
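Every method in this impl follows the recipe visible in bool_into_int: define a local Ops struct whose execute does the real backend call, allocate an uninitialized output tensor on the client, and register a description of the op together with the boxed executor instead of running it eagerly. A rough sketch of that deferred-execution idea, using made-up types (Queue, Desc, Executor) rather than the real FusionClient API, and a plain f32 store in place of tensors:

    // Hypothetical stand-ins for TensorDescription / Ops / the client queue.
    struct Desc {
        input: u64,
        out: u64,
    }

    trait Executor {
        fn execute(&self, desc: &Desc, store: &mut Vec<Option<f32>>);
    }

    struct NegOp;
    impl Executor for NegOp {
        fn execute(&self, desc: &Desc, store: &mut Vec<Option<f32>>) {
            let value = store[desc.input as usize].expect("input materialized");
            store[desc.out as usize] = Some(-value);
        }
    }

    #[derive(Default)]
    struct Queue {
        ops: Vec<(Desc, Box<dyn Executor>)>,
        store: Vec<Option<f32>>,
    }

    impl Queue {
        // Counterpart of tensor_uninitialized: reserve an id, no data yet.
        fn tensor_uninitialized(&mut self) -> u64 {
            self.store.push(None);
            (self.store.len() - 1) as u64
        }
        // Counterpart of client.register(...): record the op, run nothing yet.
        fn register(&mut self, desc: Desc, op: Box<dyn Executor>) {
            self.ops.push((desc, op));
        }
        // Later the recorded ops are executed in one pass.
        fn drain(&mut self) {
            for (desc, op) in self.ops.drain(..) {
                op.execute(&desc, &mut self.store);
            }
        }
    }

    fn main() {
        let mut queue = Queue::default();
        let input = queue.tensor_uninitialized();
        queue.store[input as usize] = Some(3.0);

        let out = queue.tensor_uninitialized();
        queue.register(Desc { input, out }, Box::new(NegOp)); // queued only
        queue.drain(); // now it actually runs
        assert_eq!(queue.store[out as usize], Some(-3.0));
    }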
shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape.clone()); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::Reshape( - ReshapeDescription { - input: tensor.into_description(), - shape, - out: out.to_description_out(), - }, - Box::new(ReshapeDimsOps::), - ), - )); - - out - } - - fn bool_slice( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - ) -> BoolTensor { - struct SliceOps; - - impl Ops for SliceOps { - type Args = SliceOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_bool_tensor::(&args.tensor); - - let output = B::bool_slice::(tensor, args.ranges.clone().try_into().unwrap()); - - handles.register_bool_tensor(&args.out.id, output); - } + fn bool_device(tensor: &BoolTensor) -> Device { + tensor.client.device().clone().into() } - let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + let device_original: &B::FusionDevice = tensor.client.device(); + let device_target: B::FusionDevice = device.clone().into(); - for i in shape.len()..D1 { - shape.push(tensor.shape[i]); - } + if device_original == &device_target { + return tensor; + } + + let client_target = get_client::(&device_target); + let client_original = tensor.client.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::Slice( - SliceOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - out: out.to_description_out(), - }, - Box::new(SliceOps::), - ), - )); - - out - } - - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [std::ops::Range; D2], - value: BoolTensor, - ) -> BoolTensor { - struct SliceAssignOps; - - impl Ops for SliceAssignOps { - type Args = SliceAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_bool_tensor::(&args.tensor); - let value = handles.get_bool_tensor::(&args.value); - - let output = - B::bool_slice_assign::(tensor, args.ranges.clone().try_into().unwrap(), value); - - handles.register_bool_tensor(&args.out.id, output); - } + client_original + .clone() + .change_client_bool::(tensor.into_description(), client_target) } - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::SliceAssign( - SliceAssignOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SliceAssignOps::), - ), - )); - - out - } - - fn bool_cat( - tensors: Vec>, - dim: usize, - ) -> BoolTensor { - struct CatOps; - - impl Ops for CatOps { - type Args = CatOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensors = args - .tensors - .iter() - .map(|tensor| handles.get_bool_tensor(tensor)) - .collect(); - - let output = B::bool_cat::(tensors, args.dim); - - handles.register_bool_tensor(&args.out.id, output); - } + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor { + struct ReshapeDimsOps; + + impl Ops for ReshapeDimsOps { + type Args = ReshapeDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) 
{ + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_reshape::(input, Shape::from(&args.shape)); + handles.register_bool_tensor(&args.out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let out = tensor.client.tensor_uninitialized(shape.clone()); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::Reshape( + ReshapeDescription { + input: tensor.into_description(), + shape, + out: out.to_description_out(), + }, + Box::new(ReshapeDimsOps::), + ), + )); + + out } - let tensor_first = tensors.get(0).unwrap(); - let client = tensor_first.client.clone(); + fn bool_slice( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + ) -> BoolTensor { + struct SliceOps; + + impl Ops for SliceOps { + type Args = SliceOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_bool_tensor::(&args.tensor); + + let output = + B::bool_slice::(tensor, args.ranges.clone().try_into().unwrap()); + + handles.register_bool_tensor(&args.out.id, output); + } + } + + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); - // Calculate the output shape - let mut shape: Vec = tensor_first.shape.clone(); - shape[dim] = 0; - for tensor in tensors.iter() { - shape[dim] += tensor.shape[dim]; + for i in shape.len()..D1 { + shape.push(tensor.shape[i]); + } + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::Slice( + SliceOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + out: out.to_description_out(), + }, + Box::new(SliceOps::), + ), + )); + + out } - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat( - CatOpsDescription { - tensors: tensors.into_iter().map(|t| t.into_description()).collect(), - dim, - out: out.to_description_out(), - }, - Box::new(CatOps::), - ))); - - out - } - - fn bool_equal( - lhs: BoolTensor, - rhs: BoolTensor, - ) -> BoolTensor { - struct EqualOps; - - impl Ops for EqualOps { - type Args = BinaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let lhs = handles.get_bool_tensor::(&args.lhs); - let rhs = handles.get_bool_tensor(&args.rhs); - let output = B::bool_equal(lhs, rhs); - handles.register_bool_tensor(&args.out.id, output); - } + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [std::ops::Range; D2], + value: BoolTensor, + ) -> BoolTensor { + struct SliceAssignOps; + + impl Ops for SliceAssignOps { + type Args = SliceAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_bool_tensor::(&args.tensor); + let value = handles.get_bool_tensor::(&args.value); + + let output = B::bool_slice_assign::( + tensor, + args.ranges.clone().try_into().unwrap(), + value, + ); + + handles.register_bool_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::SliceAssign( + SliceAssignOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SliceAssignOps::), + ), + )); + + out } - let out = lhs - .client - 
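The shape computation in bool_slice (and the identical one in the float slice later in this patch) takes the length of each supplied range for the leading dimensions and carries the input's size through for the remaining trailing dimensions. A small standalone check of that rule, with a plain helper in place of the generic method:

    fn slice_shape(input_shape: &[usize], ranges: &[std::ops::Range<usize>]) -> Vec<usize> {
        let mut shape: Vec<usize> = ranges.iter().map(|r| r.end - r.start).collect();
        // Dimensions not covered by a range pass through unchanged.
        for i in shape.len()..input_shape.len() {
            shape.push(input_shape[i]);
        }
        shape
    }

    fn main() {
        // Slicing a [4, 6, 8] tensor with ranges on the first two dims only.
        assert_eq!(slice_shape(&[4, 6, 8], &[0..2, 1..4]), vec![2, 3, 8]);
    }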
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::Equal( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(EqualOps::), - ), - )); - - out - } - - fn bool_not(tensor: BoolTensor) -> BoolTensor { - struct NotOps; - - impl Ops for NotOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_not(input); - handles.register_bool_tensor(&args.out.id, output); - } + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> BoolTensor { + struct CatOps; + + impl Ops for CatOps { + type Args = CatOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensors = args + .tensors + .iter() + .map(|tensor| handles.get_bool_tensor(tensor)) + .collect(); + + let output = B::bool_cat::(tensors, args.dim); + + handles.register_bool_tensor(&args.out.id, output); + } + } + + let tensor_first = tensors.get(0).unwrap(); + let client = tensor_first.client.clone(); + + // Calculate the output shape + let mut shape: Vec = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; + } + + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat( + CatOpsDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }, + Box::new(CatOps::), + ))); + + out } - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::BoolOps( - crate::graph::BoolOpsDescription::Not( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(NotOps::), - ), - )); - - out - } - - fn bool_swap_dims( - tensor: BoolTensor, - dim1: usize, - dim2: usize, - ) -> BoolTensor { - struct SwapDimsOps; - - impl Ops for SwapDimsOps { - type Args = SwapDimsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_bool_tensor::(&args.input); - let output = B::bool_swap_dims(input, args.dim1, args.dim2); - handles.register_bool_tensor(&args.out.id, output); - } + fn bool_equal( + lhs: BoolTensor, + rhs: BoolTensor, + ) -> BoolTensor { + struct EqualOps; + + impl Ops for EqualOps { + type Args = BinaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let lhs = handles.get_bool_tensor::(&args.lhs); + let rhs = handles.get_bool_tensor(&args.rhs); + let output = B::bool_equal(lhs, rhs); + handles.register_bool_tensor(&args.out.id, output); + } + } + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::Equal( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(EqualOps::), + ), + )); + + out } - let mut shape = tensor.shape.clone(); - shape[dim1] = tensor.shape[dim2]; - shape[dim2] = tensor.shape[dim1]; - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsBool( - BaseOpsDescription::SwapDims( - SwapDimsDescription { - input: 
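bool_cat computes its output shape by zeroing the concatenation dimension and then summing that dimension over every input, while all other dimensions are taken from the first tensor. For example, with a standalone helper rather than the burn API:

    fn cat_shape(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
        let mut out = shapes[0].clone();
        out[dim] = shapes.iter().map(|s| s[dim]).sum();
        out
    }

    fn main() {
        // Concatenating [2, 3, 4] and [2, 5, 4] along dim 1 gives [2, 8, 4].
        assert_eq!(cat_shape(&[vec![2, 3, 4], vec![2, 5, 4]], 1), vec![2, 8, 4]);
    }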
tensor.into_description(), - dim1, - dim2, - out: out.to_description_out(), - }, - Box::new(SwapDimsOps::), - ), - )); - - out - } + fn bool_not(tensor: BoolTensor) -> BoolTensor { + struct NotOps; + + impl Ops for NotOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_not(input); + handles.register_bool_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::BoolOps( + crate::graph::BoolOpsDescription::Not( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(NotOps::), + ), + )); + + out + } + + fn bool_swap_dims( + tensor: BoolTensor, + dim1: usize, + dim2: usize, + ) -> BoolTensor { + struct SwapDimsOps; + + impl Ops for SwapDimsOps { + type Args = SwapDimsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_bool_tensor::(&args.input); + let output = B::bool_swap_dims(input, args.dim1, args.dim2); + handles.register_bool_tensor(&args.out.id, output); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsBool( + BaseOpsDescription::SwapDims( + SwapDimsDescription { + input: tensor.into_description(), + dim1, + dim2, + out: out.to_description_out(), + }, + Box::new(SwapDimsOps::), + ), + )); + + out + } } diff --git a/burn-fusion/src/ops/float.rs b/burn-fusion/src/ops/float.rs index 156fcb259d..fc60d3ef4b 100644 --- a/burn-fusion/src/ops/float.rs +++ b/burn-fusion/src/ops/float.rs @@ -1,1672 +1,1669 @@ use crate::{ - binary_float_cmp_ops, binary_float_ops, - client::FusionClient, - get_client, - graph::{ - BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription, - FloatOpsDescription, GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, - NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription, - ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription, SelectOpsDescription, - SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, TensorOpsDescription, - UnaryOpsDescription, - }, - ops::binary::binary_ops_shape, - scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, unary_float_ops, Fusion, - FusionBackend, TensorDescription, + binary_float_cmp_ops, binary_float_ops, + client::FusionClient, + get_client, + graph::{ + BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription, + FloatOpsDescription, GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, + NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription, + ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription, + SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, + TensorOpsDescription, UnaryOpsDescription, + }, + ops::binary::binary_ops_shape, + scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, unary_float_ops, Fusion, + FusionBackend, TensorDescription, }; use burn_tensor::{ - ops::{BoolTensor, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, - Data, Device, Distribution, Reader, Shape, + ops::{BoolTensor, FloatElem, 
FloatTensor, FullPrecisionBackend, IntTensor, TensorOps}, + Data, Device, Distribution, Reader, Shape, }; use std::ops::Range; impl TensorOps for Fusion { - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::from_data(data, device); - let shape = B::shape(&tensor); + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::from_data(data, device); + let shape = B::shape(&tensor); + + client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) + } + + fn random( + shape: Shape, + distribution: Distribution>, + device: &Device, + ) -> FloatTensor { + struct RandomOps; + + impl Ops for RandomOps { + type Args = (TensorDescription, Distribution>); + + fn execute( + &self, + (out, distribution): &Self::Args, + handles: &mut crate::HandleContainer, + ) { + let shape = Shape::from(out.shape.clone()); + let output: B::TensorPrimitive = + B::random(shape, *distribution, &handles.device); + handles.register_float_tensor(&out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Random( + (out.to_description_out(), distribution), + Box::new(RandomOps::), + ))); + + out + } + + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + struct ZerosOps; + + impl Ops for ZerosOps { + type Args = TensorDescription; + + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::zeros::(shape, &handles.device); + handles.register_float_tensor(&out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), + )); + + out + } + + fn ones(shape: Shape, device: &Device) -> FloatTensor { + struct OnesOps; + + impl Ops for OnesOps { + type Args = TensorDescription; + + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::ones::(shape, &handles.device); + handles.register_float_tensor(&out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), + )); + + out + } + + fn full( + shape: Shape, + fill_value: FloatElem, + device: &Device, + ) -> FloatTensor { + struct FullOps; + + impl Ops for FullOps { + type Args = (TensorDescription, FloatElem); + + fn execute(&self, (out, value): &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output: B::TensorPrimitive = B::full(shape, *value, &handles.device); + handles.register_float_tensor(&out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Full( + (out.to_description_out(), fill_value), + Box::new(FullOps::), + ), + )); 
+ + out + } + + fn shape(tensor: &FloatTensor) -> Shape { + tensor.shape() + } + + fn into_data(tensor: FloatTensor) -> Reader, D>> { + tensor.into_data() + } + + fn device(tensor: &FloatTensor) -> Device { + tensor.client.device().clone().into() + } + + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor { + let device_original: &B::FusionDevice = tensor.client.device(); + let device_target: B::FusionDevice = device.clone().into(); + + if device_original == &device_target { + return tensor; + } + + let client_target = get_client::(&device_target); + let client_original = tensor.client.clone(); + + client_original + .clone() + .change_client_float::(tensor.into_description(), client_target) + } + + fn into_int(tensor: FloatTensor) -> IntTensor { + struct IntoIntOps; + + impl Ops for IntoIntOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = B::into_int(input); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::FloatOps( + FloatOpsDescription::IntoInt( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoIntOps::), + ), + )); + + out + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::empty(shape.clone(), device); + + client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(AddOps, B::add); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Add( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(AddOps, B::add_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::AddScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out + } + + fn clamp_min( + tensor: FloatTensor, + min: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(ClampMinOps, B::clamp_min); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ClampMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: min, + out: out.to_description_out(), + }, + Box::new(ClampMinOps::), + ), + )); + + out + } + + fn clamp_max( + tensor: FloatTensor, + max: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(ClampMaxOps, B::clamp_max); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ClampMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: max, + out: out.to_description_out(), + }, + Box::new(ClampMaxOps::), + ), + )); + + out + } + + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> 
FloatTensor { + struct ClampOps; + + impl Ops for ClampOps { + type Args = ClampOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.tensor); + let output = B::clamp(input, args.min, args.max); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Clamp( + ClampOpsDescription { + tensor: tensor.into_description(), + min, + max, + out: out.to_description_out(), + }, + Box::new(ClampOps::), + ), + )); + + out + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(SubOps, B::sub); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Sub( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(SubOps, B::sub_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::SubScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(MulOps, B::mul); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Mul( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(MulOps::), + ), + )); + + out + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(MulOps, B::mul_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MulScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(MulOps::), + ), + )); + + out + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(DivOps, B::div); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Div( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + scalar_float_ops!(DivOps, B::div_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::DivScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + binary_float_ops!(MatmulOps, B::matmul); + + let mut shape = binary_ops_shape(&lhs.shape, 
&rhs.shape); + + shape[D - 2] = lhs.shape[D - 2]; + shape[D - 1] = rhs.shape[D - 1]; + + let out = lhs.client.tensor_uninitialized(shape); + + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Matmul( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(MatmulOps::), + ))); + + out + } + + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + struct SwapDimsOps; + + impl Ops for SwapDimsOps { + type Args = SwapDimsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = B::swap_dims(input, args.dim1, args.dim2); + handles.register_float_tensor(&args.out.id, output); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::SwapDims( + SwapDimsDescription { + input: tensor.into_description(), + dim1, + dim2, + out: out.to_description_out(), + }, + Box::new(SwapDimsOps::), + ), + )); + + out + } + + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + struct ReshapeDimsOps; + + impl Ops for ReshapeDimsOps { + type Args = ReshapeDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = B::reshape::(input, Shape::from(&args.shape)); + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let out = tensor.client.tensor_uninitialized(shape.clone()); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::Reshape( + ReshapeDescription { + input: tensor.into_description(), + shape, + out: out.to_description_out(), + }, + Box::new(ReshapeDimsOps::), + ), + )); + + out + } + + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + struct GatherOps; + + impl Ops for GatherOps { + type Args = GatherOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::gather(args.dim, tensor, indices); + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = indices.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Gather( + GatherOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(GatherOps::), + ), + )); + + out + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + struct ScatterOps; + + impl Ops for ScatterOps { + type Args = ScatterOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_float_tensor(&args.value); + + let output = B::scatter(args.dim, tensor, indices, value); + + handles.register_float_tensor(&args.out.id, output); 
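matmul above first broadcasts the two operand shapes with binary_ops_shape and then overwrites the last two dimensions: rows come from lhs, columns from rhs. A standalone sketch of that shape rule (assuming the inner dimensions already agree; the helper name is illustrative):

    fn matmul_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
        let d = lhs.len();
        // Batch dimensions broadcast element-wise, as in binary_ops_shape.
        let mut shape: Vec<usize> = lhs
            .iter()
            .zip(rhs.iter())
            .map(|(l, r)| usize::max(*l, *r))
            .collect();
        // The matrix dimensions are then fixed from lhs rows and rhs columns.
        shape[d - 2] = lhs[d - 2];
        shape[d - 1] = rhs[d - 1];
        shape
    }

    fn main() {
        // Batched matmul: [8, 1, 3, 5] x [1, 4, 5, 7] gives [8, 4, 3, 7].
        assert_eq!(matmul_shape(&[8, 1, 3, 5], &[1, 4, 5, 7]), vec![8, 4, 3, 7]);
    }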
+ } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Scatter( + ScatterOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(ScatterOps::), + ), + )); + + out + } + + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + struct SelectOps; + + impl Ops for SelectOps { + type Args = SelectOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::select(tensor, args.dim, indices); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let mut shape: Vec = tensor.shape.clone(); + shape[dim] = indices.shape[0]; + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Select( + SelectOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectOps::), + ), + )); + + out + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + struct SelectAssignOps; + + impl Ops for SelectAssignOps { + type Args = SelectAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_float_tensor(&args.value); + + let output = B::select_assign(tensor, args.dim, indices, value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::SelectAssign( + SelectAssignOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectAssignOps::), + ), + )); + + out + } + + fn slice( + tensor: FloatTensor, + ranges: [Range; D2], + ) -> FloatTensor { + struct SliceOps; + + impl Ops for SliceOps { + type Args = SliceOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + + let output = B::slice::(tensor, args.ranges.clone().try_into().unwrap()); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + + for i in shape.len()..D1 { + shape.push(tensor.shape[i]); + } + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::Slice( + SliceOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + out: out.to_description_out(), + }, + Box::new(SliceOps::), + ), + )); + + out + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [Range; D2], + value: FloatTensor, + ) -> FloatTensor { + struct SliceAssignOps; + + impl Ops for 
SliceAssignOps { + type Args = SliceAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let value = handles.get_float_tensor::(&args.value); + + let output = B::slice_assign::( + tensor, + args.ranges.clone().try_into().unwrap(), + value, + ); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::SliceAssign( + SliceAssignOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SliceAssignOps::), + ), + )); + + out + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor { + struct MaskWhereOps; + + impl Ops for MaskWhereOps { + type Args = MaskWhereOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let value = handles.get_float_tensor(&args.value); + let mask = handles.get_bool_tensor(&args.mask); + + let output = B::mask_where(tensor, mask, value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaskWhere( + MaskWhereOpsDescription { + tensor: tensor.into_description(), + value: value.into_description(), + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskWhereOps::), + ), + )); + + out + } + + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + struct MaskFillOps; + + impl Ops for MaskFillOps { + type Args = MaskFillOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let mask = handles.get_bool_tensor(&args.mask); + + let output = B::mask_fill(tensor, mask, args.value); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaskFill( + MaskFillOpsDescription { + tensor: tensor.into_description(), + value, + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskFillOps::), + ), + )); + + out + } + + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(EqualOps, B::equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::BaseOpsFloat( + BaseOpsDescription::Equal( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(EqualOps::), + ), + )); + + out + } + + fn equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(EqualElemOps, B::equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::EqualElem( + ScalarOpsDescription 
{ + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(EqualElemOps::), + ), + )); + + out + } + + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(GreaterOps, B::greater); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Greater( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterOps::), + ), + )); + + out + } + + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::GreaterElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterElemOps::), + ), + )); + + out + } + + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(GreaterEqualOps, B::greater_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::GreaterEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterEqualOps::), + ), + )); + + out + } + + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::GreaterEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterEqualElemOps::), + ), + )); + + out + } + + fn lower( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(LowerOps, B::lower); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Lower( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerOps::), + ), + )); + + out + } + + fn lower_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(LowerElemOps, B::lower_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::LowerElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerElemOps::), + ), + )); + + out + } + + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + binary_float_cmp_ops!(LowerEqualOps, B::lower_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::LowerEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerEqualOps::), + ), + )); + + out + } + + fn lower_equal_elem( + lhs: 
FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::LowerEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerEqualElemOps::), + ), + )); + + out + } + + fn sum(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(SumOps, B::sum); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Sum( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SumOps::), + ), + )); + + out + } + + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(SumDimOps, B::sum_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::SumDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(SumDimOps::), + ), + )); + + out + } + + fn mean(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(MeanOps, B::mean); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Mean( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MeanOps::), + ), + )); + + out + } + + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(MeanDimOps, B::mean_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MeanDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MeanDimOps::), + ), + )); + + out + } + + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D> { + tensor.clone() + } + + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + tensor + } + + fn exp(lhs: FloatTensor) -> FloatTensor { + unary_float_ops!(ExpOps, B::exp); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Exp( + UnaryOpsDescription { + input: lhs.into_description(), + out: out.to_description_out(), + }, + Box::new(ExpOps::), + ))); + + out + } + + fn log(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(LogOps, B::log); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) - } + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(LogOps::), + ))); - fn random( - shape: Shape, - distribution: Distribution>, - device: &Device, - ) -> FloatTensor { - struct RandomOps; + out + } + + fn log1p(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(Log1pOps, B::log1p); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client + 
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log1p( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(Log1pOps::), + ))); + + out + } + + fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { + scalar_float_ops!(PowfOps, B::powf, f32); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Powf( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(PowfOps::), + ))); + + out + } + + fn sqrt(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(SqrtOps, B::sqrt); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sqrt( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SqrtOps::), + ))); + + out + } + + fn abs(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(AbsOps, B::abs); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Abs( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(AbsOps::), + ), + )); + + out + } + + fn cos(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(CosOps, B::cos); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Cos( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(CosOps::), + ))); - impl Ops for RandomOps { - type Args = (TensorDescription, Distribution>); + out + } + + fn sin(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(SinOps, B::sin); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sin( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SinOps::), + ))); + + out + } - fn execute(&self, (out, distribution): &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output: B::TensorPrimitive = B::random(shape, *distribution, &handles.device); - handles.register_float_tensor(&out.id, output); - } + fn tanh(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(TanhOps, B::tanh); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Tanh( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(TanhOps::), + ))); + + out } - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); + fn recip(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(Recip, B::recip); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(Recip::), + ))); + out + } + + fn erf(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(TanhOps, B::erf); - 
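// Illustrative sketch (an assumption, not text from the patch): the unary
// float ops above (exp, log, log1p, sqrt, abs, cos, sin, tanh, recip, erf)
// are produced by the `unary_float_ops!` helper macro, whose expansion is
// not shown in this hunk. The reconstruction below is inferred from the
// hand-written `Ops` impls that do appear in this diff (for example CatOps
// and MaxDimWithIndicesOps); the generic parameters were stripped during
// extraction, so the `<B, D>` bounds, the turbofish on `get_float_tensor`,
// and the struct name `SqrtOps` are all assumptions.
struct SqrtOps<const D: usize>;

impl<const D: usize, B: FusionBackend> Ops<B> for SqrtOps<D> {
    type Args = UnaryOpsDescription;

    fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer<B>) {
        // Resolve the lazily described input to a concrete backend tensor,
        // run the eager op, and store the result under the output id that
        // was reserved when the op was registered on the client.
        let input = handles.get_float_tensor::<D>(&args.input);
        let output = B::sqrt(input);
        handles.register_float_tensor(&args.out.id, output);
    }
}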
client.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Random( - (out.to_description_out(), distribution), - Box::new(RandomOps::), - ))); - - out - } - - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - struct ZerosOps; + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - impl Ops for ZerosOps { - type Args = TensorDescription; + out.client + .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Erf( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(TanhOps::), + ))); - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::zeros::(shape, &handles.device); - handles.register_float_tensor(&out.id, output); - } + out } - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + struct CatOps; + + impl Ops for CatOps { + type Args = CatOpsDescription; - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), - )); + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensors = args + .tensors + .iter() + .map(|tensor| handles.get_float_tensor(tensor)) + .collect(); - out - } + let output = B::cat::(tensors, args.dim); - fn ones(shape: Shape, device: &Device) -> FloatTensor { - struct OnesOps; + handles.register_float_tensor(&args.out.id, output); + } + } - impl Ops for OnesOps { - type Args = TensorDescription; + let tensor_first = tensors.get(0).unwrap(); + let client = tensor_first.client.clone(); - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::ones::(shape, &handles.device); - handles.register_float_tensor(&out.id, output); - } + // Calculate the output shape + let mut shape: Vec = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; + } + + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::BaseOpsFloat(BaseOpsDescription::Cat( + CatOpsDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }, + Box::new(CatOps::), + ))); + + out + } + + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + scalar_float2int_ops!(ArgMaxOps, B::argmax, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ArgMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMaxOps::), + ), + )); + + out + } + + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + scalar_float2int_ops!(ArgMinOps, B::argmin, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::ArgMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMinOps::), + ), + )); + + out } - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = 
client.tensor_uninitialized(shape); + fn max(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(MaxOps, B::max); - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), - )); + let out = tensor.client.tensor_uninitialized(vec![1]); - out - } + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Max( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MaxOps::), + ), + )); - fn full( - shape: Shape, - fill_value: FloatElem, - device: &Device, - ) -> FloatTensor { - struct FullOps; + out + } - impl Ops for FullOps { - type Args = (TensorDescription, FloatElem); + fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(MaxDimOps, B::max_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaxDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MaxDimOps::), + ), + )); + + out + } - fn execute(&self, (out, value): &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output: B::TensorPrimitive = B::full(shape, *value, &handles.device); - handles.register_float_tensor(&out.id, output); - } + fn max_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + struct MaxDimWithIndicesOps; + + impl Ops for MaxDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let (output, indices) = B::max_dim_with_indices(tensor, args.dim); + + handles.register_float_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MaxDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxDimWithIndicesOps::), + ), + )); + + (out, out_indices) } - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); + fn min(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(MinOps, B::min); - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Full( - (out.to_description_out(), fill_value), - Box::new(FullOps::), - ), - )); + let out = tensor.client.tensor_uninitialized(vec![1]); - out - } + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::Min( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MinOps::), + ), + )); - fn shape(tensor: &FloatTensor) -> Shape { - tensor.shape() - } + out + } - fn into_data(tensor: FloatTensor) -> Reader, D>> { - tensor.into_data() - } + fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + scalar_float_ops!(MinDimOps, B::min_dim, usize); + + let mut shape = 
tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MinDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MinDimOps::), + ), + )); + + out + } - fn device(tensor: &FloatTensor) -> Device { - tensor.client.device().clone().into() - } - - fn to_device( - tensor: FloatTensor, - device: &Device, - ) -> FloatTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); - - if device_original == &device_target { - return tensor; - } - - let client_target = get_client::(&device_target); - let client_original = tensor.client.clone(); - - client_original - .clone() - .change_client_float::(tensor.into_description(), client_target) - } - - fn into_int(tensor: FloatTensor) -> IntTensor { - struct IntoIntOps; - - impl Ops for IntoIntOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = B::into_int(input); - - handles.register_int_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::FloatOps( - FloatOpsDescription::IntoInt( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoIntOps::), - ), - )); - - out - } - - fn empty(shape: Shape, device: &Device) -> FloatTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::empty(shape.clone(), device); - - client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) - } - - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(AddOps, B::add); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Add( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(AddOps, B::add_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::AddScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out - } - - fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(ClampMinOps, B::clamp_min); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ClampMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: min, - out: out.to_description_out(), - }, - Box::new(ClampMinOps::), - ), - )); - - out - } - - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(ClampMaxOps, B::clamp_max); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ClampMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - 
rhs: max, - out: out.to_description_out(), - }, - Box::new(ClampMaxOps::), - ), - )); - - out - } - - fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - struct ClampOps; - - impl Ops for ClampOps { - type Args = ClampOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.tensor); - let output = B::clamp(input, args.min, args.max); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Clamp( - ClampOpsDescription { - tensor: tensor.into_description(), - min, - max, - out: out.to_description_out(), - }, - Box::new(ClampOps::), - ), - )); - - out - } - - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(SubOps, B::sub); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Sub( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(SubOps, B::sub_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::SubScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(MulOps, B::mul); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Mul( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(MulOps, B::mul_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MulScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(DivOps, B::div); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Div( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(DivOps, B::div_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::DivScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out 
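// Worked shape example (assumed values, not part of the patch): the
// elementwise binary ops above reserve their output with
// `binary_ops_shape(&lhs.shape, &rhs.shape)`, which presumably applies the
// usual broadcasting rules, while `matmul` just below starts from that
// shape and then overrides the two innermost dimensions:
//
//     let mut shape = binary_ops_shape(&lhs.shape, &rhs.shape);
//     shape[D - 2] = lhs.shape[D - 2];
//     shape[D - 1] = rhs.shape[D - 1];
//
// With hypothetical inputs lhs: [8, 64, 32] and rhs: [8, 32, 16], the
// reserved matmul output is therefore [8, 64, 16]. Only the shape is
// computed at registration time; no data is produced until the fusion
// graph executes.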
- } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - binary_float_ops!(MatmulOps, B::matmul); - - let mut shape = binary_ops_shape(&lhs.shape, &rhs.shape); - - shape[D - 2] = lhs.shape[D - 2]; - shape[D - 1] = rhs.shape[D - 1]; - - let out = lhs.client.tensor_uninitialized(shape); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Matmul( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(MatmulOps::), - ))); - - out - } - - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor { - struct SwapDimsOps; - - impl Ops for SwapDimsOps { - type Args = SwapDimsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = B::swap_dims(input, args.dim1, args.dim2); - handles.register_float_tensor(&args.out.id, output); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim1] = tensor.shape[dim2]; - shape[dim2] = tensor.shape[dim1]; - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::SwapDims( - SwapDimsDescription { - input: tensor.into_description(), - dim1, - dim2, - out: out.to_description_out(), - }, - Box::new(SwapDimsOps::), - ), - )); - - out - } - - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - struct ReshapeDimsOps; - - impl Ops for ReshapeDimsOps { - type Args = ReshapeDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = B::reshape::(input, Shape::from(&args.shape)); - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape.clone()); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::Reshape( - ReshapeDescription { - input: tensor.into_description(), - shape, - out: out.to_description_out(), - }, - Box::new(ReshapeDimsOps::), - ), - )); - - out - } - - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - struct GatherOps; - - impl Ops for GatherOps { - type Args = GatherOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::gather(args.dim, tensor, indices); - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Gather( - GatherOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(GatherOps::), - ), - )); - - out - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - struct ScatterOps; - - impl Ops for ScatterOps { - type Args = ScatterOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = 
handles.get_int_tensor(&args.indices); - let value = handles.get_float_tensor(&args.value); - - let output = B::scatter(args.dim, tensor, indices, value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Scatter( - ScatterOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(ScatterOps::), - ), - )); - - out - } - - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - struct SelectOps; - - impl Ops for SelectOps { - type Args = SelectOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::select(tensor, args.dim, indices); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let mut shape: Vec = tensor.shape.clone(); - shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Select( - SelectOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectOps::), - ), - )); - - out - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - struct SelectAssignOps; - - impl Ops for SelectAssignOps { - type Args = SelectAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - let value = handles.get_float_tensor(&args.value); - - let output = B::select_assign(tensor, args.dim, indices, value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::SelectAssign( - SelectAssignOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectAssignOps::), - ), - )); - - out - } - - fn slice( - tensor: FloatTensor, - ranges: [Range; D2], - ) -> FloatTensor { - struct SliceOps; - - impl Ops for SliceOps { - type Args = SliceOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - - let output = B::slice::(tensor, args.ranges.clone().try_into().unwrap()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); - - for i in shape.len()..D1 { - shape.push(tensor.shape[i]); - } - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::Slice( - SliceOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - out: out.to_description_out(), 
- }, - Box::new(SliceOps::), - ), - )); - - out - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [Range; D2], - value: FloatTensor, - ) -> FloatTensor { - struct SliceAssignOps; - - impl Ops for SliceAssignOps { - type Args = SliceAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let value = handles.get_float_tensor::(&args.value); - - let output = - B::slice_assign::(tensor, args.ranges.clone().try_into().unwrap(), value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::SliceAssign( - SliceAssignOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SliceAssignOps::), - ), - )); - - out - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor { - struct MaskWhereOps; - - impl Ops for MaskWhereOps { - type Args = MaskWhereOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let value = handles.get_float_tensor(&args.value); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::mask_where(tensor, mask, value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaskWhere( - MaskWhereOpsDescription { - tensor: tensor.into_description(), - value: value.into_description(), - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskWhereOps::), - ), - )); - - out - } - - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - struct MaskFillOps; - - impl Ops for MaskFillOps { - type Args = MaskFillOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::mask_fill(tensor, mask, args.value); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaskFill( - MaskFillOpsDescription { - tensor: tensor.into_description(), - value, - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskFillOps::), - ), - )); - - out - } - - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(EqualOps, B::equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::BaseOpsFloat( - BaseOpsDescription::Equal( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(EqualOps::), - ), - )); - - out - } - - fn equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(EqualElemOps, 
B::equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::EqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(EqualElemOps::), - ), - )); - - out - } - - fn greater( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(GreaterOps, B::greater); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Greater( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterOps::), - ), - )); - - out - } - - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::GreaterElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterElemOps::), - ), - )); - - out - } - - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(GreaterEqualOps, B::greater_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::GreaterEqual( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterEqualOps::), - ), - )); - - out - } - - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::GreaterEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterEqualElemOps::), - ), - )); - - out - } - - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(LowerOps, B::lower); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Lower( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerOps::), - ), - )); - - out - } - - fn lower_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(LowerElemOps, B::lower_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::LowerElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerElemOps::), - ), - )); - - out - } - - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - binary_float_cmp_ops!(LowerEqualOps, B::lower_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::LowerEqual( - 
BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerEqualOps::), - ), - )); - - out - } - - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::LowerEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerEqualElemOps::), - ), - )); - - out - } - - fn sum(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SumOps, B::sum); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Sum( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SumOps::), - ), - )); - - out - } - - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(SumDimOps, B::sum_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::SumDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(SumDimOps::), - ), - )); - - out - } - - fn mean(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MeanOps, B::mean); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Mean( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MeanOps::), - ), - )); - - out - } - - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(MeanDimOps, B::mean_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MeanDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MeanDimOps::), - ), - )); - - out - } - - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - tensor.clone() - } - - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - tensor - } - - fn exp(lhs: FloatTensor) -> FloatTensor { - unary_float_ops!(ExpOps, B::exp); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Exp( - UnaryOpsDescription { - input: lhs.into_description(), - out: out.to_description_out(), - }, - Box::new(ExpOps::), - ))); - - out - } - - fn log(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(LogOps, B::log); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(LogOps::), - ))); - - out - } - - fn log1p(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(Log1pOps, B::log1p); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - 
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log1p( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(Log1pOps::), - ))); - - out - } - - fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { - scalar_float_ops!(PowfOps, B::powf, f32); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Powf( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(PowfOps::), - ))); - - out - } - - fn sqrt(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SqrtOps, B::sqrt); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sqrt( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SqrtOps::), - ))); - - out - } - - fn abs(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(AbsOps, B::abs); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Abs( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(AbsOps::), - ), - )); - - out - } - - fn cos(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(CosOps, B::cos); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Cos( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(CosOps::), - ))); - - out - } - - fn sin(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SinOps, B::sin); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sin( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SinOps::), - ))); - - out - } - - fn tanh(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(TanhOps, B::tanh); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Tanh( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(TanhOps::), - ))); - - out - } - - fn recip(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(Recip, B::recip); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(Recip::), - ))); - out - } - - fn erf(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(TanhOps, B::erf); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out - .client - .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Erf( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(TanhOps::), - ))); - - out - } - - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - struct CatOps; - - impl Ops for CatOps { - type Args = CatOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensors = args - .tensors 
- .iter() - .map(|tensor| handles.get_float_tensor(tensor)) - .collect(); - - let output = B::cat::(tensors, args.dim); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let tensor_first = tensors.get(0).unwrap(); - let client = tensor_first.client.clone(); - - // Calculate the output shape - let mut shape: Vec = tensor_first.shape.clone(); - shape[dim] = 0; - for tensor in tensors.iter() { - shape[dim] += tensor.shape[dim]; - } - - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::BaseOpsFloat(BaseOpsDescription::Cat( - CatOpsDescription { - tensors: tensors.into_iter().map(|t| t.into_description()).collect(), - dim, - out: out.to_description_out(), - }, - Box::new(CatOps::), - ))); - - out - } - - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - scalar_float2int_ops!(ArgMaxOps, B::argmax, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ArgMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMaxOps::), - ), - )); - - out - } - - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - scalar_float2int_ops!(ArgMinOps, B::argmin, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::ArgMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMinOps::), - ), - )); - - out - } - - fn max(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MaxOps, B::max); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Max( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MaxOps::), - ), - )); - - out - } - - fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(MaxDimOps, B::max_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaxDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MaxDimOps::), - ), - )); - - out - } - - fn max_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - struct MaxDimWithIndicesOps; - - impl Ops for MaxDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let (output, indices) = B::max_dim_with_indices(tensor, args.dim); - - handles.register_float_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MaxDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: 
out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxDimWithIndicesOps::), - ), - )); - - (out, out_indices) - } - - fn min(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MinOps, B::min); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::Min( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MinOps::), - ), - )); - - out - } - - fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - scalar_float_ops!(MinDimOps, B::min_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MinDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MinDimOps::), - ), - )); - - out - } - - fn min_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - struct MinDimWithIndicesOps; - - impl Ops for MinDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_float_tensor::(&args.tensor); - let (output, indices) = B::min_dim_with_indices(tensor, args.dim); - - handles.register_float_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } - } - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsFloat( - NumericOpsDescription::MinDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MinDimWithIndicesOps::), - ), - )); - - (out, out_indices) - } + fn min_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + struct MinDimWithIndicesOps; + + impl Ops for MinDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&args.tensor); + let (output, indices) = B::min_dim_with_indices(tensor, args.dim); + + handles.register_float_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsFloat( + NumericOpsDescription::MinDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MinDimWithIndicesOps::), + ), + )); + + (out, out_indices) + } } diff --git a/burn-fusion/src/ops/int.rs b/burn-fusion/src/ops/int.rs index 8e7f989d8f..32d2d6a547 100644 --- a/burn-fusion/src/ops/int.rs +++ b/burn-fusion/src/ops/int.rs @@ -1,1406 +1,1401 @@ use crate::{ - binary_int_cmp_ops, binary_int_ops, - client::FusionClient, - get_client, - graph::{ - self, BaseOpsDescription, 
BinaryOpsDescription, CatOpsDescription, ClampOpsDescription, - GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, NumericOpsDescription, - Ops, ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOpsDescription, - ScatterOpsDescription, SelectAssignOpsDescription, SelectOpsDescription, - SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, TensorOpsDescription, - UnaryOpsDescription, - }, - ops::binary::binary_ops_shape, - scalar_int_cmp_ops, scalar_int_ops, unary_int_ops, Fusion, FusionBackend, TensorDescription, + binary_int_cmp_ops, binary_int_ops, + client::FusionClient, + get_client, + graph::{ + self, BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription, + GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription, + NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription, + ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription, + SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription, + TensorOpsDescription, UnaryOpsDescription, + }, + ops::binary::binary_ops_shape, + scalar_int_cmp_ops, scalar_int_ops, unary_int_ops, Fusion, FusionBackend, TensorDescription, }; use burn_tensor::{ - ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, - Data, Device, Reader, Shape, + ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, + Data, Device, Reader, Shape, }; use core::ops::Range; impl IntTensorOps for Fusion { - fn int_empty(shape: Shape, device: &Device) -> IntTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::int_empty(shape.clone(), device); - - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - tensor.shape() - } - - fn int_into_data(tensor: IntTensor) -> Reader, D>> { - tensor.int_into_data() - } - - fn int_from_data( - data: Data, D>, - device: &Device, - ) -> IntTensor { - let client = get_client::(&device.clone().into()); - let tensor = B::int_from_data(data, device); - let shape = B::int_shape(&tensor); - - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) - } - - fn int_device(tensor: &IntTensor) -> Device { - tensor.client.device().clone().into() - } - - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); - - if device_original == &device_target { - return tensor; + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::int_empty(shape.clone(), device); + + client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + tensor.shape() + } + + fn int_into_data(tensor: IntTensor) -> Reader, D>> { + tensor.int_into_data() + } + + fn int_from_data( + data: Data, D>, + device: &Device, + ) -> IntTensor { + let client = get_client::(&device.clone().into()); + let tensor = B::int_from_data(data, device); + let shape = B::int_shape(&tensor); + + client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) } - let client_target = get_client::(&device_target); - let client_original = tensor.client.clone(); + fn int_device(tensor: &IntTensor) -> Device { + tensor.client.device().clone().into() + } - client_original - .clone() - .change_client_int::(tensor.into_description(), 
client_target) - } + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + let device_original: &B::FusionDevice = tensor.client.device(); + let device_target: B::FusionDevice = device.clone().into(); - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - struct ReshapeDimsOps; + if device_original == &device_target { + return tensor; + } - impl Ops for ReshapeDimsOps { - type Args = ReshapeDescription; + let client_target = get_client::(&device_target); + let client_original = tensor.client.clone(); - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = B::int_reshape::(input, Shape::from(&args.shape)); - handles.register_int_tensor(&args.out.id, output); - } + client_original + .clone() + .change_client_int::(tensor.into_description(), client_target) } - let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape.clone()); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt( - BaseOpsDescription::Reshape( - ReshapeDescription { - input: tensor.into_description(), - shape, - out: out.to_description_out(), - }, - Box::new(ReshapeDimsOps::), - ), - )); - - out - } - - fn int_slice( - tensor: IntTensor, - ranges: [Range; D2], - ) -> IntTensor { - struct SliceOps; - - impl Ops for SliceOps { - type Args = SliceOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - - let output = B::int_slice::(tensor, args.ranges.clone().try_into().unwrap()); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + struct ReshapeDimsOps; + + impl Ops for ReshapeDimsOps { + type Args = ReshapeDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = B::int_reshape::(input, Shape::from(&args.shape)); + handles.register_int_tensor(&args.out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let out = tensor.client.tensor_uninitialized(shape.clone()); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt( + BaseOpsDescription::Reshape( + ReshapeDescription { + input: tensor.into_description(), + shape, + out: out.to_description_out(), + }, + Box::new(ReshapeDimsOps::), + ), + )); + + out } - let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + fn int_slice( + tensor: IntTensor, + ranges: [Range; D2], + ) -> IntTensor { + struct SliceOps; + + impl Ops for SliceOps { + type Args = SliceOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + + let output = + B::int_slice::(tensor, args.ranges.clone().try_into().unwrap()); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); - for i in shape.len()..D1 { - shape.push(tensor.shape[i]); + for i in shape.len()..D1 { + shape.push(tensor.shape[i]); + } + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Slice( + SliceOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + out: out.to_description_out(), + }, + 
Box::new(SliceOps::), + ))); + + out } - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Slice( - SliceOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - out: out.to_description_out(), - }, - Box::new(SliceOps::), - ))); - - out - } - - fn int_slice_assign( - tensor: IntTensor, - ranges: [Range; D2], - value: IntTensor, - ) -> IntTensor { - struct SliceAssignOps; - - impl Ops for SliceAssignOps { - type Args = SliceAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let value = handles.get_int_tensor::(&args.value); - - let output = - B::int_slice_assign::(tensor, args.ranges.clone().try_into().unwrap(), value); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_slice_assign( + tensor: IntTensor, + ranges: [Range; D2], + value: IntTensor, + ) -> IntTensor { + struct SliceAssignOps; + + impl Ops for SliceAssignOps { + type Args = SliceAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let value = handles.get_int_tensor::(&args.value); + + let output = B::int_slice_assign::( + tensor, + args.ranges.clone().try_into().unwrap(), + value, + ); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt( + BaseOpsDescription::SliceAssign( + SliceAssignOpsDescription { + tensor: tensor.into_description(), + ranges: ranges.into(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SliceAssignOps::), + ), + )); + + out } - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt( - BaseOpsDescription::SliceAssign( - SliceAssignOpsDescription { - tensor: tensor.into_description(), - ranges: ranges.into(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SliceAssignOps::), - ), - )); - - out - } - - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - value: IntTensor, - ) -> IntTensor { - struct MaskWhereOps; - - impl Ops for MaskWhereOps { - type Args = MaskWhereOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let value = handles.get_int_tensor(&args.value); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::int_mask_where(tensor, mask, value); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + value: IntTensor, + ) -> IntTensor { + struct MaskWhereOps; + + impl Ops for MaskWhereOps { + type Args = MaskWhereOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let value = handles.get_int_tensor(&args.value); + let mask = handles.get_bool_tensor(&args.mask); + + let output = B::int_mask_where(tensor, mask, value); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + 
.client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaskWhere( + MaskWhereOpsDescription { + tensor: tensor.into_description(), + value: value.into_description(), + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskWhereOps::), + ), + )); + + out } - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaskWhere( - MaskWhereOpsDescription { - tensor: tensor.into_description(), - value: value.into_description(), - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskWhereOps::), - ), - )); - - out - } - - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor { - struct MaskFillOps; - - impl Ops for MaskFillOps { - type Args = MaskFillOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let mask = handles.get_bool_tensor(&args.mask); - - let output = B::int_mask_fill(tensor, mask, args.value); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor { + struct MaskFillOps; + + impl Ops for MaskFillOps { + type Args = MaskFillOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let mask = handles.get_bool_tensor(&args.mask); + + let output = B::int_mask_fill(tensor, mask, args.value); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaskFill( + MaskFillOpsDescription { + tensor: tensor.into_description(), + value, + mask: mask.into_description(), + out: out.to_description_out(), + }, + Box::new(MaskFillOps::), + ), + )); + + out } - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaskFill( - MaskFillOpsDescription { - tensor: tensor.into_description(), - value, - mask: mask.into_description(), - out: out.to_description_out(), - }, - Box::new(MaskFillOps::), - ), - )); - - out - } - - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - struct GatherOps; - - impl Ops for GatherOps { - type Args = GatherOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::int_gather(args.dim, tensor, indices); - handles.register_int_tensor(&args.out.id, output); - } + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + struct GatherOps; + + impl Ops for GatherOps { + type Args = GatherOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::int_gather(args.dim, tensor, indices); + handles.register_int_tensor(&args.out.id, output); + } + } + + let 
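// A simplified, self-contained sketch of the deferred-execution pattern that every op in this
// file follows: allocate an uninitialized output handle, record a description together with a
// boxed executor, and let the graph run the executor later against a handle container. All of
// the types below (`ToyHandles`, `LazyOp`, `ToyGraph`, `AddOp`) are hypothetical stand-ins, not
// the crate's real handle, ops, or graph types.
use std::collections::HashMap;

struct ToyHandles {
    tensors: HashMap<usize, Vec<i64>>,
}

trait LazyOp {
    fn execute(&self, handles: &mut ToyHandles);
}

struct AddOp {
    lhs: usize,
    rhs: usize,
    out: usize,
}

impl LazyOp for AddOp {
    fn execute(&self, handles: &mut ToyHandles) {
        // Resolve the recorded handles only at execution time, then register the result.
        let lhs = handles.tensors[&self.lhs].clone();
        let rhs = handles.tensors[&self.rhs].clone();
        let sum: Vec<i64> = lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect();
        handles.tensors.insert(self.out, sum);
    }
}

#[derive(Default)]
struct ToyGraph {
    ops: Vec<Box<dyn LazyOp>>,
}

impl ToyGraph {
    fn register(&mut self, op: Box<dyn LazyOp>) {
        self.ops.push(op); // registration is cheap; nothing is computed yet
    }
    fn execute_all(&mut self, handles: &mut ToyHandles) {
        for op in self.ops.drain(..) {
            op.execute(handles);
        }
    }
}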
shape: Vec = indices.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Gather( + GatherOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(GatherOps::), + ), + )); + + out } - let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Gather( - GatherOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(GatherOps::), - ), - )); - - out - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - struct ScatterOps; - - impl Ops for ScatterOps { - type Args = ScatterOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - let value = handles.get_int_tensor(&args.value); - - let output = B::int_scatter(args.dim, tensor, indices, value); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + struct ScatterOps; + + impl Ops for ScatterOps { + type Args = ScatterOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_int_tensor(&args.value); + + let output = B::int_scatter(args.dim, tensor, indices, value); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Scatter( + ScatterOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(ScatterOps::), + ), + )); + + out } - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Scatter( - ScatterOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(ScatterOps::), - ), - )); - - out - } - - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - struct SelectOps; - - impl Ops for SelectOps { - type Args = SelectOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - - let output = B::int_select(tensor, args.dim, indices); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + struct SelectOps; + + impl Ops for SelectOps { + type Args = SelectOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut 
crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + + let output = B::int_select(tensor, args.dim, indices); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let mut shape: Vec = tensor.shape.clone(); + shape[dim] = indices.shape[0]; + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Select( + SelectOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectOps::), + ), + )); + + out } - let mut shape: Vec = tensor.shape.clone(); - shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Select( - SelectOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectOps::), - ), - )); - - out - } - - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - struct SelectAssignOps; - - impl Ops for SelectAssignOps { - type Args = SelectAssignOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let indices = handles.get_int_tensor(&args.indices); - let value = handles.get_int_tensor(&args.value); - - let output = B::int_select_assign(tensor, args.dim, indices, value); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + struct SelectAssignOps; + + impl Ops for SelectAssignOps { + type Args = SelectAssignOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let indices = handles.get_int_tensor(&args.indices); + let value = handles.get_int_tensor(&args.value); + + let output = B::int_select_assign(tensor, args.dim, indices, value); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::SelectAssign( + SelectAssignOpsDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }, + Box::new(SelectAssignOps::), + ), + )); + + out } - let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::SelectAssign( - SelectAssignOpsDescription { - tensor: tensor.into_description(), - dim, - indices: indices.into_description(), - value: value.into_description(), - out: out.to_description_out(), - }, - Box::new(SelectAssignOps::), - ), - )); - - out - } - - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - struct CatOps; - - impl Ops for CatOps { - type Args = CatOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensors = args - .tensors - .iter() - .map(|tensor| 
handles.get_int_tensor(tensor)) - .collect(); - - let output = B::int_cat::(tensors, args.dim); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + struct CatOps; + + impl Ops for CatOps { + type Args = CatOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensors = args + .tensors + .iter() + .map(|tensor| handles.get_int_tensor(tensor)) + .collect(); + + let output = B::int_cat::(tensors, args.dim); + + handles.register_int_tensor(&args.out.id, output); + } + } + + let tensor_first = tensors.get(0).unwrap(); + let client = tensor_first.client.clone(); + + // Calculate the output shape + let mut shape: Vec = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; + } + + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat( + CatOpsDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }, + Box::new(CatOps::), + ))); + + out } - let tensor_first = tensors.get(0).unwrap(); - let client = tensor_first.client.clone(); + fn int_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(EqualOps, B::int_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client + .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(EqualOps::), + ))); + + out + } - // Calculate the output shape - let mut shape: Vec = tensor_first.shape.clone(); - shape[dim] = 0; - for tensor in tensors.iter() { - shape[dim] += tensor.shape[dim]; + fn int_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::EqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(EqualElemOps::), + ), + )); + + out } - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat( - CatOpsDescription { - tensors: tensors.into_iter().map(|t| t.into_description()).collect(), - dim, - out: out.to_description_out(), - }, - Box::new(CatOps::), - ))); - - out - } - - fn int_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(EqualOps, B::int_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out - .client - .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(EqualOps::), - ))); - - out - } - - fn int_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::EqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(EqualElemOps::), - ), - )); - - out - 
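// A standalone mirror of the `int_cat` shape inference just above: the output takes the first
// tensor's shape and sums the concatenated dimension across all inputs. `cat_shape` is a
// hypothetical helper name; like the `unwrap` above, it assumes at least one input tensor.
fn cat_shape(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
    let mut out = shapes[0].clone();
    out[dim] = shapes.iter().map(|s| s[dim]).sum();
    out
}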
} - - fn int_greater( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(GreaterOps, B::int_greater); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Greater( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterOps::), - ), - )); - - out - } - - fn int_greater_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::GreaterElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterElemOps::), - ), - )); - - out - } - - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::GreaterEqual( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(GreaterEqualOps::), - ), - )); - - out - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::GreaterEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(GreaterEqualElemOps::), - ), - )); - - out - } - - fn int_lower( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(LowerOps, B::int_lower); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Lower( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerOps::), - ), - )); - - out - } - - fn int_lower_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::LowerElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerElemOps::), - ), - )); - - out - } - - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::LowerEqual( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(LowerEqualOps::), - ), - )); - - out - } - - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); - - let out = 
lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::LowerEqualElem( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(LowerEqualElemOps::), - ), - )); - - out - } - - fn int_add( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(AddOps, B::int_add); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Add( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out - } - - fn int_add_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(AddOps, B::int_add_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::AddScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(AddOps::), - ), - )); - - out - } - - fn int_sub( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(SubOps, B::int_sub); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Sub( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out - } - - fn int_sub_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(SubOps, B::int_sub_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::SubScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(SubOps::), - ), - )); - - out - } - - fn int_mul( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(MulOps, B::int_mul); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Mul( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out - } - - fn int_mul_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(MulOps, B::int_mul_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MulScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(MulOps::), - ), - )); - - out - } - - fn int_div( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - binary_int_ops!(DivOps, B::int_div); - - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Div( - BinaryOpsDescription { - lhs: lhs.into_description(), - rhs: rhs.into_description(), - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out 
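// Assumption, not established by this diff: `binary_ops_shape`, used by the binary and
// comparison ops in this file, presumably yields the broadcasted output shape. One common
// formulation for equal-rank inputs is the element-wise maximum; `broadcast_shape` below is a
// hypothetical illustration of that rule, not the crate's actual implementation.
fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    lhs.iter().zip(rhs.iter()).map(|(l, r)| (*l).max(*r)).collect()
}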
- } - - fn int_div_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - scalar_int_ops!(DivOps, B::int_div_scalar); - - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); - - out - .client - .register(graph::TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::DivScalar( - ScalarOpsDescription { - lhs: lhs.into_description(), - rhs, - out: out.to_description_out(), - }, - Box::new(DivOps::), - ), - )); - - out - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - struct ZerosOps; - - impl Ops for ZerosOps { - type Args = TensorDescription; - - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::int_zeros::(shape, &handles.device); - handles.register_int_tensor(&out.id, output); - } + fn int_greater( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(GreaterOps, B::int_greater); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Greater( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterOps::), + ), + )); + + out } - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); + fn int_greater_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::GreaterElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterElemOps::), + ), + )); + + out + } - client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), - )); + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::GreaterEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(GreaterEqualOps::), + ), + )); + + out + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::GreaterEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(GreaterEqualElemOps::), + ), + )); + + out + } + + fn int_lower( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(LowerOps, B::int_lower); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Lower( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerOps::), + ), + )); + + out + } + + fn int_lower_elem( + 
lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::LowerElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerElemOps::), + ), + )); + + out + } - out - } + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::LowerEqual( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(LowerEqualOps::), + ), + )); + + out + } - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - struct OnesOps; + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::LowerEqualElem( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(LowerEqualElemOps::), + ), + )); + + out + } - impl Ops for OnesOps { - type Args = TensorDescription; + fn int_add( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + binary_int_ops!(AddOps, B::int_add); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Add( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out + } - fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { - let shape = Shape::from(out.shape.clone()); - let output = B::int_ones::(shape, &handles.device); - handles.register_int_tensor(&out.id, output); - } + fn int_add_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(AddOps, B::int_add_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::AddScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(AddOps::), + ), + )); + + out } - let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); - let out = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), - )); - - out - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - unary_int_ops!(SumOps, B::int_sum); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Sum( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(SumOps::), - ), - )); - - out - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(SumDimOps, B::int_sum_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = 
tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::SumDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(SumDimOps::), - ), - )); - - out - } - - fn int_mean(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MeanOps, B::int_mean); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Mean( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MeanOps::), - ), - )); - - out - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(MeanDimOps, B::int_mean_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MeanDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MeanDimOps::), - ), - )); - - out - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(ArgMaxOps, B::int_argmax, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ArgMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMaxOps::), - ), - )); - - out - } - - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(ArgMinOps, B::int_argmin, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ArgMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(ArgMinOps::), - ), - )); - - out - } - - fn int_clamp_min( - tensor: IntTensor, - min: IntElem, - ) -> IntTensor { - scalar_int_ops!(ClampMinOps, B::int_clamp_min); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ClampMin( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: min, - out: out.to_description_out(), - }, - Box::new(ClampMinOps::), - ), - )); - - out - } - - fn int_clamp_max( - tensor: IntTensor, - max: IntElem, - ) -> IntTensor { - scalar_int_ops!(ClampMaxOps, B::int_clamp_max); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::ClampMax( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: max, - out: out.to_description_out(), - }, - Box::new(ClampMaxOps::), - ), - )); - - out - } - - fn int_clamp( - tensor: IntTensor, - min: IntElem, - max: IntElem, - ) -> IntTensor { - struct ClampOps; - - impl Ops for ClampOps { - type Args = ClampOpsDescription>; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.tensor); - let output = B::int_clamp(input, args.min, args.max); - - handles.register_int_tensor(&args.out.id, output); - } + fn int_sub( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + 
binary_int_ops!(SubOps, B::int_sub); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Sub( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out } - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Clamp( - ClampOpsDescription { - tensor: tensor.into_description(), - min, - max, - out: out.to_description_out(), - }, - Box::new(ClampOps::), - ), - )); - - out - } - - fn int_abs(tensor: IntTensor) -> IntTensor { - unary_int_ops!(AbsOps, B::int_abs); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Abs( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(AbsOps::), - ), - )); - - out - } - - fn int_into_float(tensor: IntTensor) -> FloatTensor { - struct IntoFloatOps; - - impl Ops for IntoFloatOps { - type Args = UnaryOpsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = B::int_into_float(input); - handles.register_float_tensor(&args.out.id, output); - } + fn int_sub_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(SubOps, B::int_sub_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::SubScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(SubOps::), + ), + )); + + out } - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - out.client.register(TensorOpsDescription::IntOps( - graph::IntOpsDescription::IntoFloat( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(IntoFloatOps::), - ), - )); - - out - } - - fn int_swap_dims( - tensor: IntTensor, - dim1: usize, - dim2: usize, - ) -> IntTensor { - struct SwapDimsOps; - - impl Ops for SwapDimsOps { - type Args = SwapDimsDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = B::int_swap_dims(input, args.dim1, args.dim2); - handles.register_int_tensor(&args.out.id, output); - } + fn int_mul( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + binary_int_ops!(MulOps, B::int_mul); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Mul( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(MulOps::), + ), + )); + + out } - let mut shape = tensor.shape.clone(); - shape[dim1] = tensor.shape[dim2]; - shape[dim2] = tensor.shape[dim1]; - - let out = tensor.client.tensor_uninitialized(shape); - - tensor - .client - .clone() - .register(TensorOpsDescription::BaseOpsInt( - BaseOpsDescription::SwapDims( - SwapDimsDescription { - input: tensor.into_description(), - dim1, - dim2, - out: out.to_description_out(), - }, - 
Box::new(SwapDimsOps::), - ), - )); - - out - } - - fn int_max(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MaxOps, B::int_max); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Max( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MaxOps::), - ), - )); - - out - } - - fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(MaxDimOps, B::int_max_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaxDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MaxDimOps::), - ), - )); - - out - } - - fn int_max_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - struct MaxDimWithIndicesOps; - - impl Ops for MaxDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let (output, indices) = B::int_max_dim_with_indices(tensor, args.dim); - - handles.register_int_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } + fn int_mul_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(MulOps, B::int_mul_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MulScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(MulOps::), + ), + )); + + out } - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MaxDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxDimWithIndicesOps::), - ), - )); - - (out, out_indices) - } - - fn int_min(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MinOps, B::int_min); - - let out = tensor.client.tensor_uninitialized(vec![1]); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::Min( - UnaryOpsDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }, - Box::new(MinOps::), - ), - )); - - out - } - - fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { - scalar_int_ops!(MinDimOps, B::int_min_dim, usize); - - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); - - out.client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MinDim( - ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: dim, - out: out.to_description_out(), - }, - Box::new(MinDimOps::), - ), - )); - - out - } - - fn int_min_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - struct MinDimWithIndicesOps; - - impl Ops for MinDimWithIndicesOps { - type Args = ReduceDimWithIndicesDescription; - - fn execute(&self, 
args: &Self::Args, handles: &mut crate::HandleContainer) { - let tensor = handles.get_int_tensor::(&args.tensor); - let (output, indices) = B::int_min_dim_with_indices(tensor, args.dim); - - handles.register_int_tensor(&args.out.id, output); - handles.register_int_tensor(&args.out_indices.id, indices); - } + fn int_div( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + binary_int_ops!(DivOps, B::int_div); + + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Div( + BinaryOpsDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out } - let mut shape = tensor.shape.clone(); - shape[dim] = 1; - let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); - - client.register(TensorOpsDescription::NumericOpsInt( - NumericOpsDescription::MinDimWithIndices( - ReduceDimWithIndicesDescription { - tensor: tensor.into_description(), - dim, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MinDimWithIndicesOps::), - ), - )); - - (out, out_indices) - } + fn int_div_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + scalar_int_ops!(DivOps, B::int_div_scalar); + + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + + out.client + .register(graph::TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::DivScalar( + ScalarOpsDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }, + Box::new(DivOps::), + ), + )); + + out + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + struct ZerosOps; + + impl Ops for ZerosOps { + type Args = TensorDescription; + + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::int_zeros::(shape, &handles.device); + handles.register_int_tensor(&out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), + )); + + out + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + struct OnesOps; + + impl Ops for OnesOps { + type Args = TensorDescription; + + fn execute(&self, out: &Self::Args, handles: &mut crate::HandleContainer) { + let shape = Shape::from(out.shape.clone()); + let output = B::int_ones::(shape, &handles.device); + handles.register_int_tensor(&out.id, output); + } + } + + let shape: Vec = shape.dims.into(); + let client = get_client::(&device.clone().into()); + let out = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), + )); + + out + } + + fn int_sum(tensor: IntTensor) -> IntTensor { + unary_int_ops!(SumOps, B::int_sum); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Sum( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(SumOps::), + ), + )); + + out + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> 
IntTensor { + scalar_int_ops!(SumDimOps, B::int_sum_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::SumDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(SumDimOps::), + ), + )); + + out + } + + fn int_mean(tensor: IntTensor) -> IntTensor { + unary_int_ops!(MeanOps, B::int_mean); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Mean( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MeanOps::), + ), + )); + + out + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(MeanDimOps, B::int_mean_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MeanDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MeanDimOps::), + ), + )); + + out + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(ArgMaxOps, B::int_argmax, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ArgMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMaxOps::), + ), + )); + + out + } + + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(ArgMinOps, B::int_argmin, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ArgMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(ArgMinOps::), + ), + )); + + out + } + + fn int_clamp_min( + tensor: IntTensor, + min: IntElem, + ) -> IntTensor { + scalar_int_ops!(ClampMinOps, B::int_clamp_min); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ClampMin( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: min, + out: out.to_description_out(), + }, + Box::new(ClampMinOps::), + ), + )); + + out + } + + fn int_clamp_max( + tensor: IntTensor, + max: IntElem, + ) -> IntTensor { + scalar_int_ops!(ClampMaxOps, B::int_clamp_max); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::ClampMax( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: max, + out: out.to_description_out(), + }, + Box::new(ClampMaxOps::), + ), + )); + + out + } + + fn int_clamp( + tensor: IntTensor, + min: IntElem, + max: IntElem, + ) -> IntTensor { + struct ClampOps; + + impl Ops for ClampOps { + type Args = ClampOpsDescription>; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.tensor); + let output = B::int_clamp(input, args.min, args.max); + + 
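// The dim-wise reductions in this section (`int_sum_dim`, `int_mean_dim`, `int_argmax`,
// `int_argmin`, and the max/min variants) all preserve rank and collapse the reduced dimension
// to length 1 before registering the op. A minimal equivalent of that shape rule, with a
// hypothetical helper name:
fn reduced_shape(mut shape: Vec<usize>, dim: usize) -> Vec<usize> {
    shape[dim] = 1; // rank is kept; only the reduced dimension collapses
    shape
}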
handles.register_int_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Clamp( + ClampOpsDescription { + tensor: tensor.into_description(), + min, + max, + out: out.to_description_out(), + }, + Box::new(ClampOps::), + ), + )); + + out + } + + fn int_abs(tensor: IntTensor) -> IntTensor { + unary_int_ops!(AbsOps, B::int_abs); + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Abs( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(AbsOps::), + ), + )); + + out + } + + fn int_into_float(tensor: IntTensor) -> FloatTensor { + struct IntoFloatOps; + + impl Ops for IntoFloatOps { + type Args = UnaryOpsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = B::int_into_float(input); + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + out.client.register(TensorOpsDescription::IntOps( + graph::IntOpsDescription::IntoFloat( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(IntoFloatOps::), + ), + )); + + out + } + + fn int_swap_dims( + tensor: IntTensor, + dim1: usize, + dim2: usize, + ) -> IntTensor { + struct SwapDimsOps; + + impl Ops for SwapDimsOps { + type Args = SwapDimsDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = B::int_swap_dims(input, args.dim1, args.dim2); + handles.register_int_tensor(&args.out.id, output); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + + let out = tensor.client.tensor_uninitialized(shape); + + tensor + .client + .clone() + .register(TensorOpsDescription::BaseOpsInt( + BaseOpsDescription::SwapDims( + SwapDimsDescription { + input: tensor.into_description(), + dim1, + dim2, + out: out.to_description_out(), + }, + Box::new(SwapDimsOps::), + ), + )); + + out + } + + fn int_max(tensor: IntTensor) -> IntTensor { + unary_int_ops!(MaxOps, B::int_max); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Max( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MaxOps::), + ), + )); + + out + } + + fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(MaxDimOps, B::int_max_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaxDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MaxDimOps::), + ), + )); + + out + } + + fn int_max_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + struct MaxDimWithIndicesOps; + + impl Ops for MaxDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = 
handles.get_int_tensor::(&args.tensor); + let (output, indices) = B::int_max_dim_with_indices(tensor, args.dim); + + handles.register_int_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MaxDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxDimWithIndicesOps::), + ), + )); + + (out, out_indices) + } + + fn int_min(tensor: IntTensor) -> IntTensor { + unary_int_ops!(MinOps, B::int_min); + + let out = tensor.client.tensor_uninitialized(vec![1]); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::Min( + UnaryOpsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }, + Box::new(MinOps::), + ), + )); + + out + } + + fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { + scalar_int_ops!(MinDimOps, B::int_min_dim, usize); + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = tensor.client.tensor_uninitialized(shape); + + out.client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MinDim( + ScalarOpsDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }, + Box::new(MinDimOps::), + ), + )); + + out + } + + fn int_min_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + struct MinDimWithIndicesOps; + + impl Ops for MinDimWithIndicesOps { + type Args = ReduceDimWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&args.tensor); + let (output, indices) = B::int_min_dim_with_indices(tensor, args.dim); + + handles.register_int_tensor(&args.out.id, output); + handles.register_int_tensor(&args.out_indices.id, indices); + } + } + + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let client = tensor.client.clone(); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); + + client.register(TensorOpsDescription::NumericOpsInt( + NumericOpsDescription::MinDimWithIndices( + ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MinDimWithIndicesOps::), + ), + )); + + (out, out_indices) + } } diff --git a/burn-fusion/src/ops/module.rs b/burn-fusion/src/ops/module.rs index d20d8ccd87..2eef4be4b3 100644 --- a/burn-fusion/src/ops/module.rs +++ b/burn-fusion/src/ops/module.rs @@ -1,900 +1,907 @@ use crate::{ - client::FusionClient, - graph::{ - AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, - AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, - AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription, - AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, - ConvTranspose2dDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, - MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, - MaxPool2dWithIndicesDescription, 
Ops, TensorOpsDescription, - }, - Fusion, FusionBackend, + client::FusionClient, + graph::{ + AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, + AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, + AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription, + AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, + ConvTranspose2dDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, + MaxPool1dWithIndicesDescription, MaxPool2dDescription, + MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Ops, + TensorOpsDescription, + }, + Fusion, FusionBackend, }; use burn_tensor::ops::{ - conv::{ - calculate_conv_output_size, calculate_conv_transpose_output_size, calculate_pool_output_size, - }, - ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool1dBackward, - MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + conv::{ + calculate_conv_output_size, calculate_conv_transpose_output_size, + calculate_pool_output_size, + }, + ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool1dBackward, + MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; impl ModuleOps> for Fusion { - fn conv1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> FloatTensor { - struct Conv1dOps; - - impl Ops for Conv1dOps { - type Args = Conv1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv1d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size = calculate_conv_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.dilation[0], - x.shape[2], - ); - - let shape = vec![x.shape[0], weight.shape[0], size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::Conv1d( - Conv1dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(Conv1dOps), - ), - )); - - out - } - - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor { - struct Conv2dOps; - - impl Ops for Conv2dOps { - type Args = Conv2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv2d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } - } - - let size_0 = calculate_conv_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.dilation[0], - x.shape[2], - ); - let size_1 = calculate_conv_output_size( - weight.shape[3], - options.stride[1], - options.padding[1], - options.dilation[1], - x.shape[3], - ); - - let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - 
x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::Conv2d( - Conv2dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(Conv2dOps), - ), - )); - - out - } - - fn conv_transpose1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> FloatTensor { - struct ConvTranspose1dOps; - - impl Ops for ConvTranspose1dOps { - type Args = ConvTranspose1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv_transpose1d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } - } + fn conv1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> FloatTensor { + struct Conv1dOps; + + impl Ops for Conv1dOps { + type Args = Conv1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv1d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); - let size = calculate_conv_transpose_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.padding_out[0], - options.dilation[0], - x.shape[2], - ); - - let shape = vec![x.shape[0], weight.shape[1] * options.groups, size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::ConvTranspose1d( - ConvTranspose1dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(ConvTranspose1dOps), - ), - )); - - out - } - - fn conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor { - struct ConvTranspose2dOps; - - impl Ops for ConvTranspose2dOps { - type Args = ConvTranspose2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); - let bias = args - .bias - .as_ref() - .map(|bias| handles.get_float_tensor(bias)); - - let output = B::conv_transpose2d(x, weight, bias, args.options.clone()); - - handles.register_float_tensor(&args.out.id, output); - } + let shape = vec![x.shape[0], weight.shape[0], size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::Conv1d( + Conv1dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(Conv1dOps), + ), + )); + + out } - let size_0 = 
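// Sketch under an assumption: `calculate_conv_output_size`, used above to size the conv1d and
// conv2d outputs, is assumed to follow the standard convolution formula. If so, it reduces to
// the hypothetical helper below (integer division gives the floor):
fn conv_out_len(kernel: usize, stride: usize, padding: usize, dilation: usize, input: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}
// Example: kernel 3, stride 1, padding 1, dilation 1, input 32 -> (32 + 2 - 2 - 1) / 1 + 1 = 32.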
calculate_conv_transpose_output_size( - weight.shape[2], - options.stride[0], - options.padding[0], - options.padding_out[0], - options.dilation[0], - x.shape[2], - ); - let size_1 = calculate_conv_transpose_output_size( - weight.shape[3], - options.stride[1], - options.padding[1], - options.padding_out[1], - options.dilation[1], - x.shape[3], - ); - - let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::ConvTranspose2d( - ConvTranspose2dDescription { - x: x.into_description(), - weight: weight.into_description(), - bias: bias.map(|bias| bias.into_description()), - options, - out: out.to_description_out(), - }, - Box::new(ConvTranspose2dOps), - ), - )); - - out - } - - fn avg_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool1dOps; - - impl Ops for AvgPool1dOps { - type Args = AvgPool1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::avg_pool1d( - x, - args.kernel_size, - args.stride, - args.padding, - args.count_include_pad, + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor { + struct Conv2dOps; + + impl Ops for Conv2dOps { + type Args = Conv2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv2d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size_0 = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.dilation[1], + x.shape[3], ); - handles.register_float_tensor(&args.out.id, output); - } + let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::Conv2d( + Conv2dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(Conv2dOps), + ), + )); + + out } - let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); - let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool1d( - AvgPool1dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool1dOps), - ), - )); - - out - } - - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool2dOps; - - impl Ops for AvgPool2dOps { - type Args = AvgPool2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - 
let x = handles.get_float_tensor(&args.x); - let output = B::avg_pool2d( - x, - args.kernel_size, - args.stride, - args.padding, - args.count_include_pad, + fn conv_transpose1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> FloatTensor { + struct ConvTranspose1dOps; + + impl Ops for ConvTranspose1dOps { + type Args = ConvTranspose1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv_transpose1d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size = calculate_conv_transpose_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.padding_out[0], + options.dilation[0], + x.shape[2], ); - handles.register_float_tensor(&args.out.id, output); - } + let shape = vec![x.shape[0], weight.shape[1] * options.groups, size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::ConvTranspose1d( + ConvTranspose1dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(ConvTranspose1dOps), + ), + )); + + out } - let size_0 = calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]); - let size_1 = calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]); - - let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool2d( - AvgPool2dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool2dOps), - ), - )); - - out - } - - fn avg_pool1d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool1dBackwardOps; - - impl Ops for AvgPool1dBackwardOps { - type Args = AvgPool1dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let output = B::avg_pool1d_backward( - x, - grad, - args.kernel_size, - args.stride, - args.padding, - args.count_include_pad, + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + struct ConvTranspose2dOps; + + impl Ops for ConvTranspose2dOps { + type Args = ConvTranspose2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let weight = handles.get_float_tensor(&args.weight); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor(bias)); + + let output = B::conv_transpose2d(x, weight, bias, args.options.clone()); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size_0 = calculate_conv_transpose_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.padding_out[0], + 
options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_transpose_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.padding_out[1], + options.dilation[1], + x.shape[3], ); - handles.register_float_tensor(&args.out.id, output); - } + let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::ConvTranspose2d( + ConvTranspose2dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options, + out: out.to_description_out(), + }, + Box::new(ConvTranspose2dOps), + ), + )); + + out } - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool1dBackward( - AvgPool1dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool1dBackwardOps), - ), - )); - - out - } - - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - struct AvgPool2dBackwardOps; - - impl Ops for AvgPool2dBackwardOps { - type Args = AvgPool2dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let output = B::avg_pool2d_backward( - x, - grad, - args.kernel_size, - args.stride, - args.padding, - args.count_include_pad, - ); - - handles.register_float_tensor(&args.out.id, output); - } + fn avg_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool1dOps; + + impl Ops for AvgPool1dOps { + type Args = AvgPool1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::avg_pool1d( + x, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, + ); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); + let shape = vec![x.shape[0], x.shape[1], size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool1d( + AvgPool1dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool1dOps), + ), + )); + + out } - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AvgPool2dBackward( - AvgPool2dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - kernel_size, - stride, - padding, - count_include_pad, - out: out.to_description_out(), - }, - Box::new(AvgPool2dBackwardOps), - ), - )); - - out - } - - fn max_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> FloatTensor { - struct MaxPool1dOps; - - impl Ops for MaxPool1dOps { - type Args = 
MaxPool1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool1d( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - ); + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool2dOps; + + impl Ops for AvgPool2dOps { + type Args = AvgPool2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::avg_pool2d( + x, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, + ); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size_0 = + calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]); + let size_1 = + calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]); + + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool2d( + AvgPool2dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool2dOps), + ), + )); + + out + } - handles.register_float_tensor(&args.out.id, output); - } + fn avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool1dBackwardOps; + + impl Ops for AvgPool1dBackwardOps { + type Args = AvgPool1dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::avg_pool1d_backward( + x, + grad, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, + ); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool1dBackward( + AvgPool1dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool1dBackwardOps), + ), + )); + + out } - let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); - - let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool1d( - MaxPool1dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool1dOps), - ), - )); - - out - } - - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor { - struct MaxPool2dOps; - - impl Ops for MaxPool2dOps { - type Args = MaxPool2dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool2d( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - ); + fn 
avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + struct AvgPool2dBackwardOps; + + impl Ops for AvgPool2dBackwardOps { + type Args = AvgPool2dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::avg_pool2d_backward( + x, + grad, + args.kernel_size, + args.stride, + args.padding, + args.count_include_pad, + ); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AvgPool2dBackward( + AvgPool2dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }, + Box::new(AvgPool2dBackwardOps), + ), + )); + + out + } - handles.register_float_tensor(&args.out.id, output); - } + fn max_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> FloatTensor { + struct MaxPool1dOps; + + impl Ops for MaxPool1dOps { + type Args = MaxPool1dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool1d( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + ); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); + + let shape = vec![x.shape[0], x.shape[1], size]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool1d( + MaxPool1dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool1dOps), + ), + )); + + out } - let size_0 = calculate_pool_output_size( - kernel_size[0], - stride[0], - padding[0], - dilation[0], - x.shape[2], - ); - let size_1 = calculate_pool_output_size( - kernel_size[1], - stride[1], - padding[1], - dilation[1], - x.shape[3], - ); - - let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool2d( - MaxPool2dDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool2dOps), - ), - )); - - out - } - - fn max_pool1d_with_indices( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices { - struct MaxPool1dWithIndicesOps; - - impl Ops for MaxPool1dWithIndicesOps { - type Args = MaxPool1dWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool1d_with_indices( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor { + struct MaxPool2dOps; + + impl Ops for MaxPool2dOps 
{ + type Args = MaxPool2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool2d( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + ); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], ); - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); - } + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool2d( + MaxPool2dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool2dOps), + ), + )); + + out } - let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); - let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool1dWithIndices( - MaxPool1dWithIndicesDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxPool1dWithIndicesOps), - ), - )); - - MaxPool1dWithIndices::new(out, out_indices) - } - - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices { - struct MaxPool2dWithIndicesOps; - - impl Ops for MaxPool2dWithIndicesOps { - type Args = MaxPool2dWithIndicesDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::max_pool2d_with_indices( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - ); - - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); - } + fn max_pool1d_with_indices( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices { + struct MaxPool1dWithIndicesOps; + + impl Ops for MaxPool1dWithIndicesOps { + type Args = MaxPool1dWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool1d_with_indices( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + ); + + handles.register_float_tensor(&args.out.id, output.output); + handles.register_int_tensor(&args.out_indices.id, output.indices); + } + } + + let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); + let shape = vec![x.shape[0], x.shape[1], size]; + let out = x.client.tensor_uninitialized(shape.clone()); + let out_indices = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool1dWithIndices( + 
MaxPool1dWithIndicesDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxPool1dWithIndicesOps), + ), + )); + + MaxPool1dWithIndices::new(out, out_indices) } - let size_0 = calculate_pool_output_size( - kernel_size[0], - stride[0], - padding[0], - dilation[0], - x.shape[2], - ); - let size_1 = calculate_pool_output_size( - kernel_size[1], - stride[1], - padding[1], - dilation[1], - x.shape[3], - ); - - let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool2dWithIndices( - MaxPool2dWithIndicesDescription { - x: x.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - out_indices: out_indices.to_description_out(), - }, - Box::new(MaxPool2dWithIndicesOps), - ), - )); - - MaxPool2dWithIndices::new(out, out_indices) - } - - fn max_pool1d_with_indices_backward( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool1dBackward { - struct MaxPool1dWithIndicesBackwardOps; - - impl Ops for MaxPool1dWithIndicesBackwardOps { - type Args = MaxPool1dWithIndicesBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); - let output = B::max_pool1d_with_indices_backward( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - grad, - indices, + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices { + struct MaxPool2dWithIndicesOps; + + impl Ops for MaxPool2dWithIndicesOps { + type Args = MaxPool2dWithIndicesDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::max_pool2d_with_indices( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + ); + + handles.register_float_tensor(&args.out.id, output.output); + handles.register_int_tensor(&args.out_indices.id, output.indices); + } + } + + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], ); - handles.register_float_tensor(&args.out.id, output.x_grad); - } + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape.clone()); + let out_indices = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool2dWithIndices( + MaxPool2dWithIndicesDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }, + Box::new(MaxPool2dWithIndicesOps), + ), + )); + + MaxPool2dWithIndices::new(out, out_indices) } - let out = x.client.tensor_uninitialized(x.shape.clone()); - - 
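// The gradient with respect to the input always matches the input's shape, so the
// output handle is allocated with `x.shape` directly.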
x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool1dWithIndicesBackward( - MaxPool1dWithIndicesBackwardDescription { - x: x.into_description(), - grad: output_grad.into_description(), - indices: indices.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool1dWithIndicesBackwardOps), - ), - )); - - MaxPool1dBackward::new(out) - } - - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward { - struct MaxPool2dWithIndicesBackwardOps; - - impl Ops for MaxPool2dWithIndicesBackwardOps { - type Args = MaxPool2dWithIndicesBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); - let output = B::max_pool2d_with_indices_backward( - x, - args.kernel_size, - args.stride, - args.padding, - args.dilation, - grad, - indices, - ); - - handles.register_float_tensor(&args.out.id, output.x_grad); - } + fn max_pool1d_with_indices_backward( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool1dBackward { + struct MaxPool1dWithIndicesBackwardOps; + + impl Ops for MaxPool1dWithIndicesBackwardOps { + type Args = MaxPool1dWithIndicesBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let indices = handles.get_int_tensor(&args.indices); + let output = B::max_pool1d_with_indices_backward( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + grad, + indices, + ); + + handles.register_float_tensor(&args.out.id, output.x_grad); + } + } + + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool1dWithIndicesBackward( + MaxPool1dWithIndicesBackwardDescription { + x: x.into_description(), + grad: output_grad.into_description(), + indices: indices.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool1dWithIndicesBackwardOps), + ), + )); + + MaxPool1dBackward::new(out) } - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::MaxPool2dWithIndicesBackward( - MaxPool2dWithIndicesBackwardDescription { - x: x.into_description(), - grad: output_grad.into_description(), - indices: indices.into_description(), - kernel_size, - stride, - padding, - dilation, - out: out.to_description_out(), - }, - Box::new(MaxPool2dWithIndicesBackwardOps), - ), - )); - - MaxPool2dBackward::new(out) - } - - fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { - struct AdaptiveAvgPool1dOps; - - impl Ops for AdaptiveAvgPool1dOps { - type Args = AdaptiveAvgPool1dDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::adaptive_avg_pool1d(x, args.output_size); - - 
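// Write the computed result into the handle container under the id that was reserved
// for the output tensor when the op was registered.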
handles.register_float_tensor(&args.out.id, output); - } + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward { + struct MaxPool2dWithIndicesBackwardOps; + + impl Ops for MaxPool2dWithIndicesBackwardOps { + type Args = MaxPool2dWithIndicesBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let indices = handles.get_int_tensor(&args.indices); + let output = B::max_pool2d_with_indices_backward( + x, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + grad, + indices, + ); + + handles.register_float_tensor(&args.out.id, output.x_grad); + } + } + + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::MaxPool2dWithIndicesBackward( + MaxPool2dWithIndicesBackwardDescription { + x: x.into_description(), + grad: output_grad.into_description(), + indices: indices.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }, + Box::new(MaxPool2dWithIndicesBackwardOps), + ), + )); + + MaxPool2dBackward::new(out) } - let shape = vec![x.shape[0], x.shape[1], output_size]; - let out = x.client.tensor_uninitialized(shape); + fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { + struct AdaptiveAvgPool1dOps; - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool1d( - AdaptiveAvgPool1dDescription { - x: x.into_description(), - output_size, - out: out.to_description_out(), - }, - Box::new(AdaptiveAvgPool1dOps), - ), - )); + impl Ops for AdaptiveAvgPool1dOps { + type Args = AdaptiveAvgPool1dDescription; - out - } + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::adaptive_avg_pool1d(x, args.output_size); - fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { - struct AdaptiveAvgPool2dOps; + handles.register_float_tensor(&args.out.id, output); + } + } - impl Ops for AdaptiveAvgPool2dOps { - type Args = AdaptiveAvgPool2dDescription; + let shape = vec![x.shape[0], x.shape[1], output_size]; + let out = x.client.tensor_uninitialized(shape); - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let output = B::adaptive_avg_pool2d(x, args.output_size); + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool1d( + AdaptiveAvgPool1dDescription { + x: x.into_description(), + output_size, + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool1dOps), + ), + )); - handles.register_float_tensor(&args.out.id, output); - } + out } - let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; - let out = x.client.tensor_uninitialized(shape); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool2d( - AdaptiveAvgPool2dDescription { - x: x.into_description(), - output_size, - out: out.to_description_out(), - }, - Box::new(AdaptiveAvgPool2dOps), - ), - )); - - out - } - - fn adaptive_avg_pool1d_backward( - x: FloatTensor, - grad: 
FloatTensor, - ) -> FloatTensor { - struct AdaptiveAvgPool1dBackwardOps; - - impl Ops for AdaptiveAvgPool1dBackwardOps { - type Args = AdaptiveAvgPool1dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let output = B::adaptive_avg_pool1d_backward(x, grad); - - handles.register_float_tensor(&args.out.id, output); - } + fn adaptive_avg_pool2d( + x: FloatTensor, + output_size: [usize; 2], + ) -> FloatTensor { + struct AdaptiveAvgPool2dOps; + + impl Ops for AdaptiveAvgPool2dOps { + type Args = AdaptiveAvgPool2dDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let output = B::adaptive_avg_pool2d(x, args.output_size); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; + let out = x.client.tensor_uninitialized(shape); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool2d( + AdaptiveAvgPool2dDescription { + x: x.into_description(), + output_size, + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool2dOps), + ), + )); + + out } - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool1dBackward( - AdaptiveAvgPool1dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - out: out.to_description_out(), - }, - Box::new(AdaptiveAvgPool1dBackwardOps), - ), - )); - - out - } - - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - struct AdaptiveAvgPool2dBackwardOps; - - impl Ops for AdaptiveAvgPool2dBackwardOps { - type Args = AdaptiveAvgPool2dBackwardDescription; - - fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let output = B::adaptive_avg_pool2d_backward(x, grad); - - handles.register_float_tensor(&args.out.id, output); - } + fn adaptive_avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + struct AdaptiveAvgPool1dBackwardOps; + + impl Ops for AdaptiveAvgPool1dBackwardOps { + type Args = AdaptiveAvgPool1dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::adaptive_avg_pool1d_backward(x, grad); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool1dBackward( + AdaptiveAvgPool1dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool1dBackwardOps), + ), + )); + + out } - let out = x.client.tensor_uninitialized(x.shape.clone()); - - x.client.clone().register(TensorOpsDescription::ModuleOps( - crate::graph::ModuleOpsDescription::AdaptiveAvgPool2dBackward( - AdaptiveAvgPool2dBackwardDescription { - x: x.into_description(), - grad: grad.into_description(), - out: out.to_description_out(), - }, - 
Box::new(AdaptiveAvgPool2dBackwardOps), - ), - )); - - out - } + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + struct AdaptiveAvgPool2dBackwardOps; + + impl Ops for AdaptiveAvgPool2dBackwardOps { + type Args = AdaptiveAvgPool2dBackwardDescription; + + fn execute(&self, args: &Self::Args, handles: &mut crate::HandleContainer) { + let x = handles.get_float_tensor(&args.x); + let grad = handles.get_float_tensor(&args.grad); + let output = B::adaptive_avg_pool2d_backward(x, grad); + + handles.register_float_tensor(&args.out.id, output); + } + } + + let out = x.client.tensor_uninitialized(x.shape.clone()); + + x.client.clone().register(TensorOpsDescription::ModuleOps( + crate::graph::ModuleOpsDescription::AdaptiveAvgPool2dBackward( + AdaptiveAvgPool2dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + out: out.to_description_out(), + }, + Box::new(AdaptiveAvgPool2dBackwardOps), + ), + )); + + out + } } diff --git a/burn-fusion/src/ops/unary.rs b/burn-fusion/src/ops/unary.rs index f35e00aef6..84f6900b51 100644 --- a/burn-fusion/src/ops/unary.rs +++ b/burn-fusion/src/ops/unary.rs @@ -1,168 +1,168 @@ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_ops { - ( + ( $name:ident, $ops:expr ) => { - scalar_float_ops!($name, $ops, FloatElem); - }; - ( + scalar_float_ops!($name, $ops, FloatElem); + }; + ( $name:ident, $ops:expr, $elem:ty ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription<$elem>; + impl Ops for $name { + type Args = ScalarOpsDescription<$elem>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_float_tensor(&args.out.id, output); - } - } - }; + handles.register_float_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float2int_ops { - ( + ( $name:ident, $ops:expr, $elem:ty ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription<$elem>; + impl Ops for $name { + type Args = ScalarOpsDescription<$elem>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
unary_float_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = UnaryOpsDescription; + impl Ops for $name { + type Args = UnaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let input = handles.get_float_tensor::(&args.input); - let output = $ops(input); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let input = handles.get_float_tensor::(&args.input); + let output = $ops(input); - handles.register_float_tensor(&args.out.id, output); - } - } - }; + handles.register_float_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! unary_int_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = UnaryOpsDescription; + impl Ops for $name { + type Args = UnaryOpsDescription; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let input = handles.get_int_tensor::(&args.input); - let output = $ops(input); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let input = handles.get_int_tensor::(&args.input); + let output = $ops(input); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription>; + impl Ops for $name { + type Args = ScalarOpsDescription>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_float_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_float_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_int_cmp_ops { - ( + ( $name:ident, $ops:expr ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription>; + impl Ops for $name { + type Args = ScalarOpsDescription>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_bool_tensor(&args.out.id, output); - } - } - }; + handles.register_bool_tensor(&args.out.id, output); + } + } + }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
scalar_int_ops { - ( + ( $name:ident, $ops:expr ) => { - scalar_int_ops!($name, $ops, IntElem); - }; - ( + scalar_int_ops!($name, $ops, IntElem); + }; + ( $name:ident, $ops:expr, $elem:ty ) => { - struct $name; + struct $name; - impl Ops for $name { - type Args = ScalarOpsDescription<$elem>; + impl Ops for $name { + type Args = ScalarOpsDescription<$elem>; - fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { - let lhs = handles.get_int_tensor::(&args.lhs); - let output = $ops(lhs, args.rhs.clone()); + fn execute(&self, args: &Self::Args, handles: &mut $crate::HandleContainer) { + let lhs = handles.get_int_tensor::(&args.lhs); + let output = $ops(lhs, args.rhs.clone()); - handles.register_int_tensor(&args.out.id, output); - } - } - }; + handles.register_int_tensor(&args.out.id, output); + } + } + }; } diff --git a/burn-fusion/src/server.rs b/burn-fusion/src/server.rs index f2dfb163d6..9b52d38295 100644 --- a/burn-fusion/src/server.rs +++ b/burn-fusion/src/server.rs @@ -1,162 +1,161 @@ use crate::{ - graph::{Graph, GraphExecution, Optimization, TensorOpsDescription}, - FusionBackend, FusionProperties, FusionStatus, HandleContainer, TensorId, + graph::{Graph, GraphExecution, Optimization, TensorOpsDescription}, + FusionBackend, FusionProperties, FusionStatus, HandleContainer, TensorId, }; use burn_tensor::ops::{FloatElem, IntElem}; use std::sync::Arc; pub struct FusionServer where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - optimizations: Vec>, - graph: Graph, - pub(crate) handles: HandleContainer, - execution: G, - pub device: B::FusionDevice, + optimizations: Vec>, + graph: Graph, + pub(crate) handles: HandleContainer, + execution: G, + pub device: B::FusionDevice, } impl FusionServer where - B: FusionBackend, - G: GraphExecution, + B: FusionBackend, + G: GraphExecution, { - pub fn new(device: B::FusionDevice) -> Self { - let optimizations = B::operations(&device.clone().into()) - .into_iter() - .map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default()))) - .collect(); - - Self { - optimizations, - graph: Graph::new(), - handles: HandleContainer::new(device.clone()), - execution: G::default(), - device, + pub fn new(device: B::FusionDevice) -> Self { + let optimizations = B::operations(&device.clone().into()) + .into_iter() + .map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default()))) + .collect(); + + Self { + optimizations, + graph: Graph::new(), + handles: HandleContainer::new(device.clone()), + execution: G::default(), + device, + } } - } - - pub fn register(&mut self, ops: TensorOpsDescription) { - let ops = Arc::new(ops); - self.graph.add(ops.clone()); - - self - .optimizations - .iter_mut() - .for_each(|optimization| optimization.register(&ops)); - - self.execution.maybe_execute( - &mut self.graph, - &mut self.handles, - &mut self.optimizations, - false, - ); - } - - pub fn drain_graph(&mut self) { - if self.graph.is_empty() { - return; + + pub fn register(&mut self, ops: TensorOpsDescription) { + let ops = Arc::new(ops); + self.graph.add(ops.clone()); + + self.optimizations + .iter_mut() + .for_each(|optimization| optimization.register(&ops)); + + self.execution.maybe_execute( + &mut self.graph, + &mut self.handles, + &mut self.optimizations, + false, + ); } - self.execution.maybe_execute( - &mut self.graph, - &mut self.handles, - &mut self.optimizations, - true, - ); - } - - pub fn create_empty_handle(&mut self) -> Arc { - self.handles.create_tensor_uninit() - } - - 
pub fn read_float( - &mut self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> { - // Make sure all registered operations are executed. - // The underlying backend can still be async. - self.drain_graph(); - - let tensor = self.handles.get_float_tensor(&tensor); - B::into_data(tensor) - } - - pub fn read_int( - &mut self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader, D>> { - // Make sure all registered operations are executed. - // The underlying backend can still be async. - self.drain_graph(); - - let tensor = self.handles.get_int_tensor(&tensor); - B::int_into_data(tensor) - } - - pub fn read_bool( - &mut self, - tensor: crate::TensorDescription, - ) -> burn_tensor::Reader> { - // Make sure all registered operations are executed. - // The underlying backend can still be async. - self.drain_graph(); - - let tensor = self.handles.get_bool_tensor(&tensor); - B::bool_into_data(tensor) - } - - pub fn change_server_float( - &mut self, - tensor: &crate::TensorDescription, - device: &B::Device, - server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_float_tensor::(tensor); - let tensor = B::to_device(tensor, device); - let id = server_device.create_empty_handle(); - - server_device - .handles - .register_float_tensor(&id, tensor.clone()); - - id - } - pub fn change_server_int( - &mut self, - tensor: &crate::TensorDescription, - device: &B::Device, - server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_int_tensor::(tensor); - let tensor = B::int_to_device(tensor, device); - let id = server_device.create_empty_handle(); - - server_device - .handles - .register_int_tensor(&id, tensor.clone()); - - id - } - pub fn change_server_bool( - &mut self, - tensor: &crate::TensorDescription, - device: &B::Device, - server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_bool_tensor::(tensor); - let tensor = B::bool_to_device(tensor, device); - let id = server_device.create_empty_handle(); - - server_device - .handles - .register_bool_tensor(&id, tensor.clone()); - - id - } - - pub fn drop_tensor_handle(&mut self, id: TensorId) { - self.handles.handles_orphan.push(id); - } + pub fn drain_graph(&mut self) { + if self.graph.is_empty() { + return; + } + + self.execution.maybe_execute( + &mut self.graph, + &mut self.handles, + &mut self.optimizations, + true, + ); + } + + pub fn create_empty_handle(&mut self) -> Arc { + self.handles.create_tensor_uninit() + } + + pub fn read_float( + &mut self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> { + // Make sure all registered operations are executed. + // The underlying backend can still be async. + self.drain_graph(); + + let tensor = self.handles.get_float_tensor(&tensor); + B::into_data(tensor) + } + + pub fn read_int( + &mut self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader, D>> { + // Make sure all registered operations are executed. + // The underlying backend can still be async. + self.drain_graph(); + + let tensor = self.handles.get_int_tensor(&tensor); + B::int_into_data(tensor) + } + + pub fn read_bool( + &mut self, + tensor: crate::TensorDescription, + ) -> burn_tensor::Reader> { + // Make sure all registered operations are executed. + // The underlying backend can still be async. 
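// `drain_graph` forces any operations still pending in the graph to execute before the
// handle is read back.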
+ self.drain_graph(); + + let tensor = self.handles.get_bool_tensor(&tensor); + B::bool_into_data(tensor) + } + + pub fn change_server_float( + &mut self, + tensor: &crate::TensorDescription, + device: &B::Device, + server_device: &mut Self, + ) -> Arc { + let tensor = self.handles.get_float_tensor::(tensor); + let tensor = B::to_device(tensor, device); + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_float_tensor(&id, tensor.clone()); + + id + } + pub fn change_server_int( + &mut self, + tensor: &crate::TensorDescription, + device: &B::Device, + server_device: &mut Self, + ) -> Arc { + let tensor = self.handles.get_int_tensor::(tensor); + let tensor = B::int_to_device(tensor, device); + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_int_tensor(&id, tensor.clone()); + + id + } + pub fn change_server_bool( + &mut self, + tensor: &crate::TensorDescription, + device: &B::Device, + server_device: &mut Self, + ) -> Arc { + let tensor = self.handles.get_bool_tensor::(tensor); + let tensor = B::bool_to_device(tensor, device); + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_bool_tensor(&id, tensor.clone()); + + id + } + + pub fn drop_tensor_handle(&mut self, id: TensorId) { + self.handles.handles_orphan.push(id); + } } diff --git a/burn-fusion/src/tensor.rs b/burn-fusion/src/tensor.rs index 7cac08af54..70ffcf3937 100644 --- a/burn-fusion/src/tensor.rs +++ b/burn-fusion/src/tensor.rs @@ -1,140 +1,140 @@ use crate::client::FusionClient; use burn_tensor::{ - backend::Backend, - ops::{FloatElem, IntElem}, - Data, Reader, Shape, + backend::Backend, + ops::{FloatElem, IntElem}, + Data, Reader, Shape, }; use std::sync::Arc; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. #[derive(Clone)] pub struct FusionTensor { - /// Tensor id. - pub id: Arc, - /// The shape of the tensor. - pub shape: Vec, - /// The [fusion client](FusionClient). - pub client: C, - // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. - // - // When a tensor is dropped and is still an orphan, we need to register it as such to avoid - // memory leak. Otherwise, the cleanup is going to happen during a graph execution. - pub(crate) is_orphan: bool, + /// Tensor id. + pub id: Arc, + /// The shape of the tensor. + pub shape: Vec, + /// The [fusion client](FusionClient). + pub client: C, + // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. + // + // When a tensor is dropped and is still an orphan, we need to register it as such to avoid + // memory leak. Otherwise, the cleanup is going to happen during a graph execution. 
+ pub(crate) is_orphan: bool, } impl core::fmt::Debug for FusionTensor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str( - format!( - "{{ id: {:?}, shape: {:?}, should_drop: {:?}, backend: {:?}, device: {:?} }}", - self.id, - self.shape, - self.is_orphan, - ::name(), - self.client.device().clone().into(), - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + format!( + "{{ id: {:?}, shape: {:?}, should_drop: {:?}, backend: {:?}, device: {:?} }}", + self.id, + self.shape, + self.is_orphan, + ::name(), + self.client.device().clone().into(), + ) + .as_str(), + ) + } } impl FusionTensor { - pub(crate) fn new(id: Arc, shape: Vec, client: C) -> Self { - Self { - id, - shape, - client, - is_orphan: true, + pub(crate) fn new(id: Arc, shape: Vec, client: C) -> Self { + Self { + id, + shape, + client, + is_orphan: true, + } } - } - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.shape.clone()) - } - - fn status(&self) -> TensorStatus { - if Arc::strong_count(&self.id) <= 1 { - TensorStatus::ReadWrite - } else { - TensorStatus::ReadOnly + pub(crate) fn shape(&self) -> Shape { + Shape::from(self.shape.clone()) } - } - - /// Description to be used when using an uninitialized tensor as output. - pub(crate) fn to_description_out(&self) -> TensorDescription { - TensorDescription { - status: TensorStatus::NotInit, - shape: self.shape.clone(), - id: self.id.as_ref().clone(), + + fn status(&self) -> TensorStatus { + if Arc::strong_count(&self.id) <= 1 { + TensorStatus::ReadWrite + } else { + TensorStatus::ReadOnly + } } - } - /// Description to be used when using an initialized tensor used as input. - pub(crate) fn into_description(mut self) -> TensorDescription { - let status = self.status(); - let mut shape_out = Vec::new(); - core::mem::swap(&mut self.shape, &mut shape_out); + /// Description to be used when using an uninitialized tensor as output. + pub(crate) fn to_description_out(&self) -> TensorDescription { + TensorDescription { + status: TensorStatus::NotInit, + shape: self.shape.clone(), + id: self.id.as_ref().clone(), + } + } - if let TensorStatus::ReadWrite = status { - self.is_orphan = false; + /// Description to be used when using an initialized tensor used as input. 
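/// Converting a `ReadWrite` tensor clears `is_orphan`: cleanup of its handle is then
/// performed during graph execution instead of in `Drop`.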
+ pub(crate) fn into_description(mut self) -> TensorDescription { + let status = self.status(); + let mut shape_out = Vec::new(); + core::mem::swap(&mut self.shape, &mut shape_out); + + if let TensorStatus::ReadWrite = status { + self.is_orphan = false; + } + + TensorDescription { + status, + shape: shape_out, + id: self.id.as_ref().clone(), + } } - TensorDescription { - status, - shape: shape_out, - id: self.id.as_ref().clone(), + pub(crate) fn into_data(self) -> Reader, D>> { + self.client + .clone() + .read_tensor_float(self.into_description()) } - } - - pub(crate) fn into_data(self) -> Reader, D>> { - self - .client - .clone() - .read_tensor_float(self.into_description()) - } - - pub(crate) fn int_into_data(self) -> Reader, D>> { - self.client.clone().read_tensor_int(self.into_description()) - } - - pub(crate) fn bool_into_data(self) -> Reader> { - self - .client - .clone() - .read_tensor_bool(self.into_description()) - } -} -impl Drop for FusionTensor { - fn drop(&mut self) { - if !self.is_orphan { - return; + pub(crate) fn int_into_data( + self, + ) -> Reader, D>> { + self.client.clone().read_tensor_int(self.into_description()) } - match self.status() { - TensorStatus::ReadWrite => { - self.client.register_orphan(&self.id); - } - TensorStatus::ReadOnly => {} - TensorStatus::NotInit => {} + pub(crate) fn bool_into_data(self) -> Reader> { + self.client + .clone() + .read_tensor_bool(self.into_description()) + } +} + +impl Drop for FusionTensor { + fn drop(&mut self) { + if !self.is_orphan { + return; + } + + match self.status() { + TensorStatus::ReadWrite => { + self.client.register_orphan(&self.id); + } + TensorStatus::ReadOnly => {} + TensorStatus::NotInit => {} + } } - } } /// The tensor unique identifier. #[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct TensorId { - value: u64, + value: u64, } /// The status of the current tensor. #[derive(Hash, Clone, Debug, PartialEq, Eq)] pub enum TensorStatus { - /// The tensor can be read, but not written. - ReadOnly, - /// The tensor can be mutated inplace. - ReadWrite, - /// No handle exists for that tensor. - NotInit, + /// The tensor can be read, but not written. + ReadOnly, + /// The tensor can be mutated inplace. + ReadWrite, + /// No handle exists for that tensor. + NotInit, } /// A tensor definition represents a snapshot of a tensor when it was used. @@ -149,17 +149,17 @@ pub enum TensorStatus { /// 4. Status::ReadWrite #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct TensorDescription { - /// The [tensor id](TensorId). - pub id: TensorId, - /// The shape of the tensor. - pub shape: Vec, - /// The [status](TensorStatus) of the tensor when it was used. - pub status: TensorStatus, + /// The [tensor id](TensorId). + pub id: TensorId, + /// The shape of the tensor. + pub shape: Vec, + /// The [status](TensorStatus) of the tensor when it was used. + pub status: TensorStatus, } impl TensorId { - /// Create a new tensor id. - pub fn new(value: u64) -> Self { - Self { value } - } + /// Create a new tensor id. 
+ pub fn new(value: u64) -> Self { + Self { value } + } } diff --git a/burn-import/build.rs b/burn-import/build.rs index da847f00dc..f6cd681658 100644 --- a/burn-import/build.rs +++ b/burn-import/build.rs @@ -1,11 +1,11 @@ fn main() { - if cfg!(feature = "onnx") { - // Generate the onnx protobuf files - protobuf_codegen::Codegen::new() - .pure() - .includes(["src"]) - .input("src/onnx/protos/onnx.proto") - .cargo_out_dir("onnx-protos") - .run_from_script(); - } + if cfg!(feature = "onnx") { + // Generate the onnx protobuf files + protobuf_codegen::Codegen::new() + .pure() + .includes(["src"]) + .input("src/onnx/protos/onnx.proto") + .cargo_out_dir("onnx-protos") + .run_from_script(); + } } diff --git a/burn-import/onnx-tests/build.rs b/burn-import/onnx-tests/build.rs index 62afd67f7d..dc1576099e 100644 --- a/burn-import/onnx-tests/build.rs +++ b/burn-import/onnx-tests/build.rs @@ -1,114 +1,114 @@ use burn_import::onnx::{ModelGen, RecordType}; fn main() { - // Re-run this build script if the onnx-tests directory changes. - println!("cargo:rerun-if-changed=tests"); + // Re-run this build script if the onnx-tests directory changes. + println!("cargo:rerun-if-changed=tests"); - // Add onnx models. - ModelGen::new() - .input("tests/add/add_int.onnx") - .input("tests/add/add.onnx") - .input("tests/avg_pool2d/avg_pool2d.onnx") - .input("tests/batch_norm/batch_norm.onnx") - .input("tests/clip/clip_opset16.onnx") - .input("tests/clip/clip_opset7.onnx") - .input("tests/concat/concat.onnx") - .input("tests/conv1d/conv1d.onnx") - .input("tests/conv2d/conv2d.onnx") - .input("tests/div/div.onnx") - .input("tests/dropout/dropout_opset16.onnx") - .input("tests/dropout/dropout_opset7.onnx") - .input("tests/equal/equal.onnx") - .input("tests/erf/erf.onnx") - .input("tests/flatten/flatten.onnx") - .input("tests/gather/gather.onnx") - .input("tests/global_avr_pool/global_avr_pool.onnx") - .input("tests/linear/linear.onnx") - .input("tests/log_softmax/log_softmax.onnx") - .input("tests/maxpool2d/maxpool2d.onnx") - .input("tests/mul/mul.onnx") - .input("tests/recip/recip.onnx") - .input("tests/relu/relu.onnx") - .input("tests/reshape/reshape.onnx") - .input("tests/sigmoid/sigmoid.onnx") - .input("tests/softmax/softmax.onnx") - .input("tests/sub/sub_int.onnx") - .input("tests/sub/sub.onnx") - .input("tests/tanh/tanh.onnx") - .input("tests/transpose/transpose.onnx") - .out_dir("model/") - .run_from_script(); + // Add onnx models. 
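// Each `.input(...)` below points at an ONNX file that `ModelGen` converts into
// generated Rust model code under the `model/` output directory at build time.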
+ ModelGen::new() + .input("tests/add/add_int.onnx") + .input("tests/add/add.onnx") + .input("tests/avg_pool2d/avg_pool2d.onnx") + .input("tests/batch_norm/batch_norm.onnx") + .input("tests/clip/clip_opset16.onnx") + .input("tests/clip/clip_opset7.onnx") + .input("tests/concat/concat.onnx") + .input("tests/conv1d/conv1d.onnx") + .input("tests/conv2d/conv2d.onnx") + .input("tests/div/div.onnx") + .input("tests/dropout/dropout_opset16.onnx") + .input("tests/dropout/dropout_opset7.onnx") + .input("tests/equal/equal.onnx") + .input("tests/erf/erf.onnx") + .input("tests/flatten/flatten.onnx") + .input("tests/gather/gather.onnx") + .input("tests/global_avr_pool/global_avr_pool.onnx") + .input("tests/linear/linear.onnx") + .input("tests/log_softmax/log_softmax.onnx") + .input("tests/maxpool2d/maxpool2d.onnx") + .input("tests/mul/mul.onnx") + .input("tests/recip/recip.onnx") + .input("tests/relu/relu.onnx") + .input("tests/reshape/reshape.onnx") + .input("tests/sigmoid/sigmoid.onnx") + .input("tests/softmax/softmax.onnx") + .input("tests/sub/sub_int.onnx") + .input("tests/sub/sub.onnx") + .input("tests/tanh/tanh.onnx") + .input("tests/transpose/transpose.onnx") + .out_dir("model/") + .run_from_script(); - // The following tests are used to generate the model with different record types. - // (e.g. bincode, pretty_json, etc.) Do not need to add new tests here, just use the default - // record type to the ModelGen::new() call above. + // The following tests are used to generate the model with different record types. + // (e.g. bincode, pretty_json, etc.) Do not need to add new tests here, just use the default + // record type to the ModelGen::new() call above. - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk/") - .record_type(RecordType::NamedMpk) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk/") + .record_type(RecordType::NamedMpk) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk_half/") - .record_type(RecordType::NamedMpk) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk_half/") + .record_type(RecordType::NamedMpk) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/pretty_json/") - .record_type(RecordType::PrettyJson) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/pretty_json/") + .record_type(RecordType::PrettyJson) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/pretty_json_half/") - .record_type(RecordType::PrettyJson) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/pretty_json_half/") + .record_type(RecordType::PrettyJson) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk_gz/") - .record_type(RecordType::NamedMpkGz) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk_gz/") + .record_type(RecordType::NamedMpkGz) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/named_mpk_gz_half/") - .record_type(RecordType::NamedMpkGz) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/named_mpk_gz_half/") + 
.record_type(RecordType::NamedMpkGz) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode/") - .record_type(RecordType::Bincode) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode/") + .record_type(RecordType::Bincode) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode_half/") - .record_type(RecordType::Bincode) - .half_precision(true) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode_half/") + .record_type(RecordType::Bincode) + .half_precision(true) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode_embedded/") - .embed_states(true) - .record_type(RecordType::Bincode) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode_embedded/") + .embed_states(true) + .record_type(RecordType::Bincode) + .run_from_script(); - ModelGen::new() - .input("tests/conv1d/conv1d.onnx") - .out_dir("model/bincode_embedded_half/") - .embed_states(true) - .half_precision(true) - .record_type(RecordType::Bincode) - .run_from_script(); + ModelGen::new() + .input("tests/conv1d/conv1d.onnx") + .out_dir("model/bincode_embedded_half/") + .embed_states(true) + .half_precision(true) + .record_type(RecordType::Bincode) + .run_from_script(); - // panic!("Purposefully failing build to output logs."); + // panic!("Purposefully failing build to output logs."); } diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs index 3863bc9204..b10c45849c 100644 --- a/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/burn-import/onnx-tests/tests/onnx_tests.rs @@ -13,599 +13,599 @@ macro_rules! include_models { // ATTENTION: Modify this macro to include all models in the `model` directory. 
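For context on the ATTENTION comment above: each name passed to include_models! corresponds to an ONNX file registered in the build script earlier in this patch, and the macro simply includes the Rust source that ModelGen wrote under OUT_DIR. A hedged sketch of the same pattern written out by hand for a single, hypothetical model called mnist (the OUT_DIR/model/<name>.rs layout follows from the out_dir("model/") call above):

    // Sketch only: including one generated model without the include_models! macro.
    // `mnist` and its path are hypothetical; the real names come from the .onnx inputs.
    pub mod mnist {
        include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
    }

    // The generated module exposes a `Model` type with a `forward` method, used by
    // the tests below roughly as:
    //     let model: mnist::Model<Backend> = mnist::Model::default();
    //     let output = model.forward(input);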
include_models!( - add_int, - add, - avg_pool2d, - batch_norm, - clip_opset16, - clip_opset7, - concat, - conv1d, - conv2d, - div, - dropout_opset16, - dropout_opset7, - equal, - erf, - flatten, - gather, - global_avr_pool, - linear, - log_softmax, - maxpool2d, - mul, - recip, - relu, - reshape, - sigmoid, - softmax, - sub_int, - sub, - tanh, - transpose + add_int, + add, + avg_pool2d, + batch_norm, + clip_opset16, + clip_opset7, + concat, + conv1d, + conv2d, + div, + dropout_opset16, + dropout_opset7, + equal, + erf, + flatten, + gather, + global_avr_pool, + linear, + log_softmax, + maxpool2d, + mul, + recip, + relu, + reshape, + sigmoid, + softmax, + sub_int, + sub, + tanh, + transpose ); #[cfg(test)] mod tests { - use core::f64::consts; - - use super::*; - - use burn::tensor::{Data, Int, Shape, Tensor}; + use core::f64::consts; + + use super::*; + + use burn::tensor::{Data, Int, Shape, Tensor}; - use float_cmp::ApproxEq; - - type Backend = burn_ndarray::NdArray; + use float_cmp::ApproxEq; + + type Backend = burn_ndarray::NdArray; - #[test] - fn add_scalar_to_tensor_and_tensor_to_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: add::Model = add::Model::default(); + #[test] + fn add_scalar_to_tensor_and_tensor_to_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: add::Model = add::Model::default(); - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let scalar = 2f64; - let output = model.forward(input, scalar); - let expected = Data::from([[[[9., 10., 11., 12.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let scalar = 2f64; + let output = model.forward(input, scalar); + let expected = Data::from([[[[9., 10., 11., 12.]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn add_scalar_to_int_tensor_and_int_tensor_to_int_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: add_int::Model = add_int::Model::default(); - - // Run the model - let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); - let scalar = 2; - let output = model.forward(input, scalar); - let expected = Data::from([[[[9, 11, 13, 15]]]]); + #[test] + fn add_scalar_to_int_tensor_and_int_tensor_to_int_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: add_int::Model = add_int::Model::default(); + + // Run the model + let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); + let scalar = 2; + let output = model.forward(input, scalar); + let expected = Data::from([[[[9, 11, 13, 15]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn sub_scalar_from_tensor_and_tensor_from_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: sub::Model = sub::Model::default(); + #[test] + fn sub_scalar_from_tensor_and_tensor_from_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: sub::Model = sub::Model::default(); - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let scalar = 3.0f64; - let output = model.forward(input, scalar); - let expected = Data::from([[[[6., 7., 8., 9.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let scalar = 3.0f64; + let output = model.forward(input, scalar); + let expected = Data::from([[[[6., 7., 8., 9.]]]]); 
- assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn sub_scalar_from_int_tensor_and_int_tensor_from_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: sub_int::Model = sub_int::Model::default(); + #[test] + fn sub_scalar_from_int_tensor_and_int_tensor_from_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: sub_int::Model = sub_int::Model::default(); - // Run the model - let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); - let scalar = 3; - let output = model.forward(input, scalar); - let expected = Data::from([[[[6, 6, 6, 6]]]]); + // Run the model + let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]]); + let scalar = 3; + let output = model.forward(input, scalar); + let expected = Data::from([[[[6, 6, 6, 6]]]]); - assert_eq!(output.to_data(), expected); - } - #[test] - fn mul_scalar_with_tensor_and_tensor_with_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: mul::Model = mul::Model::default(); + assert_eq!(output.to_data(), expected); + } + #[test] + fn mul_scalar_with_tensor_and_tensor_with_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: mul::Model = mul::Model::default(); - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let scalar = 6.0f64; - let output = model.forward(input, scalar); - let expected = Data::from([[[[126., 252., 378., 504.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let scalar = 6.0f64; + let output = model.forward(input, scalar); + let expected = Data::from([[[[126., 252., 378., 504.]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn div_tensor_by_scalar_and_tensor_by_tensor() { - // Initialize the model without weights (because the exported file does not contain them) - let model: div::Model = div::Model::new(); + #[test] + fn div_tensor_by_scalar_and_tensor_by_tensor() { + // Initialize the model without weights (because the exported file does not contain them) + let model: div::Model = div::Model::new(); - // Run the model - let input = Tensor::::from_floats([[[[3., 6., 6., 9.]]]]); - let scalar1 = 9.0f64; - let scalar2 = 3.0f64; - let output = model.forward(input, scalar1, scalar2); - let expected = Data::from([[[[1., 2., 2., 3.]]]]); + // Run the model + let input = Tensor::::from_floats([[[[3., 6., 6., 9.]]]]); + let scalar1 = 9.0f64; + let scalar2 = 3.0f64; + let output = model.forward(input, scalar1, scalar2); + let expected = Data::from([[[[1., 2., 2., 3.]]]]); - assert_eq!(output.to_data(), expected); - } + assert_eq!(output.to_data(), expected); + } - #[test] - fn concat_tensors() { - // Initialize the model - let model: concat::Model = concat::Model::new(); + #[test] + fn concat_tensors() { + // Initialize the model + let model: concat::Model = concat::Model::new(); - // Run the model - let input = Tensor::::zeros([1, 2, 3, 5]); + // Run the model + let input = Tensor::::zeros([1, 2, 3, 5]); - let output = model.forward(input); + let output = model.forward(input); - let expected = Shape::from([1, 18, 3, 5]); + let expected = Shape::from([1, 18, 3, 5]); - assert_eq!(output.shape(), expected); - } + assert_eq!(output.shape(), expected); + } - #[test] - fn conv1d() { - // Initialize the model with weights (loaded from the exported file) - let model: conv1d::Model = conv1d::Model::default(); + 
#[test] + fn conv1d() { + // Initialize the model with weights (loaded from the exported file) + let model: conv1d::Model = conv1d::Model::default(); - // Run the model with pi as input for easier testing - let input = Tensor::::full([6, 4, 10], consts::PI); + // Run the model with pi as input for easier testing + let input = Tensor::::full([6, 4, 10], consts::PI); - let output = model.forward(input); + let output = model.forward(input); - // test the output shape - let expected_shape: Shape<3> = Shape::from([6, 2, 7]); - assert_eq!(output.shape(), expected_shape); + // test the output shape + let expected_shape: Shape<3> = Shape::from([6, 2, 7]); + assert_eq!(output.shape(), expected_shape); - // We are using the sum of the output tensor to test the correctness of the conv1d node - // because the output tensor is too large to compare with the expected tensor. - let output_sum = output.sum().into_scalar(); - let expected_sum = -54.549_243; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + // We are using the sum of the output tensor to test the correctness of the conv1d node + // because the output tensor is too large to compare with the expected tensor. + let output_sum = output.sum().into_scalar(); + let expected_sum = -54.549_243; // from pytorch + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn conv2d() { - // Initialize the model with weights (loaded from the exported file) - let model: conv2d::Model = conv2d::Model::default(); + #[test] + fn conv2d() { + // Initialize the model with weights (loaded from the exported file) + let model: conv2d::Model = conv2d::Model::default(); - // Run the model with ones as input for easier testing - let input = Tensor::::ones([2, 4, 10, 15]); + // Run the model with ones as input for easier testing + let input = Tensor::::ones([2, 4, 10, 15]); - let output = model.forward(input); + let output = model.forward(input); - let expected_shape = Shape::from([2, 6, 6, 15]); - assert_eq!(output.shape(), expected_shape); + let expected_shape = Shape::from([2, 6, 6, 15]); + assert_eq!(output.shape(), expected_shape); - // We are using the sum of the output tensor to test the correctness of the conv2d node - // because the output tensor is too large to compare with the expected tensor. - let output_sum = output.sum().into_scalar(); + // We are using the sum of the output tensor to test the correctness of the conv2d node + // because the output tensor is too large to compare with the expected tensor. 
+ let output_sum = output.sum().into_scalar(); - let expected_sum = -113.869_99; // from pytorch + let expected_sum = -113.869_99; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn dropout_opset16() { - let model: dropout_opset16::Model = dropout_opset16::Model::default(); + #[test] + fn dropout_opset16() { + let model: dropout_opset16::Model = dropout_opset16::Model::default(); - // Run the model with ones as input for easier testing - let input = Tensor::::ones([2, 4, 10, 15]); + // Run the model with ones as input for easier testing + let input = Tensor::::ones([2, 4, 10, 15]); - let output = model.forward(input); + let output = model.forward(input); - let expected_shape = Shape::from([2, 4, 10, 15]); - assert_eq!(output.shape(), expected_shape); + let expected_shape = Shape::from([2, 4, 10, 15]); + assert_eq!(output.shape(), expected_shape); - let output_sum = output.sum().into_scalar(); + let output_sum = output.sum().into_scalar(); - let expected_sum = 1200.0; // from pytorch + let expected_sum = 1200.0; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn dropout_opset7() { - let model: dropout_opset7::Model = dropout_opset7::Model::default(); + #[test] + fn dropout_opset7() { + let model: dropout_opset7::Model = dropout_opset7::Model::default(); - // Run the model with ones as input for easier testing - let input = Tensor::::ones([2, 4, 10, 15]); + // Run the model with ones as input for easier testing + let input = Tensor::::ones([2, 4, 10, 15]); - let output = model.forward(input); + let output = model.forward(input); - let expected_shape = Shape::from([2, 4, 10, 15]); - assert_eq!(output.shape(), expected_shape); + let expected_shape = Shape::from([2, 4, 10, 15]); + assert_eq!(output.shape(), expected_shape); - let output_sum = output.sum().into_scalar(); + let output_sum = output.sum().into_scalar(); - let expected_sum = 1200.0; // from pytorch + let expected_sum = 1200.0; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } - #[test] - fn erf() { - let model: erf::Model = erf::Model::default(); + #[test] + fn erf() { + let model: erf::Model = erf::Model::default(); - let input = Tensor::::from_data([[[[1.0, 2.0, 3.0, 4.0]]]]); - let output = model.forward(input); - let expected = Tensor::::from_data([[[[0.8427, 0.9953, 1.0000, 1.0000]]]]); + let input = Tensor::::from_data([[[[1.0, 2.0, 3.0, 4.0]]]]); + let output = model.forward(input); + let expected = Tensor::::from_data([[[[0.8427, 0.9953, 1.0000, 1.0000]]]]); - output.to_data().assert_approx_eq(&expected.to_data(), 4); - } + output.to_data().assert_approx_eq(&expected.to_data(), 4); + } - #[test] - fn gather() { - // Initialize the model with weights (loaded from the exported file) - let model: gather::Model = gather::Model::default(); + #[test] + fn gather() { + // Initialize the model with weights (loaded from the exported file) + let model: gather::Model = gather::Model::default(); - // Run the model - let input = Tensor::::from_floats([[1., 2.], [3., 4.]]); - let index = Tensor::::from_ints([[0, 0], [1, 0]]); - let output = model.forward(input, index); - let expected = Data::from([[1., 1.], [4., 3.]]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn globalavrpool_1d_2d() { - // The model contains 1d 
and 2d global average pooling nodes - let model: global_avr_pool::Model = global_avr_pool::Model::default(); - - // Run the model with ones as input for easier testing - let input_1d = Tensor::::ones([2, 4, 10]); - let input_2d = Tensor::::ones([3, 10, 3, 15]); - - let (output_1d, output_2d) = model.forward(input_1d, input_2d); - - let expected_shape_1d = Shape::from([2, 4, 1]); - let expected_shape_2d = Shape::from([3, 10, 1, 1]); - assert_eq!(output_1d.shape(), expected_shape_1d); - assert_eq!(output_2d.shape(), expected_shape_2d); - - let output_sum_1d = output_1d.sum().into_scalar(); - let output_sum_2d = output_2d.sum().into_scalar(); - - let expected_sum_1d = 8.0; // from pytorch - let expected_sum_2d = 30.0; // from pytorch - - assert!(expected_sum_1d.approx_eq(output_sum_1d, (1.0e-4, 2))); - assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2))); - } - - #[test] - fn softmax() { - // Initialize the model without weights (because the exported file does not contain them) - let model: softmax::Model = softmax::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.36830685, 0.29917702, 0.33251613], - [0.521_469_2, 0.13475533, 0.343_775_5], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn log_softmax() { - // Initialize the model without weights (because the exported file does not contain them) - let model: log_softmax::Model = log_softmax::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [-0.998_838_9, -1.206_719_9, -1.101_067], - [-0.651_105_1, -2.004_294_6, -1.067_766_4], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn maxpool2d() { - // Initialize the model without weights (because the exported file does not contain them) - let model: maxpool2d::Model = maxpool2d::Model::new(); - - // Run the model - let input = Tensor::::from_floats([[[ - [1.927, 1.487, 0.901, -2.106, 0.678], - [-1.235, -0.043, -1.605, -0.752, -0.687], - [-0.493, 0.241, -1.111, 0.092, -2.317], - [-0.217, -1.385, -0.396, 0.803, -0.622], - [-0.592, -0.063, -0.829, 0.331, -1.558], - ]]]); - let output = model.forward(input); - let expected = Data::from([[[ - [0.901, 1.927, 1.487, 0.901], - [0.901, 1.927, 1.487, 0.901], - [-0.396, 0.803, 0.241, -0.396], - ]]]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn avg_pool2d() { - // Initialize the model without weights (because the exported file does not contain them) - let model: avg_pool2d::Model = avg_pool2d::Model::new(); - - // Run the model - let input = Tensor::::from_floats([[[ - [-0.077, 0.360, -0.782, 0.072, 0.665], - [-0.287, 1.621, -1.597, -0.052, 0.611], - [0.760, -0.034, -0.345, 0.494, -0.078], - [-1.805, -0.476, 0.205, 0.338, 1.353], - [0.374, 0.013, 0.774, -0.109, -0.271], - ]]]); - let output = model.forward(input); - let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]); - - output.to_data().assert_approx_eq(&expected, 3); - } - - #[test] - fn reshape() { - // Initialize the model without weights (because the exported file does not contain them) - let model: reshape::Model = reshape::Model::new(); - - // Run the model - let input = Tensor::::from_floats([0., 1., 2., 3.]); - let output = model.forward(input); - let 
expected = Data::from([[0., 1., 2., 3.]]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn flatten() { - // Initialize the model without weights (because the exported file does not contain them) - let model: flatten::Model = flatten::Model::new(); - - // Run the model - let input = Tensor::::ones([1, 5, 15]); - let output = model.forward(input); - - let expected_shape = Shape::from([1, 75]); - assert_eq!(expected_shape, output.shape()); - } - - #[test] - fn batch_norm() { - let model: batch_norm::Model = batch_norm::Model::default(); - - // Run the model with ones as input for easier testing - let input = Tensor::::ones([1, 20, 1]); - let output = model.forward(input); - - let expected_shape = Shape::from([1, 5, 2, 2]); - assert_eq!(output.shape(), expected_shape); - - let output_sum = output.sum().into_scalar(); - let expected_sum = 19.999_802; // from pytorch - assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2))); - } - - #[test] - fn relu() { - // Initialize the model without weights (because the exported file does not contain them) - let model: relu::Model = relu::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, 0.00000000, 0.00000000], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn sigmoid() { - // Initialize the model without weights (because the exported file does not contain them) - let model: sigmoid::Model = sigmoid::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.58338636, 0.532_157_9, 0.55834854], - [0.557_33, 0.24548186, 0.45355222], - ]); - - output.to_data().assert_approx_eq(&expected, 7); - } - - #[test] - fn transpose() { - // Initialize the model without weights (because the exported file does not contain them) - let model: transpose::Model = transpose::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], - ]); - let output = model.forward(input); - let expected = Data::from([ - [0.33669037, 0.23033303], - [0.128_809_4, -1.122_856_4], - [0.23446237, -0.18632829], - ]); - - assert_eq!(output.to_data(), expected); - } - - #[test] - fn equal_scalar_to_scalar_and_tensor_to_tensor() { - // Initialize the model with weights (loaded from the exported file) - let model: equal::Model = equal::Model::default(); - - // Run the model - let input = Tensor::::from_floats([[[[1., 1., 1., 1.]]]]); - - let scalar = 2f64; - let (tensor_out, scalar_out) = model.forward(input, scalar); - let expected_tensor = Data::from([[[[true, true, true, true]]]]); - let expected_scalar = false; - - assert_eq!(tensor_out.to_data(), expected_tensor); - assert_eq!(scalar_out, expected_scalar); - } - - #[test] - fn clip_opset16() { - // Initialize the model without weights (because the exported file does not contain them) - let model: clip_opset16::Model = clip_opset16::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - 0.88226926, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let (output1, output2, output3) = model.forward(input); - let expected1 = Data::from([ - 0.88226926, - 0.91500396, 
- 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); - let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); - - assert_eq!(output1.to_data(), expected1); - assert_eq!(output2.to_data(), expected2); - assert_eq!(output3.to_data(), expected3); - } - - #[test] - fn clip_opset7() { - // Initialize the model without weights (because the exported file does not contain them) - let model: clip_opset7::Model = clip_opset7::Model::new(); - - // Run the model - let input = Tensor::::from_floats([ - 0.88226926, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let (output1, output2, output3) = model.forward(input); - let expected1 = Data::from([ - 0.88226926, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); - let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); - - assert_eq!(output1.to_data(), expected1); - assert_eq!(output2.to_data(), expected2); - assert_eq!(output3.to_data(), expected3); - } - - #[test] - fn linear() { - // Initialize the model with weights (loaded from the exported file) - let model: linear::Model = linear::Model::default(); - #[allow(clippy::approx_constant)] - let input1 = Tensor::::full([4, 3], 3.14); - #[allow(clippy::approx_constant)] - let input2 = Tensor::::full([2, 5], 3.14); - #[allow(clippy::approx_constant)] - let input3 = Tensor::::full([3, 2, 7], 3.14); - - let (output1, output2, output3) = model.forward(input1, input2, input3); - - // test the output shape - let expected_shape1: Shape<2> = Shape::from([4, 4]); - let expected_shape2: Shape<2> = Shape::from([2, 6]); - let expected_shape3: Shape<3> = Shape::from([3, 2, 8]); - assert_eq!(output1.shape(), expected_shape1); - assert_eq!(output2.shape(), expected_shape2); - assert_eq!(output3.shape(), expected_shape3); - - // We are using the sum of the output tensor to test the correctness of the conv1d node - // because the output tensor is too large to compare with the expected tensor. 
- let output_sum1 = output1.sum().into_scalar(); - let output_sum2 = output2.sum().into_scalar(); - let output_sum3 = output3.sum().into_scalar(); - - let expected_sum1 = -9.655_477; // from pytorch - let expected_sum2 = -8.053_822; // from pytorch - let expected_sum3 = 27.575_281; // from pytorch - - assert!(expected_sum1.approx_eq(output_sum1, (1.0e-6, 2))); - assert!(expected_sum2.approx_eq(output_sum2, (1.0e-6, 2))); - assert!(expected_sum3.approx_eq(output_sum3, (1.0e-6, 2))); - } - - #[test] - fn tanh() { - // Initialize the model - let model = tanh::Model::::new(); - - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let output = model.forward(input); - // data from pyTorch - let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]); - output.to_data().assert_approx_eq(&expected, 4); - } - - #[test] - fn recip() { - // Initialize the model - let model = recip::Model::::new(); - - // Run the model - let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); - let output = model.forward(input); - // data from pyTorch - let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]); - output.to_data().assert_approx_eq(&expected, 4); - } + // Run the model + let input = Tensor::::from_floats([[1., 2.], [3., 4.]]); + let index = Tensor::::from_ints([[0, 0], [1, 0]]); + let output = model.forward(input, index); + let expected = Data::from([[1., 1.], [4., 3.]]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn globalavrpool_1d_2d() { + // The model contains 1d and 2d global average pooling nodes + let model: global_avr_pool::Model = global_avr_pool::Model::default(); + + // Run the model with ones as input for easier testing + let input_1d = Tensor::::ones([2, 4, 10]); + let input_2d = Tensor::::ones([3, 10, 3, 15]); + + let (output_1d, output_2d) = model.forward(input_1d, input_2d); + + let expected_shape_1d = Shape::from([2, 4, 1]); + let expected_shape_2d = Shape::from([3, 10, 1, 1]); + assert_eq!(output_1d.shape(), expected_shape_1d); + assert_eq!(output_2d.shape(), expected_shape_2d); + + let output_sum_1d = output_1d.sum().into_scalar(); + let output_sum_2d = output_2d.sum().into_scalar(); + + let expected_sum_1d = 8.0; // from pytorch + let expected_sum_2d = 30.0; // from pytorch + + assert!(expected_sum_1d.approx_eq(output_sum_1d, (1.0e-4, 2))); + assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2))); + } + + #[test] + fn softmax() { + // Initialize the model without weights (because the exported file does not contain them) + let model: softmax::Model = softmax::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.36830685, 0.29917702, 0.33251613], + [0.521_469_2, 0.13475533, 0.343_775_5], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn log_softmax() { + // Initialize the model without weights (because the exported file does not contain them) + let model: log_softmax::Model = log_softmax::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [-0.998_838_9, -1.206_719_9, -1.101_067], + [-0.651_105_1, -2.004_294_6, -1.067_766_4], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn maxpool2d() { + // Initialize the model 
without weights (because the exported file does not contain them) + let model: maxpool2d::Model = maxpool2d::Model::new(); + + // Run the model + let input = Tensor::::from_floats([[[ + [1.927, 1.487, 0.901, -2.106, 0.678], + [-1.235, -0.043, -1.605, -0.752, -0.687], + [-0.493, 0.241, -1.111, 0.092, -2.317], + [-0.217, -1.385, -0.396, 0.803, -0.622], + [-0.592, -0.063, -0.829, 0.331, -1.558], + ]]]); + let output = model.forward(input); + let expected = Data::from([[[ + [0.901, 1.927, 1.487, 0.901], + [0.901, 1.927, 1.487, 0.901], + [-0.396, 0.803, 0.241, -0.396], + ]]]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn avg_pool2d() { + // Initialize the model without weights (because the exported file does not contain them) + let model: avg_pool2d::Model = avg_pool2d::Model::new(); + + // Run the model + let input = Tensor::::from_floats([[[ + [-0.077, 0.360, -0.782, 0.072, 0.665], + [-0.287, 1.621, -1.597, -0.052, 0.611], + [0.760, -0.034, -0.345, 0.494, -0.078], + [-1.805, -0.476, 0.205, 0.338, 1.353], + [0.374, 0.013, 0.774, -0.109, -0.271], + ]]]); + let output = model.forward(input); + let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]); + + output.to_data().assert_approx_eq(&expected, 3); + } + + #[test] + fn reshape() { + // Initialize the model without weights (because the exported file does not contain them) + let model: reshape::Model = reshape::Model::new(); + + // Run the model + let input = Tensor::::from_floats([0., 1., 2., 3.]); + let output = model.forward(input); + let expected = Data::from([[0., 1., 2., 3.]]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn flatten() { + // Initialize the model without weights (because the exported file does not contain them) + let model: flatten::Model = flatten::Model::new(); + + // Run the model + let input = Tensor::::ones([1, 5, 15]); + let output = model.forward(input); + + let expected_shape = Shape::from([1, 75]); + assert_eq!(expected_shape, output.shape()); + } + + #[test] + fn batch_norm() { + let model: batch_norm::Model = batch_norm::Model::default(); + + // Run the model with ones as input for easier testing + let input = Tensor::::ones([1, 20, 1]); + let output = model.forward(input); + + let expected_shape = Shape::from([1, 5, 2, 2]); + assert_eq!(output.shape(), expected_shape); + + let output_sum = output.sum().into_scalar(); + let expected_sum = 19.999_802; // from pytorch + assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2))); + } + + #[test] + fn relu() { + // Initialize the model without weights (because the exported file does not contain them) + let model: relu::Model = relu::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, 0.00000000, 0.00000000], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn sigmoid() { + // Initialize the model without weights (because the exported file does not contain them) + let model: sigmoid::Model = sigmoid::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.58338636, 0.532_157_9, 0.55834854], + [0.557_33, 0.24548186, 0.45355222], + ]); + + output.to_data().assert_approx_eq(&expected, 7); + } + + #[test] 
+ fn transpose() { + // Initialize the model without weights (because the exported file does not contain them) + let model: transpose::Model = transpose::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + [0.33669037, 0.128_809_4, 0.23446237], + [0.23033303, -1.122_856_4, -0.18632829], + ]); + let output = model.forward(input); + let expected = Data::from([ + [0.33669037, 0.23033303], + [0.128_809_4, -1.122_856_4], + [0.23446237, -0.18632829], + ]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn equal_scalar_to_scalar_and_tensor_to_tensor() { + // Initialize the model with weights (loaded from the exported file) + let model: equal::Model = equal::Model::default(); + + // Run the model + let input = Tensor::::from_floats([[[[1., 1., 1., 1.]]]]); + + let scalar = 2f64; + let (tensor_out, scalar_out) = model.forward(input, scalar); + let expected_tensor = Data::from([[[[true, true, true, true]]]]); + let expected_scalar = false; + + assert_eq!(tensor_out.to_data(), expected_tensor); + assert_eq!(scalar_out, expected_scalar); + } + + #[test] + fn clip_opset16() { + // Initialize the model without weights (because the exported file does not contain them) + let model: clip_opset16::Model = clip_opset16::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let (output1, output2, output3) = model.forward(input); + let expected1 = Data::from([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); + let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); + + assert_eq!(output1.to_data(), expected1); + assert_eq!(output2.to_data(), expected2); + assert_eq!(output3.to_data(), expected3); + } + + #[test] + fn clip_opset7() { + // Initialize the model without weights (because the exported file does not contain them) + let model: clip_opset7::Model = clip_opset7::Model::new(); + + // Run the model + let input = Tensor::::from_floats([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let (output1, output2, output3) = model.forward(input); + let expected1 = Data::from([ + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, + ]); + let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); + let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); + + assert_eq!(output1.to_data(), expected1); + assert_eq!(output2.to_data(), expected2); + assert_eq!(output3.to_data(), expected3); + } + + #[test] + fn linear() { + // Initialize the model with weights (loaded from the exported file) + let model: linear::Model = linear::Model::default(); + #[allow(clippy::approx_constant)] + let input1 = Tensor::::full([4, 3], 3.14); + #[allow(clippy::approx_constant)] + let input2 = Tensor::::full([2, 5], 3.14); + #[allow(clippy::approx_constant)] + let input3 = Tensor::::full([3, 2, 7], 3.14); + + let (output1, output2, output3) = model.forward(input1, input2, input3); + + // test the output shape + let expected_shape1: Shape<2> = Shape::from([4, 4]); + let expected_shape2: Shape<2> = Shape::from([2, 6]); + let expected_shape3: Shape<3> = Shape::from([3, 2, 8]); + assert_eq!(output1.shape(), expected_shape1); + assert_eq!(output2.shape(), expected_shape2); + assert_eq!(output3.shape(), expected_shape3); + + // We are using the sum of the 
output tensor to test the correctness of the conv1d node + // because the output tensor is too large to compare with the expected tensor. + let output_sum1 = output1.sum().into_scalar(); + let output_sum2 = output2.sum().into_scalar(); + let output_sum3 = output3.sum().into_scalar(); + + let expected_sum1 = -9.655_477; // from pytorch + let expected_sum2 = -8.053_822; // from pytorch + let expected_sum3 = 27.575_281; // from pytorch + + assert!(expected_sum1.approx_eq(output_sum1, (1.0e-6, 2))); + assert!(expected_sum2.approx_eq(output_sum2, (1.0e-6, 2))); + assert!(expected_sum3.approx_eq(output_sum3, (1.0e-6, 2))); + } + + #[test] + fn tanh() { + // Initialize the model + let model = tanh::Model::::new(); + + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let output = model.forward(input); + // data from pyTorch + let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]); + output.to_data().assert_approx_eq(&expected, 4); + } + + #[test] + fn recip() { + // Initialize the model + let model = recip::Model::::new(); + + // Run the model + let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]]); + let output = model.forward(input); + // data from pyTorch + let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]); + output.to_data().assert_approx_eq(&expected, 4); + } } diff --git a/burn-import/onnx-tests/tests/record_type_tests.rs b/burn-import/onnx-tests/tests/record_type_tests.rs index d532672125..98558b5019 100644 --- a/burn-import/onnx-tests/tests/record_type_tests.rs +++ b/burn-import/onnx-tests/tests/record_type_tests.rs @@ -6,58 +6,58 @@ // different. macro_rules! test_model { - ($mod_name:ident) => { - test_model!($mod_name, 1.0e-4); // Default tolerance - }; - ($mod_name:ident, $tolerance:expr) => { - pub mod $mod_name { - include!(concat!( - env!("OUT_DIR"), - "/model/", - stringify!($mod_name), - "/conv1d.rs" - )); - } - - #[test] - fn $mod_name() { - // Initialize the model with weights (loaded from the exported file) - let model: $mod_name::Model = $mod_name::Model::default(); - - // Run the model with pi as input for easier testing - let input = Tensor::::full([6, 4, 10], consts::PI); - - let output = model.forward(input); - - // test the output shape - let expected_shape: Shape<3> = Shape::from([6, 2, 7]); - assert_eq!(output.shape(), expected_shape); - - // We are using the sum of the output tensor to test the correctness of the conv1d node - // because the output tensor is too large to compare with the expected tensor. 
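A note on the assertion style used throughout these tests and in the test_model! macro above: instead of comparing full tensors, the tests reduce the output to a single sum and compare it against a value recorded from PyTorch using float_cmp's (epsilon, ulps) margin. A minimal, self-contained illustration of that check with made-up numbers (it assumes float_cmp as a dev-dependency, which these tests already use):

    use float_cmp::ApproxEq;

    fn main() {
        // Reference value (e.g. taken from PyTorch) and the locally computed sum.
        let expected_sum: f32 = -54.549_243;
        let output_sum: f32 = -54.549_24; // differs by roughly 3e-6

        // Passes if the values are within an absolute epsilon of 1e-4
        // or within 2 units-in-the-last-place of each other.
        assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
    }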
- let output_sum = output.sum().into_scalar(); - let expected_sum = -54.549_243; // from pytorch - assert!(expected_sum.approx_eq(output_sum, ($tolerance, 2))); - } - }; + ($mod_name:ident) => { + test_model!($mod_name, 1.0e-4); // Default tolerance + }; + ($mod_name:ident, $tolerance:expr) => { + pub mod $mod_name { + include!(concat!( + env!("OUT_DIR"), + "/model/", + stringify!($mod_name), + "/conv1d.rs" + )); + } + + #[test] + fn $mod_name() { + // Initialize the model with weights (loaded from the exported file) + let model: $mod_name::Model = $mod_name::Model::default(); + + // Run the model with pi as input for easier testing + let input = Tensor::::full([6, 4, 10], consts::PI); + + let output = model.forward(input); + + // test the output shape + let expected_shape: Shape<3> = Shape::from([6, 2, 7]); + assert_eq!(output.shape(), expected_shape); + + // We are using the sum of the output tensor to test the correctness of the conv1d node + // because the output tensor is too large to compare with the expected tensor. + let output_sum = output.sum().into_scalar(); + let expected_sum = -54.549_243; // from pytorch + assert!(expected_sum.approx_eq(output_sum, ($tolerance, 2))); + } + }; } #[cfg(test)] mod tests { - use burn::tensor::{Shape, Tensor}; - use float_cmp::ApproxEq; - use std::f64::consts; - - type Backend = burn_ndarray::NdArray; - - test_model!(named_mpk); - test_model!(named_mpk_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(pretty_json); - test_model!(pretty_json_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(named_mpk_gz); - test_model!(named_mpk_gz_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(bincode); - test_model!(bincode_half, 1.0e-2); // Reduce tolerance for half precision - test_model!(bincode_embedded); - test_model!(bincode_embedded_half, 1.0e-2); // Reduce tolerance for half precision + use burn::tensor::{Shape, Tensor}; + use float_cmp::ApproxEq; + use std::f64::consts; + + type Backend = burn_ndarray::NdArray; + + test_model!(named_mpk); + test_model!(named_mpk_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(pretty_json); + test_model!(pretty_json_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(named_mpk_gz); + test_model!(named_mpk_gz_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(bincode); + test_model!(bincode_half, 1.0e-2); // Reduce tolerance for half precision + test_model!(bincode_embedded); + test_model!(bincode_embedded_half, 1.0e-2); // Reduce tolerance for half precision } diff --git a/burn-import/src/burn/codegen.rs b/burn-import/src/burn/codegen.rs index 2f61292b2b..ed617e8086 100644 --- a/burn-import/src/burn/codegen.rs +++ b/burn-import/src/burn/codegen.rs @@ -5,89 +5,89 @@ use burn::nn::PaddingConfig1d; use burn::nn::PaddingConfig2d; fn convert_primitive(primitive: T) -> TokenStream { - let value = primitive.to_string(); - value.parse().unwrap() + let value = primitive.to_string(); + value.parse().unwrap() } fn convert_to_array<'a, I, T: ToTokens>(list: I) -> TokenStream where - I: Iterator, - T: 'a, + I: Iterator, + T: 'a, { - let mut body = quote! {}; + let mut body = quote! {}; - list.for_each(|item| { - let elem = item.to_tokens(); - body.extend(quote! {#elem,}); - }); + list.for_each(|item| { + let elem = item.to_tokens(); + body.extend(quote! {#elem,}); + }); - quote! { - [#body] - } + quote! 
{ + [#body] + } } pub trait ToTokens { - fn to_tokens(&self) -> TokenStream; + fn to_tokens(&self) -> TokenStream; } impl ToTokens for [T; N] { - fn to_tokens(&self) -> TokenStream { - convert_to_array(self.iter()) - } + fn to_tokens(&self) -> TokenStream { + convert_to_array(self.iter()) + } } impl ToTokens for Vec { - fn to_tokens(&self) -> TokenStream { - convert_to_array(self.iter()) - } + fn to_tokens(&self) -> TokenStream { + convert_to_array(self.iter()) + } } /// Prettier output for `usize` impl ToTokens for usize { - fn to_tokens(&self) -> TokenStream { - convert_primitive(self) - } + fn to_tokens(&self) -> TokenStream { + convert_primitive(self) + } } /// Prettier output for `i64` impl ToTokens for i64 { - fn to_tokens(&self) -> TokenStream { - convert_primitive(self) - } + fn to_tokens(&self) -> TokenStream { + convert_primitive(self) + } } /// Prettier output for `f64` impl ToTokens for f64 { - fn to_tokens(&self) -> TokenStream { - convert_primitive(self) - } + fn to_tokens(&self) -> TokenStream { + convert_primitive(self) + } } /// Padding configuration impl ToTokens for PaddingConfig1d { - fn to_tokens(&self) -> TokenStream { - match self { - Self::Same => quote! { PaddingConfig1d::Same }, - Self::Valid => quote! { PaddingConfig1d::Valid }, - Self::Explicit(padding) => { - let padding = padding.to_tokens(); - quote! { PaddingConfig1d::Explicit(#padding) } - } + fn to_tokens(&self) -> TokenStream { + match self { + Self::Same => quote! { PaddingConfig1d::Same }, + Self::Valid => quote! { PaddingConfig1d::Valid }, + Self::Explicit(padding) => { + let padding = padding.to_tokens(); + quote! { PaddingConfig1d::Explicit(#padding) } + } + } } - } } /// Padding configuration impl ToTokens for PaddingConfig2d { - fn to_tokens(&self) -> TokenStream { - match self { - Self::Same => quote! { PaddingConfig2d::Same }, - Self::Valid => quote! { PaddingConfig2d::Valid }, - Self::Explicit(padding1, padding2) => { - let padding1 = padding1.to_tokens(); - let padding2 = padding2.to_tokens(); - quote! { PaddingConfig2d::Explicit(#padding1, #padding2) } - } + fn to_tokens(&self) -> TokenStream { + match self { + Self::Same => quote! { PaddingConfig2d::Same }, + Self::Valid => quote! { PaddingConfig2d::Valid }, + Self::Explicit(padding1, padding2) => { + let padding1 = padding1.to_tokens(); + let padding2 = padding2.to_tokens(); + quote! { PaddingConfig2d::Explicit(#padding1, #padding2) } + } + } } - } } diff --git a/burn-import/src/burn/graph.rs b/burn-import/src/burn/graph.rs index ea0c2dfc9c..98437c3877 100644 --- a/burn-import/src/burn/graph.rs +++ b/burn-import/src/burn/graph.rs @@ -1,618 +1,606 @@ use super::{BurnImports, Scope, Type}; use crate::burn::{ - node::{Node, NodeCodegen}, - TensorKind, TensorType, + node::{Node, NodeCodegen}, + TensorKind, TensorType, }; use burn::record::{ - BinFileRecorder, BurnRecord, FileRecorder, NamedMpkFileRecorder, NamedMpkGzFileRecorder, - PrecisionSettings, PrettyJsonFileRecorder, Recorder, + BinFileRecorder, BurnRecord, FileRecorder, NamedMpkFileRecorder, NamedMpkGzFileRecorder, + PrecisionSettings, PrettyJsonFileRecorder, Recorder, }; use proc_macro2::TokenStream; use quote::quote; use serde::{ - ser::{SerializeMap, SerializeTuple}, - Serialize, + ser::{SerializeMap, SerializeTuple}, + Serialize, }; use std::{any::type_name, collections::HashMap, path::PathBuf}; /// Type of the record to be saved. #[derive(Debug, Clone, Default, Copy)] pub enum RecordType { - /// Pretty JSON format (useful for debugging). 
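As an aside to the codegen.rs hunk above: convert_primitive and convert_to_array turn plain Rust values into TokenStreams so node configurations can be spliced into the generated model source. A small standalone approximation of the array case follows; it re-implements the helper locally so it can run outside burn-import, and it assumes quote and proc-macro2 as dependencies.

    use proc_macro2::TokenStream;
    use quote::quote;

    // Local re-implementation of the `convert_to_array` idea for usize slices.
    fn to_tokens(values: &[usize]) -> TokenStream {
        let mut body = quote! {};
        for v in values {
            // `usize` is rendered without a type suffix, as in `convert_primitive`.
            let elem: TokenStream = v.to_string().parse().unwrap();
            body.extend(quote! { #elem, });
        }
        quote! { [#body] }
    }

    fn main() {
        // Prints something like `[3 , 3 ,]`, i.e. an array literal that can be
        // spliced into generated module code (for example a kernel size).
        println!("{}", to_tokens(&[3, 3]));
    }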
- PrettyJson, + /// Pretty JSON format (useful for debugging). + PrettyJson, - #[default] - /// Compressed Named MessagePack. - NamedMpkGz, + #[default] + /// Compressed Named MessagePack. + NamedMpkGz, - /// Uncompressed Named MessagePack. - NamedMpk, + /// Uncompressed Named MessagePack. + NamedMpk, - /// Bincode format (useful for embedding and for no-std support). - Bincode, + /// Bincode format (useful for embedding and for no-std support). + Bincode, } /// Burn graph intermediate representation of modules and tensor operations. #[derive(Default, Debug)] pub struct BurnGraph { - nodes: Vec>, - scope: Scope, - imports: BurnImports, - top_comment: Option, - default: Option, - blank_spaces: bool, - gen_new_fn: bool, - graph_input_types: Vec, - graph_output_types: Vec, + nodes: Vec>, + scope: Scope, + imports: BurnImports, + top_comment: Option, + default: Option, + blank_spaces: bool, + gen_new_fn: bool, + graph_input_types: Vec, + graph_output_types: Vec, } impl BurnGraph { - /// Register a new operation node into the graph. - /// - /// # Notes - /// - /// The node must be registered in the same order they will be executed in the forward pass. - pub fn register + 'static>(&mut self, node: N) { - let node = node.into_node(); - log::debug!("Registering node => '{}'", node.name()); - self.nodes.push(node); - } - - /// Generate a function `Model::new()` without any argument when `gen_new_fn` is `true`. - /// - /// This is useful if you intend to train the model generated. - pub fn with_new_fn(mut self, gen_new_fn: bool) -> Self { - self.gen_new_fn = gen_new_fn; - self - } - - /// Save the state of each node in a record file. - /// - /// The `Default` trait will be implemented for the generated model, which will load the record - /// saved at the provided path. In case of `embed_states` is true, the record will be embedded - /// in the generated code (useful for no-std support). - /// - /// # Arguments - /// - /// * `out_file` - The path to the record file. - /// * `record_type` - The type of the record to be saved. - /// * `embed_states` - Embed the record in the generated code. - /// - /// # Panics - /// - /// Panics if the record type is not `RecordType::Bincode` and `embed_states` is `true`. - pub fn with_record( - mut self, - out_file: PathBuf, - record_type: RecordType, - embed_states: bool, - ) -> Self { - let precision_ty_str = extract_type_name_by_type::(); - self - .imports - .register(format!("burn::record::{precision_ty_str}")); - - match record_type { - RecordType::PrettyJson => { - PrettyJsonFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructMap(BurnGraphState::new( - &self.nodes, - ))), - out_file.clone(), - ) - .unwrap(); - - assert!( - !embed_states, - "Embedding states is not supported for PrettyJsonFileRecorder." - ); + /// Register a new operation node into the graph. + /// + /// # Notes + /// + /// The node must be registered in the same order they will be executed in the forward pass. + pub fn register + 'static>(&mut self, node: N) { + let node = node.into_node(); + log::debug!("Registering node => '{}'", node.name()); + self.nodes.push(node); + } - self.register_record_file( - out_file, - &format!("burn::record::PrettyJsonFileRecorder::<{precision_ty_str}>"), - ); - } - RecordType::NamedMpkGz => { - NamedMpkGzFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructMap(BurnGraphState::new( - &self.nodes, - ))), - out_file.clone(), - ) - .unwrap(); + /// Generate a function `Model::new()` without any argument when `gen_new_fn` is `true`. 
+ /// + /// This is useful if you intend to train the model generated. + pub fn with_new_fn(mut self, gen_new_fn: bool) -> Self { + self.gen_new_fn = gen_new_fn; + self + } - assert!( - !embed_states, - "Embedding states is not supported for NamedMpkGzFileRecorder." - ); - self.register_record_file( - out_file, - &format!("burn::record::NamedMpkGzFileRecorder::<{precision_ty_str}>"), - ); - } - - RecordType::NamedMpk => { - NamedMpkFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructMap(BurnGraphState::new( - &self.nodes, - ))), - out_file.clone(), - ) - .unwrap(); + /// Save the state of each node in a record file. + /// + /// The `Default` trait will be implemented for the generated model, which will load the record + /// saved at the provided path. In case of `embed_states` is true, the record will be embedded + /// in the generated code (useful for no-std support). + /// + /// # Arguments + /// + /// * `out_file` - The path to the record file. + /// * `record_type` - The type of the record to be saved. + /// * `embed_states` - Embed the record in the generated code. + /// + /// # Panics + /// + /// Panics if the record type is not `RecordType::Bincode` and `embed_states` is `true`. + pub fn with_record( + mut self, + out_file: PathBuf, + record_type: RecordType, + embed_states: bool, + ) -> Self { + let precision_ty_str = extract_type_name_by_type::(); + self.imports + .register(format!("burn::record::{precision_ty_str}")); + + match record_type { + RecordType::PrettyJson => { + PrettyJsonFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructMap( + BurnGraphState::new(&self.nodes), + )), + out_file.clone(), + ) + .unwrap(); + + assert!( + !embed_states, + "Embedding states is not supported for PrettyJsonFileRecorder." + ); + + self.register_record_file( + out_file, + &format!("burn::record::PrettyJsonFileRecorder::<{precision_ty_str}>"), + ); + } + RecordType::NamedMpkGz => { + NamedMpkGzFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructMap( + BurnGraphState::new(&self.nodes), + )), + out_file.clone(), + ) + .unwrap(); + + assert!( + !embed_states, + "Embedding states is not supported for NamedMpkGzFileRecorder." + ); + self.register_record_file( + out_file, + &format!("burn::record::NamedMpkGzFileRecorder::<{precision_ty_str}>"), + ); + } - assert!( - !embed_states, - "Embedding states is not supported for NamedMpkFileRecorder." - ); + RecordType::NamedMpk => { + NamedMpkFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructMap( + BurnGraphState::new(&self.nodes), + )), + out_file.clone(), + ) + .unwrap(); + + assert!( + !embed_states, + "Embedding states is not supported for NamedMpkFileRecorder." 
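The assert!s in with_record above encode a simple rule: only the Bincode recorder can be embedded in the generated source via embed_states, while the JSON and MessagePack recorders always go through a file loaded at run time. On the user-facing side this corresponds to the ModelGen options exercised by the test build script earlier in the patch; a hedged build.rs sketch for a hypothetical downstream model:

    use burn_import::onnx::{ModelGen, RecordType};

    fn main() {
        ModelGen::new()
            .input("src/model/mnist.onnx") // hypothetical path
            .out_dir("model/")
            // Bincode is required for embedding; other record types would hit
            // the `embed_states` assertions shown above.
            .record_type(RecordType::Bincode)
            .half_precision(true) // store the record weights in half precision
            .embed_states(true)   // bake the record bytes into the generated code
            .run_from_script();
    }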
+ ); + + self.register_record_file( + out_file, + &format!("burn::record::NamedMpkFileRecorder::<{precision_ty_str}>"), + ); + } - self.register_record_file( - out_file, - &format!("burn::record::NamedMpkFileRecorder::<{precision_ty_str}>"), - ); - } - - RecordType::Bincode => { - BinFileRecorder::::new() - .save_item( - BurnRecord::new::>(StructTuple(BurnGraphState::new(&self.nodes))), - out_file.clone(), - ) - .unwrap(); - - if embed_states { - self.register_record_embed(out_file); - } else { - self.register_record_file( - out_file, - &format!("burn::record::BinFileRecorder::<{precision_ty_str}>"), - ); + RecordType::Bincode => { + BinFileRecorder::::new() + .save_item( + BurnRecord::new::>(StructTuple(BurnGraphState::new( + &self.nodes, + ))), + out_file.clone(), + ) + .unwrap(); + + if embed_states { + self.register_record_embed(out_file); + } else { + self.register_record_file( + out_file, + &format!("burn::record::BinFileRecorder::<{precision_ty_str}>"), + ); + } + } } - } + + self } - self - } - - /// Add blank spaces in some places - /// - /// # Notes - /// - /// It can be problematic when testing. - pub fn with_blank_space(mut self, blank_spaces: bool) -> Self { - self.blank_spaces = blank_spaces; - self - } - - /// Add a comment at the top of the generated file. - pub fn with_top_comment(mut self, top_comment: Option) -> Self { - self.top_comment = top_comment; - self - } - - /// Generate tokens reprensenting the graph with Burn modules and tensor operations. - pub fn codegen(mut self) -> TokenStream { - self.build_scope(); - - self.register_imports(); - - let codegen_imports = self.imports.codegen(); - let codegen_struct = self.codegen_struct(); - let codegen_new_record = self.codegen_new_record(); - let codegen_forward = self.codegen_forward(); - - let maybe_blank = match self.blank_spaces { - true => quote! { - _blank_!(); - }, - false => quote! {}, - }; - let codegen_new = match self.gen_new_fn { - true => { - let new_fn = self.codegen_new(); + /// Add blank spaces in some places + /// + /// # Notes + /// + /// It can be problematic when testing. + pub fn with_blank_space(mut self, blank_spaces: bool) -> Self { + self.blank_spaces = blank_spaces; + self + } + + /// Add a comment at the top of the generated file. + pub fn with_top_comment(mut self, top_comment: Option) -> Self { + self.top_comment = top_comment; + self + } + + /// Generate tokens reprensenting the graph with Burn modules and tensor operations. + pub fn codegen(mut self) -> TokenStream { + self.build_scope(); + + self.register_imports(); + + let codegen_imports = self.imports.codegen(); + let codegen_struct = self.codegen_struct(); + let codegen_new_record = self.codegen_new_record(); + let codegen_forward = self.codegen_forward(); + + let maybe_blank = match self.blank_spaces { + true => quote! { + _blank_!(); + }, + false => quote! {}, + }; + let codegen_new = match self.gen_new_fn { + true => { + let new_fn = self.codegen_new(); + quote! { + #new_fn + #maybe_blank + } + } + false => quote! {}, + }; + let codegen_default = match self.default { + Some(default) => quote! { + #default + #maybe_blank + }, + None => quote! {}, + }; + + let maybe_top_file_comment = match self.top_comment { + Some(comment) => quote! { + _comment_!(#comment); + }, + None => quote! {}, + }; + quote! { - #new_fn + #maybe_top_file_comment + #codegen_imports #maybe_blank - } - } - false => quote! {}, - }; - let codegen_default = match self.default { - Some(default) => quote! { - #default - #maybe_blank - }, - None => quote! 
{}, - }; - - let maybe_top_file_comment = match self.top_comment { - Some(comment) => quote! { - _comment_!(#comment); - }, - None => quote! {}, - }; - - quote! { - #maybe_top_file_comment - #codegen_imports - #maybe_blank - #maybe_blank - - #codegen_struct - #maybe_blank - - #codegen_default - - impl Model { - #codegen_new_record #maybe_blank - #codegen_new - #codegen_forward + #codegen_struct + #maybe_blank + + #codegen_default + + impl Model { + #codegen_new_record + #maybe_blank + + #codegen_new + #codegen_forward + } } } - } - - fn register_imports(&mut self) { - // Register imports from nodes - self - .nodes - .iter() - .for_each(|node| node.register_imports(&mut self.imports)); - - // Combine input and output types into a single vector - let all_types = self - .graph_input_types - .iter() - .chain(&self.graph_output_types); - - // Register imports for bool and int tensors - for ty in all_types { - match ty { - Type::Tensor(TensorType { - kind: TensorKind::Bool, - .. - }) => { - self.imports.register("burn::tensor::Bool"); - } - Type::Tensor(TensorType { - kind: TensorKind::Int, - .. - }) => { - self.imports.register("burn::tensor::Int"); + + fn register_imports(&mut self) { + // Register imports from nodes + self.nodes + .iter() + .for_each(|node| node.register_imports(&mut self.imports)); + + // Combine input and output types into a single vector + let all_types = self + .graph_input_types + .iter() + .chain(&self.graph_output_types); + + // Register imports for bool and int tensors + for ty in all_types { + match ty { + Type::Tensor(TensorType { + kind: TensorKind::Bool, + .. + }) => { + self.imports.register("burn::tensor::Bool"); + } + Type::Tensor(TensorType { + kind: TensorKind::Int, + .. + }) => { + self.imports.register("burn::tensor::Int"); + } + _ => {} + } } - _ => {} - } } - } - /// Build the scope state to make sure tensor clones are added where needed. - fn build_scope(&mut self) { - log::debug!("Building the scope nodes len => '{}'", self.nodes.len()); - - fn to_tensor(ty: Type) -> Option { - match ty { - Type::Tensor(tensor) => Some(tensor), - Type::Scalar(_) => None, - Type::Other(_) => None, - } + /// Build the scope state to make sure tensor clones are added where needed. 
+ fn build_scope(&mut self) { + log::debug!("Building the scope nodes len => '{}'", self.nodes.len()); + + fn to_tensor(ty: Type) -> Option { + match ty { + Type::Tensor(tensor) => Some(tensor), + Type::Scalar(_) => None, + Type::Other(_) => None, + } + } + + // Register graph tensor input with 0 as node position + self.graph_input_types + .clone() + .into_iter() + .flat_map(to_tensor) + .for_each(|tensor| { + self.scope.tensor_register_variable(&tensor, 0); + }); + + self.nodes + .iter() + .enumerate() + .for_each(|(node_position, node)| { + node.output_types() + .into_iter() + .flat_map(to_tensor) + .for_each(|tensor| { + self.scope + .tensor_register_variable(&tensor, node_position + 1) + }) + }); + + self.nodes + .iter() + .enumerate() + .for_each(|(node_position, node)| { + node.input_types() + .into_iter() + .flat_map(to_tensor) + .for_each(|tensor| { + self.scope + .tensor_register_future_use(&tensor, node_position) + }) + }); } - // Register graph tensor input with 0 as node position - self - .graph_input_types - .clone() - .into_iter() - .flat_map(to_tensor) - .for_each(|tensor| { - self.scope.tensor_register_variable(&tensor, 0); - }); - - self - .nodes - .iter() - .enumerate() - .for_each(|(node_position, node)| { - node - .output_types() - .into_iter() - .flat_map(to_tensor) - .for_each(|tensor| { - self - .scope - .tensor_register_variable(&tensor, node_position + 1) - }) - }); - - self - .nodes - .iter() - .enumerate() - .for_each(|(node_position, node)| { - node - .input_types() - .into_iter() - .flat_map(to_tensor) - .for_each(|tensor| { - self - .scope - .tensor_register_future_use(&tensor, node_position) - }) - }); - } - - fn register_record_file(&mut self, file: PathBuf, recorder_str: &str) { - self.imports.register("burn::record::Recorder"); - - let recorder_ty = syn::parse_str::(recorder_str).unwrap(); - - // Add default implementation - let file = file.to_str().unwrap(); - self.default = Some(quote! { - _blank_!(); - impl Default for Model { - fn default() -> Self { - Self::from_file(#file) + fn register_record_file(&mut self, file: PathBuf, recorder_str: &str) { + self.imports.register("burn::record::Recorder"); + + let recorder_ty = syn::parse_str::(recorder_str).unwrap(); + + // Add default implementation + let file = file.to_str().unwrap(); + self.default = Some(quote! { + _blank_!(); + impl Default for Model { + fn default() -> Self { + Self::from_file(#file) + } } - } - _blank_!(); - impl Model { - pub fn from_file(file: &str) -> Self { - let record = #recorder_ty::new() - .load(file.into()) - .expect("Record file to exist."); - Self::new_with(record) + _blank_!(); + impl Model { + pub fn from_file(file: &str) -> Self { + let record = #recorder_ty::new() + .load(file.into()) + .expect("Record file to exist."); + Self::new_with(record) + } } - } - }); - } - - fn register_record_embed(&mut self, file: PathBuf) { - self.imports.register("burn::record::Recorder"); - - // NOTE: Bincode format is used for embedding states for now. - let precision = extract_type_name_by_type::(); - let precision_ty = syn::parse_str::(&precision).unwrap(); - self.imports.register("burn::record::BinBytesRecorder"); - - let mut file = file; - file.set_extension(BinFileRecorder::::file_extension()); - let file = file.to_str().unwrap(); - self.default = Some(quote! 
{ - _blank_!(); - static EMBEDDED_STATES: &[u8] = include_bytes!(#file); - _blank_!(); - impl Default for Model { - fn default() -> Self { - Self::from_embedded() + }); + } + + fn register_record_embed(&mut self, file: PathBuf) { + self.imports.register("burn::record::Recorder"); + + // NOTE: Bincode format is used for embedding states for now. + let precision = extract_type_name_by_type::(); + let precision_ty = syn::parse_str::(&precision).unwrap(); + self.imports.register("burn::record::BinBytesRecorder"); + + let mut file = file; + file.set_extension(BinFileRecorder::::file_extension()); + let file = file.to_str().unwrap(); + self.default = Some(quote! { + _blank_!(); + static EMBEDDED_STATES: &[u8] = include_bytes!(#file); + _blank_!(); + impl Default for Model { + fn default() -> Self { + Self::from_embedded() + } } - } - _blank_!(); - impl Model { - pub fn from_embedded() -> Self { - let record = BinBytesRecorder::<#precision_ty>::default() - .load(EMBEDDED_STATES.to_vec()) - .expect("Failed to decode state"); - - Self::new_with(record) + _blank_!(); + impl Model { + pub fn from_embedded() -> Self { + let record = BinBytesRecorder::<#precision_ty>::default() + .load(EMBEDDED_STATES.to_vec()) + .expect("Failed to decode state"); + + Self::new_with(record) + } } - } - }); - } - - fn codegen_struct(&self) -> TokenStream { - let mut body = quote! {}; - self - .nodes - .iter() - .filter_map(|node| node.field_type()) - .map(|field| { - let name = field.name(); - let ty = field.ty(); - - if matches!(&field, Type::Tensor(_)) { - quote! { - #name: burn::module::Param<#ty>, - } - } else { - quote! { - #name: #ty, - } - } - }) - .for_each(|code| body.extend(code)); - - // Extend with phantom data to avoid unused generic type. - body.extend(quote! { - phantom: core::marker::PhantomData, - }); - - quote! { - #[derive(Module, Debug)] - pub struct Model { - #body - } + }); } - } - - fn codegen_new(&self) -> TokenStream { - let mut body = quote! {}; - - self - .nodes - .iter() - .map(|node| node.field_init(false)) - .for_each(|code| body.extend(code)); - - let fields = self - .nodes - .iter() - .flat_map(|node| node.field_type()) - .map(|field| field.name().clone()) - .collect::>(); - - quote! { - #[allow(dead_code)] - pub fn new() -> Self { - #body - - Self { - #(#fields,)* - phantom: core::marker::PhantomData, + + fn codegen_struct(&self) -> TokenStream { + let mut body = quote! {}; + self.nodes + .iter() + .filter_map(|node| node.field_type()) + .map(|field| { + let name = field.name(); + let ty = field.ty(); + + if matches!(&field, Type::Tensor(_)) { + quote! { + #name: burn::module::Param<#ty>, + } + } else { + quote! { + #name: #ty, + } + } + }) + .for_each(|code| body.extend(code)); + + // Extend with phantom data to avoid unused generic type. + body.extend(quote! { + phantom: core::marker::PhantomData, + }); + + quote! { + #[derive(Module, Debug)] + pub struct Model { + #body } } } - } - fn codegen_new_record(&self) -> TokenStream { - let mut body = quote! {}; - - self - .nodes - .iter() - .map(|node| node.field_init(true)) - .for_each(|code| body.extend(code)); - - let fields = self - .nodes - .iter() - .flat_map(|node| node.field_type()) - .map(|field| field.name().clone()) - .collect::>(); - - quote! { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - #body - - Self { - #(#fields,)* - phantom: core::marker::PhantomData, + + fn codegen_new(&self) -> TokenStream { + let mut body = quote! 
{}; + + self.nodes + .iter() + .map(|node| node.field_init(false)) + .for_each(|code| body.extend(code)); + + let fields = self + .nodes + .iter() + .flat_map(|node| node.field_type()) + .map(|field| field.name().clone()) + .collect::>(); + + quote! { + #[allow(dead_code)] + pub fn new() -> Self { + #body + + Self { + #(#fields,)* + phantom: core::marker::PhantomData, + } } } } - } + fn codegen_new_record(&self) -> TokenStream { + let mut body = quote! {}; - fn codegen_forward(&mut self) -> TokenStream { - let mut input_def = quote! {}; - let mut output_type_def = quote! {}; - let mut output_return_def = quote! {}; + self.nodes + .iter() + .map(|node| node.field_init(true)) + .for_each(|code| body.extend(code)); - self.graph_input_types.iter().for_each(|input| { - let name = input.name().clone(); - let ty = input.ty(); + let fields = self + .nodes + .iter() + .flat_map(|node| node.field_type()) + .map(|field| field.name().clone()) + .collect::>(); - input_def.extend(quote! { - #name: #ty, + quote! { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + #body + + Self { + #(#fields,)* + phantom: core::marker::PhantomData, + } + } + } + } - }) - }); + fn codegen_forward(&mut self) -> TokenStream { + let mut input_def = quote! {}; + let mut output_type_def = quote! {}; + let mut output_return_def = quote! {}; - let multiple_output = self.graph_output_types.len() > 1; + self.graph_input_types.iter().for_each(|input| { + let name = input.name().clone(); + let ty = input.ty(); - self.graph_output_types.iter().for_each(|output| { - let name = output.name(); - let ty = output.ty(); + input_def.extend(quote! { + #name: #ty, - if multiple_output { - output_type_def.extend(quote! { - #ty, - }); - output_return_def.extend(quote! { - #name, - }); - } else { - output_type_def.extend(quote! { - #ty + }) }); - output_return_def.extend(quote! { - #name + + let multiple_output = self.graph_output_types.len() > 1; + + self.graph_output_types.iter().for_each(|output| { + let name = output.name(); + let ty = output.ty(); + + if multiple_output { + output_type_def.extend(quote! { + #ty, + }); + output_return_def.extend(quote! { + #name, + }); + } else { + output_type_def.extend(quote! { + #ty + }); + output_return_def.extend(quote! { + #name + }); + } }); - } - }); - - if multiple_output { - output_return_def = quote! { - (#output_return_def) - }; - output_type_def = quote! { - (#output_type_def) - }; - } - let mut body = quote! {}; - self - .nodes - .iter() - .enumerate() - .map(|(index, node)| node.forward(&mut self.scope, index)) - .for_each(|code| body.extend(code)); - - // TODO Return the result without a `let` binding from a block, - // otherwise let_and_return error will be triggered by clippy. - // For now, we just disable the warning. - quote! { - #[allow(clippy::let_and_return)] - pub fn forward(&self, #input_def) -> #output_type_def { - #body - - #output_return_def + if multiple_output { + output_return_def = quote! { + (#output_return_def) + }; + output_type_def = quote! { + (#output_type_def) + }; + } + + let mut body = quote! {}; + self.nodes + .iter() + .enumerate() + .map(|(index, node)| node.forward(&mut self.scope, index)) + .for_each(|code| body.extend(code)); + + // TODO Return the result without a `let` binding from a block, + // otherwise let_and_return error will be triggered by clippy. + // For now, we just disable the warning. + quote! 
{ + #[allow(clippy::let_and_return)] + pub fn forward(&self, #input_def) -> #output_type_def { + #body + + #output_return_def + } } } - } - - /// Register the input and output types of the graph using the passed in names. - /// The names must be unique and match the names of the inputs and outputs of the nodes. - /// The order will be preserved. - /// - /// # Arguments - /// - /// * `input_names` - The names of the inputs of the graph. - /// * `output_names` - The names of the outputs of the graph. - /// - /// # Panics - /// - /// Panics if the graph is empty. - pub fn register_input_output(&mut self, input_names: Vec, output_names: Vec) { - assert!( - !self.nodes.is_empty(), - "Cannot register input and output types for an empty graph." - ); - - // Get the unique names of each input of the nodes - let mut inputs = HashMap::new(); - let mut outputs = HashMap::new(); - for node in self.nodes.iter() { - for input in node.input_types() { - inputs.insert(input.name().to_string(), input); - } - for output in node.output_types() { - outputs.insert(output.name().to_string(), output); - } - } - // Get the input and output types of the graph using passed in names - input_names.iter().for_each(|input| { - self - .graph_input_types - .push(inputs.get(input).unwrap().clone()); - }); - - output_names.iter().for_each(|output| { - self.graph_output_types.push( - outputs - .get(output) - .unwrap_or_else(|| panic!("Output type is not found for {output}")) - .clone(), - ); - }); - } + /// Register the input and output types of the graph using the passed in names. + /// The names must be unique and match the names of the inputs and outputs of the nodes. + /// The order will be preserved. + /// + /// # Arguments + /// + /// * `input_names` - The names of the inputs of the graph. + /// * `output_names` - The names of the outputs of the graph. + /// + /// # Panics + /// + /// Panics if the graph is empty. + pub fn register_input_output(&mut self, input_names: Vec, output_names: Vec) { + assert!( + !self.nodes.is_empty(), + "Cannot register input and output types for an empty graph." + ); + + // Get the unique names of each input of the nodes + let mut inputs = HashMap::new(); + let mut outputs = HashMap::new(); + for node in self.nodes.iter() { + for input in node.input_types() { + inputs.insert(input.name().to_string(), input); + } + for output in node.output_types() { + outputs.insert(output.name().to_string(), output); + } + } + + // Get the input and output types of the graph using passed in names + input_names.iter().for_each(|input| { + self.graph_input_types + .push(inputs.get(input).unwrap().clone()); + }); + + output_names.iter().for_each(|output| { + self.graph_output_types.push( + outputs + .get(output) + .unwrap_or_else(|| panic!("Output type is not found for {output}")) + .clone(), + ); + }); + } } #[derive(new, Debug)] struct BurnGraphState<'a, PS: PrecisionSettings> { - nodes: &'a Vec>, + nodes: &'a Vec>, } /// Represents a custom serialization strategy for the graph state in the module struct. 
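For reference on how the builder methods reformatted above fit together (`with_record`, `with_new_fn`, `register_input_output`, `codegen`), here is a minimal usage sketch that is not part of this patch. It mirrors the codegen tests further down; the record path `model.mpk`, the import path used for `RecordType`, and the choice of `AvgPool2dNode` as the registered node are illustrative assumptions rather than anything this change prescribes.

use std::path::PathBuf;

use burn::nn::pool::AvgPool2dConfig;
use burn::nn::PaddingConfig2d;
use burn::record::FullPrecisionSettings;

// Import paths below are assumed; `RecordType` may be re-exported elsewhere in burn-import.
use crate::burn::graph::{BurnGraph, RecordType};
use crate::burn::node::avg_pool2d::AvgPool2dNode;
use crate::burn::TensorType;

fn emit_model_source() -> String {
    // Assemble the graph, registering nodes in execution order.
    let mut graph = BurnGraph::<FullPrecisionSettings>::default();
    graph.register(AvgPool2dNode::new(
        "avg_pool2d",
        TensorType::new_float("input", 4),
        TensorType::new_float("output", 4),
        AvgPool2dConfig::new([3, 3])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Valid),
    ));

    // The names must be unique and match the node input/output names; order is preserved.
    graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

    // `with_record` writes the record file right away and wires up `Default`/`from_file`
    // in the generated module; `with_new_fn` additionally emits a plain `new()` constructor.
    graph
        .with_record(PathBuf::from("model.mpk"), RecordType::NamedMpk, false)
        .with_new_fn(true)
        .codegen()
        .to_string()
}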
@@ -630,24 +618,24 @@ struct BurnGraphState<'a, PS: PrecisionSettings> { struct StructMap<'a, PS: PrecisionSettings>(BurnGraphState<'a, PS>); impl<'a, PS: PrecisionSettings> Serialize for StructMap<'a, PS> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let nodes_with_names = self - .0 - .nodes - .iter() - .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) - .collect::>(); - let mut map = serializer.serialize_map(Some(nodes_with_names.len()))?; - - for (node, name) in nodes_with_names.iter() { - map.serialize_entry(&name.to_string(), &node)?; - } + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let nodes_with_names = self + .0 + .nodes + .iter() + .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) + .collect::>(); + let mut map = serializer.serialize_map(Some(nodes_with_names.len()))?; + + for (node, name) in nodes_with_names.iter() { + map.serialize_entry(&name.to_string(), &node)?; + } - map.end() - } + map.end() + } } /// Represents a custom serialization strategy for the graph state in the module struct. @@ -664,31 +652,31 @@ impl<'a, PS: PrecisionSettings> Serialize for StructMap<'a, PS> { struct StructTuple<'a, PS: PrecisionSettings>(BurnGraphState<'a, PS>); impl<'a, PS: PrecisionSettings> Serialize for StructTuple<'a, PS> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let nodes_with_names = self - .0 - .nodes - .iter() - .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) - .collect::>(); - let mut map = serializer.serialize_tuple(nodes_with_names.len())?; - - for (node, _name) in nodes_with_names.iter() { - map.serialize_element(&node)?; - } + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let nodes_with_names = self + .0 + .nodes + .iter() + .filter_map(|node| node.field_type().map(|ty| (node, ty.name().clone()))) + .collect::>(); + let mut map = serializer.serialize_tuple(nodes_with_names.len())?; + + for (node, _name) in nodes_with_names.iter() { + map.serialize_element(&node)?; + } - map.end() - } + map.end() + } } fn extract_type_name_by_type() -> String { - let full_type_name = type_name::(); - full_type_name - .rsplit("::") - .next() - .unwrap_or(full_type_name) - .to_string() + let full_type_name = type_name::(); + full_type_name + .rsplit("::") + .next() + .unwrap_or(full_type_name) + .to_string() } diff --git a/burn-import/src/burn/imports.rs b/burn-import/src/burn/imports.rs index c091eb0d24..540a45529b 100644 --- a/burn-import/src/burn/imports.rs +++ b/burn-import/src/burn/imports.rs @@ -5,37 +5,38 @@ use std::collections::HashSet; /// Keep track of imported modules. #[derive(Debug, Default)] pub struct BurnImports { - imports: HashSet, + imports: HashSet, } impl BurnImports { - /// Register an import type. - /// - /// # Notes - /// - /// Each import statement will be generated just once no matter how many times it was - /// registered. - pub fn register>(&mut self, import: S) { - self.imports.insert(import.into()); - } + /// Register an import type. + /// + /// # Notes + /// + /// Each import statement will be generated just once no matter how many times it was + /// registered. + pub fn register>(&mut self, import: S) { + self.imports.insert(import.into()); + } - /// Generate the import tokens. - pub fn codegen(&self) -> TokenStream { - let mut import_tokens = vec![]; + /// Generate the import tokens. 
+ pub fn codegen(&self) -> TokenStream { + let mut import_tokens = vec![]; - for import in self.imports.iter() { - let path: syn::Path = syn::parse_str(import).expect("Unable to parse input string as a path"); + for import in self.imports.iter() { + let path: syn::Path = + syn::parse_str(import).expect("Unable to parse input string as a path"); - import_tokens.push(quote! { #path }); - } + import_tokens.push(quote! { #path }); + } - quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #(use #import_tokens;)* + #(use #import_tokens;)* + } } - } } diff --git a/burn-import/src/burn/node/avg_pool2d.rs b/burn-import/src/burn/node/avg_pool2d.rs index b35cfdcdc0..3f12457f20 100644 --- a/burn-import/src/burn/node/avg_pool2d.rs +++ b/burn-import/src/burn/node/avg_pool2d.rs @@ -8,151 +8,151 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; #[derive(Debug, Clone)] pub struct AvgPool2dNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub config: AvgPool2dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub config: AvgPool2dConfig, } impl AvgPool2dNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - config: AvgPool2dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - AvgPool2d - }, - ), - input, - output, - config, + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + config: AvgPool2dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + AvgPool2d + }, + ), + input, + output, + config, + } } - } } impl NodeCodegen for AvgPool2dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - let kernel_size = self.config.kernel_size.to_tokens(); - let strides = self.config.strides.to_tokens(); - let padding = self.config.padding.to_tokens(); - - let init_line = quote! { - init(); - }; + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } - let tokens = quote! { - let #name = AvgPool2dConfig::new(#kernel_size) - .with_strides(#strides) - .with_padding(#padding) - .#init_line - }; + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; + let kernel_size = self.config.kernel_size.to_tokens(); + let strides = self.config.strides.to_tokens(); + let padding = self.config.padding.to_tokens(); - Some(tokens) - } + let init_line = quote! { + init(); + }; - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + let tokens = quote! { + let #name = AvgPool2dConfig::new(#kernel_size) + .with_strides(#strides) + .with_padding(#padding) + .#init_line + }; - quote! 
{ - let #output = self.#field.forward(#input); + Some(tokens) } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig2d"); - imports.register("burn::nn::pool::AvgPool2d"); - imports.register("burn::nn::pool::AvgPool2dConfig"); - } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - fn into_node(self) -> Node { - Node::AvgPool2d(self) - } + quote! { + let #output = self.#field.forward(#input); + } + } + + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig2d"); + imports.register("burn::nn::pool::AvgPool2d"); + imports.register("burn::nn::pool::AvgPool2dConfig"); + } - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) - } + fn into_node(self) -> Node { + Node::AvgPool2d(self) + } + + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{avg_pool2d::AvgPool2dNode, test::assert_tokens}, - TensorType, - }; - use burn::{nn::pool::AvgPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(AvgPool2dNode::new( - "avg_pool2d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - AvgPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig2d; - use burn::nn::pool::AvgPool2d; - use burn::nn::pool::AvgPool2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - avg_pool2d: AvgPool2d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{avg_pool2d::AvgPool2dNode, test::assert_tokens}, + TensorType, + }; + use burn::{nn::pool::AvgPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(AvgPool2dNode::new( + "avg_pool2d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + AvgPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::pool::AvgPool2d; + use burn::nn::pool::AvgPool2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + avg_pool2d: AvgPool2d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let avg_pool2d = AvgPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .init(); - - Self { - avg_pool2d, - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let avg_pool2d = AvgPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .init(); + + Self { + avg_pool2d, + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.avg_pool2d.forward(input); + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.avg_pool2d.forward(input); - output + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/base.rs b/burn-import/src/burn/node/base.rs index f4262da460..944156306b 100644 --- a/burn-import/src/burn/node/base.rs +++ b/burn-import/src/burn/node/base.rs @@ -1,8 +1,9 @@ use super::{ - avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, - concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, - dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, linear::LinearNode, - matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode, + avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, + concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, + dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, + linear::LinearNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode, + unary::UnaryNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::record::PrecisionSettings; @@ -15,371 +16,371 @@ pub type SerializationBackend = NdArray; /// Codegen trait that should be implemented by all [node](Node) entries. pub trait NodeCodegen: std::fmt::Debug { - /// All types that are used as inputs during the forward pass. - /// - /// # Notes - /// The vec should not include types that are accessible with `self`. - /// See [field type](NodeCodegen::field_type). - fn input_types(&self) -> Vec; - - /// All types that are produced during the forward pass. - fn output_types(&self) -> Vec; - - /// The forward pass implementation of the node. - /// - /// # Notes - /// - /// The [Scope](Scope) struct should be used for [input tensor type](Type::Tensor) access. - /// The method [use_owned_tensor](Scope::use_owned_tensor) keeps track of tensor reference - /// count and insert `clone` with necessary. - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream; - - /// Convert the node implementation into a [node entry](Node). - fn into_node(self) -> Node; - - /// Register the necessary imports. 
- fn register_imports(&self, _imports: &mut BurnImports) {} - - /// (Optional) Declare the type of the field - /// - /// # Notes - /// - /// This should be implemented when the node has some parameters. - /// Just one field per type is possible, if the node has multiple types for its parameters, a - /// tuple can be used. - /// - /// Other field functions should be implemented when this one returns something other than None. - /// * [field_init](NodeCodegen::field_init) to initialize parameters. - /// * [field_serialize](NodeCodegen::field_serialize) to create the model record. - fn field_type(&self) -> Option { - None - } - - /// (Optional) Declare how the parameters are initialized with and without a record. - /// - /// The function should be implemented along [field_type](NodeCodegen::field_type). - fn field_init(&self, _with_record: bool) -> Option { - None - } - - /// (Optional) Declare how the parameters are serialized in a record. - /// - /// The function should be implemented along [field_type](NodeCodegen::field_type). - fn field_serialize(&self, _serializer: S) -> Result { - panic!("Serialization should be implemented when field_type is not None."); - } + /// All types that are used as inputs during the forward pass. + /// + /// # Notes + /// The vec should not include types that are accessible with `self`. + /// See [field type](NodeCodegen::field_type). + fn input_types(&self) -> Vec; + + /// All types that are produced during the forward pass. + fn output_types(&self) -> Vec; + + /// The forward pass implementation of the node. + /// + /// # Notes + /// + /// The [Scope](Scope) struct should be used for [input tensor type](Type::Tensor) access. + /// The method [use_owned_tensor](Scope::use_owned_tensor) keeps track of tensor reference + /// count and insert `clone` with necessary. + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream; + + /// Convert the node implementation into a [node entry](Node). + fn into_node(self) -> Node; + + /// Register the necessary imports. + fn register_imports(&self, _imports: &mut BurnImports) {} + + /// (Optional) Declare the type of the field + /// + /// # Notes + /// + /// This should be implemented when the node has some parameters. + /// Just one field per type is possible, if the node has multiple types for its parameters, a + /// tuple can be used. + /// + /// Other field functions should be implemented when this one returns something other than None. + /// * [field_init](NodeCodegen::field_init) to initialize parameters. + /// * [field_serialize](NodeCodegen::field_serialize) to create the model record. + fn field_type(&self) -> Option { + None + } + + /// (Optional) Declare how the parameters are initialized with and without a record. + /// + /// The function should be implemented along [field_type](NodeCodegen::field_type). + fn field_init(&self, _with_record: bool) -> Option { + None + } + + /// (Optional) Declare how the parameters are serialized in a record. + /// + /// The function should be implemented along [field_type](NodeCodegen::field_type). 
+ fn field_serialize(&self, _serializer: S) -> Result { + panic!("Serialization should be implemented when field_type is not None."); + } } #[derive(Debug, Clone)] pub enum Node { - AvgPool2d(AvgPool2dNode), - BatchNorm(BatchNormNode), - Binary(BinaryNode), - Clip(ClipNode), - Concat(ConcatNode), - Constant(ConstantNode), - Conv1d(Conv1dNode), - Conv2d(Conv2dNode), - Dropout(DropoutNode), - Gather(GatherNode), - GlobalAvgPool(GlobalAvgPoolNode), - Linear(LinearNode), - Matmul(MatmulNode), - MaxPool2d(MaxPool2dNode), - Reshape(ReshapeNode), - Unary(UnaryNode), + AvgPool2d(AvgPool2dNode), + BatchNorm(BatchNormNode), + Binary(BinaryNode), + Clip(ClipNode), + Concat(ConcatNode), + Constant(ConstantNode), + Conv1d(Conv1dNode), + Conv2d(Conv2dNode), + Dropout(DropoutNode), + Gather(GatherNode), + GlobalAvgPool(GlobalAvgPoolNode), + Linear(LinearNode), + Matmul(MatmulNode), + MaxPool2d(MaxPool2dNode), + Reshape(ReshapeNode), + Unary(UnaryNode), } macro_rules! match_all { - ($self:expr, $func:expr) => {{ - #[allow(clippy::redundant_closure_call)] - match $self { - Node::AvgPool2d(node) => $func(node), - Node::BatchNorm(node) => $func(node), - Node::Binary(node) => $func(node), - Node::Clip(node) => $func(node), - Node::Concat(node) => $func(node), - Node::Constant(node) => $func(node), - Node::Conv1d(node) => $func(node), - Node::Conv2d(node) => $func(node), - Node::Dropout(node) => $func(node), - Node::Gather(node) => $func(node), - Node::GlobalAvgPool(node) => $func(node), - Node::Linear(node) => $func(node), - Node::Matmul(node) => $func(node), - Node::MaxPool2d(node) => $func(node), - Node::Reshape(node) => $func(node), - Node::Unary(node) => $func(node), - } - }}; + ($self:expr, $func:expr) => {{ + #[allow(clippy::redundant_closure_call)] + match $self { + Node::AvgPool2d(node) => $func(node), + Node::BatchNorm(node) => $func(node), + Node::Binary(node) => $func(node), + Node::Clip(node) => $func(node), + Node::Concat(node) => $func(node), + Node::Constant(node) => $func(node), + Node::Conv1d(node) => $func(node), + Node::Conv2d(node) => $func(node), + Node::Dropout(node) => $func(node), + Node::Gather(node) => $func(node), + Node::GlobalAvgPool(node) => $func(node), + Node::Linear(node) => $func(node), + Node::Matmul(node) => $func(node), + Node::MaxPool2d(node) => $func(node), + Node::Reshape(node) => $func(node), + Node::Unary(node) => $func(node), + } + }}; } impl Serialize for Node { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.field_serialize(serializer) - } + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.field_serialize(serializer) + } } impl Node { - pub fn name(&self) -> &str { - match self { - Node::AvgPool2d(_) => "avg_pool2d", - Node::BatchNorm(_) => "batch_norm", - Node::Binary(binary) => binary.binary_type.as_str(), - Node::Concat(_) => "concat", - Node::Clip(_) => "clip", - Node::Constant(_) => "constant", - Node::Conv1d(_) => "conv1d", - Node::Conv2d(_) => "conv2d", - Node::Dropout(_) => "dropout", - Node::Gather(_) => "gather", - Node::GlobalAvgPool(_) => "global_avg_pool", - Node::Linear(_) => "linear", - Node::Matmul(_) => "matmul", - Node::MaxPool2d(_) => "max_pool2d", - Node::Reshape(_) => "reshape", - Node::Unary(unary) => unary.kind.as_str(), + pub fn name(&self) -> &str { + match self { + Node::AvgPool2d(_) => "avg_pool2d", + Node::BatchNorm(_) => "batch_norm", + Node::Binary(binary) => binary.binary_type.as_str(), + Node::Concat(_) => "concat", + Node::Clip(_) => "clip", + 
Node::Constant(_) => "constant", + Node::Conv1d(_) => "conv1d", + Node::Conv2d(_) => "conv2d", + Node::Dropout(_) => "dropout", + Node::Gather(_) => "gather", + Node::GlobalAvgPool(_) => "global_avg_pool", + Node::Linear(_) => "linear", + Node::Matmul(_) => "matmul", + Node::MaxPool2d(_) => "max_pool2d", + Node::Reshape(_) => "reshape", + Node::Unary(unary) => unary.kind.as_str(), + } } - } } impl NodeCodegen for Node { - fn output_types(&self) -> Vec { - match_all!(self, NodeCodegen::::output_types) - } - - fn input_types(&self) -> Vec { - match_all!(self, NodeCodegen::::input_types) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - match_all!(self, |node| NodeCodegen::::forward( - node, - scope, - node_position - )) - } - - fn field_type(&self) -> Option { - match_all!(self, NodeCodegen::::field_type) - } - - fn field_init(&self, with_record: bool) -> Option { - match_all!(self, |node| NodeCodegen::::field_init( - node, - with_record - )) - } - - fn register_imports(&self, imports: &mut BurnImports) { - match_all!(self, |node| NodeCodegen::::register_imports( - node, imports - )) - } - - fn into_node(self) -> Node { - self - } - - fn field_serialize(&self, serializer: S) -> Result { - match_all!(self, |node| NodeCodegen::::field_serialize( - node, serializer - )) - } + fn output_types(&self) -> Vec { + match_all!(self, NodeCodegen::::output_types) + } + + fn input_types(&self) -> Vec { + match_all!(self, NodeCodegen::::input_types) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + match_all!(self, |node| NodeCodegen::::forward( + node, + scope, + node_position + )) + } + + fn field_type(&self) -> Option { + match_all!(self, NodeCodegen::::field_type) + } + + fn field_init(&self, with_record: bool) -> Option { + match_all!(self, |node| NodeCodegen::::field_init( + node, + with_record + )) + } + + fn register_imports(&self, imports: &mut BurnImports) { + match_all!(self, |node| NodeCodegen::::register_imports( + node, imports + )) + } + + fn into_node(self) -> Node { + self + } + + fn field_serialize(&self, serializer: S) -> Result { + match_all!(self, |node| NodeCodegen::::field_serialize( + node, serializer + )) + } } #[cfg(test)] pub(crate) mod tests { - use crate::burn::{ - graph::BurnGraph, - node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen}, - TensorType, - }; - use burn::{ - nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, - }; - use proc_macro2::TokenStream; - use quote::quote; - - pub(crate) fn one_node_graph + 'static>( - node_gen: T, - forward: TokenStream, - input_names: Vec, - output_names: Vec, - ) { - let mut graph = BurnGraph::::default(); - - graph.register(node_gen); - - graph.register_input_output(input_names, output_names); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + use crate::burn::{ + graph::BurnGraph, + node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen}, + TensorType, + }; + use burn::{ + nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, + }; + use proc_macro2::TokenStream; + use quote::quote; + + pub(crate) fn one_node_graph + 'static>( + node_gen: T, + forward: TokenStream, + input_names: Vec, + output_names: Vec, + ) { + let mut graph = BurnGraph::::default(); + + graph.register(node_gen); + + graph.register_input_output(input_names, output_names); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } } + + #[allow(clippy::let_and_return)] + #forward } + }; - #[allow(clippy::let_and_return)] - #forward - } - }; + assert_tokens(graph.codegen(), expected); + } - assert_tokens(graph.codegen(), expected); - } - - #[test] - fn test_codegen_two_nodes() { - let mut graph = BurnGraph::::default(); - - graph.register(MatmulNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor3", 4), - )); - graph.register(Conv2dNode::new( - "conv2d", - TensorType::new_float("tensor3", 4), - TensorType::new_float("tensor4", 4), - Data::from([2.]).serialize(), - None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor4".to_string()], - ); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::conv::Conv2dConfig; - use burn::nn::conv::Conv2d; - use burn::nn::PaddingConfig2d; - - #[derive(Module, Debug)] - pub struct Model { - conv2d: Conv2d, - phantom: core::marker::PhantomData, - } + #[test] + fn test_codegen_two_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(MatmulNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor3", 4), + )); + graph.register(Conv2dNode::new( + "conv2d", + TensorType::new_float("tensor3", 4), + TensorType::new_float("tensor4", 4), + Data::from([2.]).serialize(), + None, + Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor4".to_string()], + ); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::conv::Conv2dConfig; + use burn::nn::conv::Conv2d; + use burn::nn::PaddingConfig2d; + + #[derive(Module, Debug)] + pub struct Model { + conv2d: Conv2d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv2d = Conv2dConfig::new([3, 3], [3, 3]) - .with_stride([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .with_groups(1) - .with_bias(true) - .init_with(record.conv2d); - - Self { - conv2d, - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv2d = Conv2dConfig::new([3, 3], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .with_groups(1) + .with_bias(true) + .init_with(record.conv2d); + + Self { + conv2d, + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.matmul(tensor2); - let tensor4 = self.conv2d.forward(tensor3); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.matmul(tensor2); + let tensor4 = self.conv2d.forward(tensor3); - tensor4 + tensor4 + } } - } - }; - - assert_tokens(graph.codegen(), expected); - } - - #[test] - fn test_codegen_clone_tensor() { - let mut graph = BurnGraph::::default(); - - graph.register(MatmulNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor3", 4), - )); - graph.register(Conv2dNode::new( - "conv2d", - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor4", 4), - Data::from([2.]).serialize(), - None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), - )); - graph.register(MatmulNode::new( - TensorType::new_float("tensor3", 4), - TensorType::new_float("tensor4", 4), - TensorType::new_float("output", 4), - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["output".to_string()], - ); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, }; - use burn::nn::PaddingConfig2d; - use burn::nn::conv::Conv2d; - use burn::nn::conv::Conv2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - conv2d: Conv2d, - phantom: core::marker::PhantomData, - } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv2d = Conv2dConfig::new([3, 3], [3, 3]) - .with_stride([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .with_groups(1) - .with_bias(true) - .init_with(record.conv2d); - - Self { - conv2d, - phantom: core::marker::PhantomData, - } + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_clone_tensor() { + let mut graph = BurnGraph::::default(); + + graph.register(MatmulNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor3", 4), + )); + graph.register(Conv2dNode::new( + "conv2d", + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor4", 4), + Data::from([2.]).serialize(), + None, + Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + )); + graph.register(MatmulNode::new( + TensorType::new_float("tensor3", 4), + TensorType::new_float("tensor4", 4), + TensorType::new_float("output", 4), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["output".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::conv::Conv2d; + use burn::nn::conv::Conv2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + conv2d: Conv2d, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.matmul(tensor2.clone()); - let tensor4 = self.conv2d.forward(tensor2); - let output = tensor3.matmul(tensor4); - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv2d = Conv2dConfig::new([3, 3], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .with_groups(1) + .with_bias(true) + .init_with(record.conv2d); + + Self { + conv2d, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.matmul(tensor2.clone()); + let tensor4 = self.conv2d.forward(tensor2); + let output = tensor3.matmul(tensor4); + + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/batch_norm.rs b/burn-import/src/burn/node/batch_norm.rs index 6f2174bcc4..b706c47cc2 100644 --- a/burn-import/src/burn/node/batch_norm.rs +++ b/burn-import/src/burn/node/batch_norm.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{ConstantRecord, Param, ParamId}, - nn::{BatchNormConfig, BatchNormRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{ConstantRecord, Param, ParamId}, + nn::{BatchNormConfig, BatchNormRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,49 +12,49 @@ use serde::Serialize; 
#[derive(Debug, Clone)] pub struct BatchNormNode { - pub dim: usize, - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub gamma: DataSerialize, - pub beta: DataSerialize, - pub running_mean: DataSerialize, - pub running_var: DataSerialize, - pub config: BatchNormConfig, + pub dim: usize, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub gamma: DataSerialize, + pub beta: DataSerialize, + pub running_mean: DataSerialize, + pub running_var: DataSerialize, + pub config: BatchNormConfig, } impl BatchNormNode { - #[allow(clippy::too_many_arguments)] - pub fn new>( - dim: usize, - name: S, - input: TensorType, - output: TensorType, - gamma: DataSerialize, - beta: DataSerialize, - running_mean: DataSerialize, - running_var: DataSerialize, - config: BatchNormConfig, - ) -> Self { - let dim_tokens = dim.to_tokens(); - - Self { - dim, - field: OtherType::new( - name, - quote! { - BatchNorm - }, - ), - input, - output, - gamma, - beta, - running_mean, - running_var, - config, + #[allow(clippy::too_many_arguments)] + pub fn new>( + dim: usize, + name: S, + input: TensorType, + output: TensorType, + gamma: DataSerialize, + beta: DataSerialize, + running_mean: DataSerialize, + running_var: DataSerialize, + config: BatchNormConfig, + ) -> Self { + let dim_tokens = dim.to_tokens(); + + Self { + dim, + field: OtherType::new( + name, + quote! { + BatchNorm + }, + ), + input, + output, + gamma, + beta, + running_mean, + running_var, + config, + } } - } } macro_rules! batch_norm_serialize { @@ -101,124 +101,124 @@ macro_rules! batch_norm_serialize { } impl NodeCodegen for BatchNormNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let num_features = self.config.num_features.to_tokens(); - let epsilon = self.config.epsilon; - let momentum = self.config.momentum; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; - - let tokens = quote! { - let #name = BatchNormConfig::new(#num_features) - .with_epsilon(#epsilon) - .with_momentum(#momentum) - .#init_line - }; - - Some(tokens) - } - - fn field_serialize(&self, serializer: S) -> Result { - batch_norm_serialize!(self, serializer) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; - - quote! 
{ - let #output = self.#field.forward(#input); + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::BatchNorm"); - imports.register("burn::nn::BatchNormConfig"); - } - - fn into_node(self) -> Node { - Node::BatchNorm(self) - } -} -#[cfg(test)] -mod tests { - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - use burn::{record::FullPrecisionSettings, tensor::Data}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(BatchNormNode::new( - 2, // Batch norm 2d - "norm", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - Data::from([2.]).serialize(), - Data::from([2.]).serialize(), - Data::from([2.]).serialize(), - BatchNormConfig::new(128), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let num_features = self.config.num_features.to_tokens(); + let epsilon = self.config.epsilon; + let momentum = self.config.momentum; + + let init_line = match with_record { + true => quote! { + init_with(record.#name); + }, + false => quote! { + init(); + }, }; - use burn::nn::BatchNorm; - use burn::nn::BatchNormConfig; - #[derive(Module, Debug)] - pub struct Model { - norm: BatchNorm, - phantom: core::marker::PhantomData, + let tokens = quote! { + let #name = BatchNormConfig::new(#num_features) + .with_epsilon(#epsilon) + .with_momentum(#momentum) + .#init_line + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + batch_norm_serialize!(self, serializer) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + + quote! { + let #output = self.#field.forward(#input); } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::BatchNorm"); + imports.register("burn::nn::BatchNormConfig"); + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let norm = BatchNormConfig::new(128) - .with_epsilon(0.00001f64) - .with_momentum(0.1f64) - .init_with(record.norm); - - Self { - norm, - phantom: core::marker::PhantomData, - } + fn into_node(self) -> Node { + Node::BatchNorm(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use burn::{record::FullPrecisionSettings, tensor::Data}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(BatchNormNode::new( + 2, // Batch norm 2d + "norm", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + Data::from([2.]).serialize(), + Data::from([2.]).serialize(), + Data::from([2.]).serialize(), + BatchNormConfig::new(128), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::BatchNorm; + use burn::nn::BatchNormConfig; + + #[derive(Module, Debug)] + pub struct Model { + norm: BatchNorm, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.norm.forward(input); - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let norm = BatchNormConfig::new(128) + .with_epsilon(0.00001f64) + .with_momentum(0.1f64) + .init_with(record.norm); + + Self { + norm, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.norm.forward(input); + + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/binary.rs b/burn-import/src/burn/node/binary.rs index e01c0d302f..f874fc8ab5 100644 --- a/burn-import/src/burn/node/binary.rs +++ b/burn-import/src/burn/node/binary.rs @@ -7,23 +7,23 @@ use std::sync::Arc; #[derive(Clone)] pub enum BinaryType { - Add, - Sub, - Mul, - Div, - Equal, + Add, + Sub, + Mul, + Div, + Equal, } impl BinaryType { - pub(crate) fn as_str(&self) -> &str { - match self { - BinaryType::Add => "add", - BinaryType::Sub => "sub", - BinaryType::Mul => "mul", - BinaryType::Div => "div", - BinaryType::Equal => "equal", + pub(crate) fn as_str(&self) -> &str { + match self { + BinaryType::Add => "add", + BinaryType::Sub => "sub", + BinaryType::Mul => "mul", + BinaryType::Div => "div", + BinaryType::Equal => "equal", + } } - } } // Simple fn pointer that receive input as a token stream and return function call. @@ -32,141 +32,141 @@ type FnPointer = Arc TokenStream>; /// Node for all binary operators. #[derive(Clone, new)] pub struct BinaryNode { - pub lhs: Type, - pub rhs: Type, - pub output: Type, - pub binary_type: BinaryType, - function: FnPointer, + pub lhs: Type, + pub rhs: Type, + pub output: Type, + pub binary_type: BinaryType, + function: FnPointer, } impl std::fmt::Debug for BinaryNode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str( - format!( - "BinaryNode {{ lhs: {:?}, rhs: {:?}, output: {:?}, name: {:?} }}", - self.lhs, - self.rhs, - self.output, - self.binary_type.as_str() - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + format!( + "BinaryNode {{ lhs: {:?}, rhs: {:?}, output: {:?}, name: {:?} }}", + self.lhs, + self.rhs, + self.output, + self.binary_type.as_str() + ) + .as_str(), + ) + } } impl NodeCodegen for BinaryNode { - fn output_types(&self) -> Vec { - vec![self.output.clone()] - } + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } - fn input_types(&self) -> Vec { - vec![self.lhs.clone(), self.rhs.clone()] - } + fn input_types(&self) -> Vec { + vec![self.lhs.clone(), self.rhs.clone()] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + // Get the lhs name in the form of token stream. + let lhs = match &self.lhs { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("lhs must be a tensor or scalar"), + }; - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - // Get the lhs name in the form of token stream. 
- let lhs = match &self.lhs { - Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), - Type::Scalar(scalar) => { - let name = scalar.name.clone(); - quote! { #name } - } - _ => panic!("lhs must be a tensor or scalar"), - }; - - // Get the rhs name in the form of token stream - let rhs = match &self.rhs { - Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), - Type::Scalar(scalar) => { - let name = scalar.name.clone(); - quote! { #name } - } - _ => panic!("rhs must be a tensor or scalar"), - }; - - let output = &self.output.name(); - let function = (self.function)(lhs, rhs); - - quote! { - let #output = #function; + // Get the rhs name in the form of token stream + let rhs = match &self.rhs { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("rhs must be a tensor or scalar"), + }; + + let output = &self.output.name(); + let function = (self.function)(lhs, rhs); + + quote! { + let #output = #function; + } } - } - fn into_node(self) -> Node { - Node::Binary(self) - } + fn into_node(self) -> Node { + Node::Binary(self) + } } impl BinaryNode { - pub(crate) fn add(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.add(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.add_scalar(#rhs) }, - (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.add_scalar(#lhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs + #rhs }, - _ => panic!("Addition is supported for tensor and scalar only"), - }; - - Self::new(lhs, rhs, output, BinaryType::Add, Arc::new(function)) - } + pub(crate) fn add(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.add(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.add_scalar(#rhs) }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.add_scalar(#lhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs + #rhs }, + _ => panic!("Addition is supported for tensor and scalar only"), + }; - pub(crate) fn sub(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, - _ => panic!("Subtraction is supported for tensor and scalar only"), - }; + Self::new(lhs, rhs, output, BinaryType::Add, Arc::new(function)) + } - Self::new(lhs, rhs, output, BinaryType::Sub, Arc::new(function)) - } + pub(crate) fn sub(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, + _ => panic!("Subtraction is supported for tensor and scalar only"), + }; - pub(crate) fn mul(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! 
{ #lhs.mul(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.mul_scalar(#rhs) }, - (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.mul_scalar(#lhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs * #rhs }, - _ => panic!("Multiplication is supported for tensor and scalar only"), - }; + Self::new(lhs, rhs, output, BinaryType::Sub, Arc::new(function)) + } - Self::new(lhs, rhs, output, BinaryType::Mul, Arc::new(function)) - } + pub(crate) fn mul(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.mul(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.mul_scalar(#rhs) }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.mul_scalar(#lhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs * #rhs }, + _ => panic!("Multiplication is supported for tensor and scalar only"), + }; - pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) }, - (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs }, - _ => panic!("Division is supported for tensor and scalar only"), - }; + Self::new(lhs, rhs, output, BinaryType::Mul, Arc::new(function)) + } - Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function)) - } + pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs }, + _ => panic!("Division is supported for tensor and scalar only"), + }; - pub(crate) fn equal(lhs: Type, rhs: Type, output: Type) -> Self { - let function = match (&lhs, &rhs) { - (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.equal(#rhs) }, - (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs == #rhs }, - _ => panic!("Comparison is supported for tensor to tensor and scalar to scalar only"), - }; + Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function)) + } - Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function)) - } + pub(crate) fn equal(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.equal(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs == #rhs }, + _ => panic!("Comparison is supported for tensor to tensor and scalar to scalar only"), + }; + + Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function)) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; + use burn::record::FullPrecisionSettings; - use super::*; - use crate::burn::graph::BurnGraph; - use crate::burn::node::test::assert_tokens; - use crate::burn::node::tests::one_node_graph; - use crate::burn::{ScalarKind, ScalarType, TensorType}; + use super::*; + use crate::burn::graph::BurnGraph; + use crate::burn::node::test::assert_tokens; + use crate::burn::node::tests::one_node_graph; + use crate::burn::{ScalarKind, ScalarType, TensorType}; - macro_rules! 
test_binary_operator_on_tensors { + macro_rules! test_binary_operator_on_tensors { ($operator:ident) => {{ one_node_graph( BinaryNode::$operator( @@ -187,158 +187,158 @@ mod tests { }}; } - macro_rules! test_binary_operator_on_tensor_and_scalar { - ($operator:ident, $burn_operator:ident) => {{ - one_node_graph( - BinaryNode::$operator( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), - Type::Tensor(TensorType::new_float("tensor3", 4)), - ), - quote! { - pub fn forward(&self, scalar1: f32, tensor1: Tensor) -> Tensor { - let tensor3 = tensor1.$burn_operator(scalar1); - - tensor3 - } - }, - vec!["scalar1".to_string(), "tensor1".to_string()], - vec!["tensor3".to_string()], - ); - }}; - } - - macro_rules! test_binary_operator_on_scalar_and_scalar { - ($operator:ident, $scalar_operator:tt) => {{ - one_node_graph( - BinaryNode::$operator( - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), - Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), - Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float32)), - ), - quote! { - pub fn forward(&self, scalar1: f32, scalar2: f32) -> f32 { - let scalar3 = scalar1 $scalar_operator scalar2; - - scalar3 - } - }, - vec!["scalar1".to_string(), "scalar2".to_string()], - vec!["scalar3".to_string()], - ); - }}; - } + macro_rules! test_binary_operator_on_tensor_and_scalar { + ($operator:ident, $burn_operator:ident) => {{ + one_node_graph( + BinaryNode::$operator( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), + Type::Tensor(TensorType::new_float("tensor3", 4)), + ), + quote! { + pub fn forward(&self, scalar1: f32, tensor1: Tensor) -> Tensor { + let tensor3 = tensor1.$burn_operator(scalar1); + + tensor3 + } + }, + vec!["scalar1".to_string(), "tensor1".to_string()], + vec!["tensor3".to_string()], + ); + }}; + } - #[test] - fn test_binary_codegen_add() { - test_binary_operator_on_tensors!(add); - } + macro_rules! test_binary_operator_on_scalar_and_scalar { + ($operator:ident, $scalar_operator:tt) => {{ + one_node_graph( + BinaryNode::$operator( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), + Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float32)), + ), + quote! 
{ + pub fn forward(&self, scalar1: f32, scalar2: f32) -> f32 { + let scalar3 = scalar1 $scalar_operator scalar2; + + scalar3 + } + }, + vec!["scalar1".to_string(), "scalar2".to_string()], + vec!["scalar3".to_string()], + ); + }}; + } - #[test] - fn test_binary_codegen_add_scalar() { - test_binary_operator_on_tensor_and_scalar!(add, add_scalar); - } + #[test] + fn test_binary_codegen_add() { + test_binary_operator_on_tensors!(add); + } - #[test] - fn test_binary_codegen_add_scalars() { - test_binary_operator_on_scalar_and_scalar!(add, +); - } + #[test] + fn test_binary_codegen_add_scalar() { + test_binary_operator_on_tensor_and_scalar!(add, add_scalar); + } - #[test] - fn test_binary_codegen_sub() { - test_binary_operator_on_tensors!(sub); - } + #[test] + fn test_binary_codegen_add_scalars() { + test_binary_operator_on_scalar_and_scalar!(add, +); + } - #[test] - fn test_binary_codegen_sub_scalar() { - test_binary_operator_on_tensor_and_scalar!(sub, sub_scalar); - } + #[test] + fn test_binary_codegen_sub() { + test_binary_operator_on_tensors!(sub); + } - #[test] - fn test_binary_codegen_sub_scalars() { - test_binary_operator_on_scalar_and_scalar!(sub, -); - } + #[test] + fn test_binary_codegen_sub_scalar() { + test_binary_operator_on_tensor_and_scalar!(sub, sub_scalar); + } - #[test] - fn test_binary_codegen_mul() { - test_binary_operator_on_tensors!(mul); - } + #[test] + fn test_binary_codegen_sub_scalars() { + test_binary_operator_on_scalar_and_scalar!(sub, -); + } - #[test] - fn test_binary_codegen_mul_scalar() { - test_binary_operator_on_tensor_and_scalar!(mul, mul_scalar); - } + #[test] + fn test_binary_codegen_mul() { + test_binary_operator_on_tensors!(mul); + } - #[test] - fn test_binary_codegen_mul_scalars() { - test_binary_operator_on_scalar_and_scalar!(mul, *); - } + #[test] + fn test_binary_codegen_mul_scalar() { + test_binary_operator_on_tensor_and_scalar!(mul, mul_scalar); + } - #[test] - fn test_binary_codegen_div() { - test_binary_operator_on_tensors!(div); - } + #[test] + fn test_binary_codegen_mul_scalars() { + test_binary_operator_on_scalar_and_scalar!(mul, *); + } - #[test] - fn test_binary_codegen_div_scalar() { - test_binary_operator_on_tensor_and_scalar!(div, div_scalar); - } + #[test] + fn test_binary_codegen_div() { + test_binary_operator_on_tensors!(div); + } - #[test] - fn test_binary_codegen_div_scalars() { - test_binary_operator_on_scalar_and_scalar!(div, /); - } + #[test] + fn test_binary_codegen_div_scalar() { + test_binary_operator_on_tensor_and_scalar!(div, div_scalar); + } - #[test] - fn test_binary_codegen_equal_tensors() { - let mut graph = BurnGraph::::default(); - let node_gen = BinaryNode::equal( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - Type::Tensor(TensorType::new_bool("tensor3", 4)), - ); - - graph.register(node_gen); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! 
{ - use burn::tensor::Bool; - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + #[test] + fn test_binary_codegen_div_scalars() { + test_binary_operator_on_scalar_and_scalar!(div, /); + } - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[test] + fn test_binary_codegen_equal_tensors() { + let mut graph = BurnGraph::::default(); + let node_gen = BinaryNode::equal( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + Type::Tensor(TensorType::new_bool("tensor3", 4)), + ); + + graph.register(node_gen); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::tensor::Bool; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.equal(tensor2); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.equal(tensor2); - tensor3 + tensor3 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } - #[test] - fn test_binary_codegen_equal_scalars() { - test_binary_operator_on_scalar_and_scalar!(equal, ==); - } + #[test] + fn test_binary_codegen_equal_scalars() { + test_binary_operator_on_scalar_and_scalar!(equal, ==); + } } diff --git a/burn-import/src/burn/node/clip.rs b/burn-import/src/burn/node/clip.rs index 69254a3d3a..3156ab6d73 100644 --- a/burn-import/src/burn/node/clip.rs +++ b/burn-import/src/burn/node/clip.rs @@ -6,182 +6,182 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct ClipNode { - pub input: TensorType, - pub output: TensorType, - pub min: Option, // Should be elem Type - pub max: Option, + pub input: TensorType, + pub output: TensorType, + pub min: Option, // Should be elem Type + pub max: Option, } impl NodeCodegen for ClipNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - - if let Some(min) = self.min { - if let Some(max) = self.max { - quote! { - let #output = #input.clamp(#min, #max); - } - } else { - quote! { - let #output = #input.clamp_min(#min); + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + if let Some(min) = self.min { + if let Some(max) = self.max { + quote! { + let #output = #input.clamp(#min, #max); + } + } else { + quote! 
{ + let #output = #input.clamp_min(#min); + } + } + } else if let Some(max) = self.max { + return quote! { + let #output = #input.clamp_max(#max); + }; + } else { + panic!("Clip node must have at least one min or max value"); } - } - } else if let Some(max) = self.max { - return quote! { - let #output = #input.clamp_max(#max); - }; - } else { - panic!("Clip node must have at least one min or max value"); } - } - fn into_node(self) -> Node { - Node::Clip(self) - } + fn into_node(self) -> Node { + Node::Clip(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; + use burn::record::FullPrecisionSettings; - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - #[test] - fn codegen_nodes_min_max() { - let mut graph = BurnGraph::::default(); + #[test] + fn codegen_nodes_min_max() { + let mut graph = BurnGraph::::default(); - graph.register(ClipNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - Some(0.0), - Some(1.0), - )); + graph.register(ClipNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + Some(0.0), + Some(1.0), + )); - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.clamp(0f64, 1f64); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.clamp(0f64, 1f64); - tensor2 + tensor2 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } - #[test] - fn codegen_nodes_min() { - let mut graph = BurnGraph::::default(); + #[test] + fn codegen_nodes_min() { + let mut graph = BurnGraph::::default(); - graph.register(ClipNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - Some(0.0), - None, - )); + graph.register(ClipNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + Some(0.0), + None, + )); - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.clamp_min(0f64); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.clamp_min(0f64); - tensor2 + tensor2 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } - #[test] - fn codegen_nodes_max() { - let mut graph = BurnGraph::::default(); + #[test] + fn codegen_nodes_max() { + let mut graph = BurnGraph::::default(); - graph.register(ClipNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - None, - Some(1.0), - )); + graph.register(ClipNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + None, + Some(1.0), + )); - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.clamp_max(1f64); + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.clamp_max(1f64); - tensor2 + tensor2 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/concat.rs b/burn-import/src/burn/node/concat.rs index 20ecc2b0fe..a0cb6e1893 100644 --- a/burn-import/src/burn/node/concat.rs +++ b/burn-import/src/burn/node/concat.rs @@ -7,101 +7,100 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct ConcatNode { - pub inputs: Vec, - pub output: TensorType, - pub dim: usize, + pub inputs: Vec, + pub output: TensorType, + pub dim: usize, } impl NodeCodegen for ConcatNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - - fn input_types(&self) -> Vec { - self - .inputs - .iter() - .map(|t| Type::Tensor(t.clone())) - .collect() - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let dim = self.dim.to_tokens(); - let inputs = self - .inputs - .iter() - .map(|t| scope.tensor_use_owned(t, node_position)); - - let output = &self.output.name; - - quote! 
{ - let #output = burn::tensor::Tensor::cat([#(#inputs),*].into(), #dim); + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] } - } - fn into_node(self) -> Node { - Node::Concat(self) - } + fn input_types(&self) -> Vec { + self.inputs + .iter() + .map(|t| Type::Tensor(t.clone())) + .collect() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let dim = self.dim.to_tokens(); + let inputs = self + .inputs + .iter() + .map(|t| scope.tensor_use_owned(t, node_position)); + + let output = &self.output.name; + + quote! { + let #output = burn::tensor::Tensor::cat([#(#inputs),*].into(), #dim); + } + } + + fn into_node(self) -> Node { + Node::Concat(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{concat::ConcatNode, test::assert_tokens}, - TensorType, - }; - - #[test] - fn test_codegen_concat() { - let mut graph = BurnGraph::::default(); - - graph.register(ConcatNode::new( - vec![ - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - ], - TensorType::new_float("tensor3", 4), - 1, - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + use burn::record::FullPrecisionSettings; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{concat::ConcatNode, test::assert_tokens}, + TensorType, + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + #[test] + fn test_codegen_concat() { + let mut graph = BurnGraph::::default(); + + graph.register(ConcatNode::new( + vec![ + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + ], + TensorType::new_float("tensor3", 4), + 1, + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1); + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } + } - tensor3 + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1); + + tensor3 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/constant.rs b/burn-import/src/burn/node/constant.rs index 886b6a85f9..d091529e1e 100644 --- a/burn-import/src/burn/node/constant.rs +++ b/burn-import/src/burn/node/constant.rs @@ -1,9 +1,9 @@ use super::{Node, NodeCodegen}; use crate::burn::{ScalarKind, ScalarType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::ParamId, - record::{ParamSerde, PrecisionSettings}, - tensor::DataSerialize, + module::ParamId, + record::{ParamSerde, PrecisionSettings}, + tensor::DataSerialize, }; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -11,178 +11,178 @@ use serde::Serialize; #[derive(Debug, Clone)] pub struct ConstantNode { - pub name: String, - pub value: ConstantValue, - pub output: Type, + pub name: String, + pub value: ConstantValue, + pub output: Type, } #[derive(Debug, Clone)] pub enum TensorValue { - Float(DataSerialize), - Int(DataSerialize), - // TODO Support bool serialization (@antimora 8/26/2023) + Float(DataSerialize), + Int(DataSerialize), + // TODO Support bool serialization (@antimora 8/26/2023) } #[derive(Debug, Clone, new)] pub enum ConstantValue { - /// Float constant. - Float32(f32), - Float64(f64), + /// Float constant. + Float32(f32), + Float64(f64), - /// Integer constant. - Int32(i32), - Int64(i64), + /// Integer constant. + Int32(i32), + Int64(i64), - // Boolean constant. - Bool(bool), + // Boolean constant. + Bool(bool), - /// Tensor constant. - Tensor(TensorType, TensorValue), + /// Tensor constant. + Tensor(TensorType, TensorValue), } impl ConstantValue { - pub fn ty_tokens(&self) -> TokenStream { - match self { - ConstantValue::Float32(_) => quote! { f32 }, - ConstantValue::Float64(_) => quote! { f64 }, - ConstantValue::Int32(_) => quote! { i32 }, - ConstantValue::Int64(_) => quote! { i64 }, - ConstantValue::Bool(_) => quote! { bool }, - ConstantValue::Tensor(tensor_type, _) => { - let ty = tensor_type.ty(); - quote! { burn::module::Param<#ty>} - } + pub fn ty_tokens(&self) -> TokenStream { + match self { + ConstantValue::Float32(_) => quote! { f32 }, + ConstantValue::Float64(_) => quote! { f64 }, + ConstantValue::Int32(_) => quote! { i32 }, + ConstantValue::Int64(_) => quote! { i64 }, + ConstantValue::Bool(_) => quote! { bool }, + ConstantValue::Tensor(tensor_type, _) => { + let ty = tensor_type.ty(); + quote! { burn::module::Param<#ty>} + } + } } - } - pub fn val_tokens(&self) -> TokenStream { - match self { - ConstantValue::Float32(val) => quote! { #val }, - ConstantValue::Float64(val) => quote! { #val }, - ConstantValue::Int32(val) => quote! { #val }, - ConstantValue::Int64(val) => quote! { #val }, - ConstantValue::Bool(val) => quote! 
{ #val }, - ConstantValue::Tensor(_, _) => { - panic!("Tensor constant is not assignable.") - } + pub fn val_tokens(&self) -> TokenStream { + match self { + ConstantValue::Float32(val) => quote! { #val }, + ConstantValue::Float64(val) => quote! { #val }, + ConstantValue::Int32(val) => quote! { #val }, + ConstantValue::Int64(val) => quote! { #val }, + ConstantValue::Bool(val) => quote! { #val }, + ConstantValue::Tensor(_, _) => { + panic!("Tensor constant is not assignable.") + } + } } - } } impl ConstantNode { - pub fn new(name: String, value: ConstantValue, output: Type) -> Self { - Self { - name, - value, - output, + pub fn new(name: String, value: ConstantValue, output: Type) -> Self { + Self { + name, + value, + output, + } } - } - pub fn constant_value_into_type(&self) -> Type { - let name = Ident::new(self.name.as_str(), Span::call_site()); - match &self.value { - ConstantValue::Float32(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Float32, - }), - ConstantValue::Float64(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Float64, - }), - ConstantValue::Int32(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Int32, - }), - ConstantValue::Int64(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Int64, - }), - ConstantValue::Bool(_) => Type::Scalar(ScalarType { - name, - kind: ScalarKind::Bool, - }), - - ConstantValue::Tensor(tensor_type, _) => Type::Tensor(tensor_type.clone()), + pub fn constant_value_into_type(&self) -> Type { + let name = Ident::new(self.name.as_str(), Span::call_site()); + match &self.value { + ConstantValue::Float32(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Float32, + }), + ConstantValue::Float64(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Float64, + }), + ConstantValue::Int32(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Int32, + }), + ConstantValue::Int64(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Int64, + }), + ConstantValue::Bool(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Bool, + }), + + ConstantValue::Tensor(tensor_type, _) => Type::Tensor(tensor_type.clone()), + } } - } } impl NodeCodegen for ConstantNode { - fn output_types(&self) -> Vec { - vec![self.output.clone()] - } - - fn input_types(&self) -> Vec { - vec![] - } - - fn field_type(&self) -> Option { - match &self.value { - ConstantValue::Tensor(tensor_type, _) => Some(Type::Tensor(tensor_type.clone())), - _ => None, + fn output_types(&self) -> Vec { + vec![self.output.clone()] } - } - fn field_init(&self, with_record: bool) -> Option { - match &self.value { - ConstantValue::Tensor(tensor_type, _) => { - let ty = tensor_type.ty(); - let name = Ident::new(self.name.as_ref(), Span::call_site()); - let shape = tensor_type.clone().shape.unwrap().to_tokens(); - let dim = tensor_type.clone().dim.to_tokens(); - - if with_record { - Some(quote! { - let #name = record.#name.map(|tensor| tensor.set_require_grad(false)); - }) - } else { - Some(quote! 
{ - let #name: burn::module::Param<#ty> = burn::module::Param::new( - burn::module::ParamId::new(), - Tensor::::zeros(#shape).set_require_grad(false), - ); - }) - } - } - _ => None, + fn input_types(&self) -> Vec { + vec![] } - } - fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { - let name = Ident::new(self.name.as_ref(), Span::call_site()); - let output = self.output.name(); + fn field_type(&self) -> Option { + match &self.value { + ConstantValue::Tensor(tensor_type, _) => Some(Type::Tensor(tensor_type.clone())), + _ => None, + } + } - match &self.value { - ConstantValue::Tensor(_, _) => { - quote! { - let #output = self.#name.val(); + fn field_init(&self, with_record: bool) -> Option { + match &self.value { + ConstantValue::Tensor(tensor_type, _) => { + let ty = tensor_type.ty(); + let name = Ident::new(self.name.as_ref(), Span::call_site()); + let shape = tensor_type.clone().shape.unwrap().to_tokens(); + let dim = tensor_type.clone().dim.to_tokens(); + + if with_record { + Some(quote! { + let #name = record.#name.map(|tensor| tensor.set_require_grad(false)); + }) + } else { + Some(quote! { + let #name: burn::module::Param<#ty> = burn::module::Param::new( + burn::module::ParamId::new(), + Tensor::::zeros(#shape).set_require_grad(false), + ); + }) + } + } + _ => None, } - } - _ => { - let val = self.value.val_tokens(); - let ty = self.value.ty_tokens(); + } - quote! { - let #output: #ty = #val; + fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { + let name = Ident::new(self.name.as_ref(), Span::call_site()); + let output = self.output.name(); + + match &self.value { + ConstantValue::Tensor(_, _) => { + quote! { + let #output = self.#name.val(); + } + } + _ => { + let val = self.value.val_tokens(); + let ty = self.value.ty_tokens(); + + quote! 
{ + let #output: #ty = #val; + } + } } - } } - } - - fn into_node(self) -> Node { - Node::Constant(self) - } - - fn field_serialize(&self, serializer: S) -> Result { - if let ConstantValue::Tensor(_, ds) = &self.value { - let data: DataSerialize = match ds { - TensorValue::Float(data) => data.clone().convert(), - TensorValue::Int(data) => data.clone().convert(), - }; - let data = ParamSerde::new(ParamId::new().into_string(), data); - return data.serialize(serializer); + + fn into_node(self) -> Node { + Node::Constant(self) } - S::serialize_none(serializer) - } + fn field_serialize(&self, serializer: S) -> Result { + if let ConstantValue::Tensor(_, ds) = &self.value { + let data: DataSerialize = match ds { + TensorValue::Float(data) => data.clone().convert(), + TensorValue::Int(data) => data.clone().convert(), + }; + let data = ParamSerde::new(ParamId::new().into_string(), data); + return data.serialize(serializer); + } + + S::serialize_none(serializer) + } } // TODO add test missing for constant node (@antimora 8/2/2023) diff --git a/burn-import/src/burn/node/conv1d.rs b/burn-import/src/burn/node/conv1d.rs index e3d5b38bc2..0c3e568880 100644 --- a/burn-import/src/burn/node/conv1d.rs +++ b/burn-import/src/burn/node/conv1d.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv1dConfig, Conv1dRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{ConstantRecord, Param, ParamId}, + nn::conv::{Conv1dConfig, Conv1dRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,191 +12,191 @@ use serde::Serialize; #[derive(Clone, Debug)] pub struct Conv1dNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub data_weights: DataSerialize, - pub data_bias: Option>, - pub config: Conv1dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub data_weights: DataSerialize, + pub data_bias: Option>, + pub config: Conv1dConfig, } impl Conv1dNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - data_weights: DataSerialize, - data_bias: Option>, - config: Conv1dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Conv1d - }, - ), - input, - output, - data_weights, - data_bias, - config, + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + data_weights: DataSerialize, + data_bias: Option>, + config: Conv1dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! 
{ + Conv1d + }, + ), + input, + output, + data_weights, + data_bias, + config, + } } - } } impl NodeCodegen for Conv1dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let channels_in = self.config.channels_in.to_tokens(); - let channels_out = self.config.channels_out.to_tokens(); - let kernel_size = self.config.kernel_size.to_tokens(); - let stride = self.config.stride.to_tokens(); - let dilation = self.config.dilation.to_tokens(); - let groups = self.config.groups.to_tokens(); - let padding = self.config.padding.to_tokens(); - let bias = self.config.bias; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } - let tokens = quote! { - let #name = Conv1dConfig::new(#channels_in, #channels_out, #kernel_size) - .with_stride(#stride) - .with_padding(#padding) - .with_dilation(#dilation) - .with_groups(#groups) - .with_bias(#bias) - .#init_line - }; + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let channels_in = self.config.channels_in.to_tokens(); + let channels_out = self.config.channels_out.to_tokens(); + let kernel_size = self.config.kernel_size.to_tokens(); + let stride = self.config.stride.to_tokens(); + let dilation = self.config.dilation.to_tokens(); + let groups = self.config.groups.to_tokens(); + let padding = self.config.padding.to_tokens(); + let bias = self.config.bias; + + let init_line = match with_record { + true => quote! { + init_with(record.#name); + }, + false => quote! { + init(); + }, + }; - Some(tokens) - } - - fn field_serialize(&self, serializer: S) -> Result { - let record = Conv1dRecord:: { - weight: Param::new( - ParamId::new(), - Tensor::from_data(self.data_weights.clone().convert()), - ), - bias: self - .data_bias - .as_ref() - .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), - stride: ConstantRecord::new(), - kernel_size: ConstantRecord::new(), - dilation: ConstantRecord::new(), - groups: ConstantRecord::new(), - padding: ConstantRecord::new(), - }; + let tokens = quote! 
{ + let #name = Conv1dConfig::new(#channels_in, #channels_out, #kernel_size) + .with_stride(#stride) + .with_padding(#padding) + .with_dilation(#dilation) + .with_groups(#groups) + .with_bias(#bias) + .#init_line + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let record = Conv1dRecord:: { + weight: Param::new( + ParamId::new(), + Tensor::from_data(self.data_weights.clone().convert()), + ), + bias: self + .data_bias + .as_ref() + .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), + stride: ConstantRecord::new(), + kernel_size: ConstantRecord::new(), + dilation: ConstantRecord::new(), + groups: ConstantRecord::new(), + padding: ConstantRecord::new(), + }; + + let item = Record::into_item::(record); + item.serialize(serializer) + } - let item = Record::into_item::(record); - item.serialize(serializer) - } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + quote! { + let #output = self.#field.forward(#input); + } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig1d"); + imports.register("burn::nn::conv::Conv1d"); + imports.register("burn::nn::conv::Conv1dConfig"); + } - quote! { - let #output = self.#field.forward(#input); + fn into_node(self) -> Node { + Node::Conv1d(self) } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig1d"); - imports.register("burn::nn::conv::Conv1d"); - imports.register("burn::nn::conv::Conv1dConfig"); - } - - fn into_node(self) -> Node { - Node::Conv1d(self) - } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{conv1d::Conv1dNode, test::assert_tokens}, - TensorType, - }; - use burn::{ - nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data, - }; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(Conv1dNode::new( - "conv1d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - None, - Conv1dConfig::new(3, 3, 3).with_padding(PaddingConfig1d::Valid), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig1d; - use burn::nn::conv::Conv1d; - use burn::nn::conv::Conv1dConfig; - - #[derive(Module, Debug)] - pub struct Model { - conv1d: Conv1d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{conv1d::Conv1dNode, test::assert_tokens}, + TensorType, + }; + use burn::{ + nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data, + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv1d = Conv1dConfig::new(3, 3, 3) - .with_stride(1) - .with_padding(PaddingConfig1d::Valid) - .with_dilation(1) - .with_groups(1) - .with_bias(true) - .init_with(record.conv1d); - - Self { - conv1d, - phantom: core::marker::PhantomData, - } + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(Conv1dNode::new( + "conv1d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + None, + Conv1dConfig::new(3, 3, 3).with_padding(PaddingConfig1d::Valid), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig1d; + use burn::nn::conv::Conv1d; + use burn::nn::conv::Conv1dConfig; + + #[derive(Module, Debug)] + pub struct Model { + conv1d: Conv1d, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.conv1d.forward(input); - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv1d = Conv1dConfig::new(3, 3, 3) + .with_stride(1) + .with_padding(PaddingConfig1d::Valid) + .with_dilation(1) + .with_groups(1) + .with_bias(true) + .init_with(record.conv1d); + + Self { + conv1d, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.conv1d.forward(input); + + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/conv2d.rs b/burn-import/src/burn/node/conv2d.rs index d6b0059794..9b3c9f4408 100644 --- a/burn-import/src/burn/node/conv2d.rs +++ b/burn-import/src/burn/node/conv2d.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv2dConfig, Conv2dRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{ConstantRecord, Param, ParamId}, + nn::conv::{Conv2dConfig, Conv2dRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,190 +12,190 @@ use serde::Serialize; #[derive(Debug, Clone)] pub struct Conv2dNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub data_weights: DataSerialize, - pub data_bias: Option>, - pub config: Conv2dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub data_weights: DataSerialize, + pub data_bias: Option>, + pub config: Conv2dConfig, } impl Conv2dNode { - pub fn new>( - name: S, - input: TensorType, - output: 
TensorType, - data_weights: DataSerialize, - data_bias: Option>, - config: Conv2dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Conv2d - }, - ), - input, - output, - data_weights, - data_bias, - config, + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + data_weights: DataSerialize, + data_bias: Option>, + config: Conv2dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + Conv2d + }, + ), + input, + output, + data_weights, + data_bias, + config, + } } - } } impl NodeCodegen for Conv2dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let channels = self.config.channels.to_tokens(); - let kernel_size = self.config.kernel_size.to_tokens(); - let stride = self.config.stride.to_tokens(); - let dilation = self.config.dilation.to_tokens(); - let groups = self.config.groups.to_tokens(); - let padding = self.config.padding.to_tokens(); - let bias = self.config.bias; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } - let tokens = quote! { - let #name = Conv2dConfig::new(#channels, #kernel_size) - .with_stride(#stride) - .with_padding(#padding) - .with_dilation(#dilation) - .with_groups(#groups) - .with_bias(#bias) - .#init_line - }; + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let channels = self.config.channels.to_tokens(); + let kernel_size = self.config.kernel_size.to_tokens(); + let stride = self.config.stride.to_tokens(); + let dilation = self.config.dilation.to_tokens(); + let groups = self.config.groups.to_tokens(); + let padding = self.config.padding.to_tokens(); + let bias = self.config.bias; + + let init_line = match with_record { + true => quote! { + init_with(record.#name); + }, + false => quote! { + init(); + }, + }; - Some(tokens) - } - - fn field_serialize(&self, serializer: S) -> Result { - let record = Conv2dRecord:: { - weight: Param::new( - ParamId::new(), - Tensor::from_data(self.data_weights.clone().convert()), - ), - bias: self - .data_bias - .as_ref() - .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), - stride: [ConstantRecord::new(); 2], - kernel_size: [ConstantRecord::new(); 2], - dilation: [ConstantRecord::new(); 2], - groups: ConstantRecord::new(), - padding: ConstantRecord::new(), - }; + let tokens = quote! 
{ + let #name = Conv2dConfig::new(#channels, #kernel_size) + .with_stride(#stride) + .with_padding(#padding) + .with_dilation(#dilation) + .with_groups(#groups) + .with_bias(#bias) + .#init_line + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let record = Conv2dRecord:: { + weight: Param::new( + ParamId::new(), + Tensor::from_data(self.data_weights.clone().convert()), + ), + bias: self + .data_bias + .as_ref() + .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), + stride: [ConstantRecord::new(); 2], + kernel_size: [ConstantRecord::new(); 2], + dilation: [ConstantRecord::new(); 2], + groups: ConstantRecord::new(), + padding: ConstantRecord::new(), + }; + + let item = Record::into_item::(record); + item.serialize(serializer) + } - let item = Record::into_item::(record); - item.serialize(serializer) - } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + quote! { + let #output = self.#field.forward(#input); + } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig2d"); + imports.register("burn::nn::conv::Conv2d"); + imports.register("burn::nn::conv::Conv2dConfig"); + } - quote! { - let #output = self.#field.forward(#input); + fn into_node(self) -> Node { + Node::Conv2d(self) } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig2d"); - imports.register("burn::nn::conv::Conv2d"); - imports.register("burn::nn::conv::Conv2dConfig"); - } - - fn into_node(self) -> Node { - Node::Conv2d(self) - } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{conv2d::Conv2dNode, test::assert_tokens}, - TensorType, - }; - use burn::{ - nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, - }; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(Conv2dNode::new( - "conv2d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig2d; - use burn::nn::conv::Conv2d; - use burn::nn::conv::Conv2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - conv2d: Conv2d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{conv2d::Conv2dNode, test::assert_tokens}, + TensorType, + }; + use burn::{ + nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let conv2d = Conv2dConfig::new([3, 3], [3, 3]) - .with_stride([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .with_groups(1) - .with_bias(true) - .init_with(record.conv2d); - - Self { - conv2d, - phantom: core::marker::PhantomData, - } + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(Conv2dNode::new( + "conv2d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + None, + Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::conv::Conv2d; + use burn::nn::conv::Conv2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + conv2d: Conv2d, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.conv2d.forward(input); - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let conv2d = Conv2dConfig::new([3, 3], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .with_groups(1) + .with_bias(true) + .init_with(record.conv2d); + + Self { + conv2d, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.conv2d.forward(input); + + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/dropout.rs b/burn-import/src/burn/node/dropout.rs index 5c8efa3045..de61653e13 100644 --- a/burn-import/src/burn/node/dropout.rs +++ b/burn-import/src/burn/node/dropout.rs @@ -8,138 +8,138 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; #[derive(Debug, Clone)] pub struct DropoutNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub config: DropoutConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub config: DropoutConfig, } impl DropoutNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - config: DropoutConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Dropout - }, - ), - input, - output, - config, + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + config: DropoutConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! 
{ + Dropout + }, + ), + input, + output, + config, + } } - } } impl NodeCodegen for DropoutNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - - let prob = self.config.prob.to_tokens(); - - let init_line = quote! { - init(); - }; - - let tokens = quote! { - let #name = DropoutConfig::new(#prob) - .#init_line - }; - - Some(tokens) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; - - quote! { - let #output = self.#field.forward(#input); + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::Dropout"); - imports.register("burn::nn::DropoutConfig"); - } - - fn into_node(self) -> Node { - Node::Dropout(self) - } - - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) - } -} -#[cfg(test)] -mod tests { - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - use burn::{nn::DropoutConfig, record::FullPrecisionSettings}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(DropoutNode::new( - "dropout", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - DropoutConfig::new(0.5), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; + + let prob = self.config.prob.to_tokens(); + + let init_line = quote! { + init(); + }; + + let tokens = quote! { + let #name = DropoutConfig::new(#prob) + .#init_line }; - use burn::nn::Dropout; - use burn::nn::DropoutConfig; - #[derive(Module, Debug)] - pub struct Model { - dropout: Dropout, - phantom: core::marker::PhantomData, + Some(tokens) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + quote! 
{ + let #output = self.#field.forward(#input); } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::Dropout"); + imports.register("burn::nn::DropoutConfig"); + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let dropout = DropoutConfig::new(0.5) - .init(); + fn into_node(self) -> Node { + Node::Dropout(self) + } + + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use burn::{nn::DropoutConfig, record::FullPrecisionSettings}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(DropoutNode::new( + "dropout", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + DropoutConfig::new(0.5), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::Dropout; + use burn::nn::DropoutConfig; + + #[derive(Module, Debug)] + pub struct Model { + dropout: Dropout, + phantom: core::marker::PhantomData, - Self { - dropout, - phantom: core::marker::PhantomData, - } } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.dropout.forward(input); - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let dropout = DropoutConfig::new(0.5) + .init(); + + Self { + dropout, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.dropout.forward(input); + + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/gather.rs b/burn-import/src/burn/node/gather.rs index 933b76f94d..2c0e6bc9ea 100644 --- a/burn-import/src/burn/node/gather.rs +++ b/burn-import/src/burn/node/gather.rs @@ -6,101 +6,101 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct GatherNode { - pub input: TensorType, - pub index: TensorType, - pub output: TensorType, - pub dim: usize, + pub input: TensorType, + pub index: TensorType, + pub output: TensorType, + pub dim: usize, } impl NodeCodegen for GatherNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - - fn input_types(&self) -> Vec { - vec![ - Type::Tensor(self.input.clone()), - Type::Tensor(self.index.clone()), - ] - } - - fn forward( - &self, - scope: &mut crate::burn::Scope, - node_position: usize, - ) -> proc_macro2::TokenStream { - let dim = self.dim.to_tokens(); - let input = scope.tensor_use_owned(&self.input, node_position); - let index = scope.tensor_use_owned(&self.index, node_position); - let output = &self.output.name; - - quote! 
{ - let #output = #input.gather(#dim, #index); + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] } - } - fn into_node(self) -> super::Node { - Node::Gather(self) - } + fn input_types(&self) -> Vec { + vec![ + Type::Tensor(self.input.clone()), + Type::Tensor(self.index.clone()), + ] + } + + fn forward( + &self, + scope: &mut crate::burn::Scope, + node_position: usize, + ) -> proc_macro2::TokenStream { + let dim = self.dim.to_tokens(); + let input = scope.tensor_use_owned(&self.input, node_position); + let index = scope.tensor_use_owned(&self.index, node_position); + let output = &self.output.name; + + quote! { + let #output = #input.gather(#dim, #index); + } + } + + fn into_node(self) -> super::Node { + Node::Gather(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{gather::GatherNode, test::assert_tokens}, - TensorType, - }; - - #[test] - fn test_codegen_gather() { - let mut graph = BurnGraph::::default(); - - graph.register(GatherNode::new( - TensorType::new_float("tensor1", 2), - TensorType::new_int("tensor2", 2), - TensorType::new_float("tensor3", 2), - 1, - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! { - use burn::tensor::Int; - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + use burn::record::FullPrecisionSettings; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{gather::GatherNode, test::assert_tokens}, + TensorType, + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + #[test] + fn test_codegen_gather() { + let mut graph = BurnGraph::::default(); + + graph.register(GatherNode::new( + TensorType::new_float("tensor1", 2), + TensorType::new_int("tensor2", 2), + TensorType::new_float("tensor3", 2), + 1, + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.gather(1, tensor2); + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } + } + + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.gather(1, tensor2); - tensor3 + tensor3 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/global_avg_pool.rs b/burn-import/src/burn/node/global_avg_pool.rs index 76c18992ec..80d6f3cec2 100644 --- a/burn-import/src/burn/node/global_avg_pool.rs +++ b/burn-import/src/burn/node/global_avg_pool.rs @@ -13,203 +13,203 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type}; /// is equivalent to global average pooling. 
#[derive(Debug, Clone)] pub struct GlobalAvgPoolNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, } impl GlobalAvgPoolNode { - pub fn new>(name: S, input: TensorType, output: TensorType) -> Self { - // Depending on the input dimension, we need to use a different type nn module - let field_type = match input.dim { - 3 => quote! { - AdaptiveAvgPool1d - }, - 4 => quote! { - AdaptiveAvgPool2d - }, - dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), - }; + pub fn new>(name: S, input: TensorType, output: TensorType) -> Self { + // Depending on the input dimension, we need to use a different type nn module + let field_type = match input.dim { + 3 => quote! { + AdaptiveAvgPool1d + }, + 4 => quote! { + AdaptiveAvgPool2d + }, + dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), + }; - Self { - field: OtherType::new(name, field_type), - input, - output, + Self { + field: OtherType::new(name, field_type), + input, + output, + } } - } } impl NodeCodegen for GlobalAvgPoolNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - - let tokens = match self.input.dim { - 3 => { - quote! { - let #name = AdaptiveAvgPool1dConfig::new(1) - .init(); - } - } - 4 => { - quote! { - let #name = AdaptiveAvgPool2dConfig::new([1,1]) - .init(); - } - } - dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), - }; - - Some(tokens) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; - - quote! { - let #output = self.#field.forward(#input); + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] } - } - - fn register_imports(&self, imports: &mut BurnImports) { - match self.input.dim { - 3 => { - imports.register("burn::nn::pool::AdaptiveAvgPool1d"); - imports.register("burn::nn::pool::AdaptiveAvgPool1dConfig"); - } - 4 => { - imports.register("burn::nn::pool::AdaptiveAvgPool2d"); - imports.register("burn::nn::pool::AdaptiveAvgPool2dConfig"); - } - dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) } - } - - fn into_node(self) -> Node { - Node::GlobalAvgPool(self) - } - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) - } -} + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; -#[cfg(test)] -mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{global_avg_pool::GlobalAvgPoolNode, test::assert_tokens}, - TensorType, - }; - use burn::record::FullPrecisionSettings; - - #[test] - fn test_codegen_2d() { - let mut graph = BurnGraph::::default(); - - graph.register(GlobalAvgPoolNode::new( - "global_avg_pool1", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, + let tokens = match self.input.dim { + 3 => { + quote! { + let #name = AdaptiveAvgPool1dConfig::new(1) + .init(); + } + } + 4 => { + quote! { + let #name = AdaptiveAvgPool2dConfig::new([1,1]) + .init(); + } + } + dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), }; - use burn::nn::pool::AdaptiveAvgPool2d; - use burn::nn::pool::AdaptiveAvgPool2dConfig; - #[derive(Module, Debug)] - pub struct Model { - global_avg_pool1: AdaptiveAvgPool2d, - phantom: core::marker::PhantomData, - } + Some(tokens) + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let global_avg_pool1 = AdaptiveAvgPool2dConfig::new([1, 1]) - .init(); + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - Self { - global_avg_pool1, - phantom: core::marker::PhantomData, - } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.global_avg_pool1.forward(input); + quote! { + let #output = self.#field.forward(#input); + } + } - output + fn register_imports(&self, imports: &mut BurnImports) { + match self.input.dim { + 3 => { + imports.register("burn::nn::pool::AdaptiveAvgPool1d"); + imports.register("burn::nn::pool::AdaptiveAvgPool1dConfig"); + } + 4 => { + imports.register("burn::nn::pool::AdaptiveAvgPool2d"); + imports.register("burn::nn::pool::AdaptiveAvgPool2dConfig"); } + dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"), } - }; + } - assert_tokens(graph.codegen(), expected); - } + fn into_node(self) -> Node { + Node::GlobalAvgPool(self) + } - #[test] - fn test_codegen_1d() { - let mut graph = BurnGraph::::default(); + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } +} - graph.register(GlobalAvgPoolNode::new( - "global_avg_pool1", - TensorType::new_float("input", 3), - TensorType::new_float("output", 3), - )); +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{global_avg_pool::GlobalAvgPoolNode, test::assert_tokens}, + TensorType, + }; + use burn::record::FullPrecisionSettings; + + #[test] + fn test_codegen_2d() { + let mut graph = BurnGraph::::default(); + + graph.register(GlobalAvgPoolNode::new( + "global_avg_pool1", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::pool::AdaptiveAvgPool2d; + use burn::nn::pool::AdaptiveAvgPool2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + global_avg_pool1: AdaptiveAvgPool2d, + phantom: core::marker::PhantomData, + } + + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let global_avg_pool1 = AdaptiveAvgPool2dConfig::new([1, 1]) + .init(); - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + Self { + global_avg_pool1, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.global_avg_pool1.forward(input); - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, + output + } + } }; - use burn::nn::pool::AdaptiveAvgPool1d; - use burn::nn::pool::AdaptiveAvgPool1dConfig; - #[derive(Module, Debug)] - pub struct Model { - global_avg_pool1: AdaptiveAvgPool1d, - phantom: core::marker::PhantomData, - } + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_1d() { + let mut graph = BurnGraph::::default(); + + graph.register(GlobalAvgPoolNode::new( + "global_avg_pool1", + TensorType::new_float("input", 3), + TensorType::new_float("output", 3), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::pool::AdaptiveAvgPool1d; + use burn::nn::pool::AdaptiveAvgPool1dConfig; + + #[derive(Module, Debug)] + pub struct Model { + global_avg_pool1: AdaptiveAvgPool1d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let global_avg_pool1 = AdaptiveAvgPool1dConfig::new(1) - .init(); + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let global_avg_pool1 = AdaptiveAvgPool1dConfig::new(1) + .init(); - Self { - global_avg_pool1, - phantom: core::marker::PhantomData, + Self { + global_avg_pool1, + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.global_avg_pool1.forward(input); + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.global_avg_pool1.forward(input); - output + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/linear.rs b/burn-import/src/burn/node/linear.rs index 1e9bd5a1c2..b413c2c4a0 100644 --- a/burn-import/src/burn/node/linear.rs +++ b/burn-import/src/burn/node/linear.rs @@ -1,10 +1,10 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ - module::{Param, ParamId}, - nn::{LinearConfig, LinearRecord}, - record::{PrecisionSettings, Record}, - tensor::{DataSerialize, Tensor}, + module::{Param, ParamId}, + nn::{LinearConfig, LinearRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,167 +12,167 @@ use serde::Serialize; #[derive(Debug, Clone)] pub struct LinearNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub data_weights: DataSerialize, - pub data_bias: Option>, - pub config: LinearConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub data_weights: DataSerialize, + pub data_bias: Option>, + pub config: LinearConfig, } impl LinearNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - data_weights: DataSerialize, - data_bias: Option>, - config: LinearConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - Linear - }, - ), - input, - output, - data_weights, - data_bias, - config, + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + data_weights: DataSerialize, + data_bias: Option>, + config: LinearConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! 
{ + Linear + }, + ), + input, + output, + data_weights, + data_bias, + config, + } } - } } impl NodeCodegen for LinearNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, with_record: bool) -> Option { - let name = &self.field.name; - let d_input = self.config.d_input.to_tokens(); - let d_output = self.config.d_output.to_tokens(); - let bias = self.config.bias; - - let init_line = match with_record { - true => quote! { - init_with(record.#name); - }, - false => quote! { - init(); - }, - }; - - let tokens = quote! { - let #name = LinearConfig::new(#d_input, #d_output) - .with_bias(#bias) - .#init_line - }; - - Some(tokens) - } - - fn field_serialize(&self, serializer: S) -> Result { - let record = LinearRecord:: { - weight: Param::new( - ParamId::new(), - Tensor::from_data(self.data_weights.clone().convert()), - ), - bias: self - .data_bias - .as_ref() - .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), - }; - - let item = Record::into_item::(record); - item.serialize(serializer) - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; - - quote! { - let #output = self.#field.forward(#input); + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::Linear"); - imports.register("burn::nn::LinearConfig"); - } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } - fn into_node(self) -> Node { - Node::Linear(self) - } -} + fn field_init(&self, with_record: bool) -> Option { + let name = &self.field.name; + let d_input = self.config.d_input.to_tokens(); + let d_output = self.config.d_output.to_tokens(); + let bias = self.config.bias; + + let init_line = match with_record { + true => quote! { + init_with(record.#name); + }, + false => quote! { + init(); + }, + }; -#[cfg(test)] -mod tests { - use super::*; - use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; - use burn::{record::FullPrecisionSettings, tensor::Data}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(LinearNode::new( - "linear", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - Data::from([2.]).serialize(), - None, - LinearConfig::new(128, 128), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, + let tokens = quote! 
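// Sketch of the generated initializer (not part of the patch), using the values
// from the test below (d_input = 128, d_output = 128, bias = true):
//
//     // with_record == true
//     let linear = LinearConfig::new(128, 128).with_bias(true).init_with(record.linear);
//     // with_record == false
//     let linear = LinearConfig::new(128, 128).with_bias(true).init();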
{ + let #name = LinearConfig::new(#d_input, #d_output) + .with_bias(#bias) + .#init_line }; - use burn::nn::Linear; - use burn::nn::LinearConfig; - #[derive(Module, Debug)] - pub struct Model { - linear: Linear, - phantom: core::marker::PhantomData, + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let record = LinearRecord:: { + weight: Param::new( + ParamId::new(), + Tensor::from_data(self.data_weights.clone().convert()), + ), + bias: self + .data_bias + .as_ref() + .map(|bias| Param::new(ParamId::new(), Tensor::from_data(bias.clone().convert()))), + }; + + let item = Record::into_item::(record); + item.serialize(serializer) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + + quote! { + let #output = self.#field.forward(#input); } + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let linear = LinearConfig::new(128, 128) - .with_bias(true) - .init_with(record.linear); + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::Linear"); + imports.register("burn::nn::LinearConfig"); + } - Self { - linear, - phantom: core::marker::PhantomData, - } + fn into_node(self) -> Node { + Node::Linear(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use burn::{record::FullPrecisionSettings, tensor::Data}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(LinearNode::new( + "linear", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + None, + LinearConfig::new(128, 128), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! 
{ + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::Linear; + use burn::nn::LinearConfig; + + #[derive(Module, Debug)] + pub struct Model { + linear: Linear, + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.linear.forward(input); - output + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let linear = LinearConfig::new(128, 128) + .with_bias(true) + .init_with(record.linear); + + Self { + linear, + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.linear.forward(input); + + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/matmul.rs b/burn-import/src/burn/node/matmul.rs index 077eb272ef..b7b1eea97b 100644 --- a/burn-import/src/burn/node/matmul.rs +++ b/burn-import/src/burn/node/matmul.rs @@ -6,93 +6,93 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct MatmulNode { - pub lhs: TensorType, - pub rhs: TensorType, - pub output: TensorType, + pub lhs: TensorType, + pub rhs: TensorType, + pub output: TensorType, } impl NodeCodegen for MatmulNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - - fn input_types(&self) -> Vec { - vec![ - Type::Tensor(self.lhs.clone()), - Type::Tensor(self.rhs.clone()), - ] - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let lhs = scope.tensor_use_owned(&self.lhs, node_position); - let rhs = scope.tensor_use_owned(&self.rhs, node_position); - let output = &self.output.name; - - quote! { - let #output = #lhs.matmul(#rhs); + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] } - } - fn into_node(self) -> Node { - Node::Matmul(self) - } + fn input_types(&self) -> Vec { + vec![ + Type::Tensor(self.lhs.clone()), + Type::Tensor(self.rhs.clone()), + ] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let lhs = scope.tensor_use_owned(&self.lhs, node_position); + let rhs = scope.tensor_use_owned(&self.rhs, node_position); + let output = &self.output.name; + + quote! { + let #output = #lhs.matmul(#rhs); + } + } + + fn into_node(self) -> Node { + Node::Matmul(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{matmul::MatmulNode, test::assert_tokens}, - TensorType, - }; - - #[test] - fn test_codegen_two_nodes() { - let mut graph = BurnGraph::::default(); - - graph.register(MatmulNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor3", 4), - )); - - graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + use burn::record::FullPrecisionSettings; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{matmul::MatmulNode, test::assert_tokens}, + TensorType, + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + #[test] + fn test_codegen_two_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(MatmulNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + TensorType::new_float("tensor3", 4), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.matmul(tensor2); + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } + } - tensor3 + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.matmul(tensor2); + + tensor3 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/max_pool2d.rs b/burn-import/src/burn/node/max_pool2d.rs index 61ba4845db..2bbf859bb9 100644 --- a/burn-import/src/burn/node/max_pool2d.rs +++ b/burn-import/src/burn/node/max_pool2d.rs @@ -8,155 +8,155 @@ use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; #[derive(Debug, Clone)] pub struct MaxPool2dNode { - pub field: OtherType, - pub input: TensorType, - pub output: TensorType, - pub config: MaxPool2dConfig, + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub config: MaxPool2dConfig, } impl MaxPool2dNode { - pub fn new>( - name: S, - input: TensorType, - output: TensorType, - config: MaxPool2dConfig, - ) -> Self { - Self { - field: OtherType::new( - name, - quote! { - MaxPool2d - }, - ), - input, - output, - config, + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + config: MaxPool2dConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + MaxPool2d + }, + ), + input, + output, + config, + } } - } } impl NodeCodegen for MaxPool2dNode { - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } - fn field_type(&self) -> Option { - Some(Type::Other(self.field.clone())) - } - - fn field_init(&self, _with_record: bool) -> Option { - let name = &self.field.name; - let kernel_size = self.config.kernel_size.to_tokens(); - let strides = self.config.strides.to_tokens(); - let padding = self.config.padding.to_tokens(); - let dilation = self.config.dilation.to_tokens(); - - let init_line = quote! { - init(); - }; + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } - let tokens = quote! 
{ - let #name = MaxPool2dConfig::new(#kernel_size) - .with_strides(#strides) - .with_padding(#padding) - .with_dilation(#dilation) - .#init_line - }; + fn field_init(&self, _with_record: bool) -> Option { + let name = &self.field.name; + let kernel_size = self.config.kernel_size.to_tokens(); + let strides = self.config.strides.to_tokens(); + let padding = self.config.padding.to_tokens(); + let dilation = self.config.dilation.to_tokens(); - Some(tokens) - } + let init_line = quote! { + init(); + }; - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let field = &self.field.name; + let tokens = quote! { + let #name = MaxPool2dConfig::new(#kernel_size) + .with_strides(#strides) + .with_padding(#padding) + .with_dilation(#dilation) + .#init_line + }; - quote! { - let #output = self.#field.forward(#input); + Some(tokens) } - } - fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::nn::PaddingConfig2d"); - imports.register("burn::nn::pool::MaxPool2d"); - imports.register("burn::nn::pool::MaxPool2dConfig"); - } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; - fn into_node(self) -> Node { - Node::MaxPool2d(self) - } + quote! { + let #output = self.#field.forward(#input); + } + } + + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PaddingConfig2d"); + imports.register("burn::nn::pool::MaxPool2d"); + imports.register("burn::nn::pool::MaxPool2dConfig"); + } - fn field_serialize(&self, serializer: S) -> Result { - S::serialize_none(serializer) - } + fn into_node(self) -> Node { + Node::MaxPool2d(self) + } + + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{max_pool2d::MaxPool2dNode, test::assert_tokens}, - TensorType, - }; - use burn::{nn::pool::MaxPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; - - #[test] - fn test_codegen() { - let mut graph = BurnGraph::::default(); - - graph.register(MaxPool2dNode::new( - "max_pool2d", - TensorType::new_float("input", 4), - TensorType::new_float("output", 4), - MaxPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]), - )); - - graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); - - let expected = quote! 
{ - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; - use burn::nn::PaddingConfig2d; - use burn::nn::pool::MaxPool2d; - use burn::nn::pool::MaxPool2dConfig; - - #[derive(Module, Debug)] - pub struct Model { - max_pool2d: MaxPool2d, - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{max_pool2d::MaxPool2dNode, test::assert_tokens}, + TensorType, + }; + use burn::{nn::pool::MaxPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(MaxPool2dNode::new( + "max_pool2d", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + MaxPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + use burn::nn::PaddingConfig2d; + use burn::nn::pool::MaxPool2d; + use burn::nn::pool::MaxPool2dConfig; + + #[derive(Module, Debug)] + pub struct Model { + max_pool2d: MaxPool2d, + phantom: core::marker::PhantomData, + } - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - let max_pool2d = MaxPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid) - .with_dilation([1, 1]) - .init(); - - Self { - max_pool2d, - phantom: core::marker::PhantomData, + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + let max_pool2d = MaxPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Valid) + .with_dilation([1, 1]) + .init(); + + Self { + max_pool2d, + phantom: core::marker::PhantomData, + } } - } - #[allow(clippy::let_and_return)] - pub fn forward(&self, input: Tensor) -> Tensor { - let output = self.max_pool2d.forward(input); + #[allow(clippy::let_and_return)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.max_pool2d.forward(input); - output + output + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/reshape.rs b/burn-import/src/burn/node/reshape.rs index 032f1a3c33..df8959e90b 100644 --- a/burn-import/src/burn/node/reshape.rs +++ b/burn-import/src/burn/node/reshape.rs @@ -6,85 +6,85 @@ use quote::quote; #[derive(Debug, Clone, new)] pub struct ReshapeNode { - pub input: TensorType, - pub output: TensorType, - pub shape: Vec, + pub input: TensorType, + pub output: TensorType, + pub shape: Vec, } impl NodeCodegen for ReshapeNode { - fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] - } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } - fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] - } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; - let shape_values = &self.shape.to_tokens(); + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let shape_values = &self.shape.to_tokens(); - quote! 
{ - let #output = #input.reshape(#shape_values); + quote! { + let #output = #input.reshape(#shape_values); + } } - } - fn into_node(self) -> Node { - Node::Reshape(self) - } + fn into_node(self) -> Node { + Node::Reshape(self) + } } #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; - - use super::*; - use crate::burn::{ - graph::BurnGraph, - node::{reshape::ReshapeNode, test::assert_tokens}, - TensorType, - }; - - #[test] - fn test_codegen_nodes() { - let mut graph = BurnGraph::::default(); - - graph.register(ReshapeNode::new( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - [4, 4, 4, 4].into(), - )); - - graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); - - let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + use burn::record::FullPrecisionSettings; - #[derive(Module, Debug)] - pub struct Model { - phantom: core::marker::PhantomData, - } + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{reshape::ReshapeNode, test::assert_tokens}, + TensorType, + }; - impl Model { - #[allow(unused_variables)] - pub fn new_with(record: ModelRecord) -> Self { - Self { - phantom: core::marker::PhantomData, - } + #[test] + fn test_codegen_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(ReshapeNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + [4, 4, 4, 4].into(), + )); + + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, } - #[allow(clippy::let_and_return)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.reshape([4, 4, 4, 4]); - tensor2 + impl Model { + #[allow(unused_variables)] + pub fn new_with(record: ModelRecord) -> Self { + Self { + phantom: core::marker::PhantomData, + } + } + #[allow(clippy::let_and_return)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.reshape([4, 4, 4, 4]); + + tensor2 + } } - } - }; + }; - assert_tokens(graph.codegen(), expected); - } + assert_tokens(graph.codegen(), expected); + } } diff --git a/burn-import/src/burn/node/test.rs b/burn-import/src/burn/node/test.rs index 33fb5dd530..903248ec98 100644 --- a/burn-import/src/burn/node/test.rs +++ b/burn-import/src/burn/node/test.rs @@ -3,8 +3,8 @@ use proc_macro2::TokenStream; #[track_caller] pub fn assert_tokens(tokens1: TokenStream, tokens2: TokenStream) { - let tokens1 = format_tokens(tokens1); - let tokens2 = format_tokens(tokens2); + let tokens1 = format_tokens(tokens1); + let tokens2 = format_tokens(tokens2); - pretty_assertions::assert_eq!(tokens1, tokens2); + pretty_assertions::assert_eq!(tokens1, tokens2); } diff --git a/burn-import/src/burn/node/unary.rs b/burn-import/src/burn/node/unary.rs index d7fc3da925..7f557b785c 100644 --- a/burn-import/src/burn/node/unary.rs +++ b/burn-import/src/burn/node/unary.rs @@ -11,386 +11,386 @@ type FnPointer = Rc TokenStream>; /// Node for all unary operators. #[derive(Clone, new)] pub struct UnaryNode { - pub input: Type, - pub output: Type, - pub kind: UnaryNodeKind, - function: FnPointer, + pub input: Type, + pub output: Type, + pub kind: UnaryNodeKind, + function: FnPointer, } /// Type of unary node. 
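// Illustrative sketch (not part of the patch): every unary operator shares the
// same `UnaryNode` shape; only the stored closure differs. The closure maps the
// input expression to the forward expression, so for instance
//
//     let node = UnaryNode::erf(
//         Type::Tensor(TensorType::new_float("tensor1", 4)),
//         Type::Tensor(TensorType::new_float("tensor2", 4)),
//     );
//
// emits `let tensor2 = tensor1.erf();` in the model's forward pass, as
// exercised by `test_unary_codegen_erf` further down.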
#[derive(Clone)] pub enum UnaryNodeKind { - Cast, - Erf, - Flatten, - LogSoftmax, - Softmax, - Relu, - Reciprocal, - Sigmoid, - Tanh, - Transpose, + Cast, + Erf, + Flatten, + LogSoftmax, + Softmax, + Relu, + Reciprocal, + Sigmoid, + Tanh, + Transpose, } impl UnaryNodeKind { - pub fn as_str(&self) -> &str { - match self { - Self::Cast => "cast", - Self::Erf => "erf", - Self::Flatten => "flatten", - Self::LogSoftmax => "log_softmax", - Self::Softmax => "softmax", - Self::Relu => "relu", - Self::Reciprocal => "reciprocal", - Self::Sigmoid => "sigmoid", - Self::Tanh => "tanh", - Self::Transpose => "transpose", + pub fn as_str(&self) -> &str { + match self { + Self::Cast => "cast", + Self::Erf => "erf", + Self::Flatten => "flatten", + Self::LogSoftmax => "log_softmax", + Self::Softmax => "softmax", + Self::Relu => "relu", + Self::Reciprocal => "reciprocal", + Self::Sigmoid => "sigmoid", + Self::Tanh => "tanh", + Self::Transpose => "transpose", + } } - } } impl std::fmt::Debug for UnaryNode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str( - format!( - "UnaryNode {{ input: {:?}, output: {:?}, name: {} }}", - self.input, - self.output, - self.kind.as_str() - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + format!( + "UnaryNode {{ input: {:?}, output: {:?}, name: {} }}", + self.input, + self.output, + self.kind.as_str() + ) + .as_str(), + ) + } } impl NodeCodegen for UnaryNode { - fn output_types(&self) -> Vec { - vec![self.output.clone()] - } - - fn input_types(&self) -> Vec { - vec![self.input.clone()] - } - - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - // Get the lhs name in the form of token stream. - let input = match &self.input { - Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), - Type::Scalar(scalar) => { - let name = scalar.name.clone(); - quote! { #name } - } - _ => panic!("lhs must be a tensor or scalar"), - }; - - // let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name(); - let function = (self.function)(input); - - quote! { - let #output = #function; + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + vec![self.input.clone()] } - } - fn into_node(self) -> Node { - Node::Unary(self) - } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + // Get the lhs name in the form of token stream. + let input = match &self.input { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("lhs must be a tensor or scalar"), + }; + + // let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name(); + let function = (self.function)(input); + + quote! { + let #output = #function; + } + } + + fn into_node(self) -> Node { + Node::Unary(self) + } } impl UnaryNode { - pub(crate) fn erf(input: Type, output: Type) -> Self { - let function = move |input| quote! { #input.erf() }; - Self::new(input, output, UnaryNodeKind::Erf, Rc::new(function)) - } - - pub(crate) fn flatten(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self { - let start_dim = start_dim.to_tokens(); - let end_dim = end_dim.to_tokens(); - let function = move |input| quote! 
{ #input.flatten(#start_dim, #end_dim) }; - - Self::new(input, output, UnaryNodeKind::Flatten, Rc::new(function)) - } - - pub(crate) fn relu(input: Type, output: Type) -> Self { - let function = move |input| quote! { burn::tensor::activation::relu(#input) }; - Self::new(input, output, UnaryNodeKind::Relu, Rc::new(function)) - } - - pub(crate) fn sigmoid(input: Type, output: Type) -> Self { - let function = move |input| quote! { burn::tensor::activation::sigmoid(#input) }; - Self::new(input, output, UnaryNodeKind::Sigmoid, Rc::new(function)) - } - - pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self { - let dim = dim.to_tokens(); - let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) }; - Self::new(input, output, UnaryNodeKind::LogSoftmax, Rc::new(function)) - } - - pub(crate) fn softmax(input: Type, output: Type, dim: usize) -> Self { - let dim = dim.to_tokens(); - let function = move |input| quote! { burn::tensor::activation::softmax(#input, #dim) }; - Self::new(input, output, UnaryNodeKind::Softmax, Rc::new(function)) - } - - pub(crate) fn tanh(input: Type, output: Type) -> Self { - let function = move |input| quote! { burn::tensor::activation::tanh(#input)}; - Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function)) - } - - pub(crate) fn transpose(input: Type, output: Type) -> Self { - let function = move |input| quote! { #input.transpose() }; - Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function)) - } - - pub(crate) fn reciprocal(input: Type, output: Type) -> Self { - let function = move |input| quote! { #input.recip() }; - Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function)) - } - - /// Casts the input to the output type. - /// - /// Currently this function only supports the following conversions: - /// 1) scalar -> scalar - /// - /// TODO: Implement the following conversions: - /// 2) tensor int -> tensor float - /// 3) tensor float -> tensor int - /// 4) tensor -> scalar - /// 5) scalar -> tensor - pub(crate) fn cast(input: Type, output: Type) -> Self { - let function = match output.clone() { - Type::Scalar(scalar) => { - let ty = scalar.ty(); - move |input| quote! { #input as #ty } - } - Type::Tensor(_tensor) => { - // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023) - // TODO: If the input is scalar and the output type is a tensor, - // we should generate another code block. (@antimora 8/4/2023) - // Tensor::from_data(Data::from([#input]).convert()).unsqueeze(); - todo!() - } - - _ => panic!("output must be a tensor"), - }; - - Self::new(input, output, UnaryNodeKind::Cast, Rc::new(function)) - } + pub(crate) fn erf(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.erf() }; + Self::new(input, output, UnaryNodeKind::Erf, Rc::new(function)) + } + + pub(crate) fn flatten(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self { + let start_dim = start_dim.to_tokens(); + let end_dim = end_dim.to_tokens(); + let function = move |input| quote! { #input.flatten(#start_dim, #end_dim) }; + + Self::new(input, output, UnaryNodeKind::Flatten, Rc::new(function)) + } + + pub(crate) fn relu(input: Type, output: Type) -> Self { + let function = move |input| quote! { burn::tensor::activation::relu(#input) }; + Self::new(input, output, UnaryNodeKind::Relu, Rc::new(function)) + } + + pub(crate) fn sigmoid(input: Type, output: Type) -> Self { + let function = move |input| quote! 
{ burn::tensor::activation::sigmoid(#input) }; + Self::new(input, output, UnaryNodeKind::Sigmoid, Rc::new(function)) + } + + pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self { + let dim = dim.to_tokens(); + let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) }; + Self::new(input, output, UnaryNodeKind::LogSoftmax, Rc::new(function)) + } + + pub(crate) fn softmax(input: Type, output: Type, dim: usize) -> Self { + let dim = dim.to_tokens(); + let function = move |input| quote! { burn::tensor::activation::softmax(#input, #dim) }; + Self::new(input, output, UnaryNodeKind::Softmax, Rc::new(function)) + } + + pub(crate) fn tanh(input: Type, output: Type) -> Self { + let function = move |input| quote! { burn::tensor::activation::tanh(#input)}; + Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function)) + } + + pub(crate) fn transpose(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.transpose() }; + Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function)) + } + + pub(crate) fn reciprocal(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.recip() }; + Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function)) + } + + /// Casts the input to the output type. + /// + /// Currently this function only supports the following conversions: + /// 1) scalar -> scalar + /// + /// TODO: Implement the following conversions: + /// 2) tensor int -> tensor float + /// 3) tensor float -> tensor int + /// 4) tensor -> scalar + /// 5) scalar -> tensor + pub(crate) fn cast(input: Type, output: Type) -> Self { + let function = match output.clone() { + Type::Scalar(scalar) => { + let ty = scalar.ty(); + move |input| quote! { #input as #ty } + } + Type::Tensor(_tensor) => { + // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023) + // TODO: If the input is scalar and the output type is a tensor, + // we should generate another code block. (@antimora 8/4/2023) + // Tensor::from_data(Data::from([#input]).convert()).unsqueeze(); + todo!() + } + + _ => panic!("output must be a tensor"), + }; + + Self::new(input, output, UnaryNodeKind::Cast, Rc::new(function)) + } } #[cfg(test)] mod tests { - use super::*; - use crate::burn::node::tests::one_node_graph; - use crate::burn::{ScalarKind, ScalarType, TensorType}; - - #[test] - fn test_unary_codegen_flatten() { - one_node_graph( - UnaryNode::flatten( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - 1, - 2, - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.flatten(1, 2); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_erf() { - one_node_graph( - UnaryNode::erf( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.erf(); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_relu() { - one_node_graph( - UnaryNode::relu( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! 
{ - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::relu(tensor1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_sigmoid() { - one_node_graph( - UnaryNode::sigmoid( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::sigmoid(tensor1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_log_softmax() { - one_node_graph( - UnaryNode::log_softmax( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - 1, - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_softmax() { - one_node_graph( - UnaryNode::softmax( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - 1, - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::softmax(tensor1, 1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_tanh() { - one_node_graph( - UnaryNode::tanh( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = burn::tensor::activation::tanh(tensor1); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_transpose() { - one_node_graph( - UnaryNode::transpose( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.transpose(); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_reciprocal() { - one_node_graph( - UnaryNode::reciprocal( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.recip(); - - tensor2 - } - }, - vec!["tensor1".to_string()], - vec!["tensor2".to_string()], - ); - } - - #[test] - fn test_unary_codegen_cast() { - one_node_graph( - UnaryNode::cast( - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)), - Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), - ), - quote! { - pub fn forward(&self, scalar1: f64) -> f32 { - let scalar2 = scalar1 as f32; - - scalar2 - } - }, - vec!["scalar1".to_string()], - vec!["scalar2".to_string()], - ); - one_node_graph( - UnaryNode::cast( - Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), - Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)), - ), - quote! 
{ - pub fn forward(&self, scalar1: f32) -> f64 { - let scalar2 = scalar1 as f64; - - scalar2 - } - }, - vec!["scalar1".to_string()], - vec!["scalar2".to_string()], - ); - } + use super::*; + use crate::burn::node::tests::one_node_graph; + use crate::burn::{ScalarKind, ScalarType, TensorType}; + + #[test] + fn test_unary_codegen_flatten() { + one_node_graph( + UnaryNode::flatten( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + 1, + 2, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.flatten(1, 2); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_erf() { + one_node_graph( + UnaryNode::erf( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.erf(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_relu() { + one_node_graph( + UnaryNode::relu( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::relu(tensor1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_sigmoid() { + one_node_graph( + UnaryNode::sigmoid( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::sigmoid(tensor1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_log_softmax() { + one_node_graph( + UnaryNode::log_softmax( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + 1, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_softmax() { + one_node_graph( + UnaryNode::softmax( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + 1, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::softmax(tensor1, 1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_tanh() { + one_node_graph( + UnaryNode::tanh( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::tanh(tensor1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_transpose() { + one_node_graph( + UnaryNode::transpose( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! 
{ + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.transpose(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_reciprocal() { + one_node_graph( + UnaryNode::reciprocal( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.recip(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_cast() { + one_node_graph( + UnaryNode::cast( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)), + ), + quote! { + pub fn forward(&self, scalar1: f64) -> f32 { + let scalar2 = scalar1 as f32; + + scalar2 + } + }, + vec!["scalar1".to_string()], + vec!["scalar2".to_string()], + ); + one_node_graph( + UnaryNode::cast( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)), + ), + quote! { + pub fn forward(&self, scalar1: f32) -> f64 { + let scalar2 = scalar1 as f64; + + scalar2 + } + }, + vec!["scalar1".to_string()], + vec!["scalar2".to_string()], + ); + } } diff --git a/burn-import/src/burn/scope.rs b/burn-import/src/burn/scope.rs index 1ceaa0b972..9e497c09aa 100644 --- a/burn-import/src/burn/scope.rs +++ b/burn-import/src/burn/scope.rs @@ -7,78 +7,78 @@ use std::collections::HashMap; /// The scope struct ensures that ownership rules are respected during the forward pass. #[derive(Clone, Debug, Default)] pub struct Scope { - variables: HashMap>, + variables: HashMap>, } #[derive(Clone, Debug, new)] struct TensorVariable { - references: usize, - node_position: usize, + references: usize, + node_position: usize, } impl Scope { - /// Declare a new tensor variable. - pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) { - if let Some(variables) = self.variables.get_mut(&tensor.name) { - for variable in variables.iter_mut() { - if variable.node_position == node_position { - variable.references += 1; - return; - } - } + /// Declare a new tensor variable. + pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) { + if let Some(variables) = self.variables.get_mut(&tensor.name) { + for variable in variables.iter_mut() { + if variable.node_position == node_position { + variable.references += 1; + return; + } + } - variables.push(TensorVariable::new(0, node_position)); - } else { - self.variables.insert( - tensor.name.clone(), - vec![TensorVariable::new(0, node_position)], - ); + variables.push(TensorVariable::new(0, node_position)); + } else { + self.variables.insert( + tensor.name.clone(), + vec![TensorVariable::new(0, node_position)], + ); + } } - } - /// Register a future use of a tensor variable. - /// - /// # Notes - /// - /// We need to know all futures use of a variable in advance. - pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) { - if let Some(variables) = self.variables.get_mut(&tensor.name) { - for variable in variables.iter_mut().rev() { - if node_position >= variable.node_position { - variable.references += 1; - break; + /// Register a future use of a tensor variable. + /// + /// # Notes + /// + /// We need to know all futures use of a variable in advance. 
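// Illustrative sketch (not part of the patch) of the intended bookkeeping,
// assuming a hypothetical tensor declared by node 0 and read by two later nodes:
//
//     let mut scope = Scope::default();
//     let t = TensorType::new_float("tensor1", 2);
//     scope.tensor_register_variable(&t, 0);   // declared by node 0
//     scope.tensor_register_future_use(&t, 1); // read by node 1
//     scope.tensor_register_future_use(&t, 2); // read by node 2
//
//     scope.tensor_use_owned(&t, 1); // a later use remains -> emits `tensor1.clone()`
//     scope.tensor_use_owned(&t, 2); // last use            -> emits `tensor1`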
+ pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) { + if let Some(variables) = self.variables.get_mut(&tensor.name) { + for variable in variables.iter_mut().rev() { + if node_position >= variable.node_position { + variable.references += 1; + break; + } + } + } else { + panic!("No variable with name {}", tensor.name); } - } - } else { - panic!("No variable with name {}", tensor.name); } - } - /// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward. - pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream { - if let Some(variables) = self.variables.get_mut(&tensor.name) { - let mut count = 0; - let name = &tensor.name; + /// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward. + pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream { + if let Some(variables) = self.variables.get_mut(&tensor.name) { + let mut count = 0; + let name = &tensor.name; - for variable in variables.iter_mut().rev() { - if node_position >= variable.node_position { - variable.references -= 1; - count = variable.references; - break; - } - } + for variable in variables.iter_mut().rev() { + if node_position >= variable.node_position { + variable.references -= 1; + count = variable.references; + break; + } + } - if count > 0 { - quote! { - #name.clone() - } - } else { - quote! { - #name + if count > 0 { + quote! { + #name.clone() + } + } else { + quote! { + #name + } + } + } else { + panic!("No variable with name {}", &tensor.name); } - } - } else { - panic!("No variable with name {}", &tensor.name); } - } } diff --git a/burn-import/src/burn/ty.rs b/burn-import/src/burn/ty.rs index fb710f872f..292523ecc7 100644 --- a/burn-import/src/burn/ty.rs +++ b/burn-import/src/burn/ty.rs @@ -7,146 +7,146 @@ use crate::burn::ToTokens; #[derive(Debug, Clone)] pub struct TensorType { - pub name: Ident, - pub dim: usize, - pub kind: TensorKind, - pub shape: Option>, + pub name: Ident, + pub dim: usize, + pub kind: TensorKind, + pub shape: Option>, } #[derive(Debug, Clone, Copy)] pub enum TensorKind { - Int, - Float, - Bool, + Int, + Float, + Bool, } #[derive(Debug, Clone)] pub enum ScalarKind { - Int32, - Int64, - Float32, - Float64, - Bool, + Int32, + Int64, + Float32, + Float64, + Bool, } #[derive(Debug, Clone)] pub struct ScalarType { - pub name: Ident, - pub kind: ScalarKind, + pub name: Ident, + pub kind: ScalarKind, } #[derive(Debug, Clone)] pub struct OtherType { - pub name: Ident, - pub ty: TokenStream, + pub name: Ident, + pub ty: TokenStream, } #[derive(Debug, Clone)] pub enum Type { - /// Tensor type. - Tensor(TensorType), + /// Tensor type. + Tensor(TensorType), - /// Scalar type. - Scalar(ScalarType), + /// Scalar type. + Scalar(ScalarType), - // Other type (more flexible type). - Other(OtherType), + // Other type (more flexible type). 
+ Other(OtherType), } impl Type { - pub fn name(&self) -> &Ident { - match self { - Type::Tensor(tensor) => &tensor.name, - Type::Scalar(scalar) => &scalar.name, - Type::Other(other) => &other.name, + pub fn name(&self) -> &Ident { + match self { + Type::Tensor(tensor) => &tensor.name, + Type::Scalar(scalar) => &scalar.name, + Type::Other(other) => &other.name, + } } - } - pub fn ty(&self) -> TokenStream { - match self { - Type::Tensor(tensor) => tensor.ty(), - Type::Scalar(scalar) => scalar.ty(), - Type::Other(other) => other.ty(), + pub fn ty(&self) -> TokenStream { + match self { + Type::Tensor(tensor) => tensor.ty(), + Type::Scalar(scalar) => scalar.ty(), + Type::Other(other) => other.ty(), + } } - } } impl ScalarType { - pub fn new>(name: S, kind: ScalarKind) -> Self { - Self { - name: Ident::new(name.as_ref(), Span::call_site()), - kind, + pub fn new>(name: S, kind: ScalarKind) -> Self { + Self { + name: Ident::new(name.as_ref(), Span::call_site()), + kind, + } } - } - pub fn ty(&self) -> TokenStream { - match self.kind { - ScalarKind::Int32 => quote! { i32 }, - ScalarKind::Int64 => quote! { i64 }, - ScalarKind::Float32 => quote! { f32 }, - ScalarKind::Float64 => quote! { f64 }, - ScalarKind::Bool => quote! { bool }, + pub fn ty(&self) -> TokenStream { + match self.kind { + ScalarKind::Int32 => quote! { i32 }, + ScalarKind::Int64 => quote! { i64 }, + ScalarKind::Float32 => quote! { f32 }, + ScalarKind::Float64 => quote! { f64 }, + ScalarKind::Bool => quote! { bool }, + } } - } } impl TensorType { - pub fn new>( - name: S, - dim: usize, - kind: TensorKind, - shape: Option>, - ) -> Self { - Self { - name: Ident::new(name.as_ref(), Span::call_site()), - dim, - kind, - shape, + pub fn new>( + name: S, + dim: usize, + kind: TensorKind, + shape: Option>, + ) -> Self { + Self { + name: Ident::new(name.as_ref(), Span::call_site()), + dim, + kind, + shape, + } } - } - pub fn new_float>(name: S, dim: usize) -> Self { - Self::new(name, dim, TensorKind::Float, None) - } - - pub fn new_int>(name: S, dim: usize) -> Self { - Self::new(name, dim, TensorKind::Int, None) - } - - pub fn new_bool>(name: S, dim: usize) -> Self { - Self::new(name, dim, TensorKind::Bool, None) - } - - pub fn ty(&self) -> TokenStream { - let dim = self.dim.to_tokens(); - match self { - TensorType { - kind: TensorKind::Float, - .. - } => quote! { - Tensor - }, - TensorType { - kind: TensorKind::Int, - .. - } => quote! { - Tensor - }, - TensorType { - kind: TensorKind::Bool, - .. - } => quote! { - Tensor - }, + pub fn new_float>(name: S, dim: usize) -> Self { + Self::new(name, dim, TensorKind::Float, None) + } + + pub fn new_int>(name: S, dim: usize) -> Self { + Self::new(name, dim, TensorKind::Int, None) + } + + pub fn new_bool>(name: S, dim: usize) -> Self { + Self::new(name, dim, TensorKind::Bool, None) + } + + pub fn ty(&self) -> TokenStream { + let dim = self.dim.to_tokens(); + match self { + TensorType { + kind: TensorKind::Float, + .. + } => quote! { + Tensor + }, + TensorType { + kind: TensorKind::Int, + .. + } => quote! { + Tensor + }, + TensorType { + kind: TensorKind::Bool, + .. + } => quote! 
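// Assumption (based on the burn tensor API, not stated in this hunk): the three
// arms of `TensorType::ty` are expected to expand to `Tensor<B, #dim>`,
// `Tensor<B, #dim, Int>`, and `Tensor<B, #dim, Bool>` for the Float, Int, and
// Bool kinds respectively.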
{ + Tensor + }, + } } - } } impl OtherType { - pub fn new>(name: S, tokens: TokenStream) -> Self { - Self { - name: Ident::new(name.as_ref(), Span::call_site()), - ty: tokens, + pub fn new>(name: S, tokens: TokenStream) -> Self { + Self { + name: Ident::new(name.as_ref(), Span::call_site()), + ty: tokens, + } + } + pub fn ty(&self) -> TokenStream { + self.ty.clone() } - } - pub fn ty(&self) -> TokenStream { - self.ty.clone() - } } diff --git a/burn-import/src/formatter.rs b/burn-import/src/formatter.rs index bba4cdaeff..adc8545c38 100644 --- a/burn-import/src/formatter.rs +++ b/burn-import/src/formatter.rs @@ -3,15 +3,15 @@ use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt}; /// Formats a token stream into a string. pub fn format_tokens(tokens: TokenStream) -> String { - let fmt = code_formatter(); + let fmt = code_formatter(); - fmt.format_tokens(tokens).expect("Valid token tree") + fmt.format_tokens(tokens).expect("Valid token tree") } fn code_formatter() -> RustFmt { - let config = Config::new_str() - .post_proc(PostProcess::ReplaceMarkersAndDocBlocks) - .edition(Edition::Rust2021); + let config = Config::new_str() + .post_proc(PostProcess::ReplaceMarkersAndDocBlocks) + .edition(Edition::Rust2021); - RustFmt::from_config(config) + RustFmt::from_config(config) } diff --git a/burn-import/src/logger.rs b/burn-import/src/logger.rs index c5a279ef99..3378f17401 100644 --- a/burn-import/src/logger.rs +++ b/burn-import/src/logger.rs @@ -2,22 +2,22 @@ use std::error::Error; use tracing_core::LevelFilter; pub fn init_log() -> Result<(), Box> { - let result = tracing_subscriber::fmt() - .with_max_level(LevelFilter::DEBUG) - .without_time() - .try_init(); + let result = tracing_subscriber::fmt() + .with_max_level(LevelFilter::DEBUG) + .without_time() + .try_init(); - if result.is_ok() { - update_panic_hook(); - } - result + if result.is_ok() { + update_panic_hook(); + } + result } fn update_panic_hook() { - let hook = std::panic::take_hook(); + let hook = std::panic::take_hook(); - std::panic::set_hook(Box::new(move |info| { - log::error!("PANIC => {}", info.to_string()); - hook(info); - })); + std::panic::set_hook(Box::new(move |info| { + log::error!("PANIC => {}", info.to_string()); + hook(info); + })); } diff --git a/burn-import/src/main.rs b/burn-import/src/main.rs index 4590e41b78..2601568250 100644 --- a/burn-import/src/main.rs +++ b/burn-import/src/main.rs @@ -2,16 +2,16 @@ use burn_import::onnx::{ModelGen, RecordType}; /// Takes an ONNX file and generates a model from it fn main() { - let onnx_file = std::env::args().nth(1).expect("No input file provided"); - let output_dir = std::env::args() - .nth(2) - .expect("No output directory provided"); + let onnx_file = std::env::args().nth(1).expect("No input file provided"); + let output_dir = std::env::args() + .nth(2) + .expect("No output directory provided"); - // Generate the model code from the ONNX file. - ModelGen::new() - .input(onnx_file.as_str()) - .development(true) - .record_type(RecordType::PrettyJson) - .out_dir(output_dir.as_str()) - .run_from_cli(); + // Generate the model code from the ONNX file. 
+ ModelGen::new() + .input(onnx_file.as_str()) + .development(true) + .record_type(RecordType::PrettyJson) + .out_dir(output_dir.as_str()) + .run_from_cli(); } diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index fe08b05d79..623d7584f2 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -5,110 +5,110 @@ use crate::onnx::ir::{ArgType, Data, TensorType}; /// The function transforms the graph into a new one where the nodes are coalesced into a single node. pub fn coalesce(nodes: &mut Vec) { - let mut iter_mut = nodes.iter_mut().peekable(); - let mut nodes_to_remove: Vec = vec![]; - while let Some(node) = iter_mut.next() { - match node.node_type { - NodeType::Gemm => convert_gemm_to_linear(node), - NodeType::MatMul => { - convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove); - } - _ => {} + let mut iter_mut = nodes.iter_mut().peekable(); + let mut nodes_to_remove: Vec = vec![]; + while let Some(node) = iter_mut.next() { + match node.node_type { + NodeType::Gemm => convert_gemm_to_linear(node), + NodeType::MatMul => { + convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove); + } + _ => {} + } } - } - // Remove nodes instructed by conversation functions - for node_to_remove in nodes_to_remove { - nodes.retain(|n| n.name != node_to_remove); - } + // Remove nodes instructed by conversation functions + for node_to_remove in nodes_to_remove { + nodes.retain(|n| n.name != node_to_remove); + } } /// This function converts a Gemm node into a Linear node /// /// PyTorch and other frameworks use Gemm node to represent Linear layer. fn convert_gemm_to_linear(node: &mut Node) { - if node.outputs.len() != 1 { - panic!("Gemm node must have 1 output"); - } - let straight_linear = match ( - node.attrs.get("alpha"), - node.attrs.get("beta"), - node.attrs.get("transB"), - ) { - ( - Some(AttributeValue::Float32(alpha)), - Some(AttributeValue::Float32(beta)), - Some(AttributeValue::Int64(trans_b)), - ) => *alpha == 1.0 && *beta == 1.0 && *trans_b == 1, - _ => false, - }; - - if straight_linear { - node.node_type = NodeType::Linear; - node.attrs.remove("alpha"); - node.attrs.remove("beta"); - node.attrs.remove("transB"); - - // Transpose the weights - transpose_linear_node_weights(node); - } else { - panic!("Full Gemm node not supported yet."); - } + if node.outputs.len() != 1 { + panic!("Gemm node must have 1 output"); + } + let straight_linear = match ( + node.attrs.get("alpha"), + node.attrs.get("beta"), + node.attrs.get("transB"), + ) { + ( + Some(AttributeValue::Float32(alpha)), + Some(AttributeValue::Float32(beta)), + Some(AttributeValue::Int64(trans_b)), + ) => *alpha == 1.0 && *beta == 1.0 && *trans_b == 1, + _ => false, + }; + + if straight_linear { + node.node_type = NodeType::Linear; + node.attrs.remove("alpha"); + node.attrs.remove("beta"); + node.attrs.remove("transB"); + + // Transpose the weights + transpose_linear_node_weights(node); + } else { + panic!("Full Gemm node not supported yet."); + } } // Transpose linear weights (required for Gemm -> Linear conversion) fn transpose_linear_node_weights(node: &mut Node) { - assert!( - node.inputs.len() > 1, - "Linear node must have at least 2 input" - ); - - assert!(node.inputs[1].value.is_some(), "Input must have a value"); - - let weight = node.inputs[1] - .clone() - .into_tensor() - .expect("Tensor input is expected"); - - assert_eq!(weight.dim, 2, "Weight must be a 2D tensor"); - - let shape = weight.shape.unwrap(); - - match weight.data.expect("Tensor 
must have data") { - Data::Float32s(data) => { - let data_t = transpose_flattened(data, shape[0], shape[1]); - node.inputs[1].value = Some(Data::Float32s(data_t)); - } - Data::Float64s(data) => { - let data_t = transpose_flattened(data, shape[0], shape[1]); - node.inputs[1].value = Some(Data::Float64s(data_t)); + assert!( + node.inputs.len() > 1, + "Linear node must have at least 2 input" + ); + + assert!(node.inputs[1].value.is_some(), "Input must have a value"); + + let weight = node.inputs[1] + .clone() + .into_tensor() + .expect("Tensor input is expected"); + + assert_eq!(weight.dim, 2, "Weight must be a 2D tensor"); + + let shape = weight.shape.unwrap(); + + match weight.data.expect("Tensor must have data") { + Data::Float32s(data) => { + let data_t = transpose_flattened(data, shape[0], shape[1]); + node.inputs[1].value = Some(Data::Float32s(data_t)); + } + Data::Float64s(data) => { + let data_t = transpose_flattened(data, shape[0], shape[1]); + node.inputs[1].value = Some(Data::Float64s(data_t)); + } + Data::Float16s(data) => { + let data_t = transpose_flattened(data, shape[0], shape[1]); + node.inputs[1].value = Some(Data::Float16s(data_t)); + } + _ => panic!("Only float types are supported for Linear node"), } - Data::Float16s(data) => { - let data_t = transpose_flattened(data, shape[0], shape[1]); - node.inputs[1].value = Some(Data::Float16s(data_t)); - } - _ => panic!("Only float types are supported for Linear node"), - } - let shape = Some(vec![shape[1], shape[0]]); // Transpose the shape - node.inputs[1].ty = ArgType::Tensor(TensorType { - shape, - elem_type: weight.elem_type, - dim: 2, - }); + let shape = Some(vec![shape[1], shape[0]]); // Transpose the shape + node.inputs[1].ty = ArgType::Tensor(TensorType { + shape, + elem_type: weight.elem_type, + dim: 2, + }); } fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec { - assert_eq!(matrix.len(), rows * cols, "Matrix must be flattened"); + assert_eq!(matrix.len(), rows * cols, "Matrix must be flattened"); - let mut transposed: Vec = vec![matrix[0]; matrix.len()]; + let mut transposed: Vec = vec![matrix[0]; matrix.len()]; - for i in 0..rows { - for j in 0..cols { - transposed[j * rows + i] = matrix[i * cols + j]; + for i in 0..rows { + for j in 0..cols { + transposed[j * rows + i] = matrix[i * cols + j]; + } } - } - transposed + transposed } /// This function converts a MatMul node into a Linear node if possible. @@ -118,65 +118,65 @@ fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec /// This function also converts the following Add node into a Linear node if possible. /// Add node is used to represent bias in PyTorch. 
fn convert_matmul_to_linear( - node: &mut Node, - iter_mut: &mut Peekable>, - nodes_to_remove: &mut Vec, + node: &mut Node, + iter_mut: &mut Peekable>, + nodes_to_remove: &mut Vec, ) { - if node.inputs.len() != 2 { - panic!("MatMul node must have 2 inputs"); - } - - // if the second input does not have a value, it is not a weight, then proceed to the next node - if node.inputs[1].value.is_none() { - return; - } - - // Check if the second input is a 2D tensor - if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { - assert_eq!(tensor_type.dim, 2, "Weight must be a 2D tensor"); - } else { - panic!("Tensor input is expected"); - } - - // Convert the node to Linear - node.node_type = NodeType::Linear; - - // Check the next node for potential conversion - if let Some(peek_node) = iter_mut.peek() { - if is_add_node_with_bias(peek_node, node) { - convert_and_remove_add_node(iter_mut, nodes_to_remove, node); + if node.inputs.len() != 2 { + panic!("MatMul node must have 2 inputs"); + } + + // if the second input does not have a value, it is not a weight, then proceed to the next node + if node.inputs[1].value.is_none() { + return; + } + + // Check if the second input is a 2D tensor + if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { + assert_eq!(tensor_type.dim, 2, "Weight must be a 2D tensor"); + } else { + panic!("Tensor input is expected"); + } + + // Convert the node to Linear + node.node_type = NodeType::Linear; + + // Check the next node for potential conversion + if let Some(peek_node) = iter_mut.peek() { + if is_add_node_with_bias(peek_node, node) { + convert_and_remove_add_node(iter_mut, nodes_to_remove, node); + } } - } } /// Helper function to check if the peeked node is an Add node with bias fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool { - peek_node.node_type == NodeType::Add - && peek_node.inputs.len() == 2 - && ((peek_node.inputs[0].name == current_node.outputs[0].name - && peek_node.inputs[1].value.is_some()) - || (peek_node.inputs[1].name == current_node.outputs[0].name - && peek_node.inputs[0].value.is_some())) + peek_node.node_type == NodeType::Add + && peek_node.inputs.len() == 2 + && ((peek_node.inputs[0].name == current_node.outputs[0].name + && peek_node.inputs[1].value.is_some()) + || (peek_node.inputs[1].name == current_node.outputs[0].name + && peek_node.inputs[0].value.is_some())) } /// Helper function to convert and remove the Add node fn convert_and_remove_add_node( - iter_mut: &mut Peekable>, - nodes_to_remove: &mut Vec, - current_node: &mut Node, + iter_mut: &mut Peekable>, + nodes_to_remove: &mut Vec, + current_node: &mut Node, ) { - let bias_node = iter_mut.next().unwrap(); + let bias_node = iter_mut.next().unwrap(); - let bias_input = if bias_node.inputs[0].value.is_some() { - bias_node.inputs[0].clone() - } else { - bias_node.inputs[1].clone() - }; + let bias_input = if bias_node.inputs[0].value.is_some() { + bias_node.inputs[0].clone() + } else { + bias_node.inputs[1].clone() + }; - // Push the bias input and update the output name - current_node.inputs.push(bias_input); - current_node.outputs[0].name = bias_node.outputs[0].name.clone(); + // Push the bias input and update the output name + current_node.inputs.push(bias_input); + current_node.outputs[0].name = bias_node.outputs[0].name.clone(); - // Remove the Add node - nodes_to_remove.push(bias_node.name.clone()); + // Remove the Add node + nodes_to_remove.push(bias_node.name.clone()); } diff --git a/burn-import/src/onnx/dim_inference.rs 
b/burn-import/src/onnx/dim_inference.rs index a168a9a28f..ce2a9ce4f3 100644 --- a/burn-import/src/onnx/dim_inference.rs +++ b/burn-import/src/onnx/dim_inference.rs @@ -3,280 +3,279 @@ use std::collections::HashMap; use protobuf::Enum; use super::{ - ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, - op_configuration::flatten_config, - protos::tensor_proto::DataType, + ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, + op_configuration::flatten_config, + protos::tensor_proto::DataType, }; struct TensorDimUpdater { - arguments: HashMap, + arguments: HashMap, } impl TensorDimUpdater { - fn new(inputs: &[Argument]) -> Self { - let mut arguments: HashMap = HashMap::with_capacity(inputs.len()); + fn new(inputs: &[Argument]) -> Self { + let mut arguments: HashMap = HashMap::with_capacity(inputs.len()); - inputs.iter().for_each(|input| { - arguments.insert(input.name.clone(), input.clone()); - }); + inputs.iter().for_each(|input| { + arguments.insert(input.name.clone(), input.clone()); + }); - Self { arguments } - } - /// Update tensor inputs from the registered arguments and returns the number of input - /// updated. - fn update_tensor_inputs(&self, node: &mut Node) -> usize { - self.update_arguments(&mut node.inputs) - } - - /// Update the arguments struct from the node output tensors and return the number of output - /// updated. - fn update_tensor_outputs(&mut self, node: &Node) -> usize { - node - .outputs - .iter() - .map(|arg| { - self.arguments.insert(arg.name.clone(), arg.clone()); - }) - .count() - } - - fn update_arguments(&self, arguments: &mut [Argument]) -> usize { - arguments - .iter_mut() - .filter_map(|input| self.arguments.get(&input.name).map(|arg| (arg, input))) - .map(|(arg, input)| { - input.ty = arg.ty.clone(); - }) - .count() - } + Self { arguments } + } + /// Update tensor inputs from the registered arguments and returns the number of input + /// updated. + fn update_tensor_inputs(&self, node: &mut Node) -> usize { + self.update_arguments(&mut node.inputs) + } + + /// Update the arguments struct from the node output tensors and return the number of output + /// updated. + fn update_tensor_outputs(&mut self, node: &Node) -> usize { + node.outputs + .iter() + .map(|arg| { + self.arguments.insert(arg.name.clone(), arg.clone()); + }) + .count() + } + + fn update_arguments(&self, arguments: &mut [Argument]) -> usize { + arguments + .iter_mut() + .filter_map(|input| self.arguments.get(&input.name).map(|arg| (arg, input))) + .map(|(arg, input)| { + input.ty = arg.ty.clone(); + }) + .count() + } } /// Infer the dimension of each output tensor and update them. 
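// Illustrative sketch (hypothetical helper, plain strings instead of NodeType):
// most operators keep the input rank, Unsqueeze adds one axis, and unknown
// operators fall through to a pass-through default, mirroring the match below.
fn infer_rank(op: &str, input_rank: usize) -> usize {
    match op {
        "Unsqueeze" => input_rank + 1,
        "Relu" | "Add" | "Softmax" | "Transpose" => input_rank,
        _ => input_rank, // pass-through stub; only a warning is emitted
    }
}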
pub fn dim_inference( - nodes: &mut Vec, - graph_inputs: &Vec, - graph_outputs: &mut Vec, + nodes: &mut Vec, + graph_inputs: &Vec, + graph_outputs: &mut Vec, ) { - let mut updater = TensorDimUpdater::new(graph_inputs); - - for node in nodes.iter_mut() { - updater.update_tensor_inputs(node); - - match node.node_type { - NodeType::Conv1d => conv1d_update_outputs(node), - NodeType::Conv2d => conv2d_update_outputs(node), - NodeType::MaxPool2d => same_as_input(node), - NodeType::Linear => linear_update_outputs(node), - NodeType::Flatten => flatten_update_outputs(node), - NodeType::GatherElements => same_as_input(node), - NodeType::Relu => same_as_input(node), - NodeType::LogSoftmax => same_as_input(node), - NodeType::BatchNormalization => same_as_input(node), - NodeType::Add => same_as_input(node), - NodeType::Sub => same_as_input(node), - NodeType::Mul => same_as_input(node), - NodeType::Cast => cast_update_outputs(node), - NodeType::Div => same_as_input(node), - NodeType::Erf => same_as_input(node), - NodeType::Sqrt => same_as_input(node), - NodeType::Tanh => same_as_input(node), - NodeType::Reciprocal => same_as_input(node), - NodeType::Softmax => same_as_input(node), - NodeType::ReduceMean => mean_update_outputs(node), - NodeType::Constant => constant_update_outputs(node), - NodeType::Equal => equal_update_outputs(node), - NodeType::Shape => shape_update_outputs(node), - NodeType::Unsqueeze => unsqueeze_update_outputs(node), - NodeType::Sigmoid => same_as_input(node), - NodeType::Transpose => same_as_input(node), - NodeType::Concat => concat_update_outputs(node), - NodeType::Reshape => reshape_update_outputs(node), - NodeType::Dropout => same_as_input(node), - NodeType::GlobalAveragePool => same_as_input(node), - NodeType::AveragePool2d => same_as_input(node), - NodeType::Clip => same_as_input(node), - // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
- _ => temporary_pass_through_stub(node), + let mut updater = TensorDimUpdater::new(graph_inputs); + + for node in nodes.iter_mut() { + updater.update_tensor_inputs(node); + + match node.node_type { + NodeType::Conv1d => conv1d_update_outputs(node), + NodeType::Conv2d => conv2d_update_outputs(node), + NodeType::MaxPool2d => same_as_input(node), + NodeType::Linear => linear_update_outputs(node), + NodeType::Flatten => flatten_update_outputs(node), + NodeType::GatherElements => same_as_input(node), + NodeType::Relu => same_as_input(node), + NodeType::LogSoftmax => same_as_input(node), + NodeType::BatchNormalization => same_as_input(node), + NodeType::Add => same_as_input(node), + NodeType::Sub => same_as_input(node), + NodeType::Mul => same_as_input(node), + NodeType::Cast => cast_update_outputs(node), + NodeType::Div => same_as_input(node), + NodeType::Erf => same_as_input(node), + NodeType::Sqrt => same_as_input(node), + NodeType::Tanh => same_as_input(node), + NodeType::Reciprocal => same_as_input(node), + NodeType::Softmax => same_as_input(node), + NodeType::ReduceMean => mean_update_outputs(node), + NodeType::Constant => constant_update_outputs(node), + NodeType::Equal => equal_update_outputs(node), + NodeType::Shape => shape_update_outputs(node), + NodeType::Unsqueeze => unsqueeze_update_outputs(node), + NodeType::Sigmoid => same_as_input(node), + NodeType::Transpose => same_as_input(node), + NodeType::Concat => concat_update_outputs(node), + NodeType::Reshape => reshape_update_outputs(node), + NodeType::Dropout => same_as_input(node), + NodeType::GlobalAveragePool => same_as_input(node), + NodeType::AveragePool2d => same_as_input(node), + NodeType::Clip => same_as_input(node), + // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
+ _ => temporary_pass_through_stub(node), + } + + updater.update_tensor_outputs(node); } - updater.update_tensor_outputs(node); - } - - updater.update_arguments(graph_outputs); + updater.update_arguments(graph_outputs); } fn constant_update_outputs(node: &mut Node) { - // Fix the tensor dimension of the output when the value is tensor - - let keys = [ - "value", - "value_float", - "value_floats", - "value_int", - "value_ints", - "value_string", - "value_strings", - "sparse_value", - ]; - - let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); - - node.outputs[0].ty = match matched_value { - Some(value) => match &value { - // The value is stored in an attribute - AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - dim: tensor.dim, - shape: tensor.shape.clone(), - }), - AttributeValue::Float32(_) => ArgType::Scalar(ElementType::Float32), - AttributeValue::Float32s(value) => ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - dim: 1, - shape: Some(vec![value.len()]), - }), - AttributeValue::Int64(_) => ArgType::Scalar(ElementType::Int64), - AttributeValue::Int64s(value) => ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - dim: 1, - shape: Some(vec![value.len()]), - }), - ty => panic!("Constant value of {:?} is not supported", ty), - }, - None => panic!("Constant node must have a value attribute"), - }; + // Fix the tensor dimension of the output when the value is tensor + + let keys = [ + "value", + "value_float", + "value_floats", + "value_int", + "value_ints", + "value_string", + "value_strings", + "sparse_value", + ]; + + let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); + + node.outputs[0].ty = match matched_value { + Some(value) => match &value { + // The value is stored in an attribute + AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + dim: tensor.dim, + shape: tensor.shape.clone(), + }), + AttributeValue::Float32(_) => ArgType::Scalar(ElementType::Float32), + AttributeValue::Float32s(value) => ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + dim: 1, + shape: Some(vec![value.len()]), + }), + AttributeValue::Int64(_) => ArgType::Scalar(ElementType::Int64), + AttributeValue::Int64s(value) => ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + dim: 1, + shape: Some(vec![value.len()]), + }), + ty => panic!("Constant value of {:?} is not supported", ty), + }, + None => panic!("Constant node must have a value attribute"), + }; } /// Infer the shape of the output tensor of a Conv2d node fn linear_update_outputs(node: &mut Node) { - // Extract the configuration of the linear layer (inputs are known) - let node_input = &node.inputs[0]; - let weight = &node.inputs[1]; - - // Calculate the output shape. Usually we do not use shapes, but since the input shape is - // known, we can calculate the output shape. - if let ArgType::Tensor(tensor) = node_input.clone().ty { - let mut tensor = tensor.clone(); - let mut shape = tensor.shape.clone().unwrap(); - - if let ArgType::Tensor(weight_tensor) = weight.clone().ty { - let last = shape.last_mut().unwrap(); - *last = *weight_tensor.shape.unwrap().first().unwrap(); + // Extract the configuration of the linear layer (inputs are known) + let node_input = &node.inputs[0]; + let weight = &node.inputs[1]; + + // Calculate the output shape. Usually we do not use shapes, but since the input shape is + // known, we can calculate the output shape. 
+ if let ArgType::Tensor(tensor) = node_input.clone().ty { + let mut tensor = tensor.clone(); + let mut shape = tensor.shape.clone().unwrap(); + + if let ArgType::Tensor(weight_tensor) = weight.clone().ty { + let last = shape.last_mut().unwrap(); + *last = *weight_tensor.shape.unwrap().first().unwrap(); + } else { + panic!("Weight must be a tensor"); + } + + tensor.shape = Some(shape); + + // Update the output tensor + node.outputs[0].ty = ArgType::Tensor(tensor); } else { - panic!("Weight must be a tensor"); + panic!("Only tensor input is valid"); } - - tensor.shape = Some(shape); - - // Update the output tensor - node.outputs[0].ty = ArgType::Tensor(tensor); - } else { - panic!("Only tensor input is valid"); - } } /// Update the output type using "to" attribute fn cast_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Cast: multiple inputs are not supported"); - } - let output = &mut node.outputs[0]; - - // Extract cast type and update the output tensor - let elem_type = match node.attrs.get("to") { - Some(value) => match &value { - AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - _ => panic!("Cast: unsupported type"), - }, - _ => panic!("'to' attribute must be an Int64"), - }, - None => panic!("Constant node must have a value attribute"), - }; - - match output.ty.clone() { - ArgType::Tensor(tensor) => { - if tensor.dim == 0 { - // treat 0-dim tensor as scalar - output.ty = ArgType::Scalar(elem_type); - } else { - todo!("Cast: support casting from different tensor types"); - } + if node.inputs.len() != 1 { + panic!("Cast: multiple inputs are not supported"); } - ArgType::Scalar(_scalar) => { - output.ty = ArgType::Scalar(elem_type); + let output = &mut node.outputs[0]; + + // Extract cast type and update the output tensor + let elem_type = match node.attrs.get("to") { + Some(value) => match &value { + AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + _ => panic!("Cast: unsupported type"), + }, + _ => panic!("'to' attribute must be an Int64"), + }, + None => panic!("Constant node must have a value attribute"), + }; + + match output.ty.clone() { + ArgType::Tensor(tensor) => { + if tensor.dim == 0 { + // treat 0-dim tensor as scalar + output.ty = ArgType::Scalar(elem_type); + } else { + todo!("Cast: support casting from different tensor types"); + } + } + ArgType::Scalar(_scalar) => { + output.ty = ArgType::Scalar(elem_type); + } + _ => panic!("Cast: only scalar input is valid"), } - _ => panic!("Cast: only scalar input is valid"), - } } fn concat_update_outputs(node: &mut Node) { - let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor), - _ => None, - }) - .unwrap(); - - node.outputs[0].ty = ArgType::Tensor(tensor.clone()); + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor), + _ => None, + }) + .unwrap(); + + node.outputs[0].ty = ArgType::Tensor(tensor.clone()); } fn reshape_update_outputs(node: &mut Node) { - assert_eq!(node.inputs.len(), 2); - - let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value { - shape 
- } else { - panic!("Reshape: int64s shape is expected per ONNX spec"); - }; - - // The output dimension is the same as the shape length - let dim = shape.len(); - let elem_type = match node.inputs[0].ty.clone() { - ArgType::Tensor(tensor) => tensor.elem_type, - _ => panic!("Reshape: invalid input type"), - }; - - let shape = shape.iter().map(|&dim| dim as usize).collect(); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - dim, - shape: Some(shape), - }); + assert_eq!(node.inputs.len(), 2); + + let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value { + shape + } else { + panic!("Reshape: int64s shape is expected per ONNX spec"); + }; + + // The output dimension is the same as the shape length + let dim = shape.len(); + let elem_type = match node.inputs[0].ty.clone() { + ArgType::Tensor(tensor) => tensor.elem_type, + _ => panic!("Reshape: invalid input type"), + }; + + let shape = shape.iter().map(|&dim| dim as usize).collect(); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + dim, + shape: Some(shape), + }); } fn mean_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Mean: multiple inputs are not supported"); - } - - // Extract the configuration of the linear layer (inputs are known) - let node_input = &mut node.inputs[0]; - let tensor = match node_input.clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - if dim_only { - node.outputs[0].ty = ArgType::Tensor(tensor); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor }); - } + if node.inputs.len() != 1 { + panic!("Mean: multiple inputs are not supported"); + } + + // Extract the configuration of the linear layer (inputs are known) + let node_input = &mut node.inputs[0]; + let tensor = match node_input.clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + if dim_only { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor }); + } } /// Infers the shape of a Unsqueeze node and replaces the shape of the output tensor. /// @@ -284,116 +283,116 @@ fn mean_update_outputs(node: &mut Node) { /// /// Unsqueeze is not implemented fully. This is left WIP from the past. 
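// Illustrative sketch (standalone): the current Unsqueeze handling only
// prepends a size-1 axis, so a [2, 3] shape becomes [1, 2, 3] and the rank
// grows by one.
fn unsqueeze_front(mut shape: Vec<usize>) -> Vec<usize> {
    shape.insert(0, 1);
    shape
}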
fn unsqueeze_update_outputs(node: &mut Node) { - let node_input = node - .inputs - .first_mut() - .expect("Unsqueeze: an input is required"); - - if let ArgType::Tensor(tensor) = &mut node_input.ty { - tensor.dim += 1; - - // add a new dimension to the input tensor by extending the shape - // TODO: support unsqueezing configurations - if let Some(shape) = &mut tensor.shape { - shape.insert(0, 1); + let node_input = node + .inputs + .first_mut() + .expect("Unsqueeze: an input is required"); + + if let ArgType::Tensor(tensor) = &mut node_input.ty { + tensor.dim += 1; + + // add a new dimension to the input tensor by extending the shape + // TODO: support unsqueezing configurations + if let Some(shape) = &mut tensor.shape { + shape.insert(0, 1); + } else { + todo!("Unsqueeze: support unsqueezing a tensor without shape"); + } + + node.outputs[0].ty = ArgType::Tensor(tensor.clone()); } else { - todo!("Unsqueeze: support unsqueezing a tensor without shape"); + panic!("Only tensor input is valid"); } - - node.outputs[0].ty = ArgType::Tensor(tensor.clone()); - } else { - panic!("Only tensor input is valid"); - } } fn same_as_input(node: &mut Node) { - node.outputs[0].ty = node.inputs[0].ty.clone(); + node.outputs[0].ty = node.inputs[0].ty.clone(); } /// Temporary pass-through stub for dimension inference so that we can export the IR model. fn temporary_pass_through_stub(node: &mut Node) { - log::warn!( - "Must implement dimension inference for {:?}", - node.node_type - ); + log::warn!( + "Must implement dimension inference for {:?}", + node.node_type + ); } fn equal_update_outputs(node: &mut Node) { - let input1_type = node.inputs[0].ty.clone(); - - match input1_type { - ArgType::Tensor(tensor) => { - // if the input is a tensor, the output is a tensor of bool - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Bool, - ..tensor - }); + let input1_type = node.inputs[0].ty.clone(); + + match input1_type { + ArgType::Tensor(tensor) => { + // if the input is a tensor, the output is a tensor of bool + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + ..tensor + }); + } + ArgType::Scalar(_) => { + node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); + } + _ => panic!("Only tensor input is valid"), } - ArgType::Scalar(_) => { - node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); - } - _ => panic!("Only tensor input is valid"), - } } fn shape_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Gather: multiple inputs are not supported: {:?}", node); - } - - // Extract the configuration of the linear layer (inputs are known) - let node_input = &mut node.inputs[0]; - if let ArgType::Tensor(tensor) = node_input.clone().ty { - // Update the output tensor - node.outputs[0].ty = ArgType::Shape(tensor.dim); - } else { - panic!("Only tensor input is valid"); - } + if node.inputs.len() != 1 { + panic!("Gather: multiple inputs are not supported: {:?}", node); + } + + // Extract the configuration of the linear layer (inputs are known) + let node_input = &mut node.inputs[0]; + if let ArgType::Tensor(tensor) = node_input.clone().ty { + // Update the output tensor + node.outputs[0].ty = ArgType::Shape(tensor.dim); + } else { + panic!("Only tensor input is valid"); + } } /// Infers the shape of a Flatten node and replaces the shape of the output tensor. 
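// Illustrative sketch (standalone): flattening axes start_dim..=end_dim
// collapses (end_dim - start_dim) axes, e.g. a rank-4 input flattened over
// (1, 2) yields a rank-3 output.
fn flattened_rank(input_rank: usize, start_dim: usize, end_dim: usize) -> usize {
    input_rank - (end_dim - start_dim)
}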
fn flatten_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Flatten: multiple inputs are not supported"); - } - let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor), - _ => None, - }) - .unwrap(); - - let input_dim = tensor.dim; - - let (start_dim, end_dim) = flatten_config(node); - - let collapsed_dims = end_dim - start_dim; - let output_dim = input_dim - collapsed_dims; - - node.outputs[0].ty = ArgType::Tensor(TensorType { - dim: output_dim, - ..tensor.clone() - }); + if node.inputs.len() != 1 { + panic!("Flatten: multiple inputs are not supported"); + } + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor), + _ => None, + }) + .unwrap(); + + let input_dim = tensor.dim; + + let (start_dim, end_dim) = flatten_config(node); + + let collapsed_dims = end_dim - start_dim; + let output_dim = input_dim - collapsed_dims; + + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim: output_dim, + ..tensor.clone() + }); } /// Infers the shape of a Conv1d node and replaces the shape of the output tensor. fn conv1d_update_outputs(node: &mut Node) { - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { - node.outputs[0].ty = ArgType::Tensor(tensor); - } else { - panic!("Only tensor input is valid"); - } + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] + if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + panic!("Only tensor input is valid"); + } } /// Infers the shape of a Conv2d node and replaces the shape of the output tensor. fn conv2d_update_outputs(node: &mut Node) { - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { - node.outputs[0].ty = ArgType::Tensor(tensor); - } else { - panic!("Only tensor input is valid"); - } + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
+ if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + panic!("Only tensor input is valid"); + } } diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index d8a2354839..6b9db5b01b 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -1,12 +1,12 @@ use std::{ - collections::{HashMap, HashSet}, - fs::File, - path::Path, + collections::{HashMap, HashSet}, + fs::File, + path::Path, }; use crate::onnx::{ - coalesce::coalesce, ir::TensorType, node_remap::remap_node_type, - proto_conversion::convert_node_proto, + coalesce::coalesce, ir::TensorType, node_remap::remap_node_type, + proto_conversion::convert_node_proto, }; use super::dim_inference::dim_inference; @@ -16,12 +16,12 @@ use super::protos::{ModelProto, TensorProto}; use protobuf::Message; const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 6] = [ - NodeType::BatchNormalization, - NodeType::Clip, - NodeType::Conv1d, - NodeType::Conv2d, - NodeType::Dropout, - NodeType::Reshape, + NodeType::BatchNormalization, + NodeType::Clip, + NodeType::Conv1d, + NodeType::Conv2d, + NodeType::Dropout, + NodeType::Reshape, ]; /// Open an onnx file and convert it to a Graph (intermediate representation) @@ -40,83 +40,83 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 6] = [ /// * If the file cannot be parsed /// * If the nodes are not topologically sorted pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { - log::info!("Parsing ONNX file: {}", onnx_path.display()); - - // Open the file - let mut file = File::open(onnx_path).expect("Unable to open file"); - let onnx_model: ModelProto = - Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file"); - - log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len()); - log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len()); - - log::debug!( - "Number of initializers: {:?}", - onnx_model.graph.initializer.len() - ); - - log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len()); - - // Convert the nodes - let mut nodes: Vec = vec![]; - for onnx_node in onnx_model.graph.node.iter() { - let mut node = convert_node_proto(onnx_node); - remap_node_type(&mut node); - nodes.push(node); - } - - // ONNX nodes must be topologically sorted per spec: - // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs - assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); - - // Move inputs with initializers to states - move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer); - - // Handle Identity nodes (expects inputs to be moved to states) - handle_identity(&mut nodes); - - // Lift constants to initializers (expects inputs to be moved to states) - lift_constants(&mut nodes); - - // Coalesce and transform nodes - coalesce(&mut nodes); - - // Rename nodes and inputs, save the mapping for later - let old_node_names = rename_nodes(&mut nodes); - - // This function collects the inputs of an ONNX model and returns them as a vector of Arguments. - let mut inputs = onnx_model - .graph - .input - .iter() - .map(|x| Argument::try_from(x.clone()).unwrap()) - .collect(); - - // Map each output in the model's graph to an Argument and collect them into a vector. 
- let mut outputs = onnx_model - .graph - .output - .iter() - .map(|x| Argument::try_from(x.clone()).unwrap()) - .collect(); - - let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs); - - // Infer shapes and update the inputs and outputs - dim_inference(&mut nodes, &inputs, &mut outputs); - - // Remove the graph inputs/output that are not used by any node - remove_unused_graph_inputs(&mut inputs, &mut outputs, &nodes); - - log::info!("Finished parsing ONNX file: {}", onnx_path.display()); - - ONNXGraph { - nodes, - inputs, - outputs, - old_node_names, - old_input_names, - } + log::info!("Parsing ONNX file: {}", onnx_path.display()); + + // Open the file + let mut file = File::open(onnx_path).expect("Unable to open file"); + let onnx_model: ModelProto = + Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file"); + + log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len()); + log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len()); + + log::debug!( + "Number of initializers: {:?}", + onnx_model.graph.initializer.len() + ); + + log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len()); + + // Convert the nodes + let mut nodes: Vec = vec![]; + for onnx_node in onnx_model.graph.node.iter() { + let mut node = convert_node_proto(onnx_node); + remap_node_type(&mut node); + nodes.push(node); + } + + // ONNX nodes must be topologically sorted per spec: + // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs + assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); + + // Move inputs with initializers to states + move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer); + + // Handle Identity nodes (expects inputs to be moved to states) + handle_identity(&mut nodes); + + // Lift constants to initializers (expects inputs to be moved to states) + lift_constants(&mut nodes); + + // Coalesce and transform nodes + coalesce(&mut nodes); + + // Rename nodes and inputs, save the mapping for later + let old_node_names = rename_nodes(&mut nodes); + + // This function collects the inputs of an ONNX model and returns them as a vector of Arguments. + let mut inputs = onnx_model + .graph + .input + .iter() + .map(|x| Argument::try_from(x.clone()).unwrap()) + .collect(); + + // Map each output in the model's graph to an Argument and collect them into a vector. 
+ let mut outputs = onnx_model + .graph + .output + .iter() + .map(|x| Argument::try_from(x.clone()).unwrap()) + .collect(); + + let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs); + + // Infer shapes and update the inputs and outputs + dim_inference(&mut nodes, &inputs, &mut outputs); + + // Remove the graph inputs/output that are not used by any node + remove_unused_graph_inputs(&mut inputs, &mut outputs, &nodes); + + log::info!("Finished parsing ONNX file: {}", onnx_path.display()); + + ONNXGraph { + nodes, + inputs, + outputs, + old_node_names, + old_input_names, + } } /// This function moves inputs that are also present @@ -128,49 +128,49 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { /// * `nodes` - A mutable reference to a vector of nodes /// * `initializers` - A vector of TensorProto fn move_inputs_to_state(nodes: &mut Vec, initializers: &[TensorProto]) { - // Convert initializers to hashmap for faster lookup - let initializers = initializers - .iter() - .map(|x| (x.name.clone(), x.clone())) - .collect::>(); - - // Iterate over each node in the graph - nodes.iter_mut().for_each(|node| { - for input in node.inputs.iter_mut() { - // If there is a corresponding initializer for the input, then move the data to the input value - if let Some(initializer) = initializers.get(&input.name) { - move_initializer_data(initializer, input); - } - } - }); + // Convert initializers to hashmap for faster lookup + let initializers = initializers + .iter() + .map(|x| (x.name.clone(), x.clone())) + .collect::>(); + + // Iterate over each node in the graph + nodes.iter_mut().for_each(|node| { + for input in node.inputs.iter_mut() { + // If there is a corresponding initializer for the input, then move the data to the input value + if let Some(initializer) = initializers.get(&input.name) { + move_initializer_data(initializer, input); + } + } + }); } fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) { - // If the input name matches the tensor name in the initializer - // Convert the initializer to a tensor - let tensor = Tensor::try_from(initializer.clone()).expect("Invalid tensor"); - - if tensor.dim == 0 { - // Convert zero dim tensor to scalar - if let Some(data) = tensor.data { - input.value = Some(data.into_scalar()); + // If the input name matches the tensor name in the initializer + // Convert the initializer to a tensor + let tensor = Tensor::try_from(initializer.clone()).expect("Invalid tensor"); + + if tensor.dim == 0 { + // Convert zero dim tensor to scalar + if let Some(data) = tensor.data { + input.value = Some(data.into_scalar()); + } else { + input.value = None; + } + + // Update the input type + input.ty = ArgType::Scalar(tensor.elem_type); } else { - input.value = None; + // Move the tensor data to the input value + input.value = tensor.data.clone(); + + // Update the input type + input.ty = ArgType::Tensor(TensorType { + dim: tensor.dim, + elem_type: tensor.elem_type, + shape: tensor.shape, + }); } - - // Update the input type - input.ty = ArgType::Scalar(tensor.elem_type); - } else { - // Move the tensor data to the input value - input.value = tensor.data.clone(); - - // Update the input type - input.ty = ArgType::Tensor(TensorType { - dim: tensor.dim, - elem_type: tensor.elem_type, - shape: tensor.shape, - }); - } } /// Lift constants from the graph into the states vector for known node types. 
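// Illustrative sketch (standalone, plain strings instead of NodeType): only
// inputs of the node kinds whitelisted in LIFT_CONSTANTS_FOR_NODE_TYPES
// receive lifted constant values; other nodes keep reading their constants
// through the graph.
use std::collections::HashSet;

fn should_lift_for(node_kind: &str) -> bool {
    let whitelist = HashSet::from([
        "BatchNormalization", "Clip", "Conv1d", "Conv2d", "Dropout", "Reshape",
    ]);
    whitelist.contains(node_kind)
}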
@@ -196,116 +196,117 @@ fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) { /// /// Panics if the node's output is not a constant. fn lift_constants(nodes: &mut Vec) { - log::info!("Lifting constants into the states"); - - // create a set to hold the node types to process - let node_types_to_process: HashSet = - LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); - - // create a new vector to hold the graph's constants (index by the node's name) - let constants = nodes - .iter() - .filter(|node| node.node_type == NodeType::Constant || node.node_type == NodeType::Identity) - .map(|node| (node.outputs[0].name.clone(), node.clone())) - .collect::>(); - - // create a set to hold the IDs of constants to be removed - let mut constant_to_removed = HashSet::::new(); - - for node in nodes.iter_mut() { - // Skip the node if it is not in the set of node types to process - if !node_types_to_process.contains(&node.node_type) { - continue; - } + log::info!("Lifting constants into the states"); - // Skip the first input because it is the node's true input and not a constant/state - node - .inputs - .iter_mut() - .skip(1) // TODO make configurable - .for_each(|input| { - if let Some(constant) = constants.get(&input.name) { - if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { - // The value comes from Identity inputs - if let Some(constant_input) = constant.inputs.first() { - input.ty = constant_input.ty.clone(); - input.value = constant_input.value.clone(); - } - } else { - // The value comes from an attribute - let arg = convert_constant_value(constant); // get the value of the constant - - input.value = arg.value; // set the input's value to the constant's value - input.ty = arg.ty; // set the input's type to the constant's type - // remove the constant from the graph - } - constant_to_removed.insert(constant.name.clone()); + // create a set to hold the node types to process + let node_types_to_process: HashSet = + LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); + + // create a new vector to hold the graph's constants (index by the node's name) + let constants = nodes + .iter() + .filter(|node| node.node_type == NodeType::Constant || node.node_type == NodeType::Identity) + .map(|node| (node.outputs[0].name.clone(), node.clone())) + .collect::>(); + + // create a set to hold the IDs of constants to be removed + let mut constant_to_removed = HashSet::::new(); + + for node in nodes.iter_mut() { + // Skip the node if it is not in the set of node types to process + if !node_types_to_process.contains(&node.node_type) { + continue; } - }); - } - // remove the constants that were moved to the states vector - nodes.retain(|node| !constant_to_removed.contains(&node.name)); + // Skip the first input because it is the node's true input and not a constant/state + node.inputs + .iter_mut() + .skip(1) // TODO make configurable + .for_each(|input| { + if let Some(constant) = constants.get(&input.name) { + if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { + // The value comes from Identity inputs + if let Some(constant_input) = constant.inputs.first() { + input.ty = constant_input.ty.clone(); + input.value = constant_input.value.clone(); + } + } else { + // The value comes from an attribute + let arg = convert_constant_value(constant); // get the value of the constant + + input.value = arg.value; // set the input's value to the constant's value + input.ty = arg.ty; // set the input's type to the constant's type + // remove the constant from the graph 
+ } + constant_to_removed.insert(constant.name.clone()); + } + }); + } - log::debug!( - "The number of constants lifted: {}", - constant_to_removed.len() - ); + // remove the constants that were moved to the states vector + nodes.retain(|node| !constant_to_removed.contains(&node.name)); + + log::debug!( + "The number of constants lifted: {}", + constant_to_removed.len() + ); } fn handle_identity(nodes: &mut Vec) { - log::info!("Handling identity nodes"); - - let mut nodes_to_remove = HashSet::new(); - - let identity_nodes = nodes - .iter() - .filter(|node| node.node_type == NodeType::Identity) - .cloned() - .collect::>(); - - // Handle pass-through nodes. - for identity_node in identity_nodes { - if identity_node.node_type == NodeType::Identity && identity_node.inputs[0].value.is_none() { - let input_name = &identity_node.inputs[0].name; - let output_name = &identity_node.outputs[0].name; - - // Replace the identity node's output with its input in the connected nodes. - for node in nodes.iter_mut() { - if let Some(matched_input) = node.inputs.iter_mut().find(|x| x.name == *output_name) { - matched_input.name = input_name.clone(); - } - } + log::info!("Handling identity nodes"); + + let mut nodes_to_remove = HashSet::new(); + + let identity_nodes = nodes + .iter() + .filter(|node| node.node_type == NodeType::Identity) + .cloned() + .collect::>(); + + // Handle pass-through nodes. + for identity_node in identity_nodes { + if identity_node.node_type == NodeType::Identity && identity_node.inputs[0].value.is_none() + { + let input_name = &identity_node.inputs[0].name; + let output_name = &identity_node.outputs[0].name; + + // Replace the identity node's output with its input in the connected nodes. + for node in nodes.iter_mut() { + if let Some(matched_input) = node.inputs.iter_mut().find(|x| x.name == *output_name) + { + matched_input.name = input_name.clone(); + } + } - nodes_to_remove.insert(identity_node); + nodes_to_remove.insert(identity_node); + } } - } - // Remove the identity nodes. - nodes.retain(|node| !nodes_to_remove.contains(node)); + // Remove the identity nodes. + nodes.retain(|node| !nodes_to_remove.contains(node)); } /// Rename the nodes in the graph to be unique and return a map of the old names to the new names. fn rename_nodes(nodes: &mut Vec) -> HashMap { - let mut old_names = HashMap::new(); - let mut counter: HashMap = HashMap::new(); + let mut old_names = HashMap::new(); + let mut counter: HashMap = HashMap::new(); - for node in nodes.iter_mut() { - // keep track of the number of nodes of each type - counter - .entry(node.node_type.clone()) - .and_modify(|e| *e += 1) - .or_insert(1); + for node in nodes.iter_mut() { + // keep track of the number of nodes of each type + counter + .entry(node.node_type.clone()) + .and_modify(|e| *e += 1) + .or_insert(1); - let old_name = node.name.clone(); - let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase(); + let old_name = node.name.clone(); + let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase(); - node.name = new_name.clone(); + node.name = new_name.clone(); - old_names.insert(old_name, new_name); - } + old_names.insert(old_name, new_name); + } - old_names + old_names } /// Rename the inputs and output in the graph and return a map of @@ -315,60 +316,60 @@ fn rename_nodes(nodes: &mut Vec) -> HashMap { /// conv2_in1, conv2_in2, etc. This is done to be consistent with /// the naming convention of the nodes and allow to be used as rust identifiers. 
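// Illustrative sketch (standalone): graph inputs are renamed input1, input2,
// ..., and node outputs become <node_name>_out<i>, so every generated name is
// a valid Rust identifier.
fn renamed_output(node_name: &str, index: usize) -> String {
    // e.g. renamed_output("conv2d1", 1) == "conv2d1_out1"
    format!("{}_out{}", node_name, index)
}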
fn rename_inputs( - nodes: &mut Vec, - inputs: &mut Vec, - outputs: &mut Vec, + nodes: &mut Vec, + inputs: &mut Vec, + outputs: &mut Vec, ) -> HashMap { - let mut old_names = HashMap::new(); - - // rename all graph input names to follow input1, input2, input3, etc. - // (assumes the input names are already unique) - let mut counter = 1; - for input in inputs.iter_mut() { - let old_name = input.name.clone(); - let new_name = format!("input{}", counter); - input.name = new_name.clone(); - old_names.insert(old_name, new_name); - counter += 1; - } - - for node in nodes.iter_mut() { + let mut old_names = HashMap::new(); + + // rename all graph input names to follow input1, input2, input3, etc. + // (assumes the input names are already unique) let mut counter = 1; + for input in inputs.iter_mut() { + let old_name = input.name.clone(); + let new_name = format!("input{}", counter); + input.name = new_name.clone(); + old_names.insert(old_name, new_name); + counter += 1; + } - // loop through node outputs and rename them and store the new name <-> old name mapping - for output in node.outputs.iter_mut() { - let old_name = output.name.clone(); - let new_name = format!("{}_out{}", node.name, counter); - output.name = new_name.clone(); - old_names.insert(old_name, new_name); - counter += 1; + for node in nodes.iter_mut() { + let mut counter = 1; + + // loop through node outputs and rename them and store the new name <-> old name mapping + for output in node.outputs.iter_mut() { + let old_name = output.name.clone(); + let new_name = format!("{}_out{}", node.name, counter); + output.name = new_name.clone(); + old_names.insert(old_name, new_name); + counter += 1; + } } - } - for node in nodes.iter_mut() { - // loop through node inputs and rename them with previously replaced names - // and mark them as passed if they are in the old_names map (i.e. they are node outputs) - for input in node.inputs.iter_mut() { - if let Some(new_name) = old_names.get(&input.name) { - input.name = new_name.clone(); - input.passed = true; - } else { - input.name = "".to_string(); // Rename to a placeholder - input.passed = false; - } + for node in nodes.iter_mut() { + // loop through node inputs and rename them with previously replaced names + // and mark them as passed if they are in the old_names map (i.e. they are node outputs) + for input in node.inputs.iter_mut() { + if let Some(new_name) = old_names.get(&input.name) { + input.name = new_name.clone(); + input.passed = true; + } else { + input.name = "".to_string(); // Rename to a placeholder + input.passed = false; + } + } } - } - // Rename the graph outputs - for output in outputs.iter_mut() { - if let Some(new_name) = old_names.get(&output.name) { - output.name = new_name.clone(); - } else { - log::warn!("Output {:?} not found in old_names", output.name); + // Rename the graph outputs + for output in outputs.iter_mut() { + if let Some(new_name) = old_names.get(&output.name) { + output.name = new_name.clone(); + } else { + log::warn!("Output {:?} not found in old_names", output.name); + } } - } - old_names + old_names } /// Removes the graph inputs/output that are not used by any node. @@ -381,90 +382,90 @@ fn rename_inputs( /// Generally, it's a good idea to remove unused inputs/outputs because it makes the /// generated code cleaner and easier to read. 
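// Illustrative sketch (hypothetical, simplified types): a graph input survives
// only if some node still reads it by name without a baked-in value; inputs
// whose data was lifted from initializers are dropped.
fn is_live_input(name: &str, node_inputs: &[(String, Option<Vec<f32>>)]) -> bool {
    node_inputs
        .iter()
        .any(|(input_name, value)| input_name == name && value.is_none())
}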
fn remove_unused_graph_inputs( - inputs: &mut Vec, - outputs: &mut Vec, - nodes: &Vec, + inputs: &mut Vec, + outputs: &mut Vec, + nodes: &Vec, ) { - // Remove inputs that are not used by any node - inputs.retain(|input| { - for node in nodes.iter() { - if node - .inputs - .iter() - .any(|x| x.name == input.name && x.value.is_none()) - { - return true; - } - } - false - }); - - // Remove outputs that are not used by any node - outputs.retain(|output| { - for node in nodes.iter() { - if node.outputs.iter().any(|x| x.name == output.name) { - return true; - } - } - false - }); + // Remove inputs that are not used by any node + inputs.retain(|input| { + for node in nodes.iter() { + if node + .inputs + .iter() + .any(|x| x.name == input.name && x.value.is_none()) + { + return true; + } + } + false + }); + + // Remove outputs that are not used by any node + outputs.retain(|output| { + for node in nodes.iter() { + if node.outputs.iter().any(|x| x.name == output.name) { + return true; + } + } + false + }); } // Define a trait for topological sorting trait TopologicalSortable { - fn is_top_sorted(&self) -> bool; + fn is_top_sorted(&self) -> bool; } impl TopologicalSortable for Vec { - fn is_top_sorted(&self) -> bool { - // Create a hashmap to store the position of each node in the vector - let position: HashMap = self - .iter() - .enumerate() - .map(|(idx, node)| (node.name.clone(), idx)) - .collect(); - - // Iterate over each node in the vector - for node in self { - // Iterate over each output of the node - for output in &node.outputs { - // Iterate over each other node in the vector - for other_node in self { - // If the other node has an input that matches the current output - if other_node.inputs.contains(output) { - // If the position of the current node is greater than the position of the other node - if position[&node.name] > position[&other_node.name] { - // The vector is not topologically sorted - return false; + fn is_top_sorted(&self) -> bool { + // Create a hashmap to store the position of each node in the vector + let position: HashMap = self + .iter() + .enumerate() + .map(|(idx, node)| (node.name.clone(), idx)) + .collect(); + + // Iterate over each node in the vector + for node in self { + // Iterate over each output of the node + for output in &node.outputs { + // Iterate over each other node in the vector + for other_node in self { + // If the other node has an input that matches the current output + if other_node.inputs.contains(output) { + // If the position of the current node is greater than the position of the other node + if position[&node.name] > position[&other_node.name] { + // The vector is not topologically sorted + return false; + } + } + } } - } } - } - } - // The vector is topologically sorted - true - } + // The vector is topologically sorted + true + } } /// Get the value of a constant node from its attributes pub(crate) fn convert_constant_value(node: &Node) -> Argument { - // A value can be stored in any of these attributes - let keys = [ - "value", - "value_float", - "value_floats", - "value_int", - "value_ints", - "value_string", - "value_strings", - "sparse_value", - ]; - - let value = keys - .iter() - .find_map(|&key| node.attrs.get(key).cloned()) - .expect("Constant should have a value"); - - Argument::from(value) + // A value can be stored in any of these attributes + let keys = [ + "value", + "value_float", + "value_floats", + "value_int", + "value_ints", + "value_string", + "value_strings", + "sparse_value", + ]; + + let value = keys + .iter() + 
.find_map(|&key| node.attrs.get(key).cloned()) + .expect("Constant should have a value"); + + Argument::from(value) } diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs index 29f9f857da..d4a2086ad5 100644 --- a/burn-import/src/onnx/ir.rs +++ b/burn-import/src/onnx/ir.rs @@ -9,39 +9,39 @@ pub type Shape = Vec; /// A node input or output. #[derive(Debug, Clone)] pub struct Argument { - /// The name of the node input. - pub name: String, + /// The name of the node input. + pub name: String, - /// The type of the argument. - pub ty: ArgType, + /// The type of the argument. + pub ty: ArgType, - /// The data of the argument. - pub value: Option, + /// The data of the argument. + pub value: Option, - /// True if the argument is passed to node, false otherwise. We use it mainly for informational purposes. - /// The argument should contain a value if passed is false. - pub passed: bool, + /// True if the argument is passed to node, false otherwise. We use it mainly for informational purposes. + /// The argument should contain a value if passed is false. + pub passed: bool, } /// The type of an argument. #[derive(Debug, Clone)] pub enum ArgType { - Scalar(ElementType), - Shape(Dim), - Tensor(TensorType), + Scalar(ElementType), + Shape(Dim), + Tensor(TensorType), } /// The type of an attribute. #[derive(Debug, Clone)] pub enum AttributeValue { - Float32(f32), - Float32s(Vec), - Int64(i64), - Int64s(Vec), - String(String), - Strings(Vec), - Tensor(Tensor), - Tensors(Vec), + Float32(f32), + Float32s(Vec), + Int64(i64), + Int64s(Vec), + String(String), + Strings(Vec), + Tensor(Tensor), + Tensors(Vec), } pub type Attributes = HashMap; @@ -49,126 +49,126 @@ pub type Attributes = HashMap; /// The type of an element. #[derive(Debug, Clone)] pub enum ElementType { - Float32, - Float64, - Int32, - Int64, - String, - Float16, - Bool, + Float32, + Float64, + Int32, + Int64, + String, + Float16, + Bool, } #[derive(Debug, Clone, Default)] pub struct TensorType { - /// The type of the tensor. - pub elem_type: ElementType, + /// The type of the tensor. + pub elem_type: ElementType, - /// The dimension of the tensor. - pub dim: Dim, + /// The dimension of the tensor. + pub dim: Dim, - /// The shape of the tensor. - pub shape: Option, + /// The shape of the tensor. + pub shape: Option, } impl Default for ElementType { - fn default() -> Self { - Self::Float32 - } + fn default() -> Self { + Self::Float32 + } } impl Default for ArgType { - fn default() -> Self { - Self::Tensor(TensorType::default()) - } + fn default() -> Self { + Self::Tensor(TensorType::default()) + } } impl Argument { - pub fn new(name: String) -> Self { - Self { - name, - ty: ArgType::default(), - value: None, - passed: false, - } - } + pub fn new(name: String) -> Self { + Self { + name, + ty: ArgType::default(), + value: None, + passed: false, + } + } } #[derive(Debug, Clone, Default)] pub struct Tensor { - /// The type of the tensor. - pub elem_type: ElementType, + /// The type of the tensor. + pub elem_type: ElementType, - /// The dimension of the tensor. - pub dim: Dim, + /// The dimension of the tensor. + pub dim: Dim, - /// The data of the tensor. - pub data: Option, + /// The data of the tensor. + pub data: Option, - /// The shape of the tensor. - pub shape: Option, + /// The shape of the tensor. 
+ pub shape: Option, } /// Container to hold data for tensors and arguments #[derive(Clone)] pub enum Data { - Bool(bool), - Bools(Vec), - Float16(f16), - Float16s(Vec), - Float32(f32), - Float32s(Vec), - Float64(f64), - Float64s(Vec), - Int32(i32), - Int32s(Vec), - Int64(i64), - Int64s(Vec), - String(String), - Strings(Vec), + Bool(bool), + Bools(Vec), + Float16(f16), + Float16s(Vec), + Float32(f32), + Float32s(Vec), + Float64(f64), + Float64s(Vec), + Int32(i32), + Int32s(Vec), + Int64(i64), + Int64s(Vec), + String(String), + Strings(Vec), } /// ONNX graph representation #[derive(Debug, Clone)] pub struct ONNXGraph { - /// The nodes of the graph. - pub nodes: Vec, + /// The nodes of the graph. + pub nodes: Vec, - /// The inputs of the graph. - pub inputs: Vec, + /// The inputs of the graph. + pub inputs: Vec, - /// The outputs of the graph. - pub outputs: Vec, + /// The outputs of the graph. + pub outputs: Vec, - /// The original node names. - pub old_node_names: HashMap, + /// The original node names. + pub old_node_names: HashMap, - /// The original input names. - pub old_input_names: HashMap, + /// The original input names. + pub old_input_names: HashMap, } #[derive(Debug, Clone)] pub struct Node { - /// The type of the node. - pub node_type: NodeType, + /// The type of the node. + pub node_type: NodeType, - /// The name of the node. - pub name: String, + /// The name of the node. + pub name: String, - /// The inputs of the node. - pub inputs: Vec, + /// The inputs of the node. + pub inputs: Vec, - /// The outputs of the node. - pub outputs: Vec, + /// The outputs of the node. + pub outputs: Vec, - /// The attributes of the node. - pub attrs: Attributes, + /// The attributes of the node. + pub attrs: Attributes, } // Required by topological sort impl PartialEq for Node { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.node_type == other.node_type - } + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.node_type == other.node_type + } } // Required by topological sort @@ -176,596 +176,596 @@ impl Eq for Node {} // Required by topological sort impl core::hash::Hash for Node { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.node_type.hash(state); - self.inputs.hash(state); - self.outputs.hash(state); - } + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.node_type.hash(state); + self.inputs.hash(state); + self.outputs.hash(state); + } } // Required by topological sort impl core::hash::Hash for Argument { - fn hash(&self, state: &mut H) { - self.name.hash(state); - } + fn hash(&self, state: &mut H) { + self.name.hash(state); + } } impl Eq for Argument {} // Required by HashSet impl PartialEq for Argument { - fn eq(&self, other: &Self) -> bool { - self.name == other.name - } + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } } /// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops) #[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)] pub enum NodeType { - Abs, - Acos, - Acosh, - Add, - And, - ArgMax, - ArgMin, - Asin, - Asinh, - Atan, - Atanh, - AveragePool, - AveragePool1d, - AveragePool2d, - BatchNormalization, - Bernoulli, - BitShift, - BitwiseAnd, - BitwiseNot, - BitwiseOr, - BitwiseXor, - BlackmanWindow, - Cast, - CastLike, - Ceil, - Celu, - CenterCropPad, - Clip, - Col, - Compress, - Concat, - ConcatFromSequence, - Constant, - ConstantOfShape, - Conv, - Conv1d, - Conv2d, - ConvInteger, - ConvTranspose, - Cos, - Cosh, - CumSum, - 
DepthToSpace, - DequantizeLinear, - Det, - DFT, - Div, - Dropout, - DynamicQuantizeLinear, - Einsum, - Elu, - Equal, - Erf, - Exp, - Expand, - EyeLike, - Flatten, - Floor, - Gather, - GatherElements, - GatherND, - Gelu, - Gemm, - GlobalAveragePool, - GlobalLpPool, - GlobalMaxPool, - Greater, - GreaterOrEqual, - GridSample, - GroupNormalization, - GRU, - HammingWindow, - HannWindow, - Hardmax, - HardSigmoid, - HardSwish, - Identity, - If, - Im, - InstanceNormalization, - IsInf, - IsNaN, - LayerNormalization, - LeakyRelu, - Less, - LessOrEqual, - Linear, - Log, - LogSoftmax, - Loop, - LpNormalization, - LpPool, - LRN, - LSTM, - MatMul, - MatMulInteger, - Max, - MaxPool, - MaxPool1d, - MaxPool2d, - MaxRoiPool, - MaxUnpool, - Mean, - MeanVarianceNormalization, - MelWeightMatrix, - Min, - Mish, - Mod, - Mul, - Multinomial, - Neg, - NegativeLogLikelihoodLoss, - NonMaxSuppression, - NonZero, - Not, - OneHot, - Optional, - OptionalGetElement, - OptionalHasElement, - Or, - Pad, - Pow, - PRelu, - QLinearConv, - QLinearMatMul, - QuantizeLinear, - RandomNormal, - RandomNormalLike, - RandomUniform, - RandomUniformLike, - Range, - Reciprocal, - ReduceL, - ReduceLogSum, - ReduceLogSumExp, - ReduceMax, - ReduceMean, - ReduceMin, - ReduceProd, - ReduceSum, - ReduceSumSquare, - Relu, - Reshape, - Resize, - ReverseSequence, - RNN, - RoiAlign, - Round, - Scan, - Scatter, - ScatterElements, - ScatterND, - Selu, - SequenceAt, - SequenceConstruct, - SequenceEmpty, - SequenceErase, - SequenceInsert, - SequenceLength, - SequenceMap, - Shape, - Shrink, - Sigmoid, - Sign, - Sin, - Sinh, - Size, - Slice, - Softmax, - SoftmaxCrossEntropyLoss, - Softplus, - Softsign, - SpaceToDepth, - Split, - SplitToSequence, - Sqrt, - Squeeze, - STFT, - StringNormalizer, - Sub, - Sum, - Tan, - Tanh, - TfIdfVectorizer, - ThresholdedRelu, - Tile, - TopK, - Transpose, - Trilu, - Unique, - Unsqueeze, - Upsample, - Where, - Xor, + Abs, + Acos, + Acosh, + Add, + And, + ArgMax, + ArgMin, + Asin, + Asinh, + Atan, + Atanh, + AveragePool, + AveragePool1d, + AveragePool2d, + BatchNormalization, + Bernoulli, + BitShift, + BitwiseAnd, + BitwiseNot, + BitwiseOr, + BitwiseXor, + BlackmanWindow, + Cast, + CastLike, + Ceil, + Celu, + CenterCropPad, + Clip, + Col, + Compress, + Concat, + ConcatFromSequence, + Constant, + ConstantOfShape, + Conv, + Conv1d, + Conv2d, + ConvInteger, + ConvTranspose, + Cos, + Cosh, + CumSum, + DepthToSpace, + DequantizeLinear, + Det, + DFT, + Div, + Dropout, + DynamicQuantizeLinear, + Einsum, + Elu, + Equal, + Erf, + Exp, + Expand, + EyeLike, + Flatten, + Floor, + Gather, + GatherElements, + GatherND, + Gelu, + Gemm, + GlobalAveragePool, + GlobalLpPool, + GlobalMaxPool, + Greater, + GreaterOrEqual, + GridSample, + GroupNormalization, + GRU, + HammingWindow, + HannWindow, + Hardmax, + HardSigmoid, + HardSwish, + Identity, + If, + Im, + InstanceNormalization, + IsInf, + IsNaN, + LayerNormalization, + LeakyRelu, + Less, + LessOrEqual, + Linear, + Log, + LogSoftmax, + Loop, + LpNormalization, + LpPool, + LRN, + LSTM, + MatMul, + MatMulInteger, + Max, + MaxPool, + MaxPool1d, + MaxPool2d, + MaxRoiPool, + MaxUnpool, + Mean, + MeanVarianceNormalization, + MelWeightMatrix, + Min, + Mish, + Mod, + Mul, + Multinomial, + Neg, + NegativeLogLikelihoodLoss, + NonMaxSuppression, + NonZero, + Not, + OneHot, + Optional, + OptionalGetElement, + OptionalHasElement, + Or, + Pad, + Pow, + PRelu, + QLinearConv, + QLinearMatMul, + QuantizeLinear, + RandomNormal, + RandomNormalLike, + RandomUniform, + RandomUniformLike, + Range, + Reciprocal, + 
ReduceL, + ReduceLogSum, + ReduceLogSumExp, + ReduceMax, + ReduceMean, + ReduceMin, + ReduceProd, + ReduceSum, + ReduceSumSquare, + Relu, + Reshape, + Resize, + ReverseSequence, + RNN, + RoiAlign, + Round, + Scan, + Scatter, + ScatterElements, + ScatterND, + Selu, + SequenceAt, + SequenceConstruct, + SequenceEmpty, + SequenceErase, + SequenceInsert, + SequenceLength, + SequenceMap, + Shape, + Shrink, + Sigmoid, + Sign, + Sin, + Sinh, + Size, + Slice, + Softmax, + SoftmaxCrossEntropyLoss, + Softplus, + Softsign, + SpaceToDepth, + Split, + SplitToSequence, + Sqrt, + Squeeze, + STFT, + StringNormalizer, + Sub, + Sum, + Tan, + Tanh, + TfIdfVectorizer, + ThresholdedRelu, + Tile, + TopK, + Transpose, + Trilu, + Unique, + Unsqueeze, + Upsample, + Where, + Xor, } /// Truncate the vector display for debug display fn trunc(v: &Vec) -> String { - const BEGIN_INDEX: usize = 0; - const MAX_LEN: usize = 5; - let mut s = String::new(); - s.push('['); - for (i, item) in v.iter().enumerate() { - if i > BEGIN_INDEX { - s.push_str(", "); - } - s.push_str(&format!("{}", item)); - if i > MAX_LEN { - s.push_str(", ..."); - break; - } - } - s.push(']'); - s + const BEGIN_INDEX: usize = 0; + const MAX_LEN: usize = 5; + let mut s = String::new(); + s.push('['); + for (i, item) in v.iter().enumerate() { + if i > BEGIN_INDEX { + s.push_str(", "); + } + s.push_str(&format!("{}", item)); + if i > MAX_LEN { + s.push_str(", ..."); + break; + } + } + s.push(']'); + s } /// Shorten the tensor data for debug display impl fmt::Debug for Data { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Data::Float16s(v) => write!(f, "Float16s({})", trunc(v)), - Data::Float32s(v) => write!(f, "Float32s({})", trunc(v)), - Data::Float64s(v) => write!(f, "Float64s({})", trunc(v)), - Data::Int32s(v) => write!(f, "Int32s({})", trunc(v)), - Data::Int64s(v) => write!(f, "Int64s({})", trunc(v)), - Data::Strings(v) => write!(f, "Strings({})", trunc(v)), - Data::Bools(v) => write!(f, "Bools({})", trunc(v)), - Data::Float16(v) => write!(f, "Float16({})", v), - Data::Float32(v) => write!(f, "Float32({})", v), - Data::Float64(v) => write!(f, "Float64({})", v), - Data::Int32(v) => write!(f, "Int32({})", v), - Data::Int64(v) => write!(f, "Int64({})", v), - Data::String(v) => write!(f, "String({})", v), - Data::Bool(v) => write!(f, "Bool({})", v), - } - } + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Data::Float16s(v) => write!(f, "Float16s({})", trunc(v)), + Data::Float32s(v) => write!(f, "Float32s({})", trunc(v)), + Data::Float64s(v) => write!(f, "Float64s({})", trunc(v)), + Data::Int32s(v) => write!(f, "Int32s({})", trunc(v)), + Data::Int64s(v) => write!(f, "Int64s({})", trunc(v)), + Data::Strings(v) => write!(f, "Strings({})", trunc(v)), + Data::Bools(v) => write!(f, "Bools({})", trunc(v)), + Data::Float16(v) => write!(f, "Float16({})", v), + Data::Float32(v) => write!(f, "Float32({})", v), + Data::Float64(v) => write!(f, "Float64({})", v), + Data::Int32(v) => write!(f, "Int32({})", v), + Data::Int64(v) => write!(f, "Int64({})", v), + Data::String(v) => write!(f, "String({})", v), + Data::Bool(v) => write!(f, "Bool({})", v), + } + } } impl Data { - pub fn into_scalar(self) -> Self { - match self { - Data::Float16s(data) => { - assert_eq!(data.len(), 1); - Data::Float16(data[0]) - } - Data::Float32s(data) => { - assert_eq!(data.len(), 1); - Data::Float32(data[0]) - } - Data::Float64s(data) => { - assert_eq!(data.len(), 1); - Data::Float64(data[0]) - } - Data::Int32s(data) => { - 
assert_eq!(data.len(), 1); - Data::Int32(data[0]) - } - Data::Int64s(data) => { - assert_eq!(data.len(), 1); - Data::Int64(data[0]) - } - Data::Bools(data) => { - assert_eq!(data.len(), 1); - Data::Bool(data[0]) - } - Data::Strings(data) => { - assert_eq!(data.len(), 1); - Data::String(data[0].clone()) - } - _ => self, - } - } - - pub fn into_f16(self) -> f16 { - if let Data::Float16(elem) = self { - elem - } else { - panic!("Expected Float16, got {:?}", self); - } - } - - pub fn into_f32(self) -> f32 { - if let Data::Float32(elem) = self { - elem - } else { - panic!("Expected Float32, got {:?}", self); - } - } - - pub fn into_f64(self) -> f64 { - if let Data::Float64(elem) = self { - elem - } else { - panic!("Expected Float64, got {:?}", self); - } - } - - pub fn into_i32(self) -> i32 { - if let Data::Int32(elem) = self { - elem - } else { - panic!("Expected Int32, got {:?}", self); - } - } - - pub fn into_i64(self) -> i64 { - if let Data::Int64(elem) = self { - elem - } else { - panic!("Expected Int64, got {:?}", self); - } - } - - pub fn into_bool(self) -> bool { - if let Data::Bool(elem) = self { - elem - } else { - panic!("Expected Bool, got {:?}", self); - } - } - - pub fn into_string(self) -> String { - if let Data::String(elem) = self { - elem - } else { - panic!("Expected String, got {:?}", self); - } - } - - pub fn into_f16s(self) -> Vec { - if let Data::Float16s(elem) = self { - elem - } else { - panic!("Expected Float16s, got {:?}", self); - } - } - - pub fn into_f32s(self) -> Vec { - if let Data::Float32s(elem) = self { - elem - } else { - panic!("Expected Float32s, got {:?}", self); - } - } - - pub fn into_f64s(self) -> Vec { - if let Data::Float64s(elem) = self { - elem - } else { - panic!("Expected Float64s, got {:?}", self); - } - } - - pub fn into_i32s(self) -> Vec { - if let Data::Int32s(elem) = self { - elem - } else { - panic!("Expected Int32s, got {:?}", self); - } - } - - pub fn into_i64s(self) -> Vec { - if let Data::Int64s(elem) = self { - elem - } else { - panic!("Expected Int64s, got {:?}", self); + pub fn into_scalar(self) -> Self { + match self { + Data::Float16s(data) => { + assert_eq!(data.len(), 1); + Data::Float16(data[0]) + } + Data::Float32s(data) => { + assert_eq!(data.len(), 1); + Data::Float32(data[0]) + } + Data::Float64s(data) => { + assert_eq!(data.len(), 1); + Data::Float64(data[0]) + } + Data::Int32s(data) => { + assert_eq!(data.len(), 1); + Data::Int32(data[0]) + } + Data::Int64s(data) => { + assert_eq!(data.len(), 1); + Data::Int64(data[0]) + } + Data::Bools(data) => { + assert_eq!(data.len(), 1); + Data::Bool(data[0]) + } + Data::Strings(data) => { + assert_eq!(data.len(), 1); + Data::String(data[0].clone()) + } + _ => self, + } + } + + pub fn into_f16(self) -> f16 { + if let Data::Float16(elem) = self { + elem + } else { + panic!("Expected Float16, got {:?}", self); + } + } + + pub fn into_f32(self) -> f32 { + if let Data::Float32(elem) = self { + elem + } else { + panic!("Expected Float32, got {:?}", self); + } + } + + pub fn into_f64(self) -> f64 { + if let Data::Float64(elem) = self { + elem + } else { + panic!("Expected Float64, got {:?}", self); + } + } + + pub fn into_i32(self) -> i32 { + if let Data::Int32(elem) = self { + elem + } else { + panic!("Expected Int32, got {:?}", self); + } + } + + pub fn into_i64(self) -> i64 { + if let Data::Int64(elem) = self { + elem + } else { + panic!("Expected Int64, got {:?}", self); + } + } + + pub fn into_bool(self) -> bool { + if let Data::Bool(elem) = self { + elem + } else { + panic!("Expected 
Bool, got {:?}", self); + } + } + + pub fn into_string(self) -> String { + if let Data::String(elem) = self { + elem + } else { + panic!("Expected String, got {:?}", self); + } } - } - pub fn into_bools(self) -> Vec { - if let Data::Bools(elem) = self { - elem - } else { - panic!("Expected Bools, got {:?}", self); + pub fn into_f16s(self) -> Vec { + if let Data::Float16s(elem) = self { + elem + } else { + panic!("Expected Float16s, got {:?}", self); + } + } + + pub fn into_f32s(self) -> Vec { + if let Data::Float32s(elem) = self { + elem + } else { + panic!("Expected Float32s, got {:?}", self); + } } - } - pub fn into_strings(self) -> Vec { - if let Data::Strings(elem) = self { - elem - } else { - panic!("Expected Strings, got {:?}", self); + pub fn into_f64s(self) -> Vec { + if let Data::Float64s(elem) = self { + elem + } else { + panic!("Expected Float64s, got {:?}", self); + } + } + + pub fn into_i32s(self) -> Vec { + if let Data::Int32s(elem) = self { + elem + } else { + panic!("Expected Int32s, got {:?}", self); + } + } + + pub fn into_i64s(self) -> Vec { + if let Data::Int64s(elem) = self { + elem + } else { + panic!("Expected Int64s, got {:?}", self); + } + } + + pub fn into_bools(self) -> Vec { + if let Data::Bools(elem) = self { + elem + } else { + panic!("Expected Bools, got {:?}", self); + } + } + + pub fn into_strings(self) -> Vec { + if let Data::Strings(elem) = self { + elem + } else { + panic!("Expected Strings, got {:?}", self); + } } - } } impl AttributeValue { - pub fn into_f32(self) -> f32 { - if let AttributeValue::Float32(elem) = self { - elem - } else { - panic!("Expected Float32, got {:?}", self); + pub fn into_f32(self) -> f32 { + if let AttributeValue::Float32(elem) = self { + elem + } else { + panic!("Expected Float32, got {:?}", self); + } } - } - pub fn into_i32(self) -> i32 { - if let AttributeValue::Int64(elem) = self { - elem as i32 - } else { - panic!("Expected Int32, got {:?}", self); + pub fn into_i32(self) -> i32 { + if let AttributeValue::Int64(elem) = self { + elem as i32 + } else { + panic!("Expected Int32, got {:?}", self); + } } - } - pub fn into_i64(self) -> i64 { - if let AttributeValue::Int64(elem) = self { - elem - } else { - panic!("Expected Int64, got {:?}", self); + pub fn into_i64(self) -> i64 { + if let AttributeValue::Int64(elem) = self { + elem + } else { + panic!("Expected Int64, got {:?}", self); + } } - } - pub fn into_string(self) -> String { - if let AttributeValue::String(elem) = self { - elem - } else { - panic!("Expected String, got {:?}", self); + pub fn into_string(self) -> String { + if let AttributeValue::String(elem) = self { + elem + } else { + panic!("Expected String, got {:?}", self); + } } - } - pub fn into_tensor(self) -> Tensor { - if let AttributeValue::Tensor(elem) = self { - elem - } else { - panic!("Expected Tensor, got {:?}", self); + pub fn into_tensor(self) -> Tensor { + if let AttributeValue::Tensor(elem) = self { + elem + } else { + panic!("Expected Tensor, got {:?}", self); + } } - } - pub fn into_f32s(self) -> Vec { - if let AttributeValue::Float32s(elem) = self { - elem - } else { - panic!("Expected Float32s, got {:?}", self); + pub fn into_f32s(self) -> Vec { + if let AttributeValue::Float32s(elem) = self { + elem + } else { + panic!("Expected Float32s, got {:?}", self); + } } - } - pub fn into_i64s(self) -> Vec { - if let AttributeValue::Int64s(elem) = self { - elem - } else { - panic!("Expected Int64s, got {:?}", self); + pub fn into_i64s(self) -> Vec { + if let AttributeValue::Int64s(elem) = self { + elem 
+ } else { + panic!("Expected Int64s, got {:?}", self); + } } - } - pub fn into_strings(self) -> Vec { - if let AttributeValue::Strings(elem) = self { - elem - } else { - panic!("Expected Strings, got {:?}", self); + pub fn into_strings(self) -> Vec { + if let AttributeValue::Strings(elem) = self { + elem + } else { + panic!("Expected Strings, got {:?}", self); + } } - } - pub fn into_tensors(self) -> Vec { - if let AttributeValue::Tensors(elem) = self { - elem - } else { - panic!("Expected Tensors, got {:?}", self); + pub fn into_tensors(self) -> Vec { + if let AttributeValue::Tensors(elem) = self { + elem + } else { + panic!("Expected Tensors, got {:?}", self); + } } - } } /// Convert AttributeValue to an Argument impl From for Argument { - fn from(attr: AttributeValue) -> Argument { - // "" is used as a placeholder for the name - let name = "".to_string(); - - match attr { - AttributeValue::Float32(value) => Argument { - ty: ArgType::Scalar(ElementType::Float32), - name, - value: Some(Data::Float32(value)), - passed: false, - }, - AttributeValue::Float32s(values) => Argument { - ty: ArgType::Tensor(TensorType { - dim: 1, - elem_type: ElementType::Float32, - shape: Some(vec![values.len()]), - }), - name, - value: Some(Data::Float32s(values)), - passed: false, - }, - AttributeValue::Int64(value) => Argument { - ty: ArgType::Scalar(ElementType::Int64), - name, - value: Some(Data::Int64(value)), - passed: false, - }, - AttributeValue::Int64s(values) => Argument { - ty: ArgType::Tensor(TensorType { - dim: 1, - elem_type: ElementType::Int64, - shape: Some(vec![values.len()]), - }), - name, - value: Some(Data::Int64s(values)), - passed: false, - }, - AttributeValue::String(value) => Argument { - ty: ArgType::Scalar(ElementType::String), - name, - value: Some(Data::String(value)), - passed: false, - }, - AttributeValue::Strings(values) => Argument { - ty: ArgType::Tensor(TensorType { - dim: 1, - elem_type: ElementType::String, - shape: Some(vec![values.len()]), - }), - name, - value: Some(Data::Strings(values)), - passed: false, - }, - AttributeValue::Tensor(tensor) => { - if tensor.dim == 0 { - // Convert zero dim tensor to scalar - if let Some(data) = tensor.data { - Argument { - ty: ArgType::Scalar(tensor.elem_type), - name, - value: Some(data.into_scalar()), - passed: false, - } - } else { - Argument { - ty: ArgType::Scalar(tensor.elem_type), - name, - value: None, - passed: false, + fn from(attr: AttributeValue) -> Argument { + // "" is used as a placeholder for the name + let name = "".to_string(); + + match attr { + AttributeValue::Float32(value) => Argument { + ty: ArgType::Scalar(ElementType::Float32), + name, + value: Some(Data::Float32(value)), + passed: false, + }, + AttributeValue::Float32s(values) => Argument { + ty: ArgType::Tensor(TensorType { + dim: 1, + elem_type: ElementType::Float32, + shape: Some(vec![values.len()]), + }), + name, + value: Some(Data::Float32s(values)), + passed: false, + }, + AttributeValue::Int64(value) => Argument { + ty: ArgType::Scalar(ElementType::Int64), + name, + value: Some(Data::Int64(value)), + passed: false, + }, + AttributeValue::Int64s(values) => Argument { + ty: ArgType::Tensor(TensorType { + dim: 1, + elem_type: ElementType::Int64, + shape: Some(vec![values.len()]), + }), + name, + value: Some(Data::Int64s(values)), + passed: false, + }, + AttributeValue::String(value) => Argument { + ty: ArgType::Scalar(ElementType::String), + name, + value: Some(Data::String(value)), + passed: false, + }, + AttributeValue::Strings(values) => Argument { + 
ty: ArgType::Tensor(TensorType { + dim: 1, + elem_type: ElementType::String, + shape: Some(vec![values.len()]), + }), + name, + value: Some(Data::Strings(values)), + passed: false, + }, + AttributeValue::Tensor(tensor) => { + if tensor.dim == 0 { + // Convert zero dim tensor to scalar + if let Some(data) = tensor.data { + Argument { + ty: ArgType::Scalar(tensor.elem_type), + name, + value: Some(data.into_scalar()), + passed: false, + } + } else { + Argument { + ty: ArgType::Scalar(tensor.elem_type), + name, + value: None, + passed: false, + } + } + } else { + // Convert tensor to argument + Argument { + ty: ArgType::Tensor(TensorType { + dim: tensor.dim, + elem_type: tensor.elem_type, + shape: tensor.shape, + }), + name, + value: tensor.data, + passed: false, + } + } } - } - } else { - // Convert tensor to argument - Argument { - ty: ArgType::Tensor(TensorType { - dim: tensor.dim, - elem_type: tensor.elem_type, - shape: tensor.shape, - }), - name, - value: tensor.data, - passed: false, - } + _ => panic!("Unsupported attribute type"), } - } - _ => panic!("Unsupported attribute type"), } - } } impl Argument { - pub fn into_tensor(self) -> Option { - if let ArgType::Tensor(tensor_type) = self.ty { - Some(Tensor { - elem_type: tensor_type.elem_type, - dim: tensor_type.dim, - data: self.value, - shape: tensor_type.shape, - }) - } else { - None - } - } + pub fn into_tensor(self) -> Option { + if let ArgType::Tensor(tensor_type) = self.ty { + Some(Tensor { + elem_type: tensor_type.elem_type, + dim: tensor_type.dim, + data: self.value, + shape: tensor_type.shape, + }) + } else { + None + } + } } diff --git a/burn-import/src/onnx/node_remap.rs b/burn-import/src/onnx/node_remap.rs index 7b773f02aa..c87059280c 100644 --- a/burn-import/src/onnx/node_remap.rs +++ b/burn-import/src/onnx/node_remap.rs @@ -3,33 +3,33 @@ use super::ir::{AttributeValue, Node, NodeType}; /// Remap node type using kernel shape pub fn remap_node_with_kernel_shape(node: &mut Node, new_node_type: F) where - F: FnOnce(&Vec) -> NodeType, + F: FnOnce(&Vec) -> NodeType, { - if let AttributeValue::Int64s(ints) = node.attrs.get("kernel_shape").unwrap() { - node.node_type = new_node_type(ints); - } else { - panic!("kernel_shape is not an int64s"); - } + if let AttributeValue::Int64s(ints) = node.attrs.get("kernel_shape").unwrap() { + node.node_type = new_node_type(ints); + } else { + panic!("kernel_shape is not an int64s"); + } } /// Remap node type to a more specific one pub fn remap_node_type(node: &mut Node) { - match node.node_type { - NodeType::Conv => remap_node_with_kernel_shape(node, |ints| match ints.len() { - 1 => NodeType::Conv1d, - 2 => NodeType::Conv2d, - _ => panic!("Only conv 1d and 2d are supported"), - }), - NodeType::MaxPool => remap_node_with_kernel_shape(node, |ints| match ints.len() { - 1 => NodeType::MaxPool1d, - 2 => NodeType::MaxPool2d, - _ => panic!("Only max_pool 1d and 2d are supported"), - }), - NodeType::AveragePool => remap_node_with_kernel_shape(node, |ints| match ints.len() { - 1 => NodeType::AveragePool1d, - 2 => NodeType::AveragePool2d, - _ => panic!("Only avg_pool 1d and 2d are supported"), - }), - _ => (), - } + match node.node_type { + NodeType::Conv => remap_node_with_kernel_shape(node, |ints| match ints.len() { + 1 => NodeType::Conv1d, + 2 => NodeType::Conv2d, + _ => panic!("Only conv 1d and 2d are supported"), + }), + NodeType::MaxPool => remap_node_with_kernel_shape(node, |ints| match ints.len() { + 1 => NodeType::MaxPool1d, + 2 => NodeType::MaxPool2d, + _ => panic!("Only max_pool 1d and 2d are 
supported"), + }), + NodeType::AveragePool => remap_node_with_kernel_shape(node, |ints| match ints.len() { + 1 => NodeType::AveragePool1d, + 2 => NodeType::AveragePool2d, + _ => panic!("Only avg_pool 1d and 2d are supported"), + }), + _ => (), + } } diff --git a/burn-import/src/onnx/op_configuration.rs b/burn-import/src/onnx/op_configuration.rs index cfcb69ec6d..e9bf018781 100644 --- a/burn-import/src/onnx/op_configuration.rs +++ b/burn-import/src/onnx/op_configuration.rs @@ -1,8 +1,8 @@ use burn::nn::{ - conv::Conv1dConfig, - conv::Conv2dConfig, - pool::{AvgPool2dConfig, MaxPool2dConfig}, - BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d, + conv::Conv1dConfig, + conv::Conv2dConfig, + pool::{AvgPool2dConfig, MaxPool2dConfig}, + BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d, }; use crate::onnx::ir::Data; @@ -11,404 +11,404 @@ use super::ir::{ArgType, Node}; /// Create a Conv1dConfig from the attributes of the node pub fn conv1d_config(curr: &Node) -> Conv1dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1]; - let mut pads = vec![0, 0]; - let mut dilations = vec![1]; - let mut group: i64 = 1; - - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { - weight - } else { - panic!("Conv1d: weight tensor must be present"); - }; - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let shape = weight.shape.clone().unwrap(); - let channels_in = shape[1]; - let channels_out = shape[0]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64(), - _ => {} - } - } - - let padding = padding_config_1d(&pads); - - Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize) - .with_stride(strides[0] as usize) - .with_dilation(dilations[0] as usize) - .with_groups(group as usize) - .with_bias(bias) - .with_padding(padding) + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1]; + let mut pads = vec![0, 0]; + let mut dilations = vec![1]; + let mut group: i64 = 1; + + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
+ let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { + weight + } else { + panic!("Conv1d: weight tensor must be present"); + }; + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let shape = weight.shape.clone().unwrap(); + let channels_in = shape[1]; + let channels_out = shape[0]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64(), + _ => {} + } + } + + let padding = padding_config_1d(&pads); + + Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize) + .with_stride(strides[0] as usize) + .with_dilation(dilations[0] as usize) + .with_groups(group as usize) + .with_bias(bias) + .with_padding(padding) } /// Create a Conv2dConfig from the attributes of the node pub fn conv2d_config(curr: &Node) -> Conv2dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - let mut group: i64 = 1; - - // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] - let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { - weight - } else { - panic!("Conv1d: weight tensor must be present"); - }; - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let shape = weight.shape.clone().unwrap(); - let channels: [usize; 2] = [shape[1], shape[0]]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64(), - _ => {} - } - } - - let padding = padding_config(&pads); - - Conv2dConfig::new( - channels, - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([strides[0] as usize, strides[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group as usize) - .with_bias(bias) - .with_padding(padding) + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + let mut group: i64 = 1; + + // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
+ let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { + weight + } else { + panic!("Conv1d: weight tensor must be present"); + }; + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let shape = weight.shape.clone().unwrap(); + let channels: [usize; 2] = [shape[1], shape[0]]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64(), + _ => {} + } + } + + let padding = padding_config(&pads); + + Conv2dConfig::new( + channels, + [kernel_shape[0] as usize, kernel_shape[1] as usize], + ) + .with_stride([strides[0] as usize, strides[1] as usize]) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) + .with_groups(group as usize) + .with_bias(bias) + .with_padding(padding) } /// Create a MaxPool2dConfig from the attributes of the node pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - _ => {} + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + _ => {} + } } - } - let padding = padding_config(&pads); + let padding = padding_config(&pads); - MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) + MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) } /// Create a AvgPool2dConfig from the attributes of the node pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut count_include_pad: i64 = 0; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "count_include_pad" => count_include_pad = value.clone().into_i64(), - _ => {} + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut count_include_pad: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "count_include_pad" => 
count_include_pad = value.clone().into_i64(), + _ => {} + } } - } - let padding = padding_config(&pads); + let padding = padding_config(&pads); - if count_include_pad == 1 && padding != PaddingConfig2d::Valid { - todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636"); - } + if count_include_pad == 1 && padding != PaddingConfig2d::Valid { + todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636"); + } - AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) + AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) } /// Create a FlattenConfig from the attributes of the node pub fn flatten_config(curr: &Node) -> (usize, usize) { - // the begin dimension is the first dimension (Default: 1 per ONNX spec) - let mut start_dim: i64 = 1; - - // check if the node has only one input - if curr.inputs.len() != 1 { - panic!( - "Flatten: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match curr.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // check if the input tensor has at least 2 dimensions - if tensor.dim < 2 { - panic!( - "Flatten: input tensor must have at least 2 dimensions (got {:?})", - tensor.dim - ); - } - - // the end dimension is the last dimension - let end_dim = tensor.dim - 1; - - // extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => start_dim = value.clone().into_i64(), - _ => {} - } - } - - // if beg_dim is negative, it is counted from the end - if start_dim < 0 { - start_dim += tensor.dim as i64; - } - - (start_dim as usize, end_dim) + // the begin dimension is the first dimension (Default: 1 per ONNX spec) + let mut start_dim: i64 = 1; + + // check if the node has only one input + if curr.inputs.len() != 1 { + panic!( + "Flatten: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // check if the input tensor has at least 2 dimensions + if tensor.dim < 2 { + panic!( + "Flatten: input tensor must have at least 2 dimensions (got {:?})", + tensor.dim + ); + } + + // the end dimension is the last dimension + let end_dim = tensor.dim - 1; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "axis" => start_dim = value.clone().into_i64(), + _ => {} + } + } + + // if beg_dim is negative, it is counted from the end + if start_dim < 0 { + start_dim += tensor.dim as i64; + } + + (start_dim as usize, end_dim) } /// Create a GatherConfig from the attributes of the node pub fn gather_config(curr: &Node) -> usize { - // Default: 0 per ONNX spec - let mut dim: i64 = 0; - - // check if the node has only one input - if curr.inputs.len() != 2 { - panic!("Gather: index tensor must be present"); - } - - // extract the shape of the input tensor - let tensor = match curr.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, 
value) in curr.attrs.iter() { - match key.as_str() { - "axis" => dim = value.clone().into_i64(), - _ => {} - } - } - - // if dim is negative, it is counted from the end - if dim < 0 { - dim += tensor.dim as i64; - } - - dim as usize + // Default: 0 per ONNX spec + let mut dim: i64 = 0; + + // check if the node has only one input + if curr.inputs.len() != 2 { + panic!("Gather: index tensor must be present"); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "axis" => dim = value.clone().into_i64(), + _ => {} + } + } + + // if dim is negative, it is counted from the end + if dim < 0 { + dim += tensor.dim as i64; + } + + dim as usize } /// Create a LinearConfig from the attributes of the node pub fn linear_config(node: &Node) -> LinearConfig { - if node.inputs.len() < 2 { - panic!("Linear: missing weight tensor"); - } - - // extract the shape of the weight tensor - let weight = if let ArgType::Tensor(ref weight) = node.inputs[1].ty { - weight - } else { - panic!("Linear: weight tensor must be present"); - }; - - // check if the weight tensor has at least 2 dimensions - if weight.dim < 2 { - panic!( - "Linear: weight tensor must have at least 2 dimensions (got {:?})", - weight.dim - ); - } - - let shape = weight.shape.clone().unwrap(); - let (in_size, out_size) = (shape[0], shape[1]); - - // check if the bias is present - let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); - - LinearConfig::new(in_size, out_size).with_bias(bias) + if node.inputs.len() < 2 { + panic!("Linear: missing weight tensor"); + } + + // extract the shape of the weight tensor + let weight = if let ArgType::Tensor(ref weight) = node.inputs[1].ty { + weight + } else { + panic!("Linear: weight tensor must be present"); + }; + + // check if the weight tensor has at least 2 dimensions + if weight.dim < 2 { + panic!( + "Linear: weight tensor must have at least 2 dimensions (got {:?})", + weight.dim + ); + } + + let shape = weight.shape.clone().unwrap(); + let (in_size, out_size) = (shape[0], shape[1]); + + // check if the bias is present + let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); + + LinearConfig::new(in_size, out_size).with_bias(bias) } /// Create a DropoutConfig from an attribute and state of the node pub fn dropout_config(node: &Node) -> DropoutConfig { - // Opset 7 and older store probability as an attribute - if node.attrs.contains_key("ratio") { - let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); - return DropoutConfig::new(prob as f64); - } - - if node.inputs.len() < 2 { - panic!("Dropout configuration must have at least 2 inputs"); - } - - let ratio = node.inputs[1] - .value - .clone() - .expect("Dropout ratio must be passed in the second input") - .into_scalar(); - - let prob = match ratio { - Data::Float16(ratio) => f64::from(f32::from(ratio)), - Data::Float32(ratio) => ratio as f64, - Data::Float64(ratio) => ratio, - _ => panic!("Dropout ratio must be a float"), - }; - - DropoutConfig::new(prob) + // Opset 7 and older store probability as an attribute + if node.attrs.contains_key("ratio") { + let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); + return DropoutConfig::new(prob as f64); + } + + if node.inputs.len() < 2 { + panic!("Dropout configuration must have at least 2 inputs"); + } + + let ratio = 
node.inputs[1] + .value + .clone() + .expect("Dropout ratio must be passed in the second input") + .into_scalar(); + + let prob = match ratio { + Data::Float16(ratio) => f64::from(f32::from(ratio)), + Data::Float32(ratio) => ratio as f64, + Data::Float64(ratio) => ratio, + _ => panic!("Dropout ratio must be a float"), + }; + + DropoutConfig::new(prob) } /// Create log_softmax config from the attributes of the node pub fn log_softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "LogSoftmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.dim as i64; - } - - axis as usize + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "LogSoftmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.dim as i64; + } + + axis as usize } /// Create softmax config from the attributes of the node pub fn softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Softmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.dim as i64; - } - - axis as usize + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Softmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.dim as i64; + } + + axis as usize } /// Create concat config from 
the attributes of the node pub fn concat_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // extract the shape of the input tensor - let tensor = match node.inputs.get(0).unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // extract the shape of the input tensor + let tensor = match node.inputs.get(0).unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } } - } - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.dim as i64; - } + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.dim as i64; + } - axis as usize + axis as usize } /// Create a BatchNormConfig from the attributes of the node pub fn batch_norm_config(node: &Node) -> BatchNormConfig { - // extract the shape of the weight tensor - let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { - tensor_type - } else { - panic!("BatchNorm: weight tensor must be present"); - }; - - let num_features: usize = tensor_type.shape.clone().unwrap()[0]; - - let mut epsilon = 0f32; - let mut momentum = 0f32; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "momentum" => momentum = value.clone().into_f32(), - "epsilon" => epsilon = value.clone().into_f32(), - _ => {} + // extract the shape of the weight tensor + let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { + tensor_type + } else { + panic!("BatchNorm: weight tensor must be present"); + }; + + let num_features: usize = tensor_type.shape.clone().unwrap()[0]; + + let mut epsilon = 0f32; + let mut momentum = 0f32; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "momentum" => momentum = value.clone().into_f32(), + "epsilon" => epsilon = value.clone().into_f32(), + _ => {} + } } - } - BatchNormConfig::new(num_features) - .with_epsilon(epsilon as f64) - .with_momentum(momentum as f64) + BatchNormConfig::new(num_features) + .with_epsilon(epsilon as f64) + .with_momentum(momentum as f64) } /// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. @@ -431,110 +431,110 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig { /// This function is used when the padding is specified as a list of integers, /// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
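// A standalone re-statement (simplified stand-in return type, not burn's API) of the
// 2d padding rules described in the comment above and implemented by `padding_config`
// just below: negative or asymmetric pads are rejected, all-zero pads map to "valid"
// padding, and symmetric pads map to an explicit (left, top) pair.
fn padding_2d(pads: &[i64; 4]) -> Result<Option<(usize, usize)>, String> {
    let [left, top, right, bottom] = *pads;
    if left < 0 || top < 0 || right < 0 || bottom < 0 {
        Err("Negative pad values are not supported".into())
    } else if left != right || top != bottom {
        Err("Asymmetric padding is not supported".into())
    } else if left == 0 && top == 0 {
        Ok(None) // corresponds to PaddingConfig2d::Valid
    } else {
        Ok(Some((left as usize, top as usize))) // corresponds to PaddingConfig2d::Explicit
    }
}

fn main() {
    assert_eq!(padding_2d(&[0, 0, 0, 0]).unwrap(), None);
    assert_eq!(padding_2d(&[2, 3, 2, 3]).unwrap(), Some((2, 3)));
    assert!(padding_2d(&[1, 0, 2, 0]).is_err()); // asymmetric: rejected
}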
fn padding_config(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == top && top == right && right == bottom && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } + let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } else if left == top && top == right && right == bottom && bottom == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig2d::Valid + } else if left == right && top == bottom { + // i.e [2, 3, 2, 3] + PaddingConfig2d::Explicit(left as usize, top as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } } pub fn reshape_config(node: &Node) -> Vec { - let mut allowzero = 0; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "allowzero" => allowzero = value.clone().into_i64(), - _ => {} - } - } - - // Burn does not support zero size shape (0 means false in ONNX) - // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) - if allowzero != 0 { - panic!("Zero shape size is not supported"); - } - - if node.inputs.len() != 2 || node.inputs[1].value.is_none() { - panic!("Reshape: shape tensor must be present"); - } - - let input_value = &node.inputs[1].value; - match &node.inputs[1].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.dim, 1, "Reshape: shape tensor must be 1D"); - - if let Some(Data::Int64s(shape)) = input_value.as_ref() { - shape.clone() - } else { - panic!("Tensor data type must be int64") - } - } - _ => panic!("Only tensor input is valid for shape"), - } + let mut allowzero = 0; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "allowzero" => allowzero = value.clone().into_i64(), + _ => {} + } + } + + // Burn does not support zero size shape (0 means false in ONNX) + // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) + if allowzero != 0 { + panic!("Zero shape size is not supported"); + } + + if node.inputs.len() != 2 || node.inputs[1].value.is_none() { + panic!("Reshape: shape tensor must be present"); + } + + let input_value = &node.inputs[1].value; + match &node.inputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.dim, 1, "Reshape: shape tensor must be 1D"); + + if let Some(Data::Int64s(shape)) = input_value.as_ref() { + shape.clone() + } else { + panic!("Tensor data type must be int64") + } + } + _ => panic!("Only tensor input is valid for shape"), + } } pub fn clip_config(node: &Node) -> (Option, Option) { - let mut min_result: Option = None; - let mut max_result: Option = None; - - // For Clip Opset 6+ , the min and max values are attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "min" => { - let min = value.clone().into_f32() as f64; - min_result = Some(min); - } - "max" => { - let max = value.clone().into_f32(); - max_result = Some(max as 
f64); - } - _ => {} - } - } - - // For Clip Opset 11+ , the min and max values are inputs - // Get the min and max values from the input values - if min_result.is_none() && max_result.is_none() { - let min = &node.inputs[1].value; - let max = &node.inputs[2].value; - - if min_result.is_none() && min.is_some() { - let min = min.clone().unwrap().into_scalar(); - min_result = match min { - Data::Float16(min) => Some(f32::from(min) as f64), - Data::Float32(min) => Some(min as f64), - Data::Float64(min) => Some(min), - _ => panic!("Clip: only float min is supported"), - }; - } - - if max_result.is_none() && max.is_some() { - let max = max.clone().unwrap().into_scalar(); - max_result = match max { - Data::Float16(max) => Some(f32::from(max) as f64), - Data::Float32(max) => Some(max as f64), - Data::Float64(max) => Some(max), - _ => panic!("Clip: only float max is supported"), - }; - } - } - - if min_result.is_none() && max_result.is_none() { - panic!("Clip: min and max values must be either attributes or inputs"); - } - - (min_result, max_result) + let mut min_result: Option = None; + let mut max_result: Option = None; + + // For Clip Opset 6+ , the min and max values are attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "min" => { + let min = value.clone().into_f32() as f64; + min_result = Some(min); + } + "max" => { + let max = value.clone().into_f32(); + max_result = Some(max as f64); + } + _ => {} + } + } + + // For Clip Opset 11+ , the min and max values are inputs + // Get the min and max values from the input values + if min_result.is_none() && max_result.is_none() { + let min = &node.inputs[1].value; + let max = &node.inputs[2].value; + + if min_result.is_none() && min.is_some() { + let min = min.clone().unwrap().into_scalar(); + min_result = match min { + Data::Float16(min) => Some(f32::from(min) as f64), + Data::Float32(min) => Some(min as f64), + Data::Float64(min) => Some(min), + _ => panic!("Clip: only float min is supported"), + }; + } + + if max_result.is_none() && max.is_some() { + let max = max.clone().unwrap().into_scalar(); + max_result = match max { + Data::Float16(max) => Some(f32::from(max) as f64), + Data::Float32(max) => Some(max as f64), + Data::Float64(max) => Some(max), + _ => panic!("Clip: only float max is supported"), + }; + } + } + + if min_result.is_none() && max_result.is_none() { + panic!("Clip: min and max values must be either attributes or inputs"); + } + + (min_result, max_result) } /// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. @@ -557,20 +557,20 @@ pub fn clip_config(node: &Node) -> (Option, Option) { /// This function is used when the padding is specified as a list of integers, /// and not used when the padding is specified as a string, e.g. "SAME_UPPER". fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { - let [left, right] = [pads[0], pads[1]]; - - if left < 0 || right < 0 { - panic!("Negative pad values are not supported"); - } else if left != right { - panic!("Asymmetric padding is not supported"); - } else if left == right && right == 0 { - // i.e. [0, 0] - PaddingConfig1d::Valid - } else if left == right { - // i.e. 
[2, 2] - PaddingConfig1d::Explicit(left as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } + let [left, right] = [pads[0], pads[1]]; + + if left < 0 || right < 0 { + panic!("Negative pad values are not supported"); + } else if left != right { + panic!("Asymmetric padding is not supported"); + } else if left == right && right == 0 { + // i.e. [0, 0] + PaddingConfig1d::Valid + } else if left == right { + // i.e. [2, 2] + PaddingConfig1d::Explicit(left as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } } diff --git a/burn-import/src/onnx/proto_conversion.rs b/burn-import/src/onnx/proto_conversion.rs index 7dccdb50ba..1fe5aac54a 100644 --- a/burn-import/src/onnx/proto_conversion.rs +++ b/burn-import/src/onnx/proto_conversion.rs @@ -4,11 +4,11 @@ use crate::onnx::ir::TensorType; use super::ir::Dim; use super::ir::{ - ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor, + ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor, }; use super::protos::{ - attribute_proto::AttributeType, tensor_proto::DataType, tensor_shape_proto::dimension::Value, - type_proto, AttributeProto, NodeProto, TensorProto, TensorShapeProto, ValueInfoProto, + attribute_proto::AttributeType, tensor_proto::DataType, tensor_shape_proto::dimension::Value, + type_proto, AttributeProto, NodeProto, TensorProto, TensorShapeProto, ValueInfoProto, }; use bytemuck::cast_slice; @@ -17,244 +17,246 @@ use protobuf::Enum; /// Error type for parsing ONNX model #[derive(Debug)] pub enum ParseError { - VariantNotFound, + VariantNotFound, } /// Convert a vector of AttributeProto to a HashMap of AttributeValue impl TryFrom for Tensor { - type Error = ParseError; - fn try_from(tensor: TensorProto) -> Result { - let (elem_type, data) = match DataType::from_i32(tensor.data_type).unwrap() { - DataType::FLOAT => ( - ElementType::Float32, - // Convert the raw data to a vector of floats - if !tensor.raw_data.is_empty() { - Data::Float32s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Float32s(tensor.float_data) - }, - ), - DataType::INT16 => { - // TODO : Add support for int16 by converting to int32 - todo!("Add support for int16"); - } - DataType::INT32 => ( - ElementType::Int32, - // Convert the raw data to a vector of ints - if !tensor.raw_data.is_empty() { - Data::Int32s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Int32s(tensor.int32_data) - }, - ), - DataType::INT64 => ( - ElementType::Int64, - // Convert the raw data to a vector of ints - if !tensor.raw_data.is_empty() { - Data::Int64s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Int64s(tensor.int64_data) - }, - ), - DataType::DOUBLE => ( - ElementType::Float64, - // Convert the raw data to a vector of floats - if !tensor.raw_data.is_empty() { - Data::Float64s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { - Data::Float64s(tensor.double_data) - }, - ), - DataType::BOOL => (ElementType::Bool, { - assert!(!tensor.raw_data.is_empty()); - Data::Bools(tensor.raw_data.iter().map(|x| *x != 0).collect()) - }), - // TODO : Add more types - _ => { - return Err(ParseError::VariantNotFound); - } - }; - let shape = convert_shape(tensor.dims); - - Ok(Tensor { - elem_type, - dim: shape.len(), - shape: Some(shape), - data: Some(data), - }) - } + type Error = ParseError; + fn try_from(tensor: TensorProto) -> Result { + let 
(elem_type, data) = match DataType::from_i32(tensor.data_type).unwrap() { + DataType::FLOAT => ( + ElementType::Float32, + // Convert the raw data to a vector of floats + if !tensor.raw_data.is_empty() { + Data::Float32s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Float32s(tensor.float_data) + }, + ), + DataType::INT16 => { + // TODO : Add support for int16 by converting to int32 + todo!("Add support for int16"); + } + DataType::INT32 => ( + ElementType::Int32, + // Convert the raw data to a vector of ints + if !tensor.raw_data.is_empty() { + Data::Int32s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Int32s(tensor.int32_data) + }, + ), + DataType::INT64 => ( + ElementType::Int64, + // Convert the raw data to a vector of ints + if !tensor.raw_data.is_empty() { + Data::Int64s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Int64s(tensor.int64_data) + }, + ), + DataType::DOUBLE => ( + ElementType::Float64, + // Convert the raw data to a vector of floats + if !tensor.raw_data.is_empty() { + Data::Float64s(cast_slice(&tensor.raw_data[..]).to_vec()) + } else { + Data::Float64s(tensor.double_data) + }, + ), + DataType::BOOL => (ElementType::Bool, { + assert!(!tensor.raw_data.is_empty()); + Data::Bools(tensor.raw_data.iter().map(|x| *x != 0).collect()) + }), + // TODO : Add more types + _ => { + return Err(ParseError::VariantNotFound); + } + }; + let shape = convert_shape(tensor.dims); + + Ok(Tensor { + elem_type, + dim: shape.len(), + shape: Some(shape), + data: Some(data), + }) + } } impl TryFrom for Vec { - type Error = ParseError; - fn try_from(shape: TensorShapeProto) -> Result, Self::Error> { - let mut result = Vec::new(); + type Error = ParseError; + fn try_from(shape: TensorShapeProto) -> Result, Self::Error> { + let mut result = Vec::new(); - for dim in shape.dim { - if let Value::DimValue(value) = dim.value.unwrap() { - result.push(value as usize); - } - } + for dim in shape.dim { + if let Value::DimValue(value) = dim.value.unwrap() { + result.push(value as usize); + } + } - Ok(result) - } + Ok(result) + } } /// Convert a vector of AttributeProto to a HashMap of AttributeValue impl TryFrom<&type_proto::Tensor> for Tensor { - type Error = ParseError; - fn try_from(tensor: &type_proto::Tensor) -> Result { - let elem_type = match DataType::from_i32(tensor.elem_type).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - DataType::BOOL => ElementType::Bool, - - // TODO : Add more types - _ => { - return Err(ParseError::VariantNotFound); - } - }; - - let shape_proto = tensor.shape.clone().unwrap(); - let shape: Vec = shape_proto.try_into().unwrap(); - - Ok(Tensor { - elem_type, - dim: shape.len(), - shape: Some(shape), - data: None, - }) - } + type Error = ParseError; + fn try_from(tensor: &type_proto::Tensor) -> Result { + let elem_type = match DataType::from_i32(tensor.elem_type).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + DataType::BOOL => ElementType::Bool, + + // TODO : Add more types + _ => { + return Err(ParseError::VariantNotFound); + } + }; + + let shape_proto = tensor.shape.clone().unwrap(); + let shape: Vec = shape_proto.try_into().unwrap(); + + Ok(Tensor { + elem_type, + dim: shape.len(), + shape: Some(shape), + data: None, + }) + } } fn convert_vec_tensor_proto(tensors: 
Vec) -> Result, ParseError> { - let mut result = Vec::new(); - for tensor in tensors { - result.push(Tensor::try_from(tensor)?); - } - Ok(result) + let mut result = Vec::new(); + for tensor in tensors { + result.push(Tensor::try_from(tensor)?); + } + Ok(result) } /// Convert a vector of AttributeProto to a HashMap of AttributeValue impl TryFrom for AttributeValue { - type Error = ParseError; - - fn try_from(attr: AttributeProto) -> Result { - let value = match attr.type_.unwrap() { - AttributeType::FLOAT => AttributeValue::Float32(attr.f), - AttributeType::INT => AttributeValue::Int64(attr.i), - AttributeType::STRING => AttributeValue::String(to_string(attr.s)), - - // warning: tensor can be empty TODO: check if it is empty - AttributeType::TENSOR => AttributeValue::Tensor(Tensor::try_from(attr.t.unwrap())?), - - // Graph is not supported for now - // AttributeType::GRAPH => AttributeValue::Graph(attr.g), - AttributeType::FLOATS => AttributeValue::Float32s(attr.floats), - AttributeType::INTS => AttributeValue::Int64s(attr.ints), - AttributeType::STRINGS => AttributeValue::Strings(to_string_vec(attr.strings)), - AttributeType::TENSORS => AttributeValue::Tensors(convert_vec_tensor_proto(attr.tensors)?), - // AttributeType::GRAPHS => AttributeValue::Graphs(attr.graphs), - // AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(attr.sparse_tensors), - // AttributeType::SPARSE_TENSOR => AttributeValue::SparseTensor(attr.sparse_tensor), - _ => { - return Err(ParseError::VariantNotFound); - } - }; - - Ok(value) - } + type Error = ParseError; + + fn try_from(attr: AttributeProto) -> Result { + let value = match attr.type_.unwrap() { + AttributeType::FLOAT => AttributeValue::Float32(attr.f), + AttributeType::INT => AttributeValue::Int64(attr.i), + AttributeType::STRING => AttributeValue::String(to_string(attr.s)), + + // warning: tensor can be empty TODO: check if it is empty + AttributeType::TENSOR => AttributeValue::Tensor(Tensor::try_from(attr.t.unwrap())?), + + // Graph is not supported for now + // AttributeType::GRAPH => AttributeValue::Graph(attr.g), + AttributeType::FLOATS => AttributeValue::Float32s(attr.floats), + AttributeType::INTS => AttributeValue::Int64s(attr.ints), + AttributeType::STRINGS => AttributeValue::Strings(to_string_vec(attr.strings)), + AttributeType::TENSORS => { + AttributeValue::Tensors(convert_vec_tensor_proto(attr.tensors)?) 
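// Note: the `?` above propagates ParseError::VariantNotFound whenever any tensor in the
// attribute list carries an element type that the TryFrom<TensorProto> conversion defined
// earlier in this file does not handle.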
+ } + // AttributeType::GRAPHS => AttributeValue::Graphs(attr.graphs), + // AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(attr.sparse_tensors), + // AttributeType::SPARSE_TENSOR => AttributeValue::SparseTensor(attr.sparse_tensor), + _ => { + return Err(ParseError::VariantNotFound); + } + }; + + Ok(value) + } } /// Convert a vector of AttributeProto to a HashMap of AttributeValue pub fn convert_vec_attrs_proto(attrs: Vec) -> Attributes { - let mut result = Attributes::new(); - for attr in attrs { - result.insert(attr.name.clone(), AttributeValue::try_from(attr).unwrap()); - } - result + let mut result = Attributes::new(); + for attr in attrs { + result.insert(attr.name.clone(), AttributeValue::try_from(attr).unwrap()); + } + result } pub fn convert_node_proto(node: &NodeProto) -> Node { - let name = node.name.clone(); + let name = node.name.clone(); - log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str()); + log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str()); - let inputs = node.input.clone().into_iter().map(Argument::new).collect(); + let inputs = node.input.clone().into_iter().map(Argument::new).collect(); - let outputs = node.output.clone().into_iter().map(Argument::new).collect(); + let outputs = node.output.clone().into_iter().map(Argument::new).collect(); - let attrs = convert_vec_attrs_proto(node.attribute.clone()); + let attrs = convert_vec_attrs_proto(node.attribute.clone()); - let node_type = NodeType::from_str(node.op_type.as_str()).expect("Unknown node type"); + let node_type = NodeType::from_str(node.op_type.as_str()).expect("Unknown node type"); - Node { - node_type, - name, - inputs, - outputs, - attrs, - } + Node { + node_type, + name, + inputs, + outputs, + attrs, + } } fn to_string(bytes: Vec) -> String { - from_utf8(bytes.as_slice()).unwrap().to_string() + from_utf8(bytes.as_slice()).unwrap().to_string() } fn to_string_vec(bytes: Vec>) -> Vec { - bytes.iter().map(|b| to_string(b.clone())).collect() + bytes.iter().map(|b| to_string(b.clone())).collect() } fn convert_shape(shape: Vec) -> Vec { - shape.iter().map(|s| *s as usize).collect() + shape.iter().map(|s| *s as usize).collect() } impl TryFrom for Argument { - type Error = ParseError; - - fn try_from(value: ValueInfoProto) -> Result { - let name = value.name.clone(); - let proto_type = value.type_.unwrap(); - - if !proto_type.has_tensor_type() { - panic!("Unsupported argument type {:?}", proto_type); + type Error = ParseError; + + fn try_from(value: ValueInfoProto) -> Result { + let name = value.name.clone(); + let proto_type = value.type_.unwrap(); + + if !proto_type.has_tensor_type() { + panic!("Unsupported argument type {:?}", proto_type); + } + + let tensor_proto = proto_type.tensor_type(); + + let elem_type = match DataType::from_i32(tensor_proto.elem_type).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + DataType::BOOL => ElementType::Bool, + _ => { + return Err(ParseError::VariantNotFound); + } + }; + + let tensor_type = TensorType { + dim: tensor_proto.shape.dim.len(), + elem_type, + shape: Some( + tensor_proto + .shape + .dim + .iter() + .map(|x| x.dim_value() as Dim) + .collect(), + ), + }; + + let ty = ArgType::Tensor(tensor_type); + + Ok(Argument { + ty, + name, + value: None, + passed: false, + }) } - - let tensor_proto = proto_type.tensor_type(); - - let elem_type = match 
DataType::from_i32(tensor_proto.elem_type).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - DataType::BOOL => ElementType::Bool, - _ => { - return Err(ParseError::VariantNotFound); - } - }; - - let tensor_type = TensorType { - dim: tensor_proto.shape.dim.len(), - elem_type, - shape: Some( - tensor_proto - .shape - .dim - .iter() - .map(|x| x.dim_value() as Dim) - .collect(), - ), - }; - - let ty = ArgType::Tensor(tensor_type); - - Ok(Argument { - ty, - name, - value: None, - passed: false, - }) - } } diff --git a/burn-import/src/onnx/protos/mod.rs b/burn-import/src/onnx/protos/mod.rs index b18e3c0908..328e850e76 100644 --- a/burn-import/src/onnx/protos/mod.rs +++ b/burn-import/src/onnx/protos/mod.rs @@ -1,5 +1,5 @@ mod inner { - include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs")); + include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs")); } pub use inner::onnx::*; diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs index b2889e3ff9..4bf912bd44 100644 --- a/burn-import/src/onnx/to_burn.rs +++ b/burn-import/src/onnx/to_burn.rs @@ -1,55 +1,56 @@ use std::{ - env, - fs::{self, create_dir_all}, - path::{Path, PathBuf}, + env, + fs::{self, create_dir_all}, + path::{Path, PathBuf}, }; use burn::{ - record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings}, - tensor::{DataSerialize, Element}, + record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings}, + tensor::{DataSerialize, Element}, }; use crate::{ - burn::{ - graph::BurnGraph, - node::{ - avg_pool2d::AvgPool2dNode, - batch_norm::BatchNormNode, - binary::BinaryNode, - clip::ClipNode, - concat::ConcatNode, - constant::{ConstantNode, ConstantValue, TensorValue}, - conv1d::Conv1dNode, - conv2d::Conv2dNode, - dropout::DropoutNode, - gather::GatherNode, - global_avg_pool::GlobalAvgPoolNode, - linear::LinearNode, - matmul::MatmulNode, - max_pool2d::MaxPool2dNode, - reshape::ReshapeNode, - unary::UnaryNode, + burn::{ + graph::BurnGraph, + node::{ + avg_pool2d::AvgPool2dNode, + batch_norm::BatchNormNode, + binary::BinaryNode, + clip::ClipNode, + concat::ConcatNode, + constant::{ConstantNode, ConstantValue, TensorValue}, + conv1d::Conv1dNode, + conv2d::Conv2dNode, + dropout::DropoutNode, + gather::GatherNode, + global_avg_pool::GlobalAvgPoolNode, + linear::LinearNode, + matmul::MatmulNode, + max_pool2d::MaxPool2dNode, + reshape::ReshapeNode, + unary::UnaryNode, + }, + ScalarKind, ScalarType, TensorKind, TensorType, Type, }, - ScalarKind, ScalarType, TensorKind, TensorType, Type, - }, - format_tokens, - logger::init_log, - onnx::{ - from_onnx::convert_constant_value, - ir::{Node, NodeType}, - op_configuration::{ - batch_norm_config, conv1d_config, conv2d_config, flatten_config, gather_config, - linear_config, log_softmax_config, max_pool2d_config, + format_tokens, + logger::init_log, + onnx::{ + from_onnx::convert_constant_value, + ir::{Node, NodeType}, + op_configuration::{ + batch_norm_config, conv1d_config, conv2d_config, flatten_config, gather_config, + linear_config, log_softmax_config, max_pool2d_config, + }, }, - }, }; use super::{ - from_onnx::parse_onnx, - ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph}, - op_configuration::{ - avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config, softmax_config, - }, + from_onnx::parse_onnx, + ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph}, + op_configuration::{ + 
avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config, + softmax_config, + }, }; pub use crate::burn::graph::RecordType; @@ -57,542 +58,547 @@ pub use crate::burn::graph::RecordType; /// Generate code and states from `.onnx` files and save them to the `out_dir`. #[derive(Debug, Default)] pub struct ModelGen { - out_dir: Option, - /// List of onnx files to generate source code from. - inputs: Vec, - development: bool, - half_precision: bool, - record_type: RecordType, - embed_states: bool, + out_dir: Option, + /// List of onnx files to generate source code from. + inputs: Vec, + development: bool, + half_precision: bool, + record_type: RecordType, + embed_states: bool, } impl ModelGen { - /// Create a new `ModelGen`. - pub fn new() -> Self { - init_log().ok(); // Error when init multiple times are ignored. - Self::default() - } - - /// Set output directory. - pub fn out_dir(&mut self, out_dir: &str) -> &mut Self { - self.out_dir = Some(Path::new(out_dir).into()); - self - } - - /// Add input file. - pub fn input(&mut self, input: &str) -> &mut Self { - self.inputs.push(input.into()); - self - } - - /// Set development mode. - /// - /// If this is set to true, the generated model will be saved as `.graph.txt` files and model - /// states will be saved as `.json` file. - pub fn development(&mut self, development: bool) -> &mut Self { - self.development = development; - self - } - - /// Run code generation. - /// - /// This function is intended to be called from `build.rs` script. - pub fn run_from_script(&self) { - self.run(true); - } - - /// Run code generation. - /// - /// This function is intended to be called from CLI. - pub fn run_from_cli(&self) { - self.run(false); - } - - /// Specify parameter precision to be saved. - /// - /// # Arguments - /// - /// * `half_precision` - If true, half precision is saved. Otherwise, full precision is saved. - pub fn half_precision(&mut self, half_precision: bool) -> &mut Self { - self.half_precision = half_precision; - self - } - - /// Specify the type of the record to be saved. - /// - /// # Arguments - /// - /// * `record_type` - The type of the record to be saved. - pub fn record_type(&mut self, record_type: RecordType) -> &mut Self { - self.record_type = record_type; - self - } - - /// Specify whether to embed states in the generated code. - /// - /// # Arguments - /// - /// * `embed_states` - If true, states are embedded in the generated code. Otherwise, states are - /// saved as a separate file. - pub fn embed_states(&mut self, embed_states: bool) -> &mut Self { - self.embed_states = embed_states; - self - } - - /// Run code generation. 
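// Note: the `is_build_script` flag only changes how `out_dir` is resolved below — when
// invoked from a build script the configured directory is appended to Cargo's OUT_DIR,
// whereas the CLI entry point uses the configured directory as given.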
- fn run(&self, is_build_script: bool) { - log::info!("Starting to convert ONNX to Burn"); - - // prepend the out_dir to the cargo_out_dir if this is a build script - let out_dir = if is_build_script { - let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set"); - let mut path = PathBuf::from(cargo_out_dir); - - // // Append the out_dir to the cargo_out_dir - path.push(self.out_dir.clone().unwrap()); - path - } else { - self.out_dir.as_ref().expect("out_dir is not set").clone() - }; - - log::debug!("Output directory: {:?}", out_dir); - - create_dir_all(&out_dir).unwrap(); - - for input in self.inputs.iter() { - let file_name = input.file_stem().unwrap(); - let out_file: PathBuf = out_dir.join(file_name); - - log::info!("Converting {:?}", input); - log::debug!("Input file name: {:?}", file_name); - log::debug!("Output file: {:?}", out_file); - - self.generate_model(input, out_file); - } - - log::info!("Finished converting ONNX to Burn"); - } - - /// Generate model source code and model state. - fn generate_model(&self, input: &PathBuf, out_file: PathBuf) { - log::info!("Generating model from {:?}", input); - log::debug!("Development mode: {:?}", self.development); - log::debug!("Output file: {:?}", out_file); - - let graph = parse_onnx(input.as_ref()); - - if self.development { - // export the graph - let debug_graph = format!("{:#?}", graph); - let graph_file = out_file.with_extension("graph.txt"); - log::debug!("Writing debug graph file: {:?}", graph_file); - fs::write(graph_file, debug_graph).unwrap(); - } - - let new_fn = true; - let blank_space = true; - let top_comment = Some(format!("Generated from ONNX {input:?} by burn-import")); - - let code = if self.half_precision { - graph - .into_burn::() - .with_record(out_file.clone(), self.record_type, self.embed_states) - .with_new_fn(new_fn) - .with_blank_space(blank_space) - .with_top_comment(top_comment) - .codegen() - } else { - graph - .into_burn::() - .with_record(out_file.clone(), self.record_type, self.embed_states) - .with_new_fn(new_fn) - .with_blank_space(blank_space) - .with_top_comment(top_comment) - .codegen() - }; - - let code_str = format_tokens(code); - fs::write(out_file.with_extension("rs"), code_str).unwrap(); - - log::info!("Model generated"); - } + /// Create a new `ModelGen`. + pub fn new() -> Self { + init_log().ok(); // Error when init multiple times are ignored. + Self::default() + } + + /// Set output directory. + pub fn out_dir(&mut self, out_dir: &str) -> &mut Self { + self.out_dir = Some(Path::new(out_dir).into()); + self + } + + /// Add input file. + pub fn input(&mut self, input: &str) -> &mut Self { + self.inputs.push(input.into()); + self + } + + /// Set development mode. + /// + /// If this is set to true, the generated model will be saved as `.graph.txt` files and model + /// states will be saved as `.json` file. + pub fn development(&mut self, development: bool) -> &mut Self { + self.development = development; + self + } + + /// Run code generation. + /// + /// This function is intended to be called from `build.rs` script. + pub fn run_from_script(&self) { + self.run(true); + } + + /// Run code generation. + /// + /// This function is intended to be called from CLI. + pub fn run_from_cli(&self) { + self.run(false); + } + + /// Specify parameter precision to be saved. + /// + /// # Arguments + /// + /// * `half_precision` - If true, half precision is saved. Otherwise, full precision is saved. 
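// A minimal usage sketch of this builder from a downstream crate's `build.rs`
// (illustrative only: the ONNX path is a placeholder and `ModelGen` is assumed to be
// in scope; just the methods defined in this impl are used):
//
//     fn main() {
//         ModelGen::new()
//             .input("src/model/my_model.onnx")
//             .out_dir("model/")
//             .run_from_script();
//     }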
+ pub fn half_precision(&mut self, half_precision: bool) -> &mut Self { + self.half_precision = half_precision; + self + } + + /// Specify the type of the record to be saved. + /// + /// # Arguments + /// + /// * `record_type` - The type of the record to be saved. + pub fn record_type(&mut self, record_type: RecordType) -> &mut Self { + self.record_type = record_type; + self + } + + /// Specify whether to embed states in the generated code. + /// + /// # Arguments + /// + /// * `embed_states` - If true, states are embedded in the generated code. Otherwise, states are + /// saved as a separate file. + pub fn embed_states(&mut self, embed_states: bool) -> &mut Self { + self.embed_states = embed_states; + self + } + + /// Run code generation. + fn run(&self, is_build_script: bool) { + log::info!("Starting to convert ONNX to Burn"); + + // prepend the out_dir to the cargo_out_dir if this is a build script + let out_dir = if is_build_script { + let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set"); + let mut path = PathBuf::from(cargo_out_dir); + + // // Append the out_dir to the cargo_out_dir + path.push(self.out_dir.clone().unwrap()); + path + } else { + self.out_dir.as_ref().expect("out_dir is not set").clone() + }; + + log::debug!("Output directory: {:?}", out_dir); + + create_dir_all(&out_dir).unwrap(); + + for input in self.inputs.iter() { + let file_name = input.file_stem().unwrap(); + let out_file: PathBuf = out_dir.join(file_name); + + log::info!("Converting {:?}", input); + log::debug!("Input file name: {:?}", file_name); + log::debug!("Output file: {:?}", out_file); + + self.generate_model(input, out_file); + } + + log::info!("Finished converting ONNX to Burn"); + } + + /// Generate model source code and model state. + fn generate_model(&self, input: &PathBuf, out_file: PathBuf) { + log::info!("Generating model from {:?}", input); + log::debug!("Development mode: {:?}", self.development); + log::debug!("Output file: {:?}", out_file); + + let graph = parse_onnx(input.as_ref()); + + if self.development { + // export the graph + let debug_graph = format!("{:#?}", graph); + let graph_file = out_file.with_extension("graph.txt"); + log::debug!("Writing debug graph file: {:?}", graph_file); + fs::write(graph_file, debug_graph).unwrap(); + } + + let new_fn = true; + let blank_space = true; + let top_comment = Some(format!("Generated from ONNX {input:?} by burn-import")); + + let code = if self.half_precision { + graph + .into_burn::() + .with_record(out_file.clone(), self.record_type, self.embed_states) + .with_new_fn(new_fn) + .with_blank_space(blank_space) + .with_top_comment(top_comment) + .codegen() + } else { + graph + .into_burn::() + .with_record(out_file.clone(), self.record_type, self.embed_states) + .with_new_fn(new_fn) + .with_blank_space(blank_space) + .with_top_comment(top_comment) + .codegen() + }; + + let code_str = format_tokens(code); + fs::write(out_file.with_extension("rs"), code_str).unwrap(); + + log::info!("Model generated"); + } } impl ONNXGraph { - /// Converts ONNX graph to Burn graph. 
- pub fn into_burn(self) -> BurnGraph { - let mut graph = BurnGraph::::default(); - - for node in self.nodes { - match node.node_type { - NodeType::Add => graph.register(Self::add_conversion(node)), - NodeType::Sub => graph.register(Self::sub_conversion(node)), - NodeType::Mul => graph.register(Self::mul_conversion(node)), - NodeType::Div => graph.register(Self::div_conversion(node)), - NodeType::Equal => graph.register(Self::equal_conversion(node)), - NodeType::Erf => graph.register(Self::erf_conversion(node)), - NodeType::Clip => graph.register(Self::clip_conversion(node)), - NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), - NodeType::Conv2d => graph.register(Self::conv2d_conversion::(node)), - NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), - NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), - NodeType::MatMul => graph.register(Self::matmul_conversion(node)), - NodeType::Linear => graph.register(Self::linear_conversion::(node)), - NodeType::BatchNormalization => graph.register(Self::batch_norm_conversion::(node)), - NodeType::Relu => graph.register(Self::relu_conversion(node)), - NodeType::Flatten => graph.register(Self::flatten_conversion(node)), - NodeType::GatherElements => graph.register(Self::gather_conversion(node)), - NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)), - NodeType::Softmax => graph.register(Self::softmax_conversion(node)), - NodeType::Tanh => graph.register(Self::tanh_conversion(node)), - NodeType::Constant => graph.register(Self::constant_conversion::(node)), - NodeType::Reshape => graph.register(Self::reshape_conversion(node)), - NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)), - NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)), - NodeType::Transpose => graph.register(Self::transpose_conversion(node)), - NodeType::Concat => graph.register(Self::concat_conversion(node)), - NodeType::Cast => graph.register(Self::cast_conversion(node)), - NodeType::Dropout => graph.register(Self::dropout_conversion(node)), - NodeType::GlobalAveragePool => graph.register(Self::global_avg_pool_conversion(node)), - _ => panic!("Unsupported node conversion {}", node.node_type), - } - } - - // Get input and output names - let input_names = self - .inputs - .iter() - .map(|input| input.name.clone()) - .collect::>(); - let output_names = self - .outputs - .iter() - .map(|output| output.name.clone()) - .collect::>(); - - // Register inputs and outputs with the graph - graph.register_input_output(input_names, output_names); - - graph - } - - fn constant_conversion(node: Node) -> ConstantNode { - let output = node.outputs.get(0).unwrap(); - - let attr = convert_constant_value(&node); - - let const_value = - match attr.ty { - ArgType::Tensor(tensor) => { - // Treat tensor with dim 0 as scalar - if tensor.dim == 0 { - panic!("Constant tensor with dim 0 should have been converted to scalar.") - } else { - let kind: TensorKind = tensor.elem_type.clone().into(); - let dim = tensor.dim; - let name = node.name.clone(); - let shape = tensor.shape.clone(); - - let tensor_value = - match tensor.elem_type { - // TODO Review how double precision should be supported - ElementType::Float32 | ElementType::Float64 => TensorValue::Float( - serialize_data::(attr.value.unwrap(), tensor.shape.unwrap()), - ), - ElementType::Int32 | ElementType::Int64 => { - TensorValue::Int(serialize_data::( - attr.value.unwrap(), - tensor.shape.unwrap(), - )) + /// Converts ONNX graph 
to Burn graph. + pub fn into_burn(self) -> BurnGraph { + let mut graph = BurnGraph::::default(); + + for node in self.nodes { + match node.node_type { + NodeType::Add => graph.register(Self::add_conversion(node)), + NodeType::Sub => graph.register(Self::sub_conversion(node)), + NodeType::Mul => graph.register(Self::mul_conversion(node)), + NodeType::Div => graph.register(Self::div_conversion(node)), + NodeType::Equal => graph.register(Self::equal_conversion(node)), + NodeType::Erf => graph.register(Self::erf_conversion(node)), + NodeType::Clip => graph.register(Self::clip_conversion(node)), + NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), + NodeType::Conv2d => graph.register(Self::conv2d_conversion::(node)), + NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), + NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), + NodeType::MatMul => graph.register(Self::matmul_conversion(node)), + NodeType::Linear => graph.register(Self::linear_conversion::(node)), + NodeType::BatchNormalization => { + graph.register(Self::batch_norm_conversion::(node)) } - // TODO support Bool tensor when it is supported by Burn - _ => panic!("Unsupported constant tensor type: {:?} ", tensor.elem_type), - }; - - ConstantValue::Tensor(TensorType::new(name, dim, kind, shape), tensor_value) - } + NodeType::Relu => graph.register(Self::relu_conversion(node)), + NodeType::Flatten => graph.register(Self::flatten_conversion(node)), + NodeType::GatherElements => graph.register(Self::gather_conversion(node)), + NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)), + NodeType::Softmax => graph.register(Self::softmax_conversion(node)), + NodeType::Tanh => graph.register(Self::tanh_conversion(node)), + NodeType::Constant => graph.register(Self::constant_conversion::(node)), + NodeType::Reshape => graph.register(Self::reshape_conversion(node)), + NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)), + NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)), + NodeType::Transpose => graph.register(Self::transpose_conversion(node)), + NodeType::Concat => graph.register(Self::concat_conversion(node)), + NodeType::Cast => graph.register(Self::cast_conversion(node)), + NodeType::Dropout => graph.register(Self::dropout_conversion(node)), + NodeType::GlobalAveragePool => { + graph.register(Self::global_avg_pool_conversion(node)) + } + _ => panic!("Unsupported node conversion {}", node.node_type), + } } - ArgType::Scalar(elem_type) => match elem_type { - ElementType::Float64 => ConstantValue::Float64(attr.value.unwrap().into_f64()), - ElementType::Float32 => ConstantValue::Float32(attr.value.unwrap().into_f32()), - ElementType::Int32 => ConstantValue::Int32(attr.value.unwrap().into_i32()), - ElementType::Int64 => ConstantValue::Int64(attr.value.unwrap().into_i64()), - ElementType::Bool => ConstantValue::Bool(attr.value.unwrap().into_bool()), - _ => panic!("Unsupported constant tensor type: {:?} ", elem_type), - }, - ArgType::Shape(_) => panic!("Shape is not supported as constant value."), - }; - ConstantNode::new(node.name.clone(), const_value, output.to_type()) - } + // Get input and output names + let input_names = self + .inputs + .iter() + .map(|input| input.name.clone()) + .collect::>(); + let output_names = self + .outputs + .iter() + .map(|output| output.name.clone()) + .collect::>(); + + // Register inputs and outputs with the graph + graph.register_input_output(input_names, output_names); + + graph + } + + 
fn constant_conversion(node: Node) -> ConstantNode { + let output = node.outputs.get(0).unwrap(); + + let attr = convert_constant_value(&node); + + let const_value = match attr.ty { + ArgType::Tensor(tensor) => { + // Treat tensor with dim 0 as scalar + if tensor.dim == 0 { + panic!("Constant tensor with dim 0 should have been converted to scalar.") + } else { + let kind: TensorKind = tensor.elem_type.clone().into(); + let dim = tensor.dim; + let name = node.name.clone(); + let shape = tensor.shape.clone(); + + let tensor_value = match tensor.elem_type { + // TODO Review how double precision should be supported + ElementType::Float32 | ElementType::Float64 => { + TensorValue::Float(serialize_data::( + attr.value.unwrap(), + tensor.shape.unwrap(), + )) + } + ElementType::Int32 | ElementType::Int64 => { + TensorValue::Int(serialize_data::( + attr.value.unwrap(), + tensor.shape.unwrap(), + )) + } + // TODO support Bool tensor when it is supported by Burn + _ => panic!("Unsupported constant tensor type: {:?} ", tensor.elem_type), + }; + + ConstantValue::Tensor(TensorType::new(name, dim, kind, shape), tensor_value) + } + } + ArgType::Scalar(elem_type) => match elem_type { + ElementType::Float64 => ConstantValue::Float64(attr.value.unwrap().into_f64()), + ElementType::Float32 => ConstantValue::Float32(attr.value.unwrap().into_f32()), + ElementType::Int32 => ConstantValue::Int32(attr.value.unwrap().into_i32()), + ElementType::Int64 => ConstantValue::Int64(attr.value.unwrap().into_i64()), + ElementType::Bool => ConstantValue::Bool(attr.value.unwrap().into_bool()), + _ => panic!("Unsupported constant tensor type: {:?} ", elem_type), + }, + ArgType::Shape(_) => panic!("Shape is not supported as constant value."), + }; + + ConstantNode::new(node.name.clone(), const_value, output.to_type()) + } + + fn add_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + BinaryNode::add(lhs, rhs, output) + } + + fn sub_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + BinaryNode::sub(lhs, rhs, output) + } + + fn mul_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + BinaryNode::mul(lhs, rhs, output) + } + + fn div_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + BinaryNode::div(lhs, rhs, output) + } + + fn matmul_conversion(node: Node) -> MatmulNode { + let lhs = node.inputs.get(0).unwrap().to_tensor_type(); + let rhs = node.inputs.get(1).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + + MatmulNode::new(lhs, rhs, output) + } + + fn equal_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.get(0).unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + BinaryNode::equal(lhs, rhs, output) + } + + fn erf_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::erf(input, output) + } + + fn 
relu_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::relu(input, output) + } + + fn flatten_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + let (start_dim, end_dim) = flatten_config(&node); + + UnaryNode::flatten(input, output, start_dim, end_dim) + } + + fn gather_conversion(node: Node) -> GatherNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let index = node.inputs.get(1).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let dim = gather_config(&node); + + GatherNode::new(input, index, output, dim) + } + + fn transpose_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - fn add_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + UnaryNode::transpose(input, output) + } + + fn cast_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::cast(input, output) + } - BinaryNode::add(lhs, rhs, output) - } + fn reshape_conversion(node: Node) -> ReshapeNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let shape = reshape_config(&node); - fn sub_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + ReshapeNode::new(input, output, shape) + } + + fn clip_conversion(node: Node) -> ClipNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let (min, max) = clip_config(&node); - BinaryNode::sub(lhs, rhs, output) - } + ClipNode::new(input, output, min, max) + } - fn mul_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + fn sigmoid_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - BinaryNode::mul(lhs, rhs, output) - } + UnaryNode::sigmoid(input, output) + } - fn div_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + fn reciprocal_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - BinaryNode::div(lhs, rhs, output) - } + UnaryNode::reciprocal(input, output) + } - fn matmul_conversion(node: Node) -> MatmulNode { - let lhs = node.inputs.get(0).unwrap().to_tensor_type(); - let rhs = node.inputs.get(1).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); + fn log_softmax_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + let dim = log_softmax_config(&node); - MatmulNode::new(lhs, rhs, output) - } 
+ UnaryNode::log_softmax(input, output, dim) + } - fn equal_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.get(0).unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + fn softmax_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + let dim = softmax_config(&node); - BinaryNode::equal(lhs, rhs, output) - } + UnaryNode::softmax(input, output, dim) + } - fn erf_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + fn tanh_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); - UnaryNode::erf(input, output) - } + UnaryNode::tanh(input, output) + } - fn relu_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + fn concat_conversion(node: Node) -> ConcatNode { + let inputs = node + .inputs + .iter() + .map(|input| input.to_tensor_type()) + .collect(); - UnaryNode::relu(input, output) - } + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let dim = concat_config(&node); - fn flatten_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - let (start_dim, end_dim) = flatten_config(&node); + ConcatNode::new(inputs, output, dim) + } - UnaryNode::flatten(input, output, start_dim, end_dim) - } + fn linear_conversion(node: Node) -> LinearNode { + let name = &node.name; + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = linear_config(&node); - fn gather_conversion(node: Node) -> GatherNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let index = node.inputs.get(1).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let dim = gather_config(&node); + let weight = extract_data_serialize::(1, &node).expect("Weight is required"); - GatherNode::new(input, index, output, dim) - } + let bias = extract_data_serialize::(2, &node); - fn transpose_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + LinearNode::new(name, input, output, weight, bias, config) + } - UnaryNode::transpose(input, output) - } + fn dropout_conversion(node: Node) -> DropoutNode { + let name = &node.name; + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = dropout_config(&node); - fn cast_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); + DropoutNode::new(name, input, output, config) + } - UnaryNode::cast(input, output) - } + fn batch_norm_conversion(node: Node) -> BatchNormNode { + let config = batch_norm_config(&node); + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let dim = input.dim - 2; + + let gamma = extract_data_serialize::(1, &node).expect("Gamma is required"); + let beta = extract_data_serialize::(2, &node).expect("Beta is required"); + let running_mean = + extract_data_serialize::(3, 
&node).expect("Running mean is required"); + let running_var = + extract_data_serialize::(4, &node).expect("Running var is required"); + + let name = &node.name; + + BatchNormNode::new( + dim, + name, + input, + output, + gamma, + beta, + running_mean, + running_var, + config, + ) + } - fn reshape_conversion(node: Node) -> ReshapeNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let shape = reshape_config(&node); + fn conv1d_conversion(node: Node) -> Conv1dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = conv1d_config(&node); - ReshapeNode::new(input, output, shape) - } + let bias = node.inputs.len() == 3; + let weight = extract_data_serialize::(1, &node).unwrap(); + let bias = match bias { + true => extract_data_serialize::(2, &node), + false => None, + }; - fn clip_conversion(node: Node) -> ClipNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let (min, max) = clip_config(&node); - - ClipNode::new(input, output, min, max) - } - - fn sigmoid_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::sigmoid(input, output) - } - - fn reciprocal_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::reciprocal(input, output) - } - - fn log_softmax_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - let dim = log_softmax_config(&node); - - UnaryNode::log_softmax(input, output, dim) - } - - fn softmax_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - let dim = softmax_config(&node); - - UnaryNode::softmax(input, output, dim) - } - - fn tanh_conversion(node: Node) -> UnaryNode { - let input = node.inputs.get(0).unwrap().to_type(); - let output = node.outputs.get(0).unwrap().to_type(); - - UnaryNode::tanh(input, output) - } - - fn concat_conversion(node: Node) -> ConcatNode { - let inputs = node - .inputs - .iter() - .map(|input| input.to_tensor_type()) - .collect(); - - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let dim = concat_config(&node); - - ConcatNode::new(inputs, output, dim) - } - - fn linear_conversion(node: Node) -> LinearNode { - let name = &node.name; - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = linear_config(&node); - - let weight = extract_data_serialize::(1, &node).expect("Weight is required"); - - let bias = extract_data_serialize::(2, &node); - - LinearNode::new(name, input, output, weight, bias, config) - } - - fn dropout_conversion(node: Node) -> DropoutNode { - let name = &node.name; - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = dropout_config(&node); - - DropoutNode::new(name, input, output, config) - } - - fn batch_norm_conversion(node: Node) -> BatchNormNode { - let config = batch_norm_config(&node); - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let 
dim = input.dim - 2; - - let gamma = extract_data_serialize::(1, &node).expect("Gamma is required"); - let beta = extract_data_serialize::(2, &node).expect("Beta is required"); - let running_mean = - extract_data_serialize::(3, &node).expect("Running mean is required"); - let running_var = - extract_data_serialize::(4, &node).expect("Running var is required"); - - let name = &node.name; - - BatchNormNode::new( - dim, - name, - input, - output, - gamma, - beta, - running_mean, - running_var, - config, - ) - } - - fn conv1d_conversion(node: Node) -> Conv1dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = conv1d_config(&node); - - let bias = node.inputs.len() == 3; - let weight = extract_data_serialize::(1, &node).unwrap(); - let bias = match bias { - true => extract_data_serialize::(2, &node), - false => None, - }; - - let name = &node.name; - Conv1dNode::::new(name, input, output, weight, bias, config) - } - - fn conv2d_conversion(node: Node) -> Conv2dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = conv2d_config(&node); - - let bias = node.inputs.len() == 3; - let weight = extract_data_serialize::(1, &node).unwrap(); - let bias = match bias { - true => extract_data_serialize::(2, &node), - false => None, - }; - - let name = &node.name; - Conv2dNode::::new(name, input, output, weight, bias, config) - } - - fn max_pool2d_conversion(node: Node) -> MaxPool2dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = max_pool2d_config(&node); - - let name = &node.name; - MaxPool2dNode::new(name, input, output, config) - } - - fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - let config = avg_pool2d_config(&node); - - let name = &node.name; - AvgPool2dNode::new(name, input, output, config) - } - - fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode { - let input = node.inputs.get(0).unwrap().to_tensor_type(); - let output = node.outputs.get(0).unwrap().to_tensor_type(); - - let name = &node.name; - - GlobalAvgPoolNode::new(name, input, output) - } + let name = &node.name; + Conv1dNode::::new(name, input, output, weight, bias, config) + } + + fn conv2d_conversion(node: Node) -> Conv2dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = conv2d_config(&node); + + let bias = node.inputs.len() == 3; + let weight = extract_data_serialize::(1, &node).unwrap(); + let bias = match bias { + true => extract_data_serialize::(2, &node), + false => None, + }; + + let name = &node.name; + Conv2dNode::::new(name, input, output, weight, bias, config) + } + + fn max_pool2d_conversion(node: Node) -> MaxPool2dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = max_pool2d_config(&node); + + let name = &node.name; + MaxPool2dNode::new(name, input, output, config) + } + + fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + let config = avg_pool2d_config(&node); + + let name = &node.name; + 
AvgPool2dNode::new(name, input, output, config) + } + + fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode { + let input = node.inputs.get(0).unwrap().to_tensor_type(); + let output = node.outputs.get(0).unwrap().to_tensor_type(); + + let name = &node.name; + + GlobalAvgPoolNode::new(name, input, output) + } } /// Extract data from node states and convert it to `DataSerialize`. @@ -603,108 +609,108 @@ impl ONNXGraph { /// * `node` - The node where value are stored. #[track_caller] fn extract_data_serialize(input_index: usize, node: &Node) -> Option> { - if node.inputs.is_empty() { - return None; - } - - let input = node.inputs.get(input_index); - input?; - let input = input.unwrap(); - input.value.as_ref()?; - let ty = input.ty.clone(); - - match ty { - ArgType::Tensor(tensor_type) => { - let value = input.value.as_ref().expect("Value to be provided.").clone(); - - Some(serialize_data( - value.clone(), - tensor_type.shape.unwrap().clone(), - )) - } - _ => panic!("Unsupported serialization type"), - } + if node.inputs.is_empty() { + return None; + } + + let input = node.inputs.get(input_index); + input?; + let input = input.unwrap(); + input.value.as_ref()?; + let ty = input.ty.clone(); + + match ty { + ArgType::Tensor(tensor_type) => { + let value = input.value.as_ref().expect("Value to be provided.").clone(); + + Some(serialize_data( + value.clone(), + tensor_type.shape.unwrap().clone(), + )) + } + _ => panic!("Unsupported serialization type"), + } } /// Convert data to `DataSerialize`. fn serialize_data(data: Data, shape: Vec) -> DataSerialize { - match data { - Data::Float16s(val) => DataSerialize::new(val, shape).convert(), - Data::Float32s(val) => DataSerialize::new(val, shape).convert(), - Data::Float64s(val) => DataSerialize::new(val, shape).convert(), - Data::Int32s(val) => DataSerialize::new(val, shape).convert(), - Data::Int64s(val) => DataSerialize::new(val, shape).convert(), - // TODO support Bool tensor when it is supported by Burn - _ => panic!("Unsupported tensor element type"), - } + match data { + Data::Float16s(val) => DataSerialize::new(val, shape).convert(), + Data::Float32s(val) => DataSerialize::new(val, shape).convert(), + Data::Float64s(val) => DataSerialize::new(val, shape).convert(), + Data::Int32s(val) => DataSerialize::new(val, shape).convert(), + Data::Int64s(val) => DataSerialize::new(val, shape).convert(), + // TODO support Bool tensor when it is supported by Burn + _ => panic!("Unsupported tensor element type"), + } } impl Argument { - pub fn to_tensor_type(&self) -> TensorType { - match &self.ty { - ArgType::Tensor(ir::TensorType { - elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64, - dim, - .. - }) => TensorType::new_float(self.name.clone(), *dim), - ArgType::Tensor(ir::TensorType { - elem_type: ElementType::Int32 | ElementType::Int64, - dim, - .. 
- }) => TensorType::new_int(self.name.clone(), *dim), - _ => panic!("Can't transform to tensor."), - } - } - - pub fn to_type(&self) -> Type { - match &self.ty { - ArgType::Tensor(tensor) => { - // Treat tensor with dim 0 as scalar - if tensor.dim == 0 { - Type::Scalar(ScalarType::new( - self.name.clone(), - ScalarKind::from(&tensor.elem_type), - )) - } else { - let kind: TensorKind = tensor.elem_type.clone().into(); - let dim = tensor.dim; - let name = self.name.clone(); - let shape = tensor.shape.clone(); - Type::Tensor(TensorType::new(name, dim, kind, shape)) + pub fn to_tensor_type(&self) -> TensorType { + match &self.ty { + ArgType::Tensor(ir::TensorType { + elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64, + dim, + .. + }) => TensorType::new_float(self.name.clone(), *dim), + ArgType::Tensor(ir::TensorType { + elem_type: ElementType::Int32 | ElementType::Int64, + dim, + .. + }) => TensorType::new_int(self.name.clone(), *dim), + _ => panic!("Can't transform to tensor."), } - } + } - ArgType::Scalar(elem_type) => { - Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into())) - } - ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."), + pub fn to_type(&self) -> Type { + match &self.ty { + ArgType::Tensor(tensor) => { + // Treat tensor with dim 0 as scalar + if tensor.dim == 0 { + Type::Scalar(ScalarType::new( + self.name.clone(), + ScalarKind::from(&tensor.elem_type), + )) + } else { + let kind: TensorKind = tensor.elem_type.clone().into(); + let dim = tensor.dim; + let name = self.name.clone(); + let shape = tensor.shape.clone(); + Type::Tensor(TensorType::new(name, dim, kind, shape)) + } + } + + ArgType::Scalar(elem_type) => { + Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into())) + } + ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."), + } } - } } impl From<&ElementType> for ScalarKind { - fn from(elem_type: &ElementType) -> Self { - match elem_type { - ElementType::Float32 => ScalarKind::Float32, - ElementType::Float64 => ScalarKind::Float64, - ElementType::Int32 => ScalarKind::Int32, - ElementType::Int64 => ScalarKind::Int64, - ElementType::Bool => ScalarKind::Bool, - ElementType::String => panic!("String tensor unsupported"), - ElementType::Float16 => panic!("Float16 tensor unsupported"), - } - } + fn from(elem_type: &ElementType) -> Self { + match elem_type { + ElementType::Float32 => ScalarKind::Float32, + ElementType::Float64 => ScalarKind::Float64, + ElementType::Int32 => ScalarKind::Int32, + ElementType::Int64 => ScalarKind::Int64, + ElementType::Bool => ScalarKind::Bool, + ElementType::String => panic!("String tensor unsupported"), + ElementType::Float16 => panic!("Float16 tensor unsupported"), + } + } } impl From for TensorKind { - fn from(elem_type: ElementType) -> Self { - match elem_type { - ElementType::Float32 => TensorKind::Float, - ElementType::Float64 => TensorKind::Float, - ElementType::Int32 => TensorKind::Int, - ElementType::Int64 => TensorKind::Int, - ElementType::Bool => TensorKind::Bool, - _ => panic!("Unsupported tensor type"), - } - } + fn from(elem_type: ElementType) -> Self { + match elem_type { + ElementType::Float32 => TensorKind::Float, + ElementType::Float64 => TensorKind::Float, + ElementType::Int32 => TensorKind::Int, + ElementType::Int64 => TensorKind::Int, + ElementType::Bool => TensorKind::Bool, + _ => panic!("Unsupported tensor type"), + } + } } diff --git a/burn-ndarray/build.rs b/burn-ndarray/build.rs index d70cd753ac..dcb4354ca6 100644 --- 
a/burn-ndarray/build.rs +++ b/burn-ndarray/build.rs @@ -1,6 +1,6 @@ fn main() { - // https://github.com/rust-ndarray/ndarray/issues/1197 - if cfg!(feature = "blas-accelerate") { - println!("cargo:rustc-link-lib=framework=Accelerate"); - } + // https://github.com/rust-ndarray/ndarray/issues/1197 + if cfg!(feature = "blas-accelerate") { + println!("cargo:rustc-link-lib=framework=Accelerate"); + } } diff --git a/burn-ndarray/src/backend.rs b/burn-ndarray/src/backend.rs index 01e629bde2..3137ccd893 100644 --- a/burn-ndarray/src/backend.rs +++ b/burn-ndarray/src/backend.rs @@ -11,14 +11,14 @@ pub(crate) static SEED: Mutex> = Mutex::new(None); /// The device type for the ndarray backend. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NdArrayDevice { - /// The CPU device. - Cpu, + /// The CPU device. + Cpu, } impl Default for NdArrayDevice { - fn default() -> Self { - Self::Cpu - } + fn default() -> Self { + Self::Cpu + } } /// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations. @@ -27,33 +27,33 @@ impl Default for NdArrayDevice { /// `wasm`, `arm`, and `x86`. #[derive(Clone, Copy, Default, Debug)] pub struct NdArray { - phantom: PhantomData, + phantom: PhantomData, } impl Backend for NdArray { - type Device = NdArrayDevice; - type FullPrecisionElem = f32; - type FullPrecisionBackend = NdArray; + type Device = NdArrayDevice; + type FullPrecisionElem = f32; + type FullPrecisionBackend = NdArray; - type TensorPrimitive = NdArrayTensor; - type FloatElem = E; + type TensorPrimitive = NdArrayTensor; + type FloatElem = E; - type IntTensorPrimitive = NdArrayTensor; - type IntElem = i64; + type IntTensorPrimitive = NdArrayTensor; + type IntElem = i64; - type BoolTensorPrimitive = NdArrayTensor; + type BoolTensorPrimitive = NdArrayTensor; - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn name() -> String { - String::from("ndarray") - } + fn name() -> String { + String::from("ndarray") + } - fn seed(seed: u64) { - let rng = StdRng::seed_from_u64(seed); - let mut seed = SEED.lock().unwrap(); - *seed = Some(rng); - } + fn seed(seed: u64) { + let rng = StdRng::seed_from_u64(seed); + let mut seed = SEED.lock().unwrap(); + *seed = Some(rng); + } } diff --git a/burn-ndarray/src/element.rs b/burn-ndarray/src/element.rs index 35381ae93e..be08c76222 100644 --- a/burn-ndarray/src/element.rs +++ b/burn-ndarray/src/element.rs @@ -6,147 +6,147 @@ use ndarray::LinalgScalar; /// A float element for ndarray backend. pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar where - Self: Sized, + Self: Sized, { } /// A general element for ndarray backend. pub trait NdArrayElement: - Element - + ndarray::LinalgScalar - + ndarray::ScalarOperand - + ExpElement - + num_traits::FromPrimitive - + core::ops::AddAssign - + core::cmp::PartialEq - + core::cmp::PartialOrd + Element + + ndarray::LinalgScalar + + ndarray::ScalarOperand + + ExpElement + + num_traits::FromPrimitive + + core::ops::AddAssign + + core::cmp::PartialEq + + core::cmp::PartialOrd { } /// A element for ndarray backend that supports exp ops. 
pub trait ExpElement { - fn exp_elem(self) -> Self; - fn log_elem(self) -> Self; - fn log1p_elem(self) -> Self; - fn powf_elem(self, value: f32) -> Self; - fn powi_elem(self, value: i32) -> Self; - fn sqrt_elem(self) -> Self; - fn abs_elem(self) -> Self; - fn int_abs_elem(self) -> Self; + fn exp_elem(self) -> Self; + fn log_elem(self) -> Self; + fn log1p_elem(self) -> Self; + fn powf_elem(self, value: f32) -> Self; + fn powi_elem(self, value: i32) -> Self; + fn sqrt_elem(self) -> Self; + fn abs_elem(self) -> Self; + fn int_abs_elem(self) -> Self; } impl FloatNdArrayElement for f64 {} impl FloatNdArrayElement for f32 {} macro_rules! make_elem { - ( + ( double $ty:ty ) => { - impl NdArrayElement for $ty {} - - impl ExpElement for $ty { - #[inline(always)] - fn exp_elem(self) -> Self { - exp(self as f64) as $ty - } - - #[inline(always)] - fn log_elem(self) -> Self { - log(self as f64) as $ty - } - - #[inline(always)] - fn log1p_elem(self) -> Self { - log1p(self as f64) as $ty - } - - #[inline(always)] - fn powf_elem(self, value: f32) -> Self { - pow(self as f64, value.into()) as $ty - } - - #[inline(always)] - fn powi_elem(self, value: i32) -> Self { - #[cfg(feature = "std")] - let val = f64::powi(self as f64, value) as $ty; - - #[cfg(not(feature = "std"))] - let val = Self::powf_elem(self, value as f32); - - val - } - - #[inline(always)] - fn sqrt_elem(self) -> Self { - sqrt(self as f64) as $ty - } - - #[inline(always)] - fn abs_elem(self) -> Self { - fabs(self as f64) as $ty - } - - #[inline(always)] - fn int_abs_elem(self) -> Self { - (self as i64).abs() as $ty - } - } - }; - ( + impl NdArrayElement for $ty {} + + impl ExpElement for $ty { + #[inline(always)] + fn exp_elem(self) -> Self { + exp(self as f64) as $ty + } + + #[inline(always)] + fn log_elem(self) -> Self { + log(self as f64) as $ty + } + + #[inline(always)] + fn log1p_elem(self) -> Self { + log1p(self as f64) as $ty + } + + #[inline(always)] + fn powf_elem(self, value: f32) -> Self { + pow(self as f64, value.into()) as $ty + } + + #[inline(always)] + fn powi_elem(self, value: i32) -> Self { + #[cfg(feature = "std")] + let val = f64::powi(self as f64, value) as $ty; + + #[cfg(not(feature = "std"))] + let val = Self::powf_elem(self, value as f32); + + val + } + + #[inline(always)] + fn sqrt_elem(self) -> Self { + sqrt(self as f64) as $ty + } + + #[inline(always)] + fn abs_elem(self) -> Self { + fabs(self as f64) as $ty + } + + #[inline(always)] + fn int_abs_elem(self) -> Self { + (self as i64).abs() as $ty + } + } + }; + ( single $ty:ty ) => { - impl NdArrayElement for $ty {} - - impl ExpElement for $ty { - #[inline(always)] - fn exp_elem(self) -> Self { - expf(self as f32) as $ty - } - - #[inline(always)] - fn log_elem(self) -> Self { - logf(self as f32) as $ty - } - - #[inline(always)] - fn log1p_elem(self) -> Self { - log1pf(self as f32) as $ty - } - - #[inline(always)] - fn powf_elem(self, value: f32) -> Self { - powf(self as f32, value.into()) as $ty - } - - #[inline(always)] - fn powi_elem(self, value: i32) -> Self { - #[cfg(feature = "std")] - let val = f32::powi(self as f32, value) as $ty; - - #[cfg(not(feature = "std"))] - let val = Self::powf_elem(self, value as f32); - - val - } - - #[inline(always)] - fn sqrt_elem(self) -> Self { - sqrtf(self as f32) as $ty - } - - #[inline(always)] - fn abs_elem(self) -> Self { - fabsf(self as f32) as $ty - } - - #[inline(always)] - fn int_abs_elem(self) -> Self { - (self as i32).abs() as $ty - } - } - }; + impl NdArrayElement for $ty {} + + impl ExpElement for $ty { + 
#[inline(always)] + fn exp_elem(self) -> Self { + expf(self as f32) as $ty + } + + #[inline(always)] + fn log_elem(self) -> Self { + logf(self as f32) as $ty + } + + #[inline(always)] + fn log1p_elem(self) -> Self { + log1pf(self as f32) as $ty + } + + #[inline(always)] + fn powf_elem(self, value: f32) -> Self { + powf(self as f32, value.into()) as $ty + } + + #[inline(always)] + fn powi_elem(self, value: i32) -> Self { + #[cfg(feature = "std")] + let val = f32::powi(self as f32, value) as $ty; + + #[cfg(not(feature = "std"))] + let val = Self::powf_elem(self, value as f32); + + val + } + + #[inline(always)] + fn sqrt_elem(self) -> Self { + sqrtf(self as f32) as $ty + } + + #[inline(always)] + fn abs_elem(self) -> Self { + fabsf(self as f32) as $ty + } + + #[inline(always)] + fn int_abs_elem(self) -> Self { + (self as i32).abs() as $ty + } + } + }; } make_elem!(double f64); diff --git a/burn-ndarray/src/lib.rs b/burn-ndarray/src/lib.rs index 1f4b9d3790..0c85644116 100644 --- a/burn-ndarray/src/lib.rs +++ b/burn-ndarray/src/lib.rs @@ -7,9 +7,9 @@ extern crate derive_new; #[cfg(any( - feature = "blas-netlib", - feature = "blas-openblas", - feature = "blas-openblas-system", + feature = "blas-netlib", + feature = "blas-openblas", + feature = "blas-openblas-system", ))] extern crate blas_src; @@ -29,14 +29,14 @@ extern crate alloc; #[cfg(test)] mod tests { - type TestBackend = crate::NdArray; - type TestTensor = burn_tensor::Tensor; - type TestTensorInt = burn_tensor::Tensor; - use alloc::format; - use alloc::vec; + type TestBackend = crate::NdArray; + type TestTensor = burn_tensor::Tensor; + type TestTensorInt = burn_tensor::Tensor; + use alloc::format; + use alloc::vec; - burn_tensor::testgen_all!(); + burn_tensor::testgen_all!(); - #[cfg(feature = "std")] - burn_autodiff::testgen_all!(); + #[cfg(feature = "std")] + burn_autodiff::testgen_all!(); } diff --git a/burn-ndarray/src/ops/activations.rs b/burn-ndarray/src/ops/activations.rs index 4a6c33eee7..40d1e16337 100644 --- a/burn-ndarray/src/ops/activations.rs +++ b/burn-ndarray/src/ops/activations.rs @@ -2,16 +2,16 @@ use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray}; use burn_tensor::{ops::ActivationOps, ElementConversion}; impl ActivationOps for NdArray { - fn relu(tensor: NdArrayTensor) -> NdArrayTensor { - let zero = 0.elem(); - let array = tensor - .array - .mapv_into(|elem| match elem < zero { - true => zero, - false => elem, - }) - .into_shared(); + fn relu(tensor: NdArrayTensor) -> NdArrayTensor { + let zero = 0.elem(); + let array = tensor + .array + .mapv_into(|elem| match elem < zero { + true => zero, + false => elem, + }) + .into_shared(); - NdArrayTensor::new(array) - } + NdArrayTensor::new(array) + } } diff --git a/burn-ndarray/src/ops/adaptive_avgpool.rs b/burn-ndarray/src/ops/adaptive_avgpool.rs index dca187c160..1e91aa227e 100644 --- a/burn-ndarray/src/ops/adaptive_avgpool.rs +++ b/burn-ndarray/src/ops/adaptive_avgpool.rs @@ -1,101 +1,103 @@ use crate::{ - element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, - tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, + tensor::NdArrayTensor, }; use burn_tensor::ElementConversion; use ndarray::Array4; pub(crate) fn adaptive_avg_pool2d( - x: NdArrayTensor, - output_size: [usize; 2], + x: NdArrayTensor, + output_size: [usize; 2], ) -> NdArrayTensor { - let [batch_size, channels, input_height, input_width] = x.shape().dims; - - let x = x.array; - let mut output = 
Array4::from_elem( - (batch_size, channels, output_size[0], output_size[1]), - 0.elem(), - ); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - for h in 0..output_size[0] { - for w in 0..output_size[1] { - let ih_start = start_index(h, output_size[0], input_height); - let ih_end = end_index(h, output_size[0], input_height); - let iw_start = start_index(w, output_size[1], input_width); - let iw_end = end_index(w, output_size[1], input_width); - - let mut sum_val: E = 0.elem(); - - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - sum_val += x[[b, c, ih, iw]]; + let [batch_size, channels, input_height, input_width] = x.shape().dims; + + let x = x.array; + let mut output = Array4::from_elem( + (batch_size, channels, output_size[0], output_size[1]), + 0.elem(), + ); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + for h in 0..output_size[0] { + for w in 0..output_size[1] { + let ih_start = start_index(h, output_size[0], input_height); + let ih_end = end_index(h, output_size[0], input_height); + let iw_start = start_index(w, output_size[1], input_width); + let iw_end = end_index(w, output_size[1], input_width); + + let mut sum_val: E = 0.elem(); + + for ih in ih_start..ih_end { + for iw in iw_start..iw_end { + sum_val += x[[b, c, ih, iw]]; + } + } + + let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); + output[[b, c, h, w]] = sum_val / count.elem(); + } } - } + }) + }); - let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); - output[[b, c, h, w]] = sum_val / count.elem(); - } - } - }) - }); - - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } pub(crate) fn adaptive_avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, + x: NdArrayTensor, + grad: NdArrayTensor, ) -> NdArrayTensor { - let [_, _, input_height, input_width] = x.shape().dims; - let [batch_size, channels, output_height, output_width] = grad.shape().dims; - - let mut output_grad = - Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output_grad = unsafe_shared_out.get(); - for oh in 0..output_height { - for ow in 0..output_width { - let ih_start = start_index(oh, output_height, input_height); - let ih_end = end_index(oh, output_height, input_height); - - let iw_start = start_index(ow, output_width, input_width); - let iw_end = end_index(ow, output_width, input_width); - - let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); - - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - output_grad[[b, c, ih, iw]] += grad.array[[b, c, oh, ow]] / count.elem(); + let [_, _, input_height, input_width] = x.shape().dims; + let [batch_size, channels, output_height, output_width] = grad.shape().dims; + + let mut output_grad = + Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); + + run_par!(|| { + 
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output_grad = unsafe_shared_out.get(); + for oh in 0..output_height { + for ow in 0..output_width { + let ih_start = start_index(oh, output_height, input_height); + let ih_end = end_index(oh, output_height, input_height); + + let iw_start = start_index(ow, output_width, input_width); + let iw_end = end_index(ow, output_width, input_width); + + let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); + + for ih in ih_start..ih_end { + for iw in iw_start..iw_end { + output_grad[[b, c, ih, iw]] += + grad.array[[b, c, oh, ow]] / count.elem(); + } + } + } } - } - } - } - }) - }); + }) + }); - NdArrayTensor::new(output_grad.into_dyn().into_shared()) + NdArrayTensor::new(output_grad.into_dyn().into_shared()) } fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { - libm::floorf((output_size_index as f32 * input_size as f32) / output_size as f32) as usize + libm::floorf((output_size_index as f32 * input_size as f32) / output_size as f32) as usize } fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { - let index = - libm::ceilf(((output_size_index + 1) as f32 * input_size as f32) / output_size as f32) as usize; + let index = + libm::ceilf(((output_size_index + 1) as f32 * input_size as f32) / output_size as f32) + as usize; - usize::min(index, input_size) + usize::min(index, input_size) } diff --git a/burn-ndarray/src/ops/avgpool.rs b/burn-ndarray/src/ops/avgpool.rs index b0778dcf84..680c4e1175 100644 --- a/burn-ndarray/src/ops/avgpool.rs +++ b/burn-ndarray/src/ops/avgpool.rs @@ -1,134 +1,135 @@ use crate::{ - element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, - tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, + tensor::NdArrayTensor, }; use burn_tensor::ElementConversion; use ndarray::Array4; pub(crate) fn avg_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> NdArrayTensor { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - - let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1; - let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1; - - let x = x.array; - - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem()); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut sum_val: E = 0.elem(); - let mut count: E = 0.elem(); - - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let ih = oh * stride_height + kh; - let iw = ow * stride_width + kw; - - if ih >= x_height + padding_height - || iw >= x_width + padding_width - || ih < padding_height - || iw < padding_width - { - continue; - } - - let ih = ih - padding_height; - let iw = iw - padding_width; - - count += 1.elem(); - sum_val += x[[b, c, ih, 
iw]]; + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + + let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1; + let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1; + + let x = x.array; + + let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem()); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + + for oh in 0..out_height { + for ow in 0..out_width { + let mut sum_val: E = 0.elem(); + let mut count: E = 0.elem(); + + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let ih = oh * stride_height + kh; + let iw = ow * stride_width + kw; + + if ih >= x_height + padding_height + || iw >= x_width + padding_width + || ih < padding_height + || iw < padding_width + { + continue; + } + + let ih = ih - padding_height; + let iw = iw - padding_width; + + count += 1.elem(); + sum_val += x[[b, c, ih, iw]]; + } + } + + if count_include_pad { + count = ((kernel_height * kernel_width) as i32).elem(); + } + + output[[b, c, oh, ow]] = sum_val / count; + } } - } - - if count_include_pad { - count = ((kernel_height * kernel_width) as i32).elem(); - } - - output[[b, c, oh, ow]] = sum_val / count; - } - } - }) - }); + }) + }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } pub(crate) fn avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: NdArrayTensor, + grad: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> NdArrayTensor { - let [kernel_height, kernel_width] = kernel_size; - let [stride_height, stride_width] = stride; - let [padding_height, padding_width] = padding; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - let [_batch_size, _channels, out_height, out_width] = grad.shape().dims; + let [kernel_height, kernel_width] = kernel_size; + let [stride_height, stride_width] = stride; + let [padding_height, padding_width] = padding; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + let [_batch_size, _channels, out_height, out_width] = grad.shape().dims; - let grad = grad.array; + let grad = grad.array; - let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem()); - let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad); + let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem()); + let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad); - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; - let output_grad = unsafe_shared_grad.get(); + let output_grad = unsafe_shared_grad.get(); - for oh in 0..out_height { - for ow in 0..out_width { - let ih_start = oh * stride_height; - let iw_start = ow * stride_width; + for oh in 0..out_height { + for ow in 0..out_width { + let ih_start = oh * 
stride_height; + let iw_start = ow * stride_width; - let ih_end = ih_start + kernel_height; - let iw_end = iw_start + kernel_width; + let ih_end = ih_start + kernel_height; + let iw_end = iw_start + kernel_width; - let ih_start = usize::max(ih_start, padding_height); - let iw_start = usize::max(iw_start, padding_width); + let ih_start = usize::max(ih_start, padding_height); + let iw_start = usize::max(iw_start, padding_width); - let ih_end = usize::min(ih_end, x_height + padding_height); - let iw_end = usize::min(iw_end, x_width + padding_width); + let ih_end = usize::min(ih_end, x_height + padding_height); + let iw_end = usize::min(iw_end, x_width + padding_width); - let count = match count_include_pad { - true => kernel_width * kernel_height, - false => (ih_end - ih_start) * (iw_end - iw_start), - }; + let count = match count_include_pad { + true => kernel_width * kernel_height, + false => (ih_end - ih_start) * (iw_end - iw_start), + }; - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - let ih = ih - padding_height; - let iw = iw - padding_width; + for ih in ih_start..ih_end { + for iw in iw_start..iw_end { + let ih = ih - padding_height; + let iw = iw - padding_width; - output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / (count as i32).elem(); + output_grad[[b, c, ih, iw]] += + grad[[b, c, oh, ow]] / (count as i32).elem(); + } + } + } } - } - } - } - }) - }); + }) + }); - NdArrayTensor::new(output_grad.into_dyn().into_shared()) + NdArrayTensor::new(output_grad.into_dyn().into_shared()) } diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs index cf2b2f5cd2..7b2b9f96b9 100644 --- a/burn-ndarray/src/ops/base.rs +++ b/burn-ndarray/src/ops/base.rs @@ -16,457 +16,478 @@ use crate::ops::macros::{keepdim, mean_dim, sum_dim}; use crate::{reshape, tensor::NdArrayTensor}; pub struct NdArrayOps { - e: PhantomData, + e: PhantomData, } pub(crate) struct NdArrayMathOps { - e: PhantomData, + e: PhantomData, } impl NdArrayOps where - E: Copy, + E: Copy, { - pub fn slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - let slices = Self::to_slice_args::(ranges); - let array = tensor.array.slice_move(slices.as_slice()).into_shared(); - - NdArrayTensor { array } - } - - pub fn slice_assign( - tensor: NdArrayTensor, - ranges: [Range; D2], - value: NdArrayTensor, - ) -> NdArrayTensor { - let slices = Self::to_slice_args::(ranges); - let mut array = tensor.array.into_owned(); - array.slice_mut(slices.as_slice()).assign(&value.array); - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - reshape!( - ty E, - shape shape, - array tensor.array, - d D2 - ) - } - - pub fn cat(tensors: Vec>, dim: usize) -> NdArrayTensor { - let arrays: Vec> = - tensors.iter().map(|t| t.array.view()).collect(); - let array = ndarray::concatenate(Axis(dim), &arrays) - .unwrap() - .into_shared(); - - NdArrayTensor { array } - } - - fn to_slice_args( - ranges: [Range; D2], - ) -> [SliceInfoElem; D1] { - let mut slices = [SliceInfoElem::NewAxis; D1]; - for i in 0..D1 { - if i >= D2 { - slices[i] = SliceInfoElem::Slice { - start: 0, - end: None, - step: 1, - } - } else { - slices[i] = SliceInfoElem::Slice { - start: ranges[i].start as isize, - end: Some(ranges[i].end as isize), - step: 1, + pub fn slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + let slices = Self::to_slice_args::(ranges); + let array = 
tensor.array.slice_move(slices.as_slice()).into_shared(); + + NdArrayTensor { array } + } + + pub fn slice_assign( + tensor: NdArrayTensor, + ranges: [Range; D2], + value: NdArrayTensor, + ) -> NdArrayTensor { + let slices = Self::to_slice_args::(ranges); + let mut array = tensor.array.into_owned(); + array.slice_mut(slices.as_slice()).assign(&value.array); + let array = array.into_shared(); + + NdArrayTensor { array } + } + + pub fn reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + reshape!( + ty E, + shape shape, + array tensor.array, + d D2 + ) + } + + pub fn cat( + tensors: Vec>, + dim: usize, + ) -> NdArrayTensor { + let arrays: Vec> = + tensors.iter().map(|t| t.array.view()).collect(); + let array = ndarray::concatenate(Axis(dim), &arrays) + .unwrap() + .into_shared(); + + NdArrayTensor { array } + } + + fn to_slice_args( + ranges: [Range; D2], + ) -> [SliceInfoElem; D1] { + let mut slices = [SliceInfoElem::NewAxis; D1]; + for i in 0..D1 { + if i >= D2 { + slices[i] = SliceInfoElem::Slice { + start: 0, + end: None, + step: 1, + } + } else { + slices[i] = SliceInfoElem::Slice { + start: ranges[i].start as isize, + end: Some(ranges[i].end as isize), + step: 1, + } + } } - } + slices } - slices - } - pub fn swap_dims( - tensor: NdArrayTensor, - dim1: usize, - dim2: usize, - ) -> NdArrayTensor { - let mut array = tensor.array; - array.swap_axes(dim1, dim2); - - NdArrayTensor::new(array) - } + pub fn swap_dims( + tensor: NdArrayTensor, + dim1: usize, + dim2: usize, + ) -> NdArrayTensor { + let mut array = tensor.array; + array.swap_axes(dim1, dim2); + + NdArrayTensor::new(array) + } } impl NdArrayMathOps where - E: Copy + NdArrayElement, + E: Copy + NdArrayElement, { - pub fn add( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = &lhs.array + &rhs.array; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array + rhs; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn sub( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = lhs.array - rhs.array; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array - rhs; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn mul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = lhs.array * rhs.array; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array * rhs; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn div( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let array = lhs.array / rhs.array; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array / rhs; - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn recip(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.map(|x| 1.elem::() / *x); - let array = array.into_shared(); - - NdArrayTensor { array } - } - - pub fn mean(tensor: NdArrayTensor) -> NdArrayTensor { - let data = Data::from([tensor.array.mean().unwrap()]); - NdArrayTensor::from_data(data) - } - - pub fn sum(tensor: NdArrayTensor) -> NdArrayTensor { - let data = 
Data::from([tensor.array.sum()]); - NdArrayTensor::from_data(data) - } - - pub fn mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - match D { - 1 => keepdim!(0, dim, tensor, mean), - 2 => keepdim!(1, dim, tensor, mean), - 3 => keepdim!(2, dim, tensor, mean), - 4 => keepdim!(3, dim, tensor, mean), - 5 => keepdim!(4, dim, tensor, mean), - 6 => keepdim!(5, dim, tensor, mean), - _ => panic!("Dim not supported {D}"), - } - } - - pub fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - match D { - 1 => keepdim!(0, dim, tensor, sum), - 2 => keepdim!(1, dim, tensor, sum), - 3 => keepdim!(2, dim, tensor, sum), - 4 => keepdim!(3, dim, tensor, sum), - 5 => keepdim!(4, dim, tensor, sum), - 6 => keepdim!(5, dim, tensor, sum), - _ => panic!("Dim not supported {D}"), - } - } - - pub fn gather( - dim: usize, - mut tensor: NdArrayTensor, - mut indices: NdArrayTensor, - ) -> NdArrayTensor { - if dim != D - 1 { - tensor.array.swap_axes(D - 1, dim); - indices.array.swap_axes(D - 1, dim); + pub fn add( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = &lhs.array + &rhs.array; + let array = array.into_shared(); + + NdArrayTensor { array } } - let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape()); - let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indices.dims[D - 1]); - let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); - let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; - let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; - let mut output = Array2::zeros((batch_size, size_index)); + pub fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array + rhs; + let array = array.into_shared(); - for b in 0..batch_size { - let indices = indices.slice(s!(b, ..)); + NdArrayTensor { array } + } - for (i, index) in indices.iter().enumerate() { - output[[b, i]] = tensor[[b, *index as usize]]; - } + pub fn sub( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = lhs.array - rhs.array; + let array = array.into_shared(); + + NdArrayTensor { array } } - let mut output = NdArrayOps::reshape( - NdArrayTensor::::new(output.into_shared().into_dyn()), - shape_indices, - ); + pub fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array - rhs; + let array = array.into_shared(); - if dim != D - 1 { - output.array.swap_axes(D - 1, dim); + NdArrayTensor { array } } - output - } + pub fn mul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = lhs.array * rhs.array; + let array = array.into_shared(); - pub fn scatter( - dim: usize, - mut tensor: NdArrayTensor, - mut indices: NdArrayTensor, - mut value: NdArrayTensor, - ) -> NdArrayTensor { - if dim != D - 1 { - tensor.array.swap_axes(D - 1, dim); - indices.array.swap_axes(D - 1, dim); - value.array.swap_axes(D - 1, dim); + NdArrayTensor { array } } - let (shape_tensor, shape_indices, shape_value) = - (tensor.shape(), indices.shape(), value.shape()); - let (size_tensor, size_index, size_value) = ( - shape_tensor.dims[D - 1], - shape_indices.dims[D - 1], - shape_value.dims[D - 1], - ); - let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); + pub fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array * rhs; + let array = array.into_shared(); - if shape_value != shape_indices { - panic!("Invalid dimension: the shape of the index tensor should be the 
same as the value tensor: Index {:?} value {:?}", shape_indices.dims, shape_value.dims); + NdArrayTensor { array } } - let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; - let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array; - let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; + pub fn div( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let array = lhs.array / rhs.array; + let array = array.into_shared(); - for b in 0..batch_size { - let indices = indices.slice(s!(b, ..)); + NdArrayTensor { array } + } + + pub fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array / rhs; + let array = array.into_shared(); - for (i, index) in indices.iter().enumerate() { - let index = *index as usize; - tensor[[b, index]] += value[[b, i]]; - } + NdArrayTensor { array } } - let mut output = NdArrayOps::reshape( - NdArrayTensor::::new(tensor.into_shared().into_dyn()), - shape_tensor, - ); - if dim != D - 1 { - output.array.swap_axes(D - 1, dim); + pub fn recip(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.map(|x| 1.elem::() / *x); + let array = array.into_shared(); + + NdArrayTensor { array } } - output - } - pub fn mask_where( - tensor: NdArrayTensor, - mask: NdArrayTensor, - source: NdArrayTensor, - ) -> NdArrayTensor { - let mask_mul_4tensor = mask.array.mapv(|x| match x { - true => 0.elem(), - false => 1.elem(), - }); - let mask_mul_4source = mask.array.mapv(|x| match x { - true => 1.elem(), - false => 0.elem(), - }); - let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source); + pub fn mean(tensor: NdArrayTensor) -> NdArrayTensor { + let data = Data::from([tensor.array.mean().unwrap()]); + NdArrayTensor::from_data(data) + } - NdArrayTensor::new(array) - } + pub fn sum(tensor: NdArrayTensor) -> NdArrayTensor { + let data = Data::from([tensor.array.sum()]); + NdArrayTensor::from_data(data) + } - pub fn mask_fill( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: E, - ) -> NdArrayTensor { - let mask_mul = mask.array.mapv(|x| match x { - true => 0.elem(), - false => 1.elem(), - }); - let mask_add = mask.array.mapv(|x| match x { - true => value, - false => 0.elem(), - }); - let array = (tensor.array * mask_mul) + mask_add; + pub fn mean_dim( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + match D { + 1 => keepdim!(0, dim, tensor, mean), + 2 => keepdim!(1, dim, tensor, mean), + 3 => keepdim!(2, dim, tensor, mean), + 4 => keepdim!(3, dim, tensor, mean), + 5 => keepdim!(4, dim, tensor, mean), + 6 => keepdim!(5, dim, tensor, mean), + _ => panic!("Dim not supported {D}"), + } + } + + pub fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + match D { + 1 => keepdim!(0, dim, tensor, sum), + 2 => keepdim!(1, dim, tensor, sum), + 3 => keepdim!(2, dim, tensor, sum), + 4 => keepdim!(3, dim, tensor, sum), + 5 => keepdim!(4, dim, tensor, sum), + 6 => keepdim!(5, dim, tensor, sum), + _ => panic!("Dim not supported {D}"), + } + } + + pub fn gather( + dim: usize, + mut tensor: NdArrayTensor, + mut indices: NdArrayTensor, + ) -> NdArrayTensor { + if dim != D - 1 { + tensor.array.swap_axes(D - 1, dim); + indices.array.swap_axes(D - 1, dim); + } + let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape()); + let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indices.dims[D - 1]); + let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); + + let 
indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; + let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; + let mut output = Array2::zeros((batch_size, size_index)); + + for b in 0..batch_size { + let indices = indices.slice(s!(b, ..)); + + for (i, index) in indices.iter().enumerate() { + output[[b, i]] = tensor[[b, *index as usize]]; + } + } + + let mut output = NdArrayOps::reshape( + NdArrayTensor::::new(output.into_shared().into_dyn()), + shape_indices, + ); + + if dim != D - 1 { + output.array.swap_axes(D - 1, dim); + } - NdArrayTensor::new(array) - } + output + } + + pub fn scatter( + dim: usize, + mut tensor: NdArrayTensor, + mut indices: NdArrayTensor, + mut value: NdArrayTensor, + ) -> NdArrayTensor { + if dim != D - 1 { + tensor.array.swap_axes(D - 1, dim); + indices.array.swap_axes(D - 1, dim); + value.array.swap_axes(D - 1, dim); + } - fn gather_batch_size(shape_tensor: &Shape, shape_indices: &Shape) -> usize { - let mut batch_size = 1; + let (shape_tensor, shape_indices, shape_value) = + (tensor.shape(), indices.shape(), value.shape()); + let (size_tensor, size_index, size_value) = ( + shape_tensor.dims[D - 1], + shape_indices.dims[D - 1], + shape_value.dims[D - 1], + ); + let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); + + if shape_value != shape_indices { + panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indices.dims, shape_value.dims); + } - for i in 0..D - 1 { - if shape_tensor.dims[i] != shape_indices.dims[i] { - panic!( + let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; + let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array; + let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; + + for b in 0..batch_size { + let indices = indices.slice(s!(b, ..)); + + for (i, index) in indices.iter().enumerate() { + let index = *index as usize; + tensor[[b, index]] += value[[b, i]]; + } + } + + let mut output = NdArrayOps::reshape( + NdArrayTensor::::new(tensor.into_shared().into_dyn()), + shape_tensor, + ); + if dim != D - 1 { + output.array.swap_axes(D - 1, dim); + } + output + } + + pub fn mask_where( + tensor: NdArrayTensor, + mask: NdArrayTensor, + source: NdArrayTensor, + ) -> NdArrayTensor { + let mask_mul_4tensor = mask.array.mapv(|x| match x { + true => 0.elem(), + false => 1.elem(), + }); + let mask_mul_4source = mask.array.mapv(|x| match x { + true => 1.elem(), + false => 0.elem(), + }); + let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source); + + NdArrayTensor::new(array) + } + + pub fn mask_fill( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: E, + ) -> NdArrayTensor { + let mask_mul = mask.array.mapv(|x| match x { + true => 0.elem(), + false => 1.elem(), + }); + let mask_add = mask.array.mapv(|x| match x { + true => value, + false => 0.elem(), + }); + let array = (tensor.array * mask_mul) + mask_add; + + NdArrayTensor::new(array) + } + + fn gather_batch_size( + shape_tensor: &Shape, + shape_indices: &Shape, + ) -> usize { + let mut batch_size = 1; + + for i in 0..D - 1 { + if shape_tensor.dims[i] != shape_indices.dims[i] { + panic!( "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims ); - } - batch_size *= shape_indices.dims[i]; + } + batch_size *= shape_indices.dims[i]; + } + + batch_size + } + 
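// Illustrative sketch (not part of the upstream hunk): the `gather`/`scatter` kernels
// above first swap the target dimension to the last axis and flatten every other axis
// into a single batch dimension via `gather_batch_size`, so the hot loop only ever
// indexes a 2-D view. A minimal standalone version of that 2-D inner step, assuming
// plain `ndarray` types instead of `NdArrayTensor` (the function name and signature
// here are illustrative only, not from the patch):
use ndarray::{s, Array2};

/// out[[b, i]] = input[[b, idx[[b, i]]]] -- gather along the last axis of a 2-D array.
fn gather_last_axis(input: &Array2<f32>, idx: &Array2<usize>) -> Array2<f32> {
    let (batch, k) = idx.dim();
    let mut out = Array2::zeros((batch, k));
    for b in 0..batch {
        // Each row of `idx` selects elements from the matching row of `input`.
        for (i, &j) in idx.slice(s![b, ..]).iter().enumerate() {
            out[[b, i]] = input[[b, j]];
        }
    }
    out
}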
+ pub fn select( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + ) -> NdArrayTensor { + let array = tensor.array.select( + Axis(dim), + &indices + .array + .into_iter() + .map(|i| i as usize) + .collect::>(), + ); + + NdArrayTensor::new(array.into_shared()) } - batch_size - } + pub fn select_assign( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + let mut output_array = tensor.array.into_owned(); - pub fn select( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - ) -> NdArrayTensor { - let array = tensor.array.select( - Axis(dim), - &indices - .array - .into_iter() - .map(|i| i as usize) - .collect::>(), - ); - - NdArrayTensor::new(array.into_shared()) - } - - pub fn select_assign( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - let mut output_array = tensor.array.into_owned(); + for (index_value, index) in indices.array.into_iter().enumerate() { + let mut view = output_array.index_axis_mut(Axis(dim), index as usize); + let value = value.array.index_axis(Axis(dim), index_value); - for (index_value, index) in indices.array.into_iter().enumerate() { - let mut view = output_array.index_axis_mut(Axis(dim), index as usize); - let value = value.array.index_axis(Axis(dim), index_value); + view.zip_mut_with(&value, |a, b| *a += *b); + } - view.zip_mut_with(&value, |a, b| *a += *b); + NdArrayTensor::new(output_array.into_shared()) + } + pub fn argmax( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + arg(tensor, dim, CmpType::Max) } - NdArrayTensor::new(output_array.into_shared()) - } - pub fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - arg(tensor, dim, CmpType::Max) - } + pub fn argmin( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + arg(tensor, dim, CmpType::Min) + } - pub fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - arg(tensor, dim, CmpType::Min) - } + pub fn clamp_min( + mut tensor: NdArrayTensor, + min: E, + ) -> NdArrayTensor { + tensor.array.mapv_inplace(|x| match x < min { + true => min, + false => x, + }); - pub fn clamp_min(mut tensor: NdArrayTensor, min: E) -> NdArrayTensor { - tensor.array.mapv_inplace(|x| match x < min { - true => min, - false => x, - }); - - tensor - } + tensor + } - pub fn clamp_max(mut tensor: NdArrayTensor, max: E) -> NdArrayTensor { - tensor.array.mapv_inplace(|x| match x > max { - true => max, - false => x, - }); + pub fn clamp_max( + mut tensor: NdArrayTensor, + max: E, + ) -> NdArrayTensor { + tensor.array.mapv_inplace(|x| match x > max { + true => max, + false => x, + }); - tensor - } - - pub fn clamp( - mut tensor: NdArrayTensor, - min: E, - max: E, - ) -> NdArrayTensor { - tensor.array.mapv_inplace(|x| match x < min { - true => min, - false => match x > max { - true => max, - false => x, - }, - }); + tensor + } - tensor - } + pub fn clamp( + mut tensor: NdArrayTensor, + min: E, + max: E, + ) -> NdArrayTensor { + tensor.array.mapv_inplace(|x| match x < min { + true => min, + false => match x > max { + true => max, + false => x, + }, + }); + + tensor + } } enum CmpType { - Min, - Max, + Min, + Max, } fn arg( - tensor: NdArrayTensor, - dim: usize, - cmp: CmpType, + tensor: NdArrayTensor, + dim: usize, + cmp: CmpType, ) -> NdArrayTensor { - let mut reshape = tensor.array.shape().to_vec(); - reshape[dim] = 1; - - let output = tensor.array.map_axis(Axis(dim), |arr| { - // Find the min/max value in the array, and return its index. 
- let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| { - let cmp = match cmp { - CmpType::Min => e < &acc.0, - CmpType::Max => e > &acc.0, - }; - - if cmp { - (*e, idx) - } else { - acc - } + let mut reshape = tensor.array.shape().to_vec(); + reshape[dim] = 1; + + let output = tensor.array.map_axis(Axis(dim), |arr| { + // Find the min/max value in the array, and return its index. + let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| { + let cmp = match cmp { + CmpType::Min => e < &acc.0, + CmpType::Max => e > &acc.0, + }; + + if cmp { + (*e, idx) + } else { + acc + } + }); + + idx as i64 }); - idx as i64 - }); + let output = output.into_shape(Dim(reshape.as_slice())).unwrap(); - let output = output.into_shape(Dim(reshape.as_slice())).unwrap(); - - NdArrayTensor { - array: output.into_shared(), - } + NdArrayTensor { + array: output.into_shared(), + } } diff --git a/burn-ndarray/src/ops/bool_tensor.rs b/burn-ndarray/src/ops/bool_tensor.rs index 8d2e163674..adddb12621 100644 --- a/burn-ndarray/src/ops/bool_tensor.rs +++ b/burn-ndarray/src/ops/bool_tensor.rs @@ -16,116 +16,116 @@ use burn_tensor::{backend::Backend, Data, Shape}; use super::NdArrayOps; impl BoolTensorOps for NdArray { - fn bool_from_data( - data: Data, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - NdArrayTensor::from_data(data) - } - - fn bool_shape( - tensor: & as Backend>::BoolTensorPrimitive, - ) -> Shape { - tensor.shape() - } - - fn bool_into_data( - tensor: as Backend>::BoolTensorPrimitive, - ) -> Reader> { - let shape = tensor.shape(); - let values = tensor.array.into_iter().collect(); - - Reader::Concrete(Data::new(values, shape)) - } - - fn bool_to_device( - tensor: NdArrayTensor, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - tensor - } - - fn bool_reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - NdArrayOps::reshape(tensor, shape) - } - - fn bool_slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - NdArrayOps::slice(tensor, ranges) - } - - fn bool_into_int( - tensor: as Backend>::BoolTensorPrimitive, - ) -> NdArrayTensor { - let data = Self::bool_into_data(tensor) - .read_sync() - .expect("Always sync with ndarray"); - NdArray::::int_from_data(data.convert(), &NdArrayDevice::Cpu) - } - - fn bool_device( - _tensor: & as Backend>::BoolTensorPrimitive, - ) -> as Backend>::Device { - NdArrayDevice::Cpu - } - - fn bool_empty( - shape: Shape, - _device: & as Backend>::Device, - ) -> as Backend>::BoolTensorPrimitive { - let values = vec![false; shape.num_elements()]; - NdArrayTensor::from_data(Data::new(values, shape)) - } - - fn bool_slice_assign( - tensor: as Backend>::BoolTensorPrimitive, - ranges: [Range; D2], - value: as Backend>::BoolTensorPrimitive, - ) -> as Backend>::BoolTensorPrimitive { - NdArrayOps::slice_assign(tensor, ranges, value) - } - - fn bool_cat( - tensors: Vec< as Backend>::BoolTensorPrimitive>, - dim: usize, - ) -> as Backend>::BoolTensorPrimitive { - NdArrayOps::cat(tensors, dim) - } - - fn bool_equal( - lhs: as Backend>::BoolTensorPrimitive, - rhs: as Backend>::BoolTensorPrimitive, - ) -> as Backend>::BoolTensorPrimitive { - let mut array = lhs.array; - array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b); - - NdArrayTensor { array } - } - - fn bool_not( - tensor: as Backend>::BoolTensorPrimitive, - ) -> as Backend>::BoolTensorPrimitive { - let array = tensor.array.mapv(|a| !a).into_shared(); - NdArrayTensor { array } - } - - fn bool_into_float( - tensor: as Backend>::BoolTensorPrimitive, - ) 
-> as Backend>::TensorPrimitive { - let array = tensor.array.mapv(|a| (a as i32).elem()).into_shared(); - NdArrayTensor { array } - } - - fn bool_swap_dims( - tensor: as Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::BoolTensorPrimitive { - NdArrayOps::swap_dims(tensor, dim1, dim2) - } + fn bool_from_data( + data: Data, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + NdArrayTensor::from_data(data) + } + + fn bool_shape( + tensor: & as Backend>::BoolTensorPrimitive, + ) -> Shape { + tensor.shape() + } + + fn bool_into_data( + tensor: as Backend>::BoolTensorPrimitive, + ) -> Reader> { + let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + + Reader::Concrete(Data::new(values, shape)) + } + + fn bool_to_device( + tensor: NdArrayTensor, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + tensor + } + + fn bool_reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + NdArrayOps::reshape(tensor, shape) + } + + fn bool_slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + NdArrayOps::slice(tensor, ranges) + } + + fn bool_into_int( + tensor: as Backend>::BoolTensorPrimitive, + ) -> NdArrayTensor { + let data = Self::bool_into_data(tensor) + .read_sync() + .expect("Always sync with ndarray"); + NdArray::::int_from_data(data.convert(), &NdArrayDevice::Cpu) + } + + fn bool_device( + _tensor: & as Backend>::BoolTensorPrimitive, + ) -> as Backend>::Device { + NdArrayDevice::Cpu + } + + fn bool_empty( + shape: Shape, + _device: & as Backend>::Device, + ) -> as Backend>::BoolTensorPrimitive { + let values = vec![false; shape.num_elements()]; + NdArrayTensor::from_data(Data::new(values, shape)) + } + + fn bool_slice_assign( + tensor: as Backend>::BoolTensorPrimitive, + ranges: [Range; D2], + value: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + NdArrayOps::slice_assign(tensor, ranges, value) + } + + fn bool_cat( + tensors: Vec< as Backend>::BoolTensorPrimitive>, + dim: usize, + ) -> as Backend>::BoolTensorPrimitive { + NdArrayOps::cat(tensors, dim) + } + + fn bool_equal( + lhs: as Backend>::BoolTensorPrimitive, + rhs: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + let mut array = lhs.array; + array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b); + + NdArrayTensor { array } + } + + fn bool_not( + tensor: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + let array = tensor.array.mapv(|a| !a).into_shared(); + NdArrayTensor { array } + } + + fn bool_into_float( + tensor: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::TensorPrimitive { + let array = tensor.array.mapv(|a| (a as i32).elem()).into_shared(); + NdArrayTensor { array } + } + + fn bool_swap_dims( + tensor: as Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::BoolTensorPrimitive { + NdArrayOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-ndarray/src/ops/conv.rs b/burn-ndarray/src/ops/conv.rs index bd300bc5a2..1d4fcc259c 100644 --- a/burn-ndarray/src/ops/conv.rs +++ b/burn-ndarray/src/ops/conv.rs @@ -1,243 +1,252 @@ use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions}, - ElementConversion, + ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions}, + ElementConversion, }; use ndarray::{s, Array3, Array4, ArrayView2, ArrayViewMut2, Axis, Dim}; use crate::{ - element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d, 
run_par, - sharing::UnsafeSharedRef, tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d, + run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor, }; #[inline(always)] fn conv2d_mad_inner( - mut output: ArrayViewMut2, - x: ArrayView2, - k: E, - k_xy: (usize, usize), - out_xy: (usize, usize), - stride: (usize, usize), - dilation: (usize, usize), + mut output: ArrayViewMut2, + x: ArrayView2, + k: E, + k_xy: (usize, usize), + out_xy: (usize, usize), + stride: (usize, usize), + dilation: (usize, usize), ) { - let (kh, kw) = k_xy; - let (out_width, out_height) = out_xy; - let (stride_width, stride_height) = stride; - let (dilation_width, dilation_height) = dilation; - - for oh in 0..out_height { - // Construct a sub-slice view of the input row. - // This is done upfront so that rustc does not have to emit bounds checks - // in the hot loop below. - let ir = x - .row(oh * stride_height + kh * dilation_height) - .to_slice() - .unwrap(); - - // Ditto. Construct a sub-slice view of the output row, and explicitly specify - // the bounds upfront as 0..out_width so that rustc can make the assumption - // that all accesses are in-bounds in the below loop. - let mut or = output.row_mut(oh); - let or = &mut or.as_slice_mut().unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - let iw = (ow * stride_width) + (kw * dilation_width); - or[ow] += ir[iw] * k; + let (kh, kw) = k_xy; + let (out_width, out_height) = out_xy; + let (stride_width, stride_height) = stride; + let (dilation_width, dilation_height) = dilation; + + for oh in 0..out_height { + // Construct a sub-slice view of the input row. + // This is done upfront so that rustc does not have to emit bounds checks + // in the hot loop below. + let ir = x + .row(oh * stride_height + kh * dilation_height) + .to_slice() + .unwrap(); + + // Ditto. Construct a sub-slice view of the output row, and explicitly specify + // the bounds upfront as 0..out_width so that rustc can make the assumption + // that all accesses are in-bounds in the below loop. + let mut or = output.row_mut(oh); + let or = &mut or.as_slice_mut().unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + let iw = (ow * stride_width) + (kw * dilation_width); + or[ow] += ir[iw] * k; + } } - } } pub(crate) fn conv2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvOptions<2>, + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvOptions<2>, ) -> NdArrayTensor { - let [dilation_height, dilation_width] = options.dilation; - let [padding_height, padding_width] = options.padding; - let [stride_height, stride_width] = options.stride; - let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; - let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims; - - let out_height = calculate_conv_output_size( - kernel_height, - stride_height, - padding_height, - dilation_height, - in_height, - ); - let out_width = calculate_conv_output_size( - kernel_width, - stride_width, - padding_width, - dilation_width, - in_width, - ); - - let x = apply_padding_4d(x, options.padding, 0i32.elem()).array; - - // Convert inputs from dynamic indexes to static to improve perf. 
- let x = x.into_dimensionality::().unwrap(); - let weights = weight.array.into_dimensionality::().unwrap(); - - let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width])); - - run_par!(|| { - iter_par!(output.axis_iter_mut(Axis(0))) - .enumerate() - .for_each( - #[inline(never)] - |(k, mut output)| { - let b = k / out_channels; - let oc = k % out_channels; - let g = k % options.groups; - - for ic in (in_channels * g)..(in_channels * (g + 1)) { - let weight_ic = ic - (g * in_channels); - - let x = x.slice(s![b, ic, .., ..]); - let k = weights.slice(s![oc, weight_ic, .., ..]); - - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let k = k[[kh, kw]]; - - // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization - // in the case that the stride/dilation is 1. - #[allow(clippy::if_same_then_else)] - if (1, 1, 1, 1) == (stride_width, stride_height, dilation_width, dilation_height) { - conv2d_mad_inner( - output.view_mut(), - x.view(), - k, - (kh, kw), - (out_width, out_height), - (stride_width, stride_height), - (dilation_width, dilation_height), - ); - } else { - conv2d_mad_inner( - output.view_mut(), - x.view(), - k, - (kh, kw), - (out_width, out_height), - (stride_width, stride_height), - (dilation_width, dilation_height), - ); - } - } - } - } - - if let Some(bias) = &bias { - let bias = bias.array[oc]; - - for oh in 0..out_height { - // Get a mutable slice reference to the row we're looping over. - // We explicitly define the bounds to 0..out_width so that rustc can make - // the assumption that all accesses are in-bounds. - let mut or = output.row_mut(oh); - let or = &mut or.as_slice_mut().unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - or[ow] += bias; - } - } - } - }, - ); - }); - - let output = output - .into_shape([batch_size, out_channels, out_height, out_width]) - .unwrap() - .into_dyn() - .into_shared(); - - NdArrayTensor::new(output) + let [dilation_height, dilation_width] = options.dilation; + let [padding_height, padding_width] = options.padding; + let [stride_height, stride_width] = options.stride; + let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; + let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims; + + let out_height = calculate_conv_output_size( + kernel_height, + stride_height, + padding_height, + dilation_height, + in_height, + ); + let out_width = calculate_conv_output_size( + kernel_width, + stride_width, + padding_width, + dilation_width, + in_width, + ); + + let x = apply_padding_4d(x, options.padding, 0i32.elem()).array; + + // Convert inputs from dynamic indexes to static to improve perf. 
+ let x = x.into_dimensionality::().unwrap(); + let weights = weight.array.into_dimensionality::().unwrap(); + + let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width])); + + run_par!(|| { + iter_par!(output.axis_iter_mut(Axis(0))) + .enumerate() + .for_each( + #[inline(never)] + |(k, mut output)| { + let b = k / out_channels; + let oc = k % out_channels; + let g = k % options.groups; + + for ic in (in_channels * g)..(in_channels * (g + 1)) { + let weight_ic = ic - (g * in_channels); + + let x = x.slice(s![b, ic, .., ..]); + let k = weights.slice(s![oc, weight_ic, .., ..]); + + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let k = k[[kh, kw]]; + + // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization + // in the case that the stride/dilation is 1. + #[allow(clippy::if_same_then_else)] + if (1, 1, 1, 1) + == ( + stride_width, + stride_height, + dilation_width, + dilation_height, + ) + { + conv2d_mad_inner( + output.view_mut(), + x.view(), + k, + (kh, kw), + (out_width, out_height), + (stride_width, stride_height), + (dilation_width, dilation_height), + ); + } else { + conv2d_mad_inner( + output.view_mut(), + x.view(), + k, + (kh, kw), + (out_width, out_height), + (stride_width, stride_height), + (dilation_width, dilation_height), + ); + } + } + } + } + + if let Some(bias) = &bias { + let bias = bias.array[oc]; + + for oh in 0..out_height { + // Get a mutable slice reference to the row we're looping over. + // We explicitly define the bounds to 0..out_width so that rustc can make + // the assumption that all accesses are in-bounds. + let mut or = output.row_mut(oh); + let or = &mut or.as_slice_mut().unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + or[ow] += bias; + } + } + } + }, + ); + }); + + let output = output + .into_shape([batch_size, out_channels, out_height, out_width]) + .unwrap() + .into_dyn() + .into_shared(); + + NdArrayTensor::new(output) } pub(crate) fn conv_transpose2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvTransposeOptions<2>, + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvTransposeOptions<2>, ) -> NdArrayTensor { - let [dilation_height, dilation_width] = options.dilation; - let [padding_height, padding_width] = options.padding; - let [stride_height, stride_width] = options.stride; - let [out_padding_height, out_padding_width] = options.padding_out; - let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; - let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims; - - let out_height = - (in_height - 1) * stride_height + dilation_height * (kernel_height - 1) + out_padding_height - - 2 * padding_height - + 1; - let out_width = - (in_width - 1) * stride_width + dilation_width * (kernel_width - 1) + out_padding_width - - 2 * padding_width - + 1; - - let x = x.array; - let mut output = Array4::zeros(Dim([ - batch_size, - out_channels * options.groups, - out_height, - out_width, - ])); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { - let b = k / (out_channels * options.groups); - let oc = k % out_channels; - let g = k % options.groups; - - let output = unsafe_shared_out.get(); - - let oc_out = oc + (out_channels * g); - let ic_start = g * (in_channels / options.groups); - let ic_end = ic_start + in_channels / 
options.groups; - - for ic in ic_start..ic_end { - for ih in 0..in_height { - for iw in 0..in_width { - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let oh = ih * stride_height + kh * dilation_height; - let ow = iw * stride_width + kw * dilation_width; - - if oh >= out_height + padding_height - || ow >= out_width + padding_width - || oh < padding_height - || ow < padding_width - { - continue; + let [dilation_height, dilation_width] = options.dilation; + let [padding_height, padding_width] = options.padding; + let [stride_height, stride_width] = options.stride; + let [out_padding_height, out_padding_width] = options.padding_out; + let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; + let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims; + + let out_height = (in_height - 1) * stride_height + + dilation_height * (kernel_height - 1) + + out_padding_height + - 2 * padding_height + + 1; + let out_width = + (in_width - 1) * stride_width + dilation_width * (kernel_width - 1) + out_padding_width + - 2 * padding_width + + 1; + + let x = x.array; + let mut output = Array4::zeros(Dim([ + batch_size, + out_channels * options.groups, + out_height, + out_width, + ])); + + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { + let b = k / (out_channels * options.groups); + let oc = k % out_channels; + let g = k % options.groups; + + let output = unsafe_shared_out.get(); + + let oc_out = oc + (out_channels * g); + let ic_start = g * (in_channels / options.groups); + let ic_end = ic_start + in_channels / options.groups; + + for ic in ic_start..ic_end { + for ih in 0..in_height { + for iw in 0..in_width { + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let oh = ih * stride_height + kh * dilation_height; + let ow = iw * stride_width + kw * dilation_width; + + if oh >= out_height + padding_height + || ow >= out_width + padding_width + || oh < padding_height + || ow < padding_width + { + continue; + } + + let oh = oh - padding_height; + let ow = ow - padding_width; + + output[[b, oc_out, oh, ow]] += + x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]]; + } + } + } } - - let oh = oh - padding_height; - let ow = ow - padding_width; - - output[[b, oc_out, oh, ow]] += x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]]; - } } - } - } - } - if let Some(bias) = &bias { - for oh in 0..out_height { - for ow in 0..out_width { - output[[b, oc_out, oh, ow]] += bias.array[oc_out]; - } - } - } + if let Some(bias) = &bias { + for oh in 0..out_height { + for ow in 0..out_width { + output[[b, oc_out, oh, ow]] += bias.array[oc_out]; + } + } + } + }); }); - }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } diff --git a/burn-ndarray/src/ops/int_tensor.rs b/burn-ndarray/src/ops/int_tensor.rs index 88b9989315..fb6adb5517 100644 --- a/burn-ndarray/src/ops/int_tensor.rs +++ b/burn-ndarray/src/ops/int_tensor.rs @@ -19,350 +19,362 @@ use burn_tensor::{backend::Backend, Data, Shape}; use super::{NdArrayMathOps, NdArrayOps}; impl IntTensorOps for NdArray { - fn int_from_data( - data: Data, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - NdArrayTensor::from_data(data) - } - - fn int_shape(tensor: &NdArrayTensor) -> Shape { - tensor.shape() - } - - fn int_into_data(tensor: NdArrayTensor) -> Reader> { - let shape = tensor.shape(); - let values = tensor.array.into_iter().collect(); - - 
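For reference, the output-size arithmetic used by the convolution code above follows the usual relations. The transposed-convolution expression is written out inline in `conv_transpose2d`; for the forward case, `calculate_conv_output_size` is assumed to implement the standard formula. A sketch of both (the function names below are illustrative, not from the patch):

/// Output size along one axis of a (non-transposed) convolution.
/// Assumed to match what `calculate_conv_output_size` computes above.
fn conv_out_size(kernel: usize, stride: usize, padding: usize, dilation: usize, input: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

/// Output size along one axis of a transposed convolution, mirroring the
/// expression written out in `conv_transpose2d` above.
fn conv_transpose_out_size(
    kernel: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    dilation: usize,
    input: usize,
) -> usize {
    (input - 1) * stride + dilation * (kernel - 1) + out_padding - 2 * padding + 1
}

fn main() {
    // A 3x3 kernel with stride 1 and padding 1 preserves the spatial size...
    assert_eq!(conv_out_size(3, 1, 1, 1, 8), 8);
    // ...and the matching transposed convolution maps it back to the same size.
    assert_eq!(conv_transpose_out_size(3, 1, 1, 0, 1, 8), 8);
}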
Reader::Concrete(Data::new(values, shape)) - } - - fn int_to_device( - tensor: NdArrayTensor, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - tensor - } - - fn int_reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - NdArrayOps::reshape(tensor, shape) - } - - fn int_slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - NdArrayOps::slice(tensor, ranges) - } - - fn int_device( - _tensor: &NdArrayTensor, - ) -> as Backend>::Device { - NdArrayDevice::Cpu - } - - fn int_empty( - shape: Shape, - _device: & as Backend>::Device, - ) -> NdArrayTensor { - let values = vec![0; shape.num_elements()]; - NdArrayTensor::from_data(Data::new(values, shape)) - } - - fn int_mask_where( - tensor: NdArrayTensor, - mask: NdArrayTensor, - source: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::mask_where(tensor, mask, source) - } - - fn int_mask_fill( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: i64, - ) -> NdArrayTensor { - NdArrayMathOps::mask_fill(tensor, mask, value) - } - - fn int_slice_assign( - tensor: NdArrayTensor, - ranges: [Range; D2], - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayOps::slice_assign(tensor, ranges, value) - } - - fn int_cat( - tensors: Vec>, - dim: usize, - ) -> NdArrayTensor { - NdArrayOps::cat(tensors, dim) - } - - fn int_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - - Self::int_equal_elem(tensor, 0) - } - - fn int_equal_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a == rhs).into_shared(); - NdArrayTensor { array } - } - - fn int_greater( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_greater_elem(tensor, 0) - } - - fn int_greater_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a > rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_greater_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_greater_equal_elem(tensor, 0) - } - - fn int_greater_equal_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a >= rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_lower( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_lower_elem(tensor, 0) - } - - fn int_lower_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a < rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_lower_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = Self::int_sub(lhs, rhs); - Self::int_lower_equal_elem(tensor, 0) - } - - fn int_lower_equal_elem( - lhs: NdArrayTensor, - rhs: i64, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a <= rhs).into_shared(); - NdArrayTensor::new(array) - } - - fn int_add( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::add(lhs, rhs) - } - - fn int_add_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { - NdArrayMathOps::add_scalar(lhs, rhs) - } - - fn int_sub( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::sub(lhs, rhs) - } - - fn int_sub_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { - NdArrayMathOps::sub_scalar(lhs, rhs) - } - - fn int_mul( - lhs: NdArrayTensor, - rhs: 
NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::mul(lhs, rhs) - } - - fn int_mul_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { - NdArrayMathOps::mul_scalar(lhs, rhs) - } - - fn int_div( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::div(lhs, rhs) - } - - fn int_div_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { - NdArrayMathOps::div_scalar(lhs, rhs) - } - - fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor { - Self::int_mul_scalar(tensor, -1) - } - - fn int_zeros( - shape: Shape, - device: & as Backend>::Device, - ) -> NdArrayTensor { - Self::int_from_data(Data::zeros(shape), device) - } - - fn int_ones( - shape: Shape, - device: & as Backend>::Device, - ) -> NdArrayTensor { - Self::int_from_data(Data::ones(shape), device) - } - - fn int_full( - shape: Shape, - fill_value: i64, - device: & as Backend>::Device, - ) -> NdArrayTensor { - Self::int_from_data(Data::full(shape, fill_value), device) - } - - fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::sum(tensor) - } - - fn int_sum_dim( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::sum_dim(tensor, dim) - } - - fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::mean(tensor) - } - - fn int_mean_dim( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::mean_dim(tensor, dim) - } - - fn int_gather( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select(tensor, dim, indices) - } - - fn int_select_assign( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select_assign(tensor, dim, indices, value) - } - fn int_argmax( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::argmax(tensor, dim) - } - - fn int_argmin( - tensor: NdArrayTensor, - dim: usize, - ) -> NdArrayTensor { - NdArrayMathOps::argmin(tensor, dim) - } - - fn int_clamp_min( - tensor: NdArrayTensor, - min: i64, - ) -> NdArrayTensor { - NdArrayMathOps::clamp_min(tensor, min) - } - - fn int_clamp_max( - tensor: NdArrayTensor, - max: i64, - ) -> NdArrayTensor { - NdArrayMathOps::clamp_max(tensor, max) - } - - fn int_clamp( - tensor: NdArrayTensor, - min: i64, - max: i64, - ) -> NdArrayTensor { - NdArrayMathOps::clamp(tensor, min, max) - } - - fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn int_into_float( - tensor: as Backend>::IntTensorPrimitive, - ) -> as Backend>::TensorPrimitive { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - NdArrayTensor { array } - } - - fn int_swap_dims( - tensor: as Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::IntTensorPrimitive { - NdArrayOps::swap_dims(tensor, dim1, dim2) - } + fn int_from_data( + data: Data, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + NdArrayTensor::from_data(data) + } + + fn int_shape(tensor: &NdArrayTensor) -> Shape { + tensor.shape() + } + + fn int_into_data(tensor: NdArrayTensor) -> Reader> { + 
let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + + Reader::Concrete(Data::new(values, shape)) + } + + fn int_to_device( + tensor: NdArrayTensor, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + tensor + } + + fn int_reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + NdArrayOps::reshape(tensor, shape) + } + + fn int_slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + NdArrayOps::slice(tensor, ranges) + } + + fn int_device( + _tensor: &NdArrayTensor, + ) -> as Backend>::Device { + NdArrayDevice::Cpu + } + + fn int_empty( + shape: Shape, + _device: & as Backend>::Device, + ) -> NdArrayTensor { + let values = vec![0; shape.num_elements()]; + NdArrayTensor::from_data(Data::new(values, shape)) + } + + fn int_mask_where( + tensor: NdArrayTensor, + mask: NdArrayTensor, + source: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mask_where(tensor, mask, source) + } + + fn int_mask_fill( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: i64, + ) -> NdArrayTensor { + NdArrayMathOps::mask_fill(tensor, mask, value) + } + + fn int_slice_assign( + tensor: NdArrayTensor, + ranges: [Range; D2], + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayOps::slice_assign(tensor, ranges, value) + } + + fn int_cat( + tensors: Vec>, + dim: usize, + ) -> NdArrayTensor { + NdArrayOps::cat(tensors, dim) + } + + fn int_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + + Self::int_equal_elem(tensor, 0) + } + + fn int_equal_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a == rhs).into_shared(); + NdArrayTensor { array } + } + + fn int_greater( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_greater_elem(tensor, 0) + } + + fn int_greater_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a > rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_greater_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_greater_equal_elem(tensor, 0) + } + + fn int_greater_equal_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a >= rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_lower( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_lower_elem(tensor, 0) + } + + fn int_lower_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a < rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_lower_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = Self::int_sub(lhs, rhs); + Self::int_lower_equal_elem(tensor, 0) + } + + fn int_lower_equal_elem( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a <= rhs).into_shared(); + NdArrayTensor::new(array) + } + + fn int_add( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::add(lhs, rhs) + } + + fn int_add_scalar( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + NdArrayMathOps::add_scalar(lhs, rhs) + } + + fn int_sub( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::sub(lhs, rhs) + } + + fn int_sub_scalar( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { 
+ NdArrayMathOps::sub_scalar(lhs, rhs) + } + + fn int_mul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mul(lhs, rhs) + } + + fn int_mul_scalar( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + NdArrayMathOps::mul_scalar(lhs, rhs) + } + + fn int_div( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::div(lhs, rhs) + } + + fn int_div_scalar( + lhs: NdArrayTensor, + rhs: i64, + ) -> NdArrayTensor { + NdArrayMathOps::div_scalar(lhs, rhs) + } + + fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor { + Self::int_mul_scalar(tensor, -1) + } + + fn int_zeros( + shape: Shape, + device: & as Backend>::Device, + ) -> NdArrayTensor { + Self::int_from_data(Data::zeros(shape), device) + } + + fn int_ones( + shape: Shape, + device: & as Backend>::Device, + ) -> NdArrayTensor { + Self::int_from_data(Data::ones(shape), device) + } + + fn int_full( + shape: Shape, + fill_value: i64, + device: & as Backend>::Device, + ) -> NdArrayTensor { + Self::int_from_data(Data::full(shape, fill_value), device) + } + + fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::sum(tensor) + } + + fn int_sum_dim( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::sum_dim(tensor, dim) + } + + fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::mean(tensor) + } + + fn int_mean_dim( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::mean_dim(tensor, dim) + } + + fn int_gather( + dim: usize, + tensor: NdArrayTensor, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: NdArrayTensor, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select(tensor, dim, indices) + } + + fn int_select_assign( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select_assign(tensor, dim, indices, value) + } + fn int_argmax( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::argmax(tensor, dim) + } + + fn int_argmin( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::argmin(tensor, dim) + } + + fn int_clamp_min( + tensor: NdArrayTensor, + min: i64, + ) -> NdArrayTensor { + NdArrayMathOps::clamp_min(tensor, min) + } + + fn int_clamp_max( + tensor: NdArrayTensor, + max: i64, + ) -> NdArrayTensor { + NdArrayMathOps::clamp_max(tensor, max) + } + + fn int_clamp( + tensor: NdArrayTensor, + min: i64, + max: i64, + ) -> NdArrayTensor { + NdArrayMathOps::clamp(tensor, min, max) + } + + fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn int_into_float( + tensor: as Backend>::IntTensorPrimitive, + ) -> as Backend>::TensorPrimitive { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + NdArrayTensor { array } + } + + fn int_swap_dims( + tensor: as Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::IntTensorPrimitive { + NdArrayOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-ndarray/src/ops/macros.rs b/burn-ndarray/src/ops/macros.rs index 6895d8fc5d..b92b37b82d 100644 --- 
a/burn-ndarray/src/ops/macros.rs +++ b/burn-ndarray/src/ops/macros.rs @@ -1,26 +1,26 @@ macro_rules! keepdim { - ( + ( $D:expr, $dim:expr, $self:expr, mean ) => {{ - let tensor: NdArrayTensor = mean_dim($self.clone(), $dim); - let mut shape = $self.shape(); - shape.dims[$dim] = 1; - NdArrayOps::reshape(tensor.clone(), shape) - }}; - ( + let tensor: NdArrayTensor = mean_dim($self.clone(), $dim); + let mut shape = $self.shape(); + shape.dims[$dim] = 1; + NdArrayOps::reshape(tensor.clone(), shape) + }}; + ( $D:expr, $dim:expr, $self:expr, sum ) => {{ - let tensor: NdArrayTensor = sum_dim($self.clone(), $dim); - let mut shape = $self.shape(); - shape.dims[$dim] = 1; - NdArrayOps::reshape(tensor, shape) - }}; + let tensor: NdArrayTensor = sum_dim($self.clone(), $dim); + let mut shape = $self.shape(); + shape.dims[$dim] = 1; + NdArrayOps::reshape(tensor, shape) + }}; } pub(crate) use keepdim; @@ -29,19 +29,19 @@ use ndarray::Axis; use crate::{element::NdArrayElement, tensor::NdArrayTensor}; pub(crate) fn mean_dim( - tensor: NdArrayTensor, - dim: usize, + tensor: NdArrayTensor, + dim: usize, ) -> NdArrayTensor { - let array = tensor.array.mean_axis(Axis(dim)).unwrap().into_shared(); + let array = tensor.array.mean_axis(Axis(dim)).unwrap().into_shared(); - NdArrayTensor { array } + NdArrayTensor { array } } pub(crate) fn sum_dim( - tensor: NdArrayTensor, - dim: usize, + tensor: NdArrayTensor, + dim: usize, ) -> NdArrayTensor { - let array = tensor.array.sum_axis(Axis(dim)).into_shared(); + let array = tensor.array.sum_axis(Axis(dim)).into_shared(); - NdArrayTensor { array } + NdArrayTensor { array } } diff --git a/burn-ndarray/src/ops/matmul.rs b/burn-ndarray/src/ops/matmul.rs index 4db10005ed..185a7567dd 100644 --- a/burn-ndarray/src/ops/matmul.rs +++ b/burn-ndarray/src/ops/matmul.rs @@ -5,101 +5,107 @@ use burn_tensor::{ops::TensorOps, Shape}; use ndarray::s; pub(crate) fn matmul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, + lhs: NdArrayTensor, + rhs: NdArrayTensor, ) -> NdArrayTensor where - E: FloatNdArrayElement, + E: FloatNdArrayElement, { - let shape_ori_lhs = lhs.shape(); - let shape_ori_rhs = rhs.shape(); + let shape_ori_lhs = lhs.shape(); + let shape_ori_rhs = rhs.shape(); - let lhs = reshape(lhs); - let rhs = reshape(rhs); + let lhs = reshape(lhs); + let rhs = reshape(rhs); - let [batch_size_lhs, m, _] = lhs.shape().dims; - let [batch_size_rhs, _, n] = rhs.shape().dims; + let [batch_size_lhs, m, _] = lhs.shape().dims; + let [batch_size_rhs, _, n] = rhs.shape().dims; - let mut shape_out = match batch_size_lhs > batch_size_rhs { - true => shape_ori_lhs, - false => shape_ori_rhs, - }; - shape_out.dims[D - 2] = m; - shape_out.dims[D - 1] = n; + let mut shape_out = match batch_size_lhs > batch_size_rhs { + true => shape_ori_lhs, + false => shape_ori_rhs, + }; + shape_out.dims[D - 2] = m; + shape_out.dims[D - 1] = n; - let out = general_matmul(lhs, rhs); + let out = general_matmul(lhs, rhs); - NdArray::::reshape(out, shape_out) + NdArray::::reshape(out, shape_out) } fn general_matmul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, + lhs: NdArrayTensor, + rhs: NdArrayTensor, ) -> NdArrayTensor { - run_par!(|| { - let [batch_size_lhs, m, _] = lhs.shape().dims; - let [batch_size_rhs, k, n] = rhs.shape().dims; - let batch_size = usize::max(batch_size_rhs, batch_size_lhs); - - if batch_size_lhs > batch_size && batch_size_lhs != 1 { - panic!("Broadcast on multiple dimensions is not yet supported"); - } - - if batch_size_rhs > batch_size && batch_size_rhs != 1 { - panic!("Broadcast on multiple 
dimensions is not yet supported"); - } - - let alpha: E = 1.0.elem(); - let beta: E = 0.0.elem(); - - let mut out_array = ndarray::Array3::::zeros((batch_size, m, n)); - let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array); - - let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap(); - let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap(); - - iter_range_par!(0, batch_size).for_each(|b| { - let lhs_slice = match batch_size_lhs == 1 { - true => lhs_array.slice(s!(0, .., ..)), - false => lhs_array.slice(s!(b, .., ..)), - }; - let rhs_slice = match batch_size_rhs == 1 { - true => rhs_array.slice(s!(0, .., ..)), - false => rhs_array.slice(s!(b, .., ..)), - }; - - unsafe { - let mut out_slice = unsafe_shared_out_array.get().slice_mut(s!(b, .., ..)); - - ndarray::linalg::general_mat_mul(alpha, &lhs_slice, &rhs_slice, beta, &mut out_slice); - } - }); - - NdArrayTensor::new(out_array.into_shared().into_dyn()) - }) + run_par!(|| { + let [batch_size_lhs, m, _] = lhs.shape().dims; + let [batch_size_rhs, k, n] = rhs.shape().dims; + let batch_size = usize::max(batch_size_rhs, batch_size_lhs); + + if batch_size_lhs > batch_size && batch_size_lhs != 1 { + panic!("Broadcast on multiple dimensions is not yet supported"); + } + + if batch_size_rhs > batch_size && batch_size_rhs != 1 { + panic!("Broadcast on multiple dimensions is not yet supported"); + } + + let alpha: E = 1.0.elem(); + let beta: E = 0.0.elem(); + + let mut out_array = ndarray::Array3::::zeros((batch_size, m, n)); + let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array); + + let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap(); + let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap(); + + iter_range_par!(0, batch_size).for_each(|b| { + let lhs_slice = match batch_size_lhs == 1 { + true => lhs_array.slice(s!(0, .., ..)), + false => lhs_array.slice(s!(b, .., ..)), + }; + let rhs_slice = match batch_size_rhs == 1 { + true => rhs_array.slice(s!(0, .., ..)), + false => rhs_array.slice(s!(b, .., ..)), + }; + + unsafe { + let mut out_slice = unsafe_shared_out_array.get().slice_mut(s!(b, .., ..)); + + ndarray::linalg::general_mat_mul( + alpha, + &lhs_slice, + &rhs_slice, + beta, + &mut out_slice, + ); + } + }); + + NdArrayTensor::new(out_array.into_shared().into_dyn()) + }) } fn reshape( - tensor: NdArrayTensor, + tensor: NdArrayTensor, ) -> NdArrayTensor { - let shape = tensor.shape(); + let shape = tensor.shape(); - if D < 2 { - NdArray::::reshape(tensor, Shape::new([1, 1, shape.dims[0]])) - } else { - let batch_size = batch_size(&shape); - let size0 = shape.dims[D - 2]; - let size1 = shape.dims[D - 1]; + if D < 2 { + NdArray::::reshape(tensor, Shape::new([1, 1, shape.dims[0]])) + } else { + let batch_size = batch_size(&shape); + let size0 = shape.dims[D - 2]; + let size1 = shape.dims[D - 1]; - NdArray::::reshape(tensor, Shape::new([batch_size, size0, size1])) - } + NdArray::::reshape(tensor, Shape::new([batch_size, size0, size1])) + } } fn batch_size(shape: &Shape) -> usize { - let mut num_batch = 1; - for i in 0..D - 2 { - num_batch *= shape.dims[i]; - } + let mut num_batch = 1; + for i in 0..D - 2 { + num_batch *= shape.dims[i]; + } - num_batch + num_batch } diff --git a/burn-ndarray/src/ops/maxpool.rs b/burn-ndarray/src/ops/maxpool.rs index 7546271205..948c942932 100644 --- a/burn-ndarray/src/ops/maxpool.rs +++ b/burn-ndarray/src/ops/maxpool.rs @@ -1,181 +1,183 @@ use crate::{ - element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, 
run_par, - sharing::UnsafeSharedRef, tensor::NdArrayTensor, + element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, run_par, + sharing::UnsafeSharedRef, tensor::NdArrayTensor, }; use burn_tensor::ElementConversion; use ndarray::Array4; pub(crate) fn max_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> NdArrayTensor { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - let inf = (-f32::INFINITY).elem::(); + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + let inf = (-f32::INFINITY).elem::(); - let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) - / stride_height) - + 1; - let out_width = - ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; + let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) + / stride_height) + + 1; + let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) + / stride_width) + + 1; - let x = apply_padding_4d(x, padding, inf).array; + let x = apply_padding_4d(x, padding, inf).array; - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; - let output = unsafe_shared_out.get(); + let output = unsafe_shared_out.get(); - for oh in 0..out_height { - for ow in 0..out_width { - let mut max_val = inf; + for oh in 0..out_height { + for ow in 0..out_width { + let mut max_val = inf; - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height; + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height; - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width; + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width; - let val = x[[b, c, ih, iw]]; + let val = x[[b, c, ih, iw]]; - if val > max_val { - max_val = val; - } - } - } + if val > max_val { + max_val = val; + } + } + } - output[[b, c, oh, ow]] = max_val; - } - } - }) - }); + output[[b, c, oh, ow]] = max_val; + } + } + }) + }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } pub(crate) fn max_pool2d_with_indices( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (NdArrayTensor, NdArrayTensor) { - let 
[kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape().dims; - let inf = (-f32::INFINITY).elem::(); - - let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) - / stride_height) - + 1; - let out_width = - ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; - - let x = apply_padding_4d(x, padding, inf).array; - - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); - let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - let indices = unsafe_shared_indices.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut max_val = inf; - let mut index = 0; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width; - let val = x[[b, c, ih, iw]]; - - if val > max_val { - max_val = val; - - let ih = ih as i64 - padding_height as i64; - let iw = iw as i64 - padding_width as i64; - - index = ih * x_height as i64 + iw; - } + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + let [batch_size, channels, x_height, x_width] = x.shape().dims; + let inf = (-f32::INFINITY).elem::(); + + let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) + / stride_height) + + 1; + let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) + / stride_width) + + 1; + + let x = apply_padding_4d(x, padding, inf).array; + + let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); + let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); + + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + let indices = unsafe_shared_indices.get(); + + for oh in 0..out_height { + for ow in 0..out_width { + let mut max_val = inf; + let mut index = 0; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width; + let val = x[[b, c, ih, iw]]; + + if val > max_val { + max_val = val; + + let ih = ih as i64 - padding_height as i64; + let iw = iw as i64 - padding_width as i64; + + index = ih * x_height as i64 + iw; + } + } + } + + output[[b, c, oh, ow]] = max_val; + indices[[b, c, oh, ow]] = index; + } } - } - - output[[b, c, oh, ow]] = max_val; - indices[[b, c, oh, ow]] = index; - } - } - }) - }); + }) + }); - let output = NdArrayTensor::new(output.into_dyn().into_shared()); - let indices = NdArrayTensor::new(indices.into_dyn().into_shared()); + let output = 
NdArrayTensor::new(output.into_dyn().into_shared()); + let indices = NdArrayTensor::new(indices.into_dyn().into_shared()); - (output, indices) + (output, indices) } pub(crate) fn max_pool2d_backward( - x: NdArrayTensor, - _kernel_size: [usize; 2], - _stride: [usize; 2], - _padding: [usize; 2], - _dilation: [usize; 2], - output_grad: NdArrayTensor, - indices: NdArrayTensor, + x: NdArrayTensor, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + _dilation: [usize; 2], + output_grad: NdArrayTensor, + indices: NdArrayTensor, ) -> NdArrayTensor { - let [_batch_size, _channels, height, width] = output_grad.shape().dims; - let [batch_size, channels, height_x, width_x] = x.shape().dims; + let [_batch_size, _channels, height, width] = output_grad.shape().dims; + let [batch_size, channels, height_x, width_x] = x.shape().dims; - let output_grad = output_grad.array; - let indices = indices.array; + let output_grad = output_grad.array; + let indices = indices.array; - let mut output = Array4::zeros((batch_size, channels, height_x, width_x)); + let mut output = Array4::zeros((batch_size, channels, height_x, width_x)); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; - let output = unsafe_shared_out.get(); + let output = unsafe_shared_out.get(); - for h in 0..height { - for w in 0..width { - let index = indices[[b, c, h, w]]; - let grad = output_grad[[b, c, h, w]]; + for h in 0..height { + for w in 0..width { + let index = indices[[b, c, h, w]]; + let grad = output_grad[[b, c, h, w]]; - let index_h = index as usize / width_x; - let index_w = index as usize % width_x; + let index_h = index as usize / width_x; + let index_w = index as usize % width_x; - output[[b, c, index_h, index_w]] += grad; - } - } + output[[b, c, index_h, index_w]] += grad; + } + } + }); }); - }); - NdArrayTensor::new(output.into_dyn().into_shared()) + NdArrayTensor::new(output.into_dyn().into_shared()) } diff --git a/burn-ndarray/src/ops/module.rs b/burn-ndarray/src/ops/module.rs index 5d7f63378e..119f13657a 100644 --- a/burn-ndarray/src/ops/module.rs +++ b/burn-ndarray/src/ops/module.rs @@ -1,102 +1,102 @@ use super::{ - adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, - avgpool::{avg_pool2d, avg_pool2d_backward}, - conv::{conv2d, conv_transpose2d}, - maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, + adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, + avgpool::{avg_pool2d, avg_pool2d_backward}, + conv::{conv2d, conv_transpose2d}, + maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, }; use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray}; use burn_tensor::ops::*; impl ModuleOps for NdArray { - fn conv2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> NdArrayTensor { - conv2d(x, weight, bias, options) - } + fn conv2d( + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> NdArrayTensor { + conv2d(x, weight, bias, options) + } - fn conv_transpose2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> NdArrayTensor { - conv_transpose2d(x, 
weight, bias, options) - } + fn conv_transpose2d( + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> NdArrayTensor { + conv_transpose2d(x, weight, bias, options) + } - fn avg_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> NdArrayTensor { - avg_pool2d(x, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> NdArrayTensor { + avg_pool2d(x, kernel_size, stride, padding, count_include_pad) + } - fn avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> NdArrayTensor { - avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d_backward( + x: NdArrayTensor, + grad: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> NdArrayTensor { + avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) + } - fn max_pool2d( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> NdArrayTensor { - max_pool2d(x, kernel_size, stride, padding, dilation) - } + fn max_pool2d( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> NdArrayTensor { + max_pool2d(x, kernel_size, stride, padding, dilation) + } - fn max_pool2d_with_indices( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); + fn max_pool2d_with_indices( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); - MaxPool2dWithIndices::new(output, indices) - } + MaxPool2dWithIndices::new(output, indices) + } - fn max_pool2d_with_indices_backward( - x: NdArrayTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: NdArrayTensor, - indices: NdArrayTensor, - ) -> MaxPool2dBackward> { - MaxPool2dBackward::new(max_pool2d_backward( - x, - kernel_size, - stride, - padding, - dilation, - output_grad, - indices, - )) - } + fn max_pool2d_with_indices_backward( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: NdArrayTensor, + indices: NdArrayTensor, + ) -> MaxPool2dBackward> { + MaxPool2dBackward::new(max_pool2d_backward( + x, + kernel_size, + stride, + padding, + dilation, + output_grad, + indices, + )) + } - fn adaptive_avg_pool2d(x: NdArrayTensor, output_size: [usize; 2]) -> NdArrayTensor { - adaptive_avg_pool2d(x, output_size) - } + fn adaptive_avg_pool2d(x: NdArrayTensor, output_size: [usize; 2]) -> NdArrayTensor { + adaptive_avg_pool2d(x, output_size) + } - fn adaptive_avg_pool2d_backward( - x: NdArrayTensor, - grad: NdArrayTensor, - ) -> NdArrayTensor { - adaptive_avg_pool2d_backward(x, grad) - } + fn adaptive_avg_pool2d_backward( + x: NdArrayTensor, + grad: NdArrayTensor, + ) -> NdArrayTensor { + 
adaptive_avg_pool2d_backward(x, grad) + } } diff --git a/burn-ndarray/src/ops/padding.rs b/burn-ndarray/src/ops/padding.rs index a072f3e77b..790a291053 100644 --- a/burn-ndarray/src/ops/padding.rs +++ b/burn-ndarray/src/ops/padding.rs @@ -3,31 +3,31 @@ use burn_tensor::ops::TensorOps; use ndarray::Array4; pub(crate) fn apply_padding_4d( - x: NdArrayTensor, - padding: [usize; 2], - elem: E, + x: NdArrayTensor, + padding: [usize; 2], + elem: E, ) -> NdArrayTensor { - let [batch_size, input_channels, height, width] = x.shape().dims; - let [padding_height, padding_width] = padding; - let padded_height = height + 2 * padding_height; - let padded_width = width + 2 * padding_width; + let [batch_size, input_channels, height, width] = x.shape().dims; + let [padding_height, padding_width] = padding; + let padded_height = height + 2 * padding_height; + let padded_width = width + 2 * padding_width; - let x_new = Array4::from_elem( - (batch_size, input_channels, padded_height, padded_width), - elem, - ); - let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn()); + let x_new = Array4::from_elem( + (batch_size, input_channels, padded_height, padded_width), + elem, + ); + let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn()); - x_new = NdArray::slice_assign( - x_new, - [ - 0..batch_size, - 0..input_channels, - padding_height..(height + padding_height), - padding_width..width + padding_width, - ], - x, - ); + x_new = NdArray::slice_assign( + x_new, + [ + 0..batch_size, + 0..input_channels, + padding_height..(height + padding_height), + padding_width..width + padding_width, + ], + x, + ); - x_new + x_new } diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs index 6367aa7ae3..0ac3f20226 100644 --- a/burn-ndarray/src/ops/tensor.rs +++ b/burn-ndarray/src/ops/tensor.rs @@ -21,419 +21,422 @@ use libm::{cos, erf, sin, tanh}; use num_traits::Float; impl TensorOps for NdArray { - fn from_data(data: Data, _device: &NdArrayDevice) -> NdArrayTensor { - NdArrayTensor::from_data(data) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: &NdArrayDevice, - ) -> NdArrayTensor { - let mut seed = SEED.lock().unwrap(); - let mut rng = if let Some(rng_seeded) = seed.as_ref() { - rng_seeded.clone() - } else { - get_seeded_rng() - }; - let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device); - *seed = Some(rng); - tensor - } - - fn shape(tensor: &NdArrayTensor) -> Shape { - tensor.shape() - } - - fn into_data( - tensor: NdArrayTensor, - ) -> Reader as Backend>::FloatElem, D>> { - let shape = tensor.shape(); - let values = tensor.array.into_iter().collect(); - - Reader::Concrete(Data::new(values, shape)) - } - - fn device(_tensor: &NdArrayTensor) -> NdArrayDevice { - NdArrayDevice::Cpu - } - - fn to_device( - tensor: NdArrayTensor, - _device: &NdArrayDevice, - ) -> NdArrayTensor { - tensor - } - - fn empty( - shape: Shape, - device: & as Backend>::Device, - ) -> NdArrayTensor { - NdArray::::zeros(shape, device) - } - - fn add( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::add(lhs, rhs) - } - - fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::add_scalar(lhs, rhs) - } - - fn sub( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::sub(lhs, rhs) - } - - fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::sub_scalar(lhs, rhs) - } - - fn mul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - 
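The `apply_padding_4d` helper above pads by allocating a larger, constant-filled buffer and slice-assigning the input into its interior. A standalone ndarray sketch of the same idea (the `pad_hw` name is illustrative, not from the patch):

use ndarray::{s, Array4};

/// Pad an NCHW tensor with a constant value on both sides of H and W.
fn pad_hw(x: &Array4<f32>, pad_h: usize, pad_w: usize, value: f32) -> Array4<f32> {
    let (n, c, h, w) = x.dim();
    let mut out = Array4::from_elem((n, c, h + 2 * pad_h, w + 2 * pad_w), value);

    // Copy the input into the interior region, leaving the border at `value`.
    out.slice_mut(s![.., .., pad_h..pad_h + h, pad_w..pad_w + w])
        .assign(x);

    out
}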
NdArrayMathOps::mul(lhs, rhs) - } - - fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::mul_scalar(lhs, rhs) - } - - fn div( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::div(lhs, rhs) - } - - fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - NdArrayMathOps::div_scalar(lhs, rhs) - } - - fn matmul( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - matmul(lhs, rhs) - } - - fn neg(tensor: NdArrayTensor) -> NdArrayTensor { - Self::mul_scalar(tensor, (-1f32).elem::()) - } - - fn recip(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::recip(tensor) - } - - fn swap_dims( - tensor: NdArrayTensor, - dim1: usize, - dim2: usize, - ) -> NdArrayTensor { - NdArrayOps::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: NdArrayTensor, - shape: Shape, - ) -> NdArrayTensor { - NdArrayOps::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::gather(dim, tensor, indices) - } - - fn scatter( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::scatter(dim, tensor, indices, value) - } - - fn select( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select(tensor, dim, indices) - } - - fn select_assign( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::select_assign(tensor, dim, indices, value) - } - - fn slice( - tensor: NdArrayTensor, - ranges: [Range; D2], - ) -> NdArrayTensor { - NdArrayOps::slice(tensor, ranges) - } - - fn slice_assign( - tensor: NdArrayTensor, - ranges: [Range; D2], - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayOps::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayMathOps::mask_where(tensor, mask, value) - } - - fn mask_fill( - tensor: NdArrayTensor, - mask: NdArrayTensor, - value: E, - ) -> NdArrayTensor { - NdArrayMathOps::mask_fill(tensor, mask, value) - } - - fn equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - - Self::equal_elem(tensor, zero) - } - - fn equal_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a == rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn greater( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::greater_elem(tensor, zero) - } - - fn greater_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a > rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn greater_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::greater_equal_elem(tensor, zero) - } - - fn greater_equal_elem( - lhs: NdArrayTensor, - rhs: E, - ) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a >= rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn lower( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::lower_elem(tensor, zero) - } - - fn lower_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array 
= lhs.array.mapv(|a| a < rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn lower_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - ) -> NdArrayTensor { - let tensor = NdArray::::sub(lhs, rhs); - let zero = 0.elem(); - Self::lower_equal_elem(tensor, zero) - } - - fn lower_equal_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { - let array = lhs.array.mapv(|a| a <= rhs).into_shared(); - - NdArrayTensor::new(array) - } - - fn detach(tensor: NdArrayTensor) -> NdArrayTensor { - tensor - } - - fn mean(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::mean(tensor) - } - - fn sum(tensor: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::sum(tensor) - } - - fn mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::mean_dim(tensor, dim) - } - - fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::sum_dim(tensor, dim) - } - - fn to_full_precision(tensor: &NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn from_full_precision(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::argmax(tensor, dim) - } - - fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - NdArrayMathOps::argmin(tensor, dim) - } - - fn exp(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn log(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn log1p(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn powf(tensor: NdArrayTensor, value: f32) -> NdArrayTensor { - let array = if value == 2.0 { - // Happens often and is faster. 
- tensor.array.mapv_into(|a| a * a).into_shared() - } else if value.floor() == value { - // Is faster then powf - tensor - .array - .mapv_into(|a| a.powi_elem(value as i32)) - .into_shared() - } else { - // Default - tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared() - }; - - NdArrayTensor::new(array) - } - - fn sqrt(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn abs(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn cos(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| cos(a.to_f64().unwrap()).elem()) - .into_shared(); + fn from_data(data: Data, _device: &NdArrayDevice) -> NdArrayTensor { + NdArrayTensor::from_data(data) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: &NdArrayDevice, + ) -> NdArrayTensor { + let mut seed = SEED.lock().unwrap(); + let mut rng = if let Some(rng_seeded) = seed.as_ref() { + rng_seeded.clone() + } else { + get_seeded_rng() + }; + let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device); + *seed = Some(rng); + tensor + } + + fn shape(tensor: &NdArrayTensor) -> Shape { + tensor.shape() + } + + fn into_data( + tensor: NdArrayTensor, + ) -> Reader as Backend>::FloatElem, D>> { + let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + + Reader::Concrete(Data::new(values, shape)) + } + + fn device(_tensor: &NdArrayTensor) -> NdArrayDevice { + NdArrayDevice::Cpu + } + + fn to_device( + tensor: NdArrayTensor, + _device: &NdArrayDevice, + ) -> NdArrayTensor { + tensor + } + + fn empty( + shape: Shape, + device: & as Backend>::Device, + ) -> NdArrayTensor { + NdArray::::zeros(shape, device) + } + + fn add( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::add(lhs, rhs) + } + + fn add_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::add_scalar(lhs, rhs) + } + + fn sub( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::sub(lhs, rhs) + } + + fn sub_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::sub_scalar(lhs, rhs) + } + + fn mul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mul(lhs, rhs) + } + + fn mul_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::mul_scalar(lhs, rhs) + } + + fn div( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::div(lhs, rhs) + } + + fn div_scalar(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + NdArrayMathOps::div_scalar(lhs, rhs) + } + + fn matmul( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + matmul(lhs, rhs) + } + + fn neg(tensor: NdArrayTensor) -> NdArrayTensor { + Self::mul_scalar(tensor, (-1f32).elem::()) + } + + fn recip(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::recip(tensor) + } + + fn swap_dims( + tensor: NdArrayTensor, + dim1: usize, + dim2: usize, + ) -> NdArrayTensor { + NdArrayOps::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: NdArrayTensor, + shape: Shape, + ) -> NdArrayTensor { + NdArrayOps::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: NdArrayTensor, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::gather(dim, tensor, indices) + } + + fn scatter( + dim: usize, + tensor: NdArrayTensor, + indices: 
NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::scatter(dim, tensor, indices, value) + } + + fn select( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select(tensor, dim, indices) + } + + fn select_assign( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::select_assign(tensor, dim, indices, value) + } + + fn slice( + tensor: NdArrayTensor, + ranges: [Range; D2], + ) -> NdArrayTensor { + NdArrayOps::slice(tensor, ranges) + } + + fn slice_assign( + tensor: NdArrayTensor, + ranges: [Range; D2], + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayOps::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayMathOps::mask_where(tensor, mask, value) + } + + fn mask_fill( + tensor: NdArrayTensor, + mask: NdArrayTensor, + value: E, + ) -> NdArrayTensor { + NdArrayMathOps::mask_fill(tensor, mask, value) + } + + fn equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + + Self::equal_elem(tensor, zero) + } + + fn equal_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a == rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn greater( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::greater_elem(tensor, zero) + } + + fn greater_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a > rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn greater_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::greater_equal_elem(tensor, zero) + } + + fn greater_equal_elem( + lhs: NdArrayTensor, + rhs: E, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a >= rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn lower( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::lower_elem(tensor, zero) + } + + fn lower_elem(lhs: NdArrayTensor, rhs: E) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a < rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn lower_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + ) -> NdArrayTensor { + let tensor = NdArray::::sub(lhs, rhs); + let zero = 0.elem(); + Self::lower_equal_elem(tensor, zero) + } + + fn lower_equal_elem( + lhs: NdArrayTensor, + rhs: E, + ) -> NdArrayTensor { + let array = lhs.array.mapv(|a| a <= rhs).into_shared(); + + NdArrayTensor::new(array) + } + + fn detach(tensor: NdArrayTensor) -> NdArrayTensor { + tensor + } + + fn mean(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::mean(tensor) + } + + fn sum(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::sum(tensor) + } + + fn mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::mean_dim(tensor, dim) + } + + fn sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::sum_dim(tensor, dim) + } + + fn to_full_precision(tensor: &NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn from_full_precision(tensor: NdArrayTensor) -> NdArrayTensor { 
+ let array = tensor.array.mapv(|a| a.elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::argmax(tensor, dim) + } + + fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + NdArrayMathOps::argmin(tensor, dim) + } + + fn exp(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn log(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn log1p(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn powf(tensor: NdArrayTensor, value: f32) -> NdArrayTensor { + let array = if value == 2.0 { + // Happens often and is faster. + tensor.array.mapv_into(|a| a * a).into_shared() + } else if value.floor() == value { + // Is faster then powf + tensor + .array + .mapv_into(|a| a.powi_elem(value as i32)) + .into_shared() + } else { + // Default + tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared() + }; + + NdArrayTensor::new(array) + } + + fn sqrt(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn abs(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn cos(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| cos(a.to_f64().unwrap()).elem()) + .into_shared(); + + NdArrayTensor::new(array) + } - NdArrayTensor::new(array) - } + fn sin(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| sin(a.to_f64().unwrap()).elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + + fn tanh(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| tanh(a.to_f64().unwrap()).elem()) + .into_shared(); - fn sin(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| sin(a.to_f64().unwrap()).elem()) - .into_shared(); - - NdArrayTensor::new(array) - } - - fn tanh(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| tanh(a.to_f64().unwrap()).elem()) - .into_shared(); - - NdArrayTensor::new(array) - } - - fn erf(tensor: NdArrayTensor) -> NdArrayTensor { - let array = tensor - .array - .mapv_into(|a| erf(a.to_f64().unwrap()).elem()) - .into_shared(); - - NdArrayTensor::new(array) - } - - fn cat(tensors: Vec>, dim: usize) -> NdArrayTensor { - NdArrayOps::cat(tensors, dim) - } - - fn clamp_min(tensor: NdArrayTensor, min: E) -> NdArrayTensor { - NdArrayMathOps::clamp_min(tensor, min) - } - - fn clamp_max(tensor: NdArrayTensor, max: E) -> NdArrayTensor { - NdArrayMathOps::clamp_max(tensor, max) - } - - fn clamp(tensor: NdArrayTensor, min: E, max: E) -> NdArrayTensor { - NdArrayMathOps::clamp(tensor, min, max) - } - - fn into_int( - tensor: as Backend>::TensorPrimitive, - ) -> as Backend>::IntTensorPrimitive { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - NdArrayTensor { array } - } + NdArrayTensor::new(array) + } + + fn erf(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| erf(a.to_f64().unwrap()).elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + + fn cat(tensors: Vec>, dim: usize) -> 
NdArrayTensor { + NdArrayOps::cat(tensors, dim) + } + + fn clamp_min(tensor: NdArrayTensor, min: E) -> NdArrayTensor { + NdArrayMathOps::clamp_min(tensor, min) + } + + fn clamp_max(tensor: NdArrayTensor, max: E) -> NdArrayTensor { + NdArrayMathOps::clamp_max(tensor, max) + } + + fn clamp(tensor: NdArrayTensor, min: E, max: E) -> NdArrayTensor { + NdArrayMathOps::clamp(tensor, min, max) + } + + fn into_int( + tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + NdArrayTensor { array } + } } diff --git a/burn-ndarray/src/parallel.rs b/burn-ndarray/src/parallel.rs index 5debb0f1b5..8229bfb396 100644 --- a/burn-ndarray/src/parallel.rs +++ b/burn-ndarray/src/parallel.rs @@ -1,51 +1,51 @@ /// Macro for running a function in parallel. #[macro_export(local_inner_macros)] macro_rules! run_par { - ( + ( $func:expr ) => {{ - #[cfg(feature = "std")] - use rayon::prelude::*; + #[cfg(feature = "std")] + use rayon::prelude::*; - #[cfg(feature = "std")] - #[allow(clippy::redundant_closure_call)] - let output = rayon::scope(|_| $func()); + #[cfg(feature = "std")] + #[allow(clippy::redundant_closure_call)] + let output = rayon::scope(|_| $func()); - #[cfg(not(feature = "std"))] - let output = $func(); + #[cfg(not(feature = "std"))] + let output = $func(); - output - }}; + output + }}; } /// Macro for iterating in parallel. #[macro_export(local_inner_macros)] macro_rules! iter_par { - ( + ( $iter:expr ) => {{ - #[cfg(feature = "std")] - let output = $iter.into_par_iter(); + #[cfg(feature = "std")] + let output = $iter.into_par_iter(); - #[cfg(not(feature = "std"))] - let output = $iter; + #[cfg(not(feature = "std"))] + let output = $iter; - output - }}; + output + }}; } /// Macro for iterating over a range in parallel. #[macro_export(local_inner_macros)] macro_rules! iter_range_par { - ( + ( $start:expr, $end:expr ) => {{ - #[cfg(feature = "std")] - let output = ($start..$end).into_par_iter(); + #[cfg(feature = "std")] + let output = ($start..$end).into_par_iter(); - #[cfg(not(feature = "std"))] - let output = ($start..$end); + #[cfg(not(feature = "std"))] + let output = ($start..$end); - output - }}; + output + }}; } diff --git a/burn-ndarray/src/sharing.rs b/burn-ndarray/src/sharing.rs index be99c0999e..073f3d3872 100644 --- a/burn-ndarray/src/sharing.rs +++ b/burn-ndarray/src/sharing.rs @@ -2,18 +2,18 @@ use core::cell::UnsafeCell; /// Similar to `SyncUnsafeCell` see [Rust issues](https://github.com/rust-lang/rust/issues/95439). pub(crate) struct UnsafeSharedRef<'a, T> { - cell: UnsafeCell<&'a mut T>, + cell: UnsafeCell<&'a mut T>, } unsafe impl<'a, T> Sync for UnsafeSharedRef<'a, T> {} impl<'a, T> UnsafeSharedRef<'a, T> { - pub fn new(data: &'a mut T) -> Self { - Self { - cell: UnsafeCell::new(data), + pub fn new(data: &'a mut T) -> Self { + Self { + cell: UnsafeCell::new(data), + } + } + pub unsafe fn get(&self) -> &'a mut T { + unsafe { core::ptr::read(self.cell.get()) } } - } - pub unsafe fn get(&self) -> &'a mut T { - unsafe { core::ptr::read(self.cell.get()) } - } } diff --git a/burn-ndarray/src/tensor.rs b/burn-ndarray/src/tensor.rs index ae05aa9e47..db99bc87e6 100644 --- a/burn-ndarray/src/tensor.rs +++ b/burn-ndarray/src/tensor.rs @@ -5,52 +5,52 @@ use ndarray::{ArcArray, Array, Dim, IxDyn}; /// Tensor primitive used by the [ndarray backend](crate::NdArray). #[derive(new, Debug, Clone)] pub struct NdArrayTensor { - /// Dynamic array that contains the data of type E. 
- pub array: ArcArray, + /// Dynamic array that contains the data of type E. + pub array: ArcArray, } impl NdArrayTensor { - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.array.shape().to_vec()) - } + pub(crate) fn shape(&self) -> Shape { + Shape::from(self.array.shape().to_vec()) + } } #[cfg(test)] mod utils { - use super::*; - use crate::element::FloatNdArrayElement; + use super::*; + use crate::element::FloatNdArrayElement; - impl NdArrayTensor - where - E: Default + Clone, - { - pub(crate) fn into_data(self) -> Data + impl NdArrayTensor where - E: FloatNdArrayElement, + E: Default + Clone, { - let shape = self.shape(); - let values = self.array.into_iter().collect(); - - Data::new(values, shape) + pub(crate) fn into_data(self) -> Data + where + E: FloatNdArrayElement, + { + let shape = self.shape(); + let values = self.array.into_iter().collect(); + + Data::new(values, shape) + } } - } } /// Converts a slice of usize to a typed dimension. #[macro_export(local_inner_macros)] macro_rules! to_typed_dims { - ( + ( $n:expr, $dims:expr, justdim ) => {{ - let mut dims = [0; $n]; - for i in 0..$n { - dims[i] = $dims[i]; - } - let dim: Dim<[usize; $n]> = Dim(dims); - dim - }}; + let mut dims = [0; $n]; + for i in 0..$n { + dims[i] = $dims[i]; + } + let dim: Dim<[usize; $n]> = Dim(dims); + dim + }}; } /// Reshapes an array into a tensor. @@ -101,82 +101,82 @@ macro_rules! reshape { impl NdArrayTensor where - E: Default + Clone, + E: Default + Clone, { - /// Create a new [ndarray tensor](NdArrayTensor) from [data](Data). - pub fn from_data(data: Data) -> NdArrayTensor { - let shape = data.shape.clone(); - let to_array = |data: Data| Array::from_iter(data.value).into_shared(); - let array = to_array(data); - - reshape!( - ty E, - shape shape, - array array, - d D - ) - } + /// Create a new [ndarray tensor](NdArrayTensor) from [data](Data). 
+ pub fn from_data(data: Data) -> NdArrayTensor { + let shape = data.shape.clone(); + let to_array = |data: Data| Array::from_iter(data.value).into_shared(); + let array = to_array(data); + + reshape!( + ty E, + shape shape, + array array, + d D + ) + } } #[cfg(test)] mod tests { - use super::*; - use burn_common::rand::get_seeded_rng; - use burn_tensor::Distribution; - - #[test] - fn should_support_into_and_from_data_1d() { - let data_expected = Data::::random( - Shape::new([3]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_2d() { - let data_expected = Data::::random( - Shape::new([2, 3]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_3d() { - let data_expected = Data::::random( - Shape::new([2, 3, 4]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_4d() { - let data_expected = Data::::random( - Shape::new([2, 3, 4, 2]), - Distribution::Default, - &mut get_seeded_rng(), - ); - let tensor = NdArrayTensor::from_data(data_expected.clone()); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_common::rand::get_seeded_rng; + use burn_tensor::Distribution; + + #[test] + fn should_support_into_and_from_data_1d() { + let data_expected = Data::::random( + Shape::new([3]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_2d() { + let data_expected = Data::::random( + Shape::new([2, 3]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_3d() { + let data_expected = Data::::random( + Shape::new([2, 3, 4]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_4d() { + let data_expected = Data::::random( + Shape::new([2, 3, 4, 2]), + Distribution::Default, + &mut get_seeded_rng(), + ); + let tensor = NdArrayTensor::from_data(data_expected.clone()); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-no-std-tests/src/conv.rs b/burn-no-std-tests/src/conv.rs index 21d4cc67d8..6f17949409 100644 --- a/burn-no-std-tests/src/conv.rs +++ b/burn-no-std-tests/src/conv.rs @@ -1,48 +1,48 @@ // Originally copied from the burn/examples/mnist package use burn::{ - config::Config, - module::Module, - nn, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn, + tensor::{backend::Backend, Tensor}, }; #[derive(Module, Debug)] pub 
struct ConvBlock { - conv: nn::conv::Conv2d, - pool: nn::pool::MaxPool2d, - activation: nn::GELU, + conv: nn::conv::Conv2d, + pool: nn::pool::MaxPool2d, + activation: nn::GELU, } #[derive(Config)] pub struct ConvBlockConfig { - channels: [usize; 2], - #[config(default = "[3, 3]")] - kernel_size: [usize; 2], + channels: [usize; 2], + #[config(default = "[3, 3]")] + kernel_size: [usize; 2], } impl ConvBlock { - pub fn new(config: &ConvBlockConfig) -> Self { - let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size) - .with_padding(nn::PaddingConfig2d::Same) - .init(); - let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size) - .with_padding(nn::PaddingConfig2d::Same) - .init(); - let activation = nn::GELU::new(); + pub fn new(config: &ConvBlockConfig) -> Self { + let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size) + .with_padding(nn::PaddingConfig2d::Same) + .init(); + let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size) + .with_padding(nn::PaddingConfig2d::Same) + .init(); + let activation = nn::GELU::new(); - Self { - conv, - pool, - activation, + Self { + conv, + pool, + activation, + } } - } - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.conv.forward(input.clone()); - let x = self.pool.forward(x); - let x = self.activation.forward(x); + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.conv.forward(input.clone()); + let x = self.pool.forward(x); + let x = self.activation.forward(x); - (x + input) / 2.0 - } + (x + input) / 2.0 + } } diff --git a/burn-no-std-tests/src/mlp.rs b/burn-no-std-tests/src/mlp.rs index 1ef28fd496..ec8f189718 100644 --- a/burn-no-std-tests/src/mlp.rs +++ b/burn-no-std-tests/src/mlp.rs @@ -3,65 +3,65 @@ use alloc::vec::Vec; use burn::{ - config::Config, - module::Module, - nn, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn, + tensor::{backend::Backend, Tensor}, }; /// Configuration to create a [Multilayer Perceptron](Mlp) layer. #[derive(Config)] pub struct MlpConfig { - /// The number of layers. - #[config(default = 3)] - pub num_layers: usize, - /// The dropout rate. - #[config(default = 0.5)] - pub dropout: f64, - /// The size of each layer. - #[config(default = 256)] - pub d_model: usize, + /// The number of layers. + #[config(default = 3)] + pub num_layers: usize, + /// The dropout rate. + #[config(default = 0.5)] + pub dropout: f64, + /// The size of each layer. + #[config(default = 256)] + pub d_model: usize, } /// Multilayer Perceptron module. #[derive(Module, Debug)] pub struct Mlp { - linears: Vec>, - dropout: nn::Dropout, - activation: nn::ReLU, + linears: Vec>, + dropout: nn::Dropout, + activation: nn::ReLU, } impl Mlp { - /// Create the module from the given configuration. - pub fn new(config: &MlpConfig) -> Self { - let mut linears = Vec::with_capacity(config.num_layers); + /// Create the module from the given configuration. + pub fn new(config: &MlpConfig) -> Self { + let mut linears = Vec::with_capacity(config.num_layers); - for _ in 0..config.num_layers { - linears.push(nn::LinearConfig::new(config.d_model, config.d_model).init()); - } + for _ in 0..config.num_layers { + linears.push(nn::LinearConfig::new(config.d_model, config.d_model).init()); + } - Self { - linears, - dropout: nn::DropoutConfig::new(0.3).init(), - activation: nn::ReLU::new(), + Self { + linears, + dropout: nn::DropoutConfig::new(0.3).init(), + activation: nn::ReLU::new(), + } } - } - /// Applies the forward pass on the input tensor. 
- /// - /// # Shapes - /// - /// - input: `[batch_size, d_model]` - /// - output: `[batch_size, d_model]` - pub fn forward(&self, input: Tensor) -> Tensor { - let mut x = input; + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[batch_size, d_model]` + /// - output: `[batch_size, d_model]` + pub fn forward(&self, input: Tensor) -> Tensor { + let mut x = input; - for linear in self.linears.iter() { - x = linear.forward(x); - x = self.dropout.forward(x); - x = self.activation.forward(x); - } + for linear in self.linears.iter() { + x = linear.forward(x); + x = self.dropout.forward(x); + x = self.activation.forward(x); + } - x - } + x + } } diff --git a/burn-no-std-tests/src/model.rs b/burn-no-std-tests/src/model.rs index 6363a4ae21..7028bdb6b3 100644 --- a/burn-no-std-tests/src/model.rs +++ b/burn-no-std-tests/src/model.rs @@ -1,66 +1,66 @@ // Originally copied from the burn/examples/mnist package use crate::{ - conv::{ConvBlock, ConvBlockConfig}, - mlp::{Mlp, MlpConfig}, + conv::{ConvBlock, ConvBlockConfig}, + mlp::{Mlp, MlpConfig}, }; use burn::{ - config::Config, - module::Module, - nn, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn, + tensor::{backend::Backend, Tensor}, }; #[derive(Config)] pub struct MnistConfig { - #[config(default = 42)] - pub seed: u64, + #[config(default = 42)] + pub seed: u64, - pub mlp: MlpConfig, + pub mlp: MlpConfig, - #[config(default = 784)] - pub input_size: usize, + #[config(default = 784)] + pub input_size: usize, - #[config(default = 10)] - pub output_size: usize, + #[config(default = 10)] + pub output_size: usize, } #[derive(Module, Debug)] pub struct Model { - mlp: Mlp, - conv: ConvBlock, - input: nn::Linear, - output: nn::Linear, - num_classes: usize, + mlp: Mlp, + conv: ConvBlock, + input: nn::Linear, + output: nn::Linear, + num_classes: usize, } impl Model { - pub fn new(config: &MnistConfig) -> Self { - let mlp = Mlp::new(&config.mlp); - let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init(); - let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init(); - let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1])); + pub fn new(config: &MnistConfig) -> Self { + let mlp = Mlp::new(&config.mlp); + let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init(); + let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init(); + let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1])); - Self { - mlp, - conv, - output, - input, - num_classes: config.output_size, + Self { + mlp, + conv, + output, + input, + num_classes: config.output_size, + } } - } - pub fn forward(&self, input: Tensor) -> Tensor { - let [batch_size, height, width] = input.dims(); + pub fn forward(&self, input: Tensor) -> Tensor { + let [batch_size, height, width] = input.dims(); - let x = input.reshape([batch_size, 1, height, width]).detach(); - let x = self.conv.forward(x); - let x = x.reshape([batch_size, height * width]); + let x = input.reshape([batch_size, 1, height, width]).detach(); + let x = self.conv.forward(x); + let x = x.reshape([batch_size, height * width]); - let x = self.input.forward(x); - let x = self.mlp.forward(x); + let x = self.input.forward(x); + let x = self.mlp.forward(x); - self.output.forward(x) - } + self.output.forward(x) + } } diff --git a/burn-no-std-tests/tests/integration_test.rs b/burn-no-std-tests/tests/integration_test.rs index 2907909cc2..6f6558cea9 100644 --- 
a/burn-no-std-tests/tests/integration_test.rs +++ b/burn-no-std-tests/tests/integration_test.rs @@ -8,23 +8,23 @@ use burn_ndarray::NdArray; #[test] fn test_mnist_model_with_random_input() { - type Backend = NdArray; + type Backend = NdArray; - // Model configurations - let mlp_config = MlpConfig::new(); - let mnist_config = MnistConfig::new(mlp_config); - let mnist_model: Model = Model::new(&mnist_config); + // Model configurations + let mlp_config = MlpConfig::new(); + let mnist_config = MnistConfig::new(mlp_config); + let mnist_model: Model = Model::new(&mnist_config); - // Pass a fixed seed for random, otherwise a build generated random seed is used - Backend::seed(mnist_config.seed); + // Pass a fixed seed for random, otherwise a build generated random seed is used + Backend::seed(mnist_config.seed); - // Some random input - let input_shape = [1, 28, 28]; - let input = Tensor::::random(input_shape, Default); + // Some random input + let input_shape = [1, 28, 28]; + let input = Tensor::::random(input_shape, Default); - // Run through the model - let output = mnist_model.forward(input); + // Run through the model + let output = mnist_model.forward(input); - assert_eq!(output.shape().dims, [1, 10]); - assert!(output.to_data().value.into_iter().all(|x| x <= 1.0)); + assert_eq!(output.shape().dims, [1, 10]); + assert!(output.to_data().value.into_iter().all(|x| x <= 1.0)); } diff --git a/burn-tch/src/backend.rs b/burn-tch/src/backend.rs index 6a70c96914..3e62d40305 100644 --- a/burn-tch/src/backend.rs +++ b/burn-tch/src/backend.rs @@ -19,46 +19,46 @@ use burn_tensor::backend::Backend; /// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan /// ``` pub enum LibTorchDevice { - /// CPU device. - Cpu, + /// CPU device. + Cpu, - /// Cuda device with the given index. The index is the index of the Cuda device in the list of - /// all Cuda devices found on the system. - Cuda(usize), + /// Cuda device with the given index. The index is the index of the Cuda device in the list of + /// all Cuda devices found on the system. + Cuda(usize), - /// Metal Performance Shaders device. - Mps, + /// Metal Performance Shaders device. + Mps, - /// Vulkan device. - Vulkan, + /// Vulkan device. + Vulkan, } impl From for tch::Device { - fn from(device: LibTorchDevice) -> Self { - match device { - LibTorchDevice::Cpu => tch::Device::Cpu, - LibTorchDevice::Cuda(num) => tch::Device::Cuda(num), - LibTorchDevice::Mps => tch::Device::Mps, - LibTorchDevice::Vulkan => tch::Device::Vulkan, + fn from(device: LibTorchDevice) -> Self { + match device { + LibTorchDevice::Cpu => tch::Device::Cpu, + LibTorchDevice::Cuda(num) => tch::Device::Cuda(num), + LibTorchDevice::Mps => tch::Device::Mps, + LibTorchDevice::Vulkan => tch::Device::Vulkan, + } } - } } impl From for LibTorchDevice { - fn from(device: tch::Device) -> Self { - match device { - tch::Device::Cpu => LibTorchDevice::Cpu, - tch::Device::Cuda(num) => LibTorchDevice::Cuda(num), - tch::Device::Mps => LibTorchDevice::Mps, - tch::Device::Vulkan => LibTorchDevice::Vulkan, + fn from(device: tch::Device) -> Self { + match device { + tch::Device::Cpu => LibTorchDevice::Cpu, + tch::Device::Cuda(num) => LibTorchDevice::Cuda(num), + tch::Device::Mps => LibTorchDevice::Mps, + tch::Device::Vulkan => LibTorchDevice::Vulkan, + } } - } } impl Default for LibTorchDevice { - fn default() -> Self { - Self::Cpu - } + fn default() -> Self { + Self::Cpu + } } /// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations. 
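// A minimal, self-contained sketch of the device-conversion pattern shown in the hunk above.
// `Device` and `ExternalDevice` are hypothetical stand-ins (the latter playing the role of
// `tch::Device`) so the snippet compiles on its own; only the `Cpu` and `Cuda(index)` variants
// that appear in this diff are mirrored.

/// Backend-facing device enum, mirroring `LibTorchDevice`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Device {
    Cpu,
    Cuda(usize),
}

/// Hypothetical stand-in for the external library's device type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ExternalDevice {
    Cpu,
    Cuda(usize),
}

// Providing `From` in both directions keeps call sites ergonomic: backend code can write
// `(*device).into()` when calling into the external library, and devices reported by the
// library can be converted back without a manual match at every call site.
impl From<Device> for ExternalDevice {
    fn from(device: Device) -> Self {
        match device {
            Device::Cpu => ExternalDevice::Cpu,
            Device::Cuda(index) => ExternalDevice::Cuda(index),
        }
    }
}

impl From<ExternalDevice> for Device {
    fn from(device: ExternalDevice) -> Self {
        match device {
            ExternalDevice::Cpu => Device::Cpu,
            ExternalDevice::Cuda(index) => Device::Cuda(index),
        }
    }
}

fn main() {
    // Round trip: backend device -> external device -> backend device.
    let device = Device::Cuda(0);
    let external: ExternalDevice = device.into();
    assert_eq!(Device::from(external), device);
}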
@@ -70,39 +70,39 @@ impl Default for LibTorchDevice { /// Refer to the [tch] crate for more information. #[derive(Clone, Copy, Default, Debug)] pub struct LibTorch { - _e: E, + _e: E, } impl Backend for LibTorch { - type Device = LibTorchDevice; - type FullPrecisionElem = f32; - type FullPrecisionBackend = LibTorch; + type Device = LibTorchDevice; + type FullPrecisionElem = f32; + type FullPrecisionBackend = LibTorch; - type TensorPrimitive = TchTensor; - type FloatElem = E; + type TensorPrimitive = TchTensor; + type FloatElem = E; - type IntTensorPrimitive = TchTensor; - type IntElem = i64; + type IntTensorPrimitive = TchTensor; + type IntElem = i64; - type BoolTensorPrimitive = TchTensor; + type BoolTensorPrimitive = TchTensor; - fn seed(seed: u64) { - tch::manual_seed(seed as i64); - } + fn seed(seed: u64) { + tch::manual_seed(seed as i64); + } - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn name() -> String { - "tch".to_string() - } + fn name() -> String { + "tch".to_string() + } - fn sync(device: &Self::Device) { - if let LibTorchDevice::Cuda(index) = device { - tch::Cuda::synchronize(*index as i64); - } else if let LibTorchDevice::Mps = device { - panic!("Can't sync MPS device") + fn sync(device: &Self::Device) { + if let LibTorchDevice::Cuda(index) = device { + tch::Cuda::synchronize(*index as i64); + } else if let LibTorchDevice::Mps = device { + panic!("Can't sync MPS device") + } } - } } diff --git a/burn-tch/src/lib.rs b/burn-tch/src/lib.rs index 0c1e130b39..f7580bf641 100644 --- a/burn-tch/src/lib.rs +++ b/burn-tch/src/lib.rs @@ -14,12 +14,12 @@ pub use tensor::*; #[cfg(test)] mod tests { - extern crate alloc; + extern crate alloc; - type TestBackend = crate::LibTorch; - type TestTensor = burn_tensor::Tensor; - type TestTensorInt = burn_tensor::Tensor; + type TestBackend = crate::LibTorch; + type TestTensor = burn_tensor::Tensor; + type TestTensorInt = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!(); + burn_autodiff::testgen_all!(); } diff --git a/burn-tch/src/ops/activation.rs b/burn-tch/src/ops/activation.rs index cf834a2368..386e87c493 100644 --- a/burn-tch/src/ops/activation.rs +++ b/burn-tch/src/ops/activation.rs @@ -2,24 +2,24 @@ use crate::{element::TchElement, LibTorch, TchTensor}; use burn_tensor::ops::ActivationOps; impl ActivationOps for LibTorch { - fn relu(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()) - } + fn relu(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()) + } - fn gelu(tensor: TchTensor) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.gelu_("none"), - |tensor| tensor.gelu("none"), - ) - } + fn gelu(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.gelu_("none"), + |tensor| tensor.gelu("none"), + ) + } - fn gelu_backward( - tensor: TchTensor, - grad: TchTensor, - ) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none"); + fn gelu_backward( + tensor: TchTensor, + grad: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none"); - TchTensor::from_existing(tensor, storage) - } + TchTensor::from_existing(tensor, storage) + } } diff --git a/burn-tch/src/ops/base.rs b/burn-tch/src/ops/base.rs index 215d4c8b45..d2cd0d4dc3 100644 --- a/burn-tch/src/ops/base.rs +++ 
b/burn-tch/src/ops/base.rs @@ -5,405 +5,408 @@ use crate::{TchShape, TchTensor}; use std::{marker::PhantomData, ops::Range}; pub struct TchOps { - e: PhantomData, + e: PhantomData, } impl TchOps { - pub fn reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - let shape_tch: TchShape = shape.into(); - - TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage) - } - - pub fn repeat( - tensor: TchTensor, - dim: usize, - times: usize, - ) -> TchTensor { - let mut dims = [1; D]; - dims[dim] = times as i64; - let tensor = tch::Tensor::repeat(&tensor.tensor, dims); - TchTensor::new(tensor) - } - - pub fn slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - let storage = tensor.storage.clone(); - let mut tensor = tensor.tensor.shallow_clone(); - - for (i, index) in ranges.iter().enumerate().take(D2) { - let start = index.start as i64; - let length = (index.end - index.start) as i64; - tensor = tensor.narrow(i as i64, start, length); - } - - TchTensor::from_existing(tensor, storage) - } - - pub fn slice_assign( - tensor: TchTensor, - ranges: [Range; D2], - value: TchTensor, - ) -> TchTensor { - let tensor_original = tensor.tensor.copy(); - let tch_shape = TchShape::from(tensor.shape()); - - let mut tensor = tensor_original.view_(tch_shape.dims); - - for (i, index) in ranges.into_iter().enumerate().take(D2) { - let start = index.start as i64; - let length = (index.end - index.start) as i64; - - tensor = tensor.narrow(i as i64, start, length); - } - - tensor.copy_(&value.tensor); - - TchTensor::new(tensor_original) - } - - pub fn gather( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - ) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false); - - TchTensor::from_existing(tensor, storage) - } - - pub fn scatter( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor - .tensor - .scatter_add(dim as i64, &indices.tensor, &value.tensor); - - TchTensor::from_existing(tensor, storage) - } - - pub fn index_select_dim( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - ) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor); - - TchTensor::from_existing(tensor, storage) - } - - pub fn select_assign( - tensor: TchTensor, - dim: usize, - indices_tensor: TchTensor, - value: TchTensor, - ) -> TchTensor { - let mut indices = Vec::with_capacity(D); - for _ in 0..D { - indices.push(None); - } - indices[dim] = Some(indices_tensor.tensor); - - tensor.unary_ops( - |mut tensor| tensor.index_put_(&indices, &value.tensor, true), - |tensor| tensor.index_put(&indices, &value.tensor, true), - ) - } - - pub fn cat(tensors: Vec>, dim: usize) -> TchTensor { - let tensors: Vec = tensors - .into_iter() - .map(|t| t.tensor.shallow_clone()) - .collect(); - let tensor = tch::Tensor::cat(&tensors, dim as i64); - - TchTensor::new(tensor) - } - - pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.eq_tensor(rhs), - ) - } - - pub fn equal_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool), - |tensor| tensor.eq(rhs.clone().into()), - ) - } 
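// A minimal, self-contained sketch of the idea behind the `unary_ops`/`binary_ops_tensor`
// calls used throughout this file: each op supplies an in-place variant (e.g. `eq_`) and an
// allocating fallback (e.g. `eq`), and the wrapper is assumed to take the in-place path only
// when the tensor's storage is not shared. `SharedTensor` and the `Rc`-based ownership check
// below are stand-ins for illustration, not the crate's actual `TchTensor` implementation.

use std::rc::Rc;

struct SharedTensor {
    // Stand-in for a storage handle that several tensor views may share.
    storage: Rc<Vec<f32>>,
}

impl SharedTensor {
    fn unary_ops<FMut, FRef>(mut self, fn_in_place: FMut, fn_alloc: FRef) -> SharedTensor
    where
        FMut: FnOnce(&mut Vec<f32>),
        FRef: FnOnce(&Vec<f32>) -> Vec<f32>,
    {
        // Sole owner of the buffer: mutate it in place and reuse the allocation.
        if let Some(buffer) = Rc::get_mut(&mut self.storage) {
            fn_in_place(buffer);
            return self;
        }
        // Storage is shared with another tensor: fall back to the allocating closure.
        SharedTensor {
            storage: Rc::new(fn_alloc(&self.storage)),
        }
    }
}

fn main() {
    // Unique storage: the in-place closure runs and the buffer is reused.
    let unique = SharedTensor { storage: Rc::new(vec![1.0_f32, -2.0, 3.0]) };
    let relu = unique.unary_ops(
        |buf| buf.iter_mut().for_each(|x| *x = x.max(0.0)),
        |buf| buf.iter().map(|x| x.max(0.0)).collect(),
    );
    assert_eq!(*relu.storage, vec![1.0_f32, 0.0, 3.0]);

    // Shared storage: the original buffer must stay intact, so the fallback allocates.
    let a = SharedTensor { storage: Rc::new(vec![-1.0_f32, 4.0]) };
    let b = SharedTensor { storage: Rc::clone(&a.storage) };
    let relu_b = b.unary_ops(
        |buf| buf.iter_mut().for_each(|x| *x = x.max(0.0)),
        |buf| buf.iter().map(|x| x.max(0.0)).collect(),
    );
    assert_eq!(*a.storage, vec![-1.0_f32, 4.0]);
    assert_eq!(*relu_b.storage, vec![0.0_f32, 4.0]);
}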
- - pub fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.greater_tensor(rhs), - ) - } - - pub fn greater_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool), - |tensor| tensor.greater(rhs.clone().into()), - ) - } - - pub fn greater_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.greater_equal_tensor(rhs), - ) - } - - pub fn greater_equal_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| { - tensor - .greater_equal_(rhs.clone().into()) - .to_kind(tch::Kind::Bool) - }, - |tensor| tensor.greater_equal(rhs.clone().into()), - ) - } - - pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.less_tensor(rhs), - ) - } - - pub fn lower_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool), - |tensor| tensor.less(rhs.clone().into()), - ) - } - - pub fn lower_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool), - |lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.less_equal_tensor(rhs), - ) - } - - pub fn lower_equal_elem + Clone>( - lhs: TchTensor, - rhs: S, - ) -> TchTensor { - lhs.unary_ops( - |mut tensor| { - tensor - .less_equal_(rhs.clone().into()) - .to_kind(tch::Kind::Bool) - }, - |tensor| tensor.less_equal(rhs.clone().into()), - ) - } - - pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_add_(rhs).unwrap(), - |lhs, rhs| rhs.f_add_(lhs).unwrap(), - |lhs, rhs| lhs.f_add(rhs).unwrap(), - ) - } - - pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_sub_(rhs).unwrap(), - |lhs, rhs| lhs.f_sub(rhs).unwrap(), - |lhs, rhs| lhs.f_sub(rhs).unwrap(), - ) - } - - pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_mul_(rhs).unwrap(), - |lhs, rhs| rhs.f_mul_(lhs).unwrap(), - |lhs, rhs| lhs.f_mul(rhs).unwrap(), - ) - } - - pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchTensor::binary_ops_tensor( - lhs, - rhs, - |lhs, rhs| lhs.f_div_(rhs).unwrap(), - |lhs, rhs| lhs.f_div(rhs).unwrap(), - |lhs, rhs| lhs.f_div(rhs).unwrap(), - ) - } - - pub fn mean(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.mean(E::KIND); - TchTensor::new(tensor) - } - - pub fn sum(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.sum(E::KIND); - TchTensor::new(tensor) - } - - pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchTensor::from_existing( - tensor - .tensor - .mean_dim(Some([dim as i64].as_slice()), true, E::KIND), - tensor.storage, - ) - } - - pub fn sum_dim(tensor: 
TchTensor, dim: usize) -> TchTensor { - TchTensor::from_existing( - tensor - .tensor - .sum_dim_intlist(Some([dim as i64].as_slice()), true, E::KIND), - tensor.storage, - ) - } - - pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.argmax(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.argmin(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn max_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - let storage = tensor.storage.clone(); - let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true); - - let tensor = TchTensor::from_existing(tensor, storage); - let indices = TchTensor::new(indices); - - (tensor, indices) - } - - pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { - let storage = tensor.storage.clone(); - let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true); - - TchTensor::from_existing(tensor, storage) - } - - pub fn min_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - let storage = tensor.storage.clone(); - let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true); - - let tensor = TchTensor::from_existing(tensor, storage); - let indices = TchTensor::new(indices); - - (tensor, indices) - } - - pub fn clamp_min + Clone + Copy>( - tensor: TchTensor, - min: S, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.clamp_min_(min), - |tensor| tensor.clamp_min(min), - ) - } - - pub fn clamp_max + Clone + Copy>( - tensor: TchTensor, - max: S, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.clamp_max_(max), - |tensor| tensor.clamp_max(max), - ) - } - - pub fn clamp + Clone + Copy>( - tensor: TchTensor, - min: S, - max: S, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.clamp_(min, max), - |tensor| tensor.clamp(min, max), - ) - } - - pub fn swap_dims( - tensor: TchTensor, - dim1: usize, - dim2: usize, - ) -> TchTensor { - let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64); - TchTensor::new(tensor) - } + pub fn reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + let shape_tch: TchShape = shape.into(); + + TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage) + } + + pub fn repeat( + tensor: TchTensor, + dim: usize, + times: usize, + ) -> TchTensor { + let mut dims = [1; D]; + dims[dim] = times as i64; + let tensor = tch::Tensor::repeat(&tensor.tensor, dims); + TchTensor::new(tensor) + } + + pub fn slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + let storage = tensor.storage.clone(); + let mut tensor = tensor.tensor.shallow_clone(); + + for (i, index) in ranges.iter().enumerate().take(D2) { + let start = index.start as i64; + let length = (index.end - index.start) as i64; + tensor = tensor.narrow(i as i64, start, length); + } + + TchTensor::from_existing(tensor, storage) + } + + pub fn slice_assign( + tensor: TchTensor, + ranges: [Range; D2], + value: TchTensor, + ) -> TchTensor { + let tensor_original = tensor.tensor.copy(); + let tch_shape = TchShape::from(tensor.shape()); + + let mut tensor = 
tensor_original.view_(tch_shape.dims); + + for (i, index) in ranges.into_iter().enumerate().take(D2) { + let start = index.start as i64; + let length = (index.end - index.start) as i64; + + tensor = tensor.narrow(i as i64, start, length); + } + + tensor.copy_(&value.tensor); + + TchTensor::new(tensor_original) + } + + pub fn gather( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false); + + TchTensor::from_existing(tensor, storage) + } + + pub fn scatter( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor + .tensor + .scatter_add(dim as i64, &indices.tensor, &value.tensor); + + TchTensor::from_existing(tensor, storage) + } + + pub fn index_select_dim( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + ) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor); + + TchTensor::from_existing(tensor, storage) + } + + pub fn select_assign( + tensor: TchTensor, + dim: usize, + indices_tensor: TchTensor, + value: TchTensor, + ) -> TchTensor { + let mut indices = Vec::with_capacity(D); + for _ in 0..D { + indices.push(None); + } + indices[dim] = Some(indices_tensor.tensor); + + tensor.unary_ops( + |mut tensor| tensor.index_put_(&indices, &value.tensor, true), + |tensor| tensor.index_put(&indices, &value.tensor, true), + ) + } + + pub fn cat(tensors: Vec>, dim: usize) -> TchTensor { + let tensors: Vec = tensors + .into_iter() + .map(|t| t.tensor.shallow_clone()) + .collect(); + let tensor = tch::Tensor::cat(&tensors, dim as i64); + + TchTensor::new(tensor) + } + + pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.eq_tensor(rhs), + ) + } + + pub fn equal_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool), + |tensor| tensor.eq(rhs.clone().into()), + ) + } + + pub fn greater( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.greater_tensor(rhs), + ) + } + + pub fn greater_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool), + |tensor| tensor.greater(rhs.clone().into()), + ) + } + + pub fn greater_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.greater_equal_tensor(rhs), + ) + } + + pub fn greater_equal_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| { + tensor + .greater_equal_(rhs.clone().into()) + .to_kind(tch::Kind::Bool) + }, + |tensor| tensor.greater_equal(rhs.clone().into()), + ) + } + + pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool), + 
|lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.less_tensor(rhs), + ) + } + + pub fn lower_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool), + |tensor| tensor.less(rhs.clone().into()), + ) + } + + pub fn lower_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool), + |lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool), + |lhs, rhs| lhs.less_equal_tensor(rhs), + ) + } + + pub fn lower_equal_elem + Clone>( + lhs: TchTensor, + rhs: S, + ) -> TchTensor { + lhs.unary_ops( + |mut tensor| { + tensor + .less_equal_(rhs.clone().into()) + .to_kind(tch::Kind::Bool) + }, + |tensor| tensor.less_equal(rhs.clone().into()), + ) + } + + pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_add_(rhs).unwrap(), + |lhs, rhs| rhs.f_add_(lhs).unwrap(), + |lhs, rhs| lhs.f_add(rhs).unwrap(), + ) + } + + pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_sub_(rhs).unwrap(), + |lhs, rhs| lhs.f_sub(rhs).unwrap(), + |lhs, rhs| lhs.f_sub(rhs).unwrap(), + ) + } + + pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_mul_(rhs).unwrap(), + |lhs, rhs| rhs.f_mul_(lhs).unwrap(), + |lhs, rhs| lhs.f_mul(rhs).unwrap(), + ) + } + + pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_div_(rhs).unwrap(), + |lhs, rhs| lhs.f_div(rhs).unwrap(), + |lhs, rhs| lhs.f_div(rhs).unwrap(), + ) + } + + pub fn mean(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.mean(E::KIND); + TchTensor::new(tensor) + } + + pub fn sum(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.sum(E::KIND); + TchTensor::new(tensor) + } + + pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchTensor::from_existing( + tensor + .tensor + .mean_dim(Some([dim as i64].as_slice()), true, E::KIND), + tensor.storage, + ) + } + + pub fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchTensor::from_existing( + tensor + .tensor + .sum_dim_intlist(Some([dim as i64].as_slice()), true, E::KIND), + tensor.storage, + ) + } + + pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.argmax(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.argmin(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn max_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + let storage = tensor.storage.clone(); + let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true); + + let tensor = TchTensor::from_existing(tensor, storage); + let indices = TchTensor::new(indices); + + (tensor, indices) + } + + pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { + let storage = tensor.storage.clone(); + let 
(tensor, _indices) = tensor.tensor.min_dim(dim as i64, true); + + TchTensor::from_existing(tensor, storage) + } + + pub fn min_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + let storage = tensor.storage.clone(); + let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true); + + let tensor = TchTensor::from_existing(tensor, storage); + let indices = TchTensor::new(indices); + + (tensor, indices) + } + + pub fn clamp_min + Clone + Copy>( + tensor: TchTensor, + min: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.clamp_min_(min), + |tensor| tensor.clamp_min(min), + ) + } + + pub fn clamp_max + Clone + Copy>( + tensor: TchTensor, + max: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.clamp_max_(max), + |tensor| tensor.clamp_max(max), + ) + } + + pub fn clamp + Clone + Copy>( + tensor: TchTensor, + min: S, + max: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.clamp_(min, max), + |tensor| tensor.clamp(min, max), + ) + } + + pub fn swap_dims( + tensor: TchTensor, + dim1: usize, + dim2: usize, + ) -> TchTensor { + let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64); + TchTensor::new(tensor) + } } diff --git a/burn-tch/src/ops/bool_tensor.rs b/burn-tch/src/ops/bool_tensor.rs index c919ed5ffc..76abdfadc4 100644 --- a/burn-tch/src/ops/bool_tensor.rs +++ b/burn-tch/src/ops/bool_tensor.rs @@ -4,111 +4,114 @@ use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Reader, Shape}; use std::ops::Range; impl BoolTensorOps for LibTorch { - fn bool_from_data( - data: Data, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::from_data(data, (*device).into()) - } - - fn bool_shape(tensor: &TchTensor) -> Shape { - tensor.shape() - } - - fn bool_repeat( - tensor: TchTensor, - dim: usize, - times: usize, - ) -> TchTensor { - TchOps::repeat(tensor, dim, times) - } - - fn bool_into_data(tensor: TchTensor) -> Reader> { - let shape = Self::bool_shape(&tensor); - let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()])); - let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); - - Reader::Concrete(Data::new(values.unwrap(), shape)) - } - - fn bool_to_device( - tensor: TchTensor, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) - } - - fn bool_reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - TchOps::reshape(tensor, shape) - } - - fn bool_device(tensor: &TchTensor) -> LibTorchDevice { - tensor.tensor.device().into() - } - - fn bool_empty( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let tensor = tch::Tensor::empty( - shape.dims.map(|a| a as i64), - (tch::Kind::Bool, (*device).into()), - ); - - TchTensor::new(tensor) - } - - fn bool_slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - TchOps::slice(tensor, ranges) - } - - fn bool_slice_assign( - tensor: TchTensor, - ranges: [std::ops::Range; D2], - value: TchTensor, - ) -> TchTensor { - TchOps::slice_assign(tensor, ranges, value) - } - - fn bool_cat(tensors: Vec>, dim: usize) -> TchTensor { - TchOps::cat(tensors, dim) - } - - fn bool_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::equal(lhs, rhs) - } - - fn bool_not(tensor: TchTensor) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool), - |tensor| tensor.eq(0), - ) - } - - fn bool_into_int(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(tch::Kind::Int64); - 
TchTensor::new(tensor) - } - - fn bool_into_float(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(E::KIND); - TchTensor::new(tensor) - } - - fn bool_swap_dims( - tensor: as Backend>::BoolTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::BoolTensorPrimitive { - TchOps::swap_dims(tensor, dim1, dim2) - } + fn bool_from_data( + data: Data, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::from_data(data, (*device).into()) + } + + fn bool_shape(tensor: &TchTensor) -> Shape { + tensor.shape() + } + + fn bool_repeat( + tensor: TchTensor, + dim: usize, + times: usize, + ) -> TchTensor { + TchOps::repeat(tensor, dim, times) + } + + fn bool_into_data(tensor: TchTensor) -> Reader> { + let shape = Self::bool_shape(&tensor); + let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()])); + let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); + + Reader::Concrete(Data::new(values.unwrap(), shape)) + } + + fn bool_to_device( + tensor: TchTensor, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::new(tensor.tensor.to((*device).into())) + } + + fn bool_reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + TchOps::reshape(tensor, shape) + } + + fn bool_device(tensor: &TchTensor) -> LibTorchDevice { + tensor.tensor.device().into() + } + + fn bool_empty( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let tensor = tch::Tensor::empty( + shape.dims.map(|a| a as i64), + (tch::Kind::Bool, (*device).into()), + ); + + TchTensor::new(tensor) + } + + fn bool_slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + TchOps::slice(tensor, ranges) + } + + fn bool_slice_assign( + tensor: TchTensor, + ranges: [std::ops::Range; D2], + value: TchTensor, + ) -> TchTensor { + TchOps::slice_assign(tensor, ranges, value) + } + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> TchTensor { + TchOps::cat(tensors, dim) + } + + fn bool_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::equal(lhs, rhs) + } + + fn bool_not(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool), + |tensor| tensor.eq(0), + ) + } + + fn bool_into_int(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(tch::Kind::Int64); + TchTensor::new(tensor) + } + + fn bool_into_float(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(E::KIND); + TchTensor::new(tensor) + } + + fn bool_swap_dims( + tensor: as Backend>::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::BoolTensorPrimitive { + TchOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs index bb6448f5d6..16eb9af05c 100644 --- a/burn-tch/src/ops/int_tensor.rs +++ b/burn-tch/src/ops/int_tensor.rs @@ -7,365 +7,385 @@ use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor}; use super::TchOps; impl IntTensorOps for LibTorch { - fn int_from_data( - data: Data, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::from_data(data, (*device).into()) - } - - fn int_shape(tensor: &TchTensor) -> Shape { - tensor.shape() - } - - fn int_repeat( - tensor: TchTensor, - dim: usize, - times: usize, - ) -> TchTensor { - TchOps::repeat(tensor, dim, times) - } - - fn int_into_data(tensor: TchTensor) -> Reader> { - let shape = Self::int_shape(&tensor); - let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()])); - let values: 
Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); - - Reader::Concrete(Data::new(values.unwrap(), shape)) - } - - fn int_to_device( - tensor: TchTensor, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) - } - - fn int_reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - TchOps::reshape(tensor, shape) - } - - fn int_device(tensor: &TchTensor) -> LibTorchDevice { - tensor.tensor.device().into() - } - - fn int_empty( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let tensor = tch::Tensor::empty( - shape.dims.map(|a| a as i64), - (tch::Kind::Int64, (*device).into()), - ); - - TchTensor::new(tensor) - } - - fn int_slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - TchOps::slice(tensor, ranges) - } - - fn int_slice_assign( - tensor: TchTensor, - ranges: [std::ops::Range; D2], - value: TchTensor, - ) -> TchTensor { - TchOps::slice_assign(tensor, ranges, value) - } - - fn int_cat(tensors: Vec>, dim: usize) -> TchTensor { - TchOps::cat(tensors, dim) - } - - fn int_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::equal(lhs, rhs) - } - - fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor { - TchOps::equal_elem(lhs, rhs) - } - - fn int_greater( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::greater(lhs, rhs) - } - - fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor { - TchOps::greater_elem(lhs, rhs) - } - - fn int_greater_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::greater_equal(lhs, rhs) - } - - fn int_greater_equal_elem( - lhs: TchTensor, - rhs: i64, - ) -> TchTensor { - TchOps::greater_equal_elem(lhs, rhs) - } - - fn int_lower( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::lower(lhs, rhs) - } - - fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor { - TchOps::lower_elem(lhs, rhs) - } - - fn int_lower_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::lower_equal(lhs, rhs) - } - - fn int_lower_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor { - TchOps::lower_equal_elem(lhs, rhs) - } - - fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::add(lhs, rhs) - } - - fn int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), - |tensor| tensor.f_add_scalar(rhs).unwrap(), - ) - } - - fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::sub(lhs, rhs) - } - - fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), - |tensor| tensor.f_sub_scalar(rhs).unwrap(), - ) - } - - fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::mul(lhs, rhs) - } - - fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - lhs.unary_ops( - |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), - |tensor| tensor.f_mul_scalar(rhs).unwrap(), - ) - } - - fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::div(lhs, rhs) - } - - fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - let lhs: TchTensor = TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false)); - let output: TchTensor = lhs.unary_ops( - |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), - |tensor| tensor.f_div_scalar(rhs).unwrap(), - ); - TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) - } - - fn int_neg(tensor: TchTensor) -> TchTensor { - Self::int_mul_scalar(tensor, -1) - } - - fn 
int_zeros( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device))) - } - - fn int_ones( - shape: Shape, - device: & as Backend>::Device, - ) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device))) - } - - fn int_full( - shape: Shape, - fill_value: i64, - device: & as Backend>::Device, - ) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::full( - shape.dims, - fill_value, - (tch::Kind::Int64, device), - )) - } - - fn int_sum(tensor: TchTensor) -> TchTensor { - TchOps::sum(tensor) - } - - fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::sum_dim(tensor, dim) - } - - fn int_mean(tensor: TchTensor) -> TchTensor { - let tensor: TchTensor = - TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); - let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor); - - TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) - } - - fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { - let tensor: TchTensor = - TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); - - let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor); - - TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) - } - - fn int_gather( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - ) -> TchTensor { - TchOps::gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - ) -> TchTensor { - TchOps::index_select_dim(tensor, dim, indices) - } - - fn int_select_assign( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::select_assign(tensor, dim, indices, value) - } - - fn int_mask_where( - tensor: TchTensor, - mask: TchTensor, - source: TchTensor, - ) -> TchTensor { - TchTensor::binary_ops_tensor( - tensor, - source, - |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), - |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), - |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), - ) - } - - fn int_mask_fill( - tensor: TchTensor, - mask: TchTensor, - value: i64, - ) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), - |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), - ) - } - - fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmax(tensor, dim) - } - - fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmin(tensor, dim) - } - - fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::max_dim(tensor, dim) - } - - fn int_max_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::max_dim_with_indices(tensor, dim) - } - - fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::min_dim(tensor, dim) - } - - fn int_min_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::min_dim_with_indices(tensor, dim) - } - - fn 
int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor { - TchOps::clamp_min(tensor, min) - } - - fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor { - TchOps::clamp_max(tensor, max) - } - - fn int_clamp(tensor: TchTensor, min: i64, max: i64) -> TchTensor { - TchOps::clamp(tensor, min, max) - } - - fn int_abs(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) - } - - fn int_into_float(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(E::KIND); - TchTensor::new(tensor) - } - - fn int_swap_dims( - tensor: as Backend>::IntTensorPrimitive, - dim1: usize, - dim2: usize, - ) -> as Backend>::IntTensorPrimitive { - TchOps::swap_dims(tensor, dim1, dim2) - } + fn int_from_data( + data: Data, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::from_data(data, (*device).into()) + } + + fn int_shape(tensor: &TchTensor) -> Shape { + tensor.shape() + } + + fn int_repeat( + tensor: TchTensor, + dim: usize, + times: usize, + ) -> TchTensor { + TchOps::repeat(tensor, dim, times) + } + + fn int_into_data(tensor: TchTensor) -> Reader> { + let shape = Self::int_shape(&tensor); + let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()])); + let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); + + Reader::Concrete(Data::new(values.unwrap(), shape)) + } + + fn int_to_device( + tensor: TchTensor, + device: &LibTorchDevice, + ) -> TchTensor { + TchTensor::new(tensor.tensor.to((*device).into())) + } + + fn int_reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + TchOps::reshape(tensor, shape) + } + + fn int_device(tensor: &TchTensor) -> LibTorchDevice { + tensor.tensor.device().into() + } + + fn int_empty( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let tensor = tch::Tensor::empty( + shape.dims.map(|a| a as i64), + (tch::Kind::Int64, (*device).into()), + ); + + TchTensor::new(tensor) + } + + fn int_slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + TchOps::slice(tensor, ranges) + } + + fn int_slice_assign( + tensor: TchTensor, + ranges: [std::ops::Range; D2], + value: TchTensor, + ) -> TchTensor { + TchOps::slice_assign(tensor, ranges, value) + } + + fn int_cat(tensors: Vec>, dim: usize) -> TchTensor { + TchOps::cat(tensors, dim) + } + + fn int_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::equal(lhs, rhs) + } + + fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor { + TchOps::equal_elem(lhs, rhs) + } + + fn int_greater( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::greater(lhs, rhs) + } + + fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor { + TchOps::greater_elem(lhs, rhs) + } + + fn int_greater_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::greater_equal(lhs, rhs) + } + + fn int_greater_equal_elem( + lhs: TchTensor, + rhs: i64, + ) -> TchTensor { + TchOps::greater_equal_elem(lhs, rhs) + } + + fn int_lower( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::lower(lhs, rhs) + } + + fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor { + TchOps::lower_elem(lhs, rhs) + } + + fn int_lower_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::lower_equal(lhs, rhs) + } + + fn int_lower_equal_elem( + lhs: TchTensor, + rhs: i64, + ) -> TchTensor { + TchOps::lower_equal_elem(lhs, rhs) + } + + fn int_add( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::add(lhs, rhs) + } + + fn 
int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), + |tensor| tensor.f_add_scalar(rhs).unwrap(), + ) + } + + fn int_sub( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::sub(lhs, rhs) + } + + fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), + |tensor| tensor.f_sub_scalar(rhs).unwrap(), + ) + } + + fn int_mul( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::mul(lhs, rhs) + } + + fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + lhs.unary_ops( + |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), + |tensor| tensor.f_mul_scalar(rhs).unwrap(), + ) + } + + fn int_div( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::div(lhs, rhs) + } + + fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { + let lhs: TchTensor = + TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false)); + let output: TchTensor = lhs.unary_ops( + |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), + |tensor| tensor.f_div_scalar(rhs).unwrap(), + ); + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) + } + + fn int_neg(tensor: TchTensor) -> TchTensor { + Self::int_mul_scalar(tensor, -1) + } + + fn int_zeros( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device))) + } + + fn int_ones( + shape: Shape, + device: & as Backend>::Device, + ) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device))) + } + + fn int_full( + shape: Shape, + fill_value: i64, + device: & as Backend>::Device, + ) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::full( + shape.dims, + fill_value, + (tch::Kind::Int64, device), + )) + } + + fn int_sum(tensor: TchTensor) -> TchTensor { + TchOps::sum(tensor) + } + + fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::sum_dim(tensor, dim) + } + + fn int_mean(tensor: TchTensor) -> TchTensor { + let tensor: TchTensor = + TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); + let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor); + + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) + } + + fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { + let tensor: TchTensor = + TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); + + let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor); + + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) + } + + fn int_gather( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + ) -> TchTensor { + TchOps::gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + ) -> TchTensor { + TchOps::index_select_dim(tensor, dim, indices) + } + + fn int_select_assign( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::select_assign(tensor, dim, indices, value) + } + + fn 
int_mask_where( + tensor: TchTensor, + mask: TchTensor, + source: TchTensor, + ) -> TchTensor { + TchTensor::binary_ops_tensor( + tensor, + source, + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + ) + } + + fn int_mask_fill( + tensor: TchTensor, + mask: TchTensor, + value: i64, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), + |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), + ) + } + + fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmax(tensor, dim) + } + + fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmin(tensor, dim) + } + + fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::max_dim(tensor, dim) + } + + fn int_max_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::max_dim_with_indices(tensor, dim) + } + + fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::min_dim(tensor, dim) + } + + fn int_min_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::min_dim_with_indices(tensor, dim) + } + + fn int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor { + TchOps::clamp_min(tensor, min) + } + + fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor { + TchOps::clamp_max(tensor, max) + } + + fn int_clamp( + tensor: TchTensor, + min: i64, + max: i64, + ) -> TchTensor { + TchOps::clamp(tensor, min, max) + } + + fn int_abs(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) + } + + fn int_into_float(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(E::KIND); + TchTensor::new(tensor) + } + + fn int_swap_dims( + tensor: as Backend>::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> as Backend>::IntTensorPrimitive { + TchOps::swap_dims(tensor, dim1, dim2) + } } diff --git a/burn-tch/src/ops/module.rs b/burn-tch/src/ops/module.rs index db522187d3..c4b0e156ca 100644 --- a/burn-tch/src/ops/module.rs +++ b/burn-tch/src/ops/module.rs @@ -1,286 +1,286 @@ use crate::{element::TchElement, LibTorch, TchTensor}; use burn_tensor::ops::{ - ConvOptions, ConvTransposeOptions, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, - ModuleOps, + ConvOptions, ConvTransposeOptions, MaxPool1dWithIndices, MaxPool2dBackward, + MaxPool2dWithIndices, ModuleOps, }; impl ModuleOps for LibTorch { - fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor { - let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false); - - TchTensor::new(tensor) - } - - fn embedding_backward( - weights: TchTensor, - output: TchTensor, - indices: TchTensor, - ) -> TchTensor { - let [n_embedding, _d_model] = weights.shape().dims; - let tensor = tch::Tensor::embedding_backward( - &output.tensor, - &indices.tensor, - n_embedding as i64, - -1, - false, - false, - ); - - TchTensor::new(tensor) - } - - fn conv1d( - x: TchTensor, - weight: TchTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> TchTensor { - let tensor = tch::Tensor::conv1d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.dilation.map(|i| i as i64), - options.groups as i64, - ); - - TchTensor::new(tensor) - } - - fn conv2d( - x: TchTensor, - weight: 
TchTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> TchTensor { - let tensor = tch::Tensor::conv2d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.dilation.map(|i| i as i64), - options.groups as i64, - ); - - TchTensor::new(tensor) - } - - fn conv_transpose2d( - x: TchTensor, - weight: TchTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> TchTensor { - let tensor = tch::Tensor::conv_transpose2d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.padding_out.map(|i| i as i64), - options.groups as i64, - options.dilation.map(|i| i as i64), - ); - - TchTensor::new(tensor) - } - - fn conv_transpose1d( - x: TchTensor, - weight: TchTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> TchTensor { - let tensor = tch::Tensor::conv_transpose1d( - &x.tensor, - &weight.tensor, - bias.map(|t| t.tensor), - options.stride.map(|i| i as i64), - options.padding.map(|i| i as i64), - options.padding_out.map(|i| i as i64), - options.groups as i64, - options.dilation.map(|i| i as i64), - ); - - TchTensor::new(tensor) - } - - fn avg_pool1d( - x: TchTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> TchTensor { - let tensor = tch::Tensor::avg_pool1d( - &x.tensor, - [kernel_size as i64], - [stride as i64], - [padding as i64], - false, - count_include_pad, - ); - - TchTensor::new(tensor) - } - fn avg_pool2d( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> TchTensor { - let tensor = tch::Tensor::avg_pool2d( - &x.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - false, - count_include_pad, - None, - ); - - TchTensor::new(tensor) - } - - fn avg_pool2d_backward( - x: TchTensor, - grad: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> TchTensor { - let tensor = tch::Tensor::avg_pool2d_backward( - &x.tensor, - &grad.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - false, - count_include_pad, - None, - ); - - TchTensor::new(tensor) - } - - fn max_pool1d( - x: TchTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> TchTensor { - let tensor = tch::Tensor::max_pool1d( - &x.tensor, - kernel_size as i64, - stride as i64, - padding as i64, - dilation as i64, - false, - ); - - TchTensor::new(tensor) - } - - fn max_pool1d_with_indices( - x: TchTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices> { - let (tensor, indices) = tch::Tensor::max_pool1d_with_indices( - &x.tensor, - kernel_size as i64, - stride as i64, - padding as i64, - dilation as i64, - false, - ); - - MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) - } - - fn max_pool2d( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> TchTensor { - let tensor = tch::Tensor::max_pool2d( - &x.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - [dilation[0] as i64, dilation[1] as i64], - false, - ); - - 
TchTensor::new(tensor) - } - - fn max_pool2d_with_indices( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - let (tensor, indices) = tch::Tensor::max_pool2d_with_indices( - &x.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - [dilation[0] as i64, dilation[1] as i64], - false, - ); - - MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) - } - - fn max_pool2d_with_indices_backward( - x: TchTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: TchTensor, - indices: TchTensor, - ) -> MaxPool2dBackward> { - let grad = tch::Tensor::max_pool2d_with_indices_backward( - &x.tensor, - &output_grad.tensor, - [kernel_size[0] as i64, kernel_size[1] as i64], - [stride[0] as i64, stride[1] as i64], - [padding[0] as i64, padding[1] as i64], - [dilation[0] as i64, dilation[1] as i64], - false, - &indices.tensor, - ); - - MaxPool2dBackward::new(TchTensor::new(grad)) - } - - fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor { - let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64)); - - TchTensor::new(tensor) - } - - fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor { - let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor); - - TchTensor::new(tensor) - } - - fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor { - let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64); - - TchTensor::new(tensor) - } + fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor { + let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false); + + TchTensor::new(tensor) + } + + fn embedding_backward( + weights: TchTensor, + output: TchTensor, + indices: TchTensor, + ) -> TchTensor { + let [n_embedding, _d_model] = weights.shape().dims; + let tensor = tch::Tensor::embedding_backward( + &output.tensor, + &indices.tensor, + n_embedding as i64, + -1, + false, + false, + ); + + TchTensor::new(tensor) + } + + fn conv1d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> TchTensor { + let tensor = tch::Tensor::conv1d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.dilation.map(|i| i as i64), + options.groups as i64, + ); + + TchTensor::new(tensor) + } + + fn conv2d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> TchTensor { + let tensor = tch::Tensor::conv2d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.dilation.map(|i| i as i64), + options.groups as i64, + ); + + TchTensor::new(tensor) + } + + fn conv_transpose2d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> TchTensor { + let tensor = tch::Tensor::conv_transpose2d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.padding_out.map(|i| i as i64), + options.groups as i64, + options.dilation.map(|i| i as i64), + ); + + TchTensor::new(tensor) + } + + fn conv_transpose1d( + x: TchTensor, + weight: TchTensor, + bias: Option>, + 
options: ConvTransposeOptions<1>, + ) -> TchTensor { + let tensor = tch::Tensor::conv_transpose1d( + &x.tensor, + &weight.tensor, + bias.map(|t| t.tensor), + options.stride.map(|i| i as i64), + options.padding.map(|i| i as i64), + options.padding_out.map(|i| i as i64), + options.groups as i64, + options.dilation.map(|i| i as i64), + ); + + TchTensor::new(tensor) + } + + fn avg_pool1d( + x: TchTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> TchTensor { + let tensor = tch::Tensor::avg_pool1d( + &x.tensor, + [kernel_size as i64], + [stride as i64], + [padding as i64], + false, + count_include_pad, + ); + + TchTensor::new(tensor) + } + fn avg_pool2d( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> TchTensor { + let tensor = tch::Tensor::avg_pool2d( + &x.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + false, + count_include_pad, + None, + ); + + TchTensor::new(tensor) + } + + fn avg_pool2d_backward( + x: TchTensor, + grad: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> TchTensor { + let tensor = tch::Tensor::avg_pool2d_backward( + &x.tensor, + &grad.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + false, + count_include_pad, + None, + ); + + TchTensor::new(tensor) + } + + fn max_pool1d( + x: TchTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> TchTensor { + let tensor = tch::Tensor::max_pool1d( + &x.tensor, + kernel_size as i64, + stride as i64, + padding as i64, + dilation as i64, + false, + ); + + TchTensor::new(tensor) + } + + fn max_pool1d_with_indices( + x: TchTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices> { + let (tensor, indices) = tch::Tensor::max_pool1d_with_indices( + &x.tensor, + kernel_size as i64, + stride as i64, + padding as i64, + dilation as i64, + false, + ); + + MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) + } + + fn max_pool2d( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> TchTensor { + let tensor = tch::Tensor::max_pool2d( + &x.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + [dilation[0] as i64, dilation[1] as i64], + false, + ); + + TchTensor::new(tensor) + } + + fn max_pool2d_with_indices( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + let (tensor, indices) = tch::Tensor::max_pool2d_with_indices( + &x.tensor, + [kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + [dilation[0] as i64, dilation[1] as i64], + false, + ); + + MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) + } + + fn max_pool2d_with_indices_backward( + x: TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: TchTensor, + indices: TchTensor, + ) -> MaxPool2dBackward> { + let grad = tch::Tensor::max_pool2d_with_indices_backward( + &x.tensor, + &output_grad.tensor, + 
[kernel_size[0] as i64, kernel_size[1] as i64], + [stride[0] as i64, stride[1] as i64], + [padding[0] as i64, padding[1] as i64], + [dilation[0] as i64, dilation[1] as i64], + false, + &indices.tensor, + ); + + MaxPool2dBackward::new(TchTensor::new(grad)) + } + + fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor { + let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64)); + + TchTensor::new(tensor) + } + + fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor { + let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor); + + TchTensor::new(tensor) + } + + fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor { + let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64); + + TchTensor::new(tensor) + } } diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs index 25c0a3ac17..2326b8a76f 100644 --- a/burn-tch/src/ops/tensor.rs +++ b/burn-tch/src/ops/tensor.rs @@ -1,438 +1,445 @@ use super::TchOps; use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor}; use burn_tensor::{ - backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape, + backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape, }; use std::ops::Range; impl TensorOps for LibTorch { - fn from_data(data: Data, device: &LibTorchDevice) -> TchTensor { - TchTensor::from_data(data, (*device).into()) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: &LibTorchDevice, - ) -> TchTensor { - match distribution { - Distribution::Default => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor - .mut_ops(|tensor| tensor.rand_like_out(tensor)) - .unwrap() - } - Distribution::Bernoulli(prob) => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor - .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap()) - .unwrap() - } - Distribution::Uniform(from, to) => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor - .mut_ops(|tensor| tensor.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap())) - .unwrap() - } - Distribution::Normal(mean, std) => { - let mut tensor = TchTensor::::empty(shape, *device); - tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap() - } - } - } - - fn arange(range: Range, device: &LibTorchDevice) -> TchTensor { - let device: tch::Device = (*device).into(); - let mut tensor = tch::Tensor::arange( - range.end as i64 - range.start as i64, - (tch::Kind::Int64, device), - ); - - if range.start != 0 { - tensor = tensor.f_add_scalar_(range.start as i64).unwrap(); - } - - TchTensor::new(tensor) - } - - fn repeat(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { - TchOps::repeat(tensor, dim, times) - } - - fn zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device))) - } - - fn ones(shape: Shape, device: &LibTorchDevice) -> TchTensor { - let shape = TchShape::from(shape); - let device: tch::Device = (*device).into(); - - TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device))) - } - - fn shape(tensor: & as Backend>::TensorPrimitive) -> Shape { - tensor.shape() - } - - fn into_data( - tensor: as Backend>::TensorPrimitive, - ) -> Reader as Backend>::FloatElem, D>> { - let shape = Self::shape(&tensor); - let tensor = Self::reshape(tensor.clone(), 
Shape::new([shape.num_elements()])); - let values: Result, tch::TchError> = tensor.tensor.try_into(); - - Reader::Concrete(Data::new(values.unwrap(), shape)) - } - - fn device(tensor: &TchTensor) -> LibTorchDevice { - tensor.tensor.device().into() - } - - fn to_device( - tensor: TchTensor, - device: &LibTorchDevice, - ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) - } - - fn empty( - shape: Shape, - device: & as Backend>::Device, - ) -> as Backend>::TensorPrimitive { - let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into())); - - TchTensor::new(tensor) - } - - fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::add(lhs, rhs) - } - - fn add_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), - |tensor| tensor.f_add_scalar(rhs).unwrap(), - ) - } - - fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::sub(lhs, rhs) - } - - fn sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), - |tensor| tensor.f_sub_scalar(rhs).unwrap(), - ) - } - - fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::mul(lhs, rhs) - } - - fn mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), - |tensor| tensor.f_mul_scalar(rhs).unwrap(), - ) - } - - fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::div(lhs, rhs) - } - - fn div_scalar(lhs: TchTensor, rhs: E) -> TchTensor { - let rhs: f64 = rhs.elem(); - - lhs.unary_ops( - |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), - |tensor| tensor.f_div_scalar(rhs).unwrap(), - ) - } - - fn matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - let tensor = lhs.tensor.matmul(&rhs.tensor); - TchTensor::new(tensor) - } - - fn neg(tensor: TchTensor) -> TchTensor { - Self::mul_scalar(tensor, (-1f32).elem::()) - } - - fn recip(tensor: TchTensor) -> TchTensor { - TchTensor::new(tensor.tensor.reciprocal()) - } - - fn swap_dims( - tensor: TchTensor, - dim1: usize, - dim2: usize, - ) -> TchTensor { - TchOps::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: TchTensor, - shape: Shape, - ) -> TchTensor { - TchOps::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - ) -> TchTensor { - TchOps::gather(dim, tensor, indices) - } - - fn scatter( - dim: usize, - tensor: TchTensor, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::scatter(dim, tensor, indices, value) - } - - fn select( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - ) -> TchTensor { - TchOps::index_select_dim(tensor, dim, indices) - } - - fn select_assign( - tensor: TchTensor, - dim: usize, - indices: TchTensor, - value: TchTensor, - ) -> TchTensor { - TchOps::select_assign(tensor, dim, indices, value) - } - - fn slice( - tensor: TchTensor, - ranges: [Range; D2], - ) -> TchTensor { - TchOps::slice(tensor, ranges) - } - - fn slice_assign( - tensor: TchTensor, - ranges: [Range; D2], - value: TchTensor, - ) -> as Backend>::TensorPrimitive { - TchOps::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: TchTensor, - mask: TchTensor, - value: TchTensor, - ) -> TchTensor { - let output = value.tensor.where_self(&mask.tensor, &tensor.tensor); - - TchTensor::new(output) - } - - fn mask_fill( - tensor: TchTensor, - mask: TchTensor, - value: E, - ) -> 
TchTensor { - let value: f64 = value.elem(); - - tensor.unary_ops( - |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), - |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), - ) - } - - fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::equal(lhs, rhs) - } - - fn equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::equal_elem(lhs, rhs.elem::()) - } - - fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::greater(lhs, rhs) - } - - fn greater_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::greater_elem(lhs, rhs.elem::()) - } - - fn greater_equal( - lhs: TchTensor, - rhs: TchTensor, - ) -> TchTensor { - TchOps::greater_equal(lhs, rhs) - } - - fn greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::greater_equal_elem(lhs, rhs.elem::()) - } - - fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::lower(lhs, rhs) - } - - fn lower_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::lower_elem(lhs, rhs.elem::()) - } - - fn lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { - TchOps::lower_equal(lhs, rhs) - } - - fn lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { - TchOps::lower_equal_elem(lhs, rhs.elem::()) - } - - fn mean(tensor: TchTensor) -> TchTensor { - TchOps::mean(tensor) - } - - fn sum(tensor: TchTensor) -> TchTensor { - TchOps::sum(tensor) - } - - fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::mean_dim(tensor, dim) - } - - fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::sum_dim(tensor, dim) - } - - fn to_full_precision(tensor: &TchTensor) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.to_kind(tch::Kind::Float); - - TchTensor::from_existing(tensor, storage) - } - - fn from_full_precision(tensor: TchTensor) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.to_kind(E::KIND); - - TchTensor::from_existing(tensor, storage) - } - - fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmax(tensor, dim) - } - - fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::argmin(tensor, dim) - } - - fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::max_dim(tensor, dim) - } - - fn max_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::max_dim_with_indices(tensor, dim) - } - - fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::min_dim(tensor, dim) - } - - fn min_dim_with_indices( - tensor: TchTensor, - dim: usize, - ) -> (TchTensor, TchTensor) { - TchOps::min_dim_with_indices(tensor, dim) - } - - fn exp(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp()) - } - - fn log(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log()) - } - - fn log1p(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p()) - } - - fn powf(tensor: TchTensor, value: f32) -> TchTensor { - tensor.unary_ops( - |mut tensor| tensor.f_pow_(value as f64).unwrap(), - |tensor| tensor.pow_tensor_scalar(value as f64), - ) - } - - fn sqrt(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt()) - } - - fn abs(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) - } - - fn cos(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos()) - 
} - - fn sin(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin()) - } - - fn tanh(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh()) - } - - fn erf(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf()) - } - - fn cat(tensors: Vec>, dim: usize) -> TchTensor { - TchOps::cat(tensors, dim) - } - - fn clamp_min( - tensor: TchTensor, - min: E, - ) -> as Backend>::TensorPrimitive { - TchOps::clamp_min(tensor, min.elem::()) - } - - fn clamp_max( - tensor: as Backend>::TensorPrimitive, - max: as Backend>::FloatElem, - ) -> as Backend>::TensorPrimitive { - TchOps::clamp_max(tensor, max.elem::()) - } - - fn clamp( - tensor: as Backend>::TensorPrimitive, - min: as Backend>::FloatElem, - max: as Backend>::FloatElem, - ) -> as Backend>::TensorPrimitive { - TchOps::clamp(tensor, min.elem::(), max.elem::()) - } - - fn into_int(tensor: TchTensor) -> TchTensor { - let tensor = tensor.tensor.to_kind(tch::Kind::Int64); - TchTensor::new(tensor) - } + fn from_data(data: Data, device: &LibTorchDevice) -> TchTensor { + TchTensor::from_data(data, (*device).into()) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: &LibTorchDevice, + ) -> TchTensor { + match distribution { + Distribution::Default => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor + .mut_ops(|tensor| tensor.rand_like_out(tensor)) + .unwrap() + } + Distribution::Bernoulli(prob) => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor + .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap()) + .unwrap() + } + Distribution::Uniform(from, to) => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor + .mut_ops(|tensor| tensor.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap())) + .unwrap() + } + Distribution::Normal(mean, std) => { + let mut tensor = TchTensor::::empty(shape, *device); + tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap() + } + } + } + + fn arange(range: Range, device: &LibTorchDevice) -> TchTensor { + let device: tch::Device = (*device).into(); + let mut tensor = tch::Tensor::arange( + range.end as i64 - range.start as i64, + (tch::Kind::Int64, device), + ); + + if range.start != 0 { + tensor = tensor.f_add_scalar_(range.start as i64).unwrap(); + } + + TchTensor::new(tensor) + } + + fn repeat( + tensor: TchTensor, + dim: usize, + times: usize, + ) -> TchTensor { + TchOps::repeat(tensor, dim, times) + } + + fn zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device))) + } + + fn ones(shape: Shape, device: &LibTorchDevice) -> TchTensor { + let shape = TchShape::from(shape); + let device: tch::Device = (*device).into(); + + TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device))) + } + + fn shape(tensor: & as Backend>::TensorPrimitive) -> Shape { + tensor.shape() + } + + fn into_data( + tensor: as Backend>::TensorPrimitive, + ) -> Reader as Backend>::FloatElem, D>> { + let shape = Self::shape(&tensor); + let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()])); + let values: Result, tch::TchError> = tensor.tensor.try_into(); + + Reader::Concrete(Data::new(values.unwrap(), shape)) + } + + fn device(tensor: &TchTensor) -> LibTorchDevice { + tensor.tensor.device().into() + } + + fn to_device( + tensor: TchTensor, + device: 
&LibTorchDevice, + ) -> TchTensor { + TchTensor::new(tensor.tensor.to((*device).into())) + } + + fn empty( + shape: Shape, + device: & as Backend>::Device, + ) -> as Backend>::TensorPrimitive { + let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into())); + + TchTensor::new(tensor) + } + + fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::add(lhs, rhs) + } + + fn add_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), + |tensor| tensor.f_add_scalar(rhs).unwrap(), + ) + } + + fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::sub(lhs, rhs) + } + + fn sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), + |tensor| tensor.f_sub_scalar(rhs).unwrap(), + ) + } + + fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::mul(lhs, rhs) + } + + fn mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), + |tensor| tensor.f_mul_scalar(rhs).unwrap(), + ) + } + + fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::div(lhs, rhs) + } + + fn div_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + let rhs: f64 = rhs.elem(); + + lhs.unary_ops( + |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), + |tensor| tensor.f_div_scalar(rhs).unwrap(), + ) + } + + fn matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + let tensor = lhs.tensor.matmul(&rhs.tensor); + TchTensor::new(tensor) + } + + fn neg(tensor: TchTensor) -> TchTensor { + Self::mul_scalar(tensor, (-1f32).elem::()) + } + + fn recip(tensor: TchTensor) -> TchTensor { + TchTensor::new(tensor.tensor.reciprocal()) + } + + fn swap_dims( + tensor: TchTensor, + dim1: usize, + dim2: usize, + ) -> TchTensor { + TchOps::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: TchTensor, + shape: Shape, + ) -> TchTensor { + TchOps::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + ) -> TchTensor { + TchOps::gather(dim, tensor, indices) + } + + fn scatter( + dim: usize, + tensor: TchTensor, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::scatter(dim, tensor, indices, value) + } + + fn select( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + ) -> TchTensor { + TchOps::index_select_dim(tensor, dim, indices) + } + + fn select_assign( + tensor: TchTensor, + dim: usize, + indices: TchTensor, + value: TchTensor, + ) -> TchTensor { + TchOps::select_assign(tensor, dim, indices, value) + } + + fn slice( + tensor: TchTensor, + ranges: [Range; D2], + ) -> TchTensor { + TchOps::slice(tensor, ranges) + } + + fn slice_assign( + tensor: TchTensor, + ranges: [Range; D2], + value: TchTensor, + ) -> as Backend>::TensorPrimitive { + TchOps::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: TchTensor, + mask: TchTensor, + value: TchTensor, + ) -> TchTensor { + let output = value.tensor.where_self(&mask.tensor, &tensor.tensor); + + TchTensor::new(output) + } + + fn mask_fill( + tensor: TchTensor, + mask: TchTensor, + value: E, + ) -> TchTensor { + let value: f64 = value.elem(); + + tensor.unary_ops( + |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), + |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), + ) + } + + fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::equal(lhs, rhs) + } + + fn 
equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::equal_elem(lhs, rhs.elem::()) + } + + fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::greater(lhs, rhs) + } + + fn greater_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::greater_elem(lhs, rhs.elem::()) + } + + fn greater_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::greater_equal(lhs, rhs) + } + + fn greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::greater_equal_elem(lhs, rhs.elem::()) + } + + fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchOps::lower(lhs, rhs) + } + + fn lower_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::lower_elem(lhs, rhs.elem::()) + } + + fn lower_equal( + lhs: TchTensor, + rhs: TchTensor, + ) -> TchTensor { + TchOps::lower_equal(lhs, rhs) + } + + fn lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + TchOps::lower_equal_elem(lhs, rhs.elem::()) + } + + fn mean(tensor: TchTensor) -> TchTensor { + TchOps::mean(tensor) + } + + fn sum(tensor: TchTensor) -> TchTensor { + TchOps::sum(tensor) + } + + fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::mean_dim(tensor, dim) + } + + fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::sum_dim(tensor, dim) + } + + fn to_full_precision(tensor: &TchTensor) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.to_kind(tch::Kind::Float); + + TchTensor::from_existing(tensor, storage) + } + + fn from_full_precision(tensor: TchTensor) -> TchTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.to_kind(E::KIND); + + TchTensor::from_existing(tensor, storage) + } + + fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmax(tensor, dim) + } + + fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::argmin(tensor, dim) + } + + fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::max_dim(tensor, dim) + } + + fn max_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::max_dim_with_indices(tensor, dim) + } + + fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::min_dim(tensor, dim) + } + + fn min_dim_with_indices( + tensor: TchTensor, + dim: usize, + ) -> (TchTensor, TchTensor) { + TchOps::min_dim_with_indices(tensor, dim) + } + + fn exp(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp()) + } + + fn log(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log()) + } + + fn log1p(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p()) + } + + fn powf(tensor: TchTensor, value: f32) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_pow_(value as f64).unwrap(), + |tensor| tensor.pow_tensor_scalar(value as f64), + ) + } + + fn sqrt(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt()) + } + + fn abs(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) + } + + fn cos(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos()) + } + + fn sin(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin()) + } + + fn tanh(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh()) + } + + fn erf(tensor: TchTensor) -> TchTensor { + 
tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf()) + } + + fn cat(tensors: Vec>, dim: usize) -> TchTensor { + TchOps::cat(tensors, dim) + } + + fn clamp_min( + tensor: TchTensor, + min: E, + ) -> as Backend>::TensorPrimitive { + TchOps::clamp_min(tensor, min.elem::()) + } + + fn clamp_max( + tensor: as Backend>::TensorPrimitive, + max: as Backend>::FloatElem, + ) -> as Backend>::TensorPrimitive { + TchOps::clamp_max(tensor, max.elem::()) + } + + fn clamp( + tensor: as Backend>::TensorPrimitive, + min: as Backend>::FloatElem, + max: as Backend>::FloatElem, + ) -> as Backend>::TensorPrimitive { + TchOps::clamp(tensor, min.elem::(), max.elem::()) + } + + fn into_int(tensor: TchTensor) -> TchTensor { + let tensor = tensor.tensor.to_kind(tch::Kind::Int64); + TchTensor::new(tensor) + } } diff --git a/burn-tch/src/tensor.rs b/burn-tch/src/tensor.rs index 1f3e73b9c0..1ca6f44f36 100644 --- a/burn-tch/src/tensor.rs +++ b/burn-tch/src/tensor.rs @@ -9,61 +9,61 @@ pub type StorageRef = Rc<*mut c_void>; /// A tensor that uses the tch backend. #[derive(Debug, PartialEq)] pub struct TchTensor { - /// Handle to the tensor. Call methods on this field. - pub tensor: tch::Tensor, - /// The tensor's storage - pub storage: StorageRef, - phantom: PhantomData, + /// Handle to the tensor. Call methods on this field. + pub tensor: tch::Tensor, + /// The tensor's storage + pub storage: StorageRef, + phantom: PhantomData, } impl TchTensor { - /// Create a new tensor. - /// - /// Note that if the tensor was created from an operation that may reuse the same tensor - /// storage as the parent, you should use [from_existing](TchTensor::from_existing) - /// instead. - pub fn new(tensor: tch::Tensor) -> Self { - let data = Rc::new(tensor.data_ptr()); - - Self { - tensor, - phantom: PhantomData, - storage: data, + /// Create a new tensor. + /// + /// Note that if the tensor was created from an operation that may reuse the same tensor + /// storage as the parent, you should use [from_existing](TchTensor::from_existing) + /// instead. + pub fn new(tensor: tch::Tensor) -> Self { + let data = Rc::new(tensor.data_ptr()); + + Self { + tensor, + phantom: PhantomData, + storage: data, + } } - } - - /// Create a tensor that was created from an operation executed on a parent tensor. - /// - /// If the child tensor shared the same storage as its parent, it will be cloned, effectively - /// tracking how much tensors point to the same memory space. - pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self { - let storage_child = tensor.data_ptr(); - - let storage = match storage_child == *storage_parent { - true => storage_parent.clone(), - false => Rc::new(storage_child), - }; - - Self { - tensor, - storage, - phantom: PhantomData, + + /// Create a tensor that was created from an operation executed on a parent tensor. + /// + /// If the child tensor shared the same storage as its parent, it will be cloned, effectively + /// tracking how much tensors point to the same memory space. 
+ pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self { + let storage_child = tensor.data_ptr(); + + let storage = match storage_child == *storage_parent { + true => storage_parent.clone(), + false => Rc::new(storage_child), + }; + + Self { + tensor, + storage, + phantom: PhantomData, + } } - } } impl std::ops::Add for TchTensor { - type Output = Self; + type Output = Self; - fn add(self, rhs: Self) -> Self::Output { - LibTorch::add(self, rhs) - } + fn add(self, rhs: Self) -> Self::Output { + LibTorch::add(self, rhs) + } } impl TchTensor { - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.tensor.size()) - } + pub(crate) fn shape(&self) -> Shape { + Shape::from(self.tensor.size()) + } } // This is safe since we don't use autodiff from LibTorch. @@ -73,209 +73,209 @@ unsafe impl Send for TchTensor {} unsafe impl Sync for TchTensor {} impl TchTensor { - /// Execute an operation on a tensor if the data can be reused. - pub fn mut_ops< - F: Fn(&mut tch::Tensor) -> tch::Tensor, - EOut: tch::kind::Element, - const D_OUT: usize, - >( - &mut self, - func: F, - ) -> Option> { - if Rc::strong_count(&self.storage) > 1 { - return None; + /// Execute an operation on a tensor if the data can be reused. + pub fn mut_ops< + F: Fn(&mut tch::Tensor) -> tch::Tensor, + EOut: tch::kind::Element, + const D_OUT: usize, + >( + &mut self, + func: F, + ) -> Option> { + if Rc::strong_count(&self.storage) > 1 { + return None; + } + + let data = self.storage.clone(); + Some(TchTensor::from_existing(func(&mut self.tensor), data)) } + /// Execute a unary ops reusing the tensor data if possible. + pub fn unary_ops( + self, + fown: FOwn, + fref: FRef, + ) -> TchTensor + where + FOwn: Fn(tch::Tensor) -> tch::Tensor, + FRef: Fn(&tch::Tensor) -> tch::Tensor, + { + if Rc::strong_count(&self.storage) > 1 { + return TchTensor::from_existing(fref(&self.tensor), self.storage); + } - let data = self.storage.clone(); - Some(TchTensor::from_existing(func(&mut self.tensor), data)) - } - /// Execute a unary ops reusing the tensor data if possible. - pub fn unary_ops( - self, - fown: FOwn, - fref: FRef, - ) -> TchTensor - where - FOwn: Fn(tch::Tensor) -> tch::Tensor, - FRef: Fn(&tch::Tensor) -> tch::Tensor, - { - if Rc::strong_count(&self.storage) > 1 { - return TchTensor::from_existing(fref(&self.tensor), self.storage); + TchTensor::from_existing(fown(self.tensor), self.storage) } - TchTensor::from_existing(fown(self.tensor), self.storage) - } - - /// Execute a binary ops reusing the tensor data if possible. - pub fn binary_ops_tensor( - mut lhs: Self, - mut rhs: Self, - flmut: FLMut, - frmut: FRMut, - fref: FRef, - ) -> TchTensor - where - FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor, - FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor, - FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor, - { - let lhs_num_elems = lhs.shape().num_elements(); - let rhs_num_elems = rhs.shape().num_elements(); - - let safe_mut_lhs = lhs_num_elems > rhs_num_elems; - let safe_mut_rhs = rhs_num_elems > lhs_num_elems; - - if safe_mut_lhs { - if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) { - return output; - } - } + /// Execute a binary ops reusing the tensor data if possible. 
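// Illustrative sketch, not part of the patch: how the Rc-based storage counter
// above gates in-place execution. The helper name `double_if_unshared` is
// hypothetical; the closure pair mirrors the `mul_scalar` pattern used
// throughout the ops modules in this diff.
fn double_if_unshared<const D: usize>(tensor: TchTensor<f32, D>) -> TchTensor<f32, D> {
    // `unary_ops` runs the first (in-place) closure only while no other
    // `TchTensor` shares the same storage (`Rc::strong_count == 1`); otherwise
    // it falls back to the allocating closure, so shared parents are never
    // mutated behind their back.
    tensor.unary_ops(
        |mut t| t.f_mul_scalar_(2.0).unwrap(),
        |t| t.f_mul_scalar(2.0).unwrap(),
    )
}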
+ pub fn binary_ops_tensor( + mut lhs: Self, + mut rhs: Self, + flmut: FLMut, + frmut: FRMut, + fref: FRef, + ) -> TchTensor + where + FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor, + FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor, + FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor, + { + let lhs_num_elems = lhs.shape().num_elements(); + let rhs_num_elems = rhs.shape().num_elements(); - if safe_mut_rhs { - if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) { - return output; - } - } + let safe_mut_lhs = lhs_num_elems > rhs_num_elems; + let safe_mut_rhs = rhs_num_elems > lhs_num_elems; + + if safe_mut_lhs { + if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) { + return output; + } + } - let storage = lhs.storage; - let tensor = fref(&lhs.tensor, &rhs.tensor); + if safe_mut_rhs { + if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) { + return output; + } + } - TchTensor::from_existing(tensor, storage) - } + let storage = lhs.storage; + let tensor = fref(&lhs.tensor, &rhs.tensor); + + TchTensor::from_existing(tensor, storage) + } } impl Clone for TchTensor { - fn clone(&self) -> Self { - Self { - tensor: self.tensor.shallow_clone(), - phantom: PhantomData, - storage: self.storage.clone(), + fn clone(&self) -> Self { + Self { + tensor: self.tensor.shallow_clone(), + phantom: PhantomData, + storage: self.storage.clone(), + } } - } } /// A shape that can be used by LibTorch. pub struct TchShape { - /// The shape's dimensions. - pub dims: [i64; D], + /// The shape's dimensions. + pub dims: [i64; D], } impl From> for TchShape { - fn from(shape: Shape) -> Self { - let mut dims = [0; D]; - for (i, dim) in dims.iter_mut().enumerate().take(D) { - *dim = shape.dims[i] as i64; + fn from(shape: Shape) -> Self { + let mut dims = [0; D]; + for (i, dim) in dims.iter_mut().enumerate().take(D) { + *dim = shape.dims[i] as i64; + } + TchShape { dims } } - TchShape { dims } - } } impl TchTensor { - /// Creates a new tensor from a shape and a device. - /// - /// # Arguments - /// - /// * `data` - The tensor's data. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// A new tensor. - pub fn from_data(data: Data, device: tch::Device) -> Self { - let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device); - let shape_tch = TchShape::from(data.shape); - let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND); - - Self::new(tensor) - } + /// Creates a new tensor from a shape and a device. + /// + /// # Arguments + /// + /// * `data` - The tensor's data. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// A new tensor. + pub fn from_data(data: Data, device: tch::Device) -> Self { + let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device); + let shape_tch = TchShape::from(data.shape); + let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND); + + Self::new(tensor) + } } #[cfg(test)] mod utils { - use super::*; - use crate::{backend::LibTorch, element::TchElement}; - - impl TchTensor { - pub(crate) fn into_data(self) -> Data - where - P: tch::kind::Element, - { - as TensorOps>>::into_data(self).read() + use super::*; + use crate::{backend::LibTorch, element::TchElement}; + + impl TchTensor { + pub(crate) fn into_data(self) -> Data + where + P: tch::kind::Element, + { + as TensorOps>>::into_data(self).read() + } } - } } impl TchTensor { - /// Creates an empty tensor from a shape and a device. 
- /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// A new empty tensor. - pub fn empty(shape: Shape, device: LibTorchDevice) -> Self { - let shape_tch = TchShape::from(shape); - let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into())); - - Self::new(tensor) - } + /// Creates an empty tensor from a shape and a device. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// A new empty tensor. + pub fn empty(shape: Shape, device: LibTorchDevice) -> Self { + let shape_tch = TchShape::from(shape); + let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into())); + + Self::new(tensor) + } } #[cfg(test)] mod tests { - use super::*; - use burn_tensor::{Distribution, Tensor}; - use rand::prelude::StdRng; - use rand::SeedableRng; - - #[test] - fn should_support_into_and_from_data_1d() { - let data_expected = Data::::random( - Shape::new([3]), - Distribution::Default, - &mut StdRng::from_entropy(), - ); - let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_into_and_from_data_2d() { - let data_expected = Data::::random( - Shape::new([2, 3]), - Distribution::Default, - &mut StdRng::from_entropy(), - ); - let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); - - let data_actual = tensor.into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_not_update_inplace_after_reshape() { - let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); - let tensor_2 = tensor_1.clone(); - - let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0); - - assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); - } - - #[test] - fn should_not_update_inplace_after_slice() { - let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); - let tensor_2 = tensor_1.clone(); - - let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0); - - assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); - } + use super::*; + use burn_tensor::{Distribution, Tensor}; + use rand::prelude::StdRng; + use rand::SeedableRng; + + #[test] + fn should_support_into_and_from_data_1d() { + let data_expected = Data::::random( + Shape::new([3]), + Distribution::Default, + &mut StdRng::from_entropy(), + ); + let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_2d() { + let data_expected = Data::::random( + Shape::new([2, 3]), + Distribution::Default, + &mut StdRng::from_entropy(), + ); + let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_not_update_inplace_after_reshape() { + let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); + let tensor_2 = tensor_1.clone(); + + let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0); + + assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); + } + + #[test] + fn should_not_update_inplace_after_slice() { + let tensor_1 = Tensor::, 1>::from_floats([4.0, 4.0]); + let tensor_2 = tensor_1.clone(); + + let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0); + + 
assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value); + } } diff --git a/burn-tensor-testgen/src/lib.rs b/burn-tensor-testgen/src/lib.rs index 5ddbcc2a8e..d67d13114a 100644 --- a/burn-tensor-testgen/src/lib.rs +++ b/burn-tensor-testgen/src/lib.rs @@ -4,22 +4,22 @@ use quote::{format_ident, quote}; #[allow(missing_docs)] #[proc_macro_attribute] pub fn testgen(attr: TokenStream, item: TokenStream) -> TokenStream { - let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item); - let attr: proc_macro2::TokenStream = proc_macro2::TokenStream::from(attr); - let macro_ident = format_ident!("testgen_{}", attr.to_string()); + let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item); + let attr: proc_macro2::TokenStream = proc_macro2::TokenStream::from(attr); + let macro_ident = format_ident!("testgen_{}", attr.to_string()); - let macro_gen = quote! { - #[macro_export] - macro_rules! #macro_ident { - () => { - mod #attr { - use super::*; + let macro_gen = quote! { + #[macro_export] + macro_rules! #macro_ident { + () => { + mod #attr { + use super::*; - #item - } - }; - } - }; + #item + } + }; + } + }; - macro_gen.into() + macro_gen.into() } diff --git a/burn-tensor/src/tensor/activation/base.rs b/burn-tensor/src/tensor/activation/base.rs index bcc3f2593e..944de1cab6 100644 --- a/burn-tensor/src/tensor/activation/base.rs +++ b/burn-tensor/src/tensor/activation/base.rs @@ -5,12 +5,12 @@ use crate::{ElementPrecision, Precision}; /// Applies the rectified linear unit function. pub fn relu(tensor: Tensor) -> Tensor { - tensor.relu() + tensor.relu() } /// Applies the Gaussian Error Linear Units function as described in the paper in [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf). pub fn gelu(tensor: Tensor) -> Tensor { - Tensor::from_primitive(B::gelu(tensor.primitive)) + Tensor::from_primitive(B::gelu(tensor.primitive)) } /// Applies the softmax function on the input tensor along the given dimension. @@ -22,13 +22,13 @@ pub fn gelu(tensor: Tensor) -> Tensor { /// The dimension argument `dim` specifies the dimension along which the function will be computed. /// It must in the range of `0` and `D-1`. pub fn softmax(tensor: Tensor, dim: usize) -> Tensor { - check!(TensorCheck::dim_ops::("softmax", dim)); + check!(TensorCheck::dim_ops::("softmax", dim)); - let tensor = tensor.clone() - tensor.detach().max_dim(dim); - let tensor = tensor.exp(); - let tensor_tmp = tensor.clone().sum_dim(dim); + let tensor = tensor.clone() - tensor.detach().max_dim(dim); + let tensor = tensor.exp(); + let tensor_tmp = tensor.clone().sum_dim(dim); - tensor.div(tensor_tmp) + tensor.div(tensor_tmp) } /// Applies the "quiet softmax" function on the input tensor along the given dimension. @@ -42,13 +42,13 @@ pub fn softmax(tensor: Tensor, dim: usize) -> /// The dimension argument `dim` specifies the dimension along which the function will be computed. /// It must in the range of `0` and `D-1`. 
pub fn quiet_softmax(tensor: Tensor, dim: usize) -> Tensor { - check!(TensorCheck::dim_ops::("softmax", dim)); + check!(TensorCheck::dim_ops::("softmax", dim)); - let tensor = tensor.clone() - tensor.detach().max_dim(dim); - let tensor = tensor.exp(); - let tensor_tmp = tensor.clone().sum_dim(dim); + let tensor = tensor.clone() - tensor.detach().max_dim(dim); + let tensor = tensor.exp(); + let tensor_tmp = tensor.clone().sum_dim(dim); - tensor.div(tensor_tmp + 1) + tensor.div(tensor_tmp + 1) } /// Applies the log softmax function on the input tensor along the given dimension. @@ -60,37 +60,37 @@ pub fn quiet_softmax(tensor: Tensor, dim: usiz /// The dimension argument `dim` specifies the dimension along which the function will be computed. /// It must in the range of `0` and `D-1`. pub fn log_softmax(tensor: Tensor, dim: usize) -> Tensor { - check!(TensorCheck::dim_ops::("log softmax", dim)); + check!(TensorCheck::dim_ops::("log softmax", dim)); - let tensor = tensor.clone() - tensor.detach().max_dim(dim); - let tensor_tmp = tensor.clone().exp().sum_dim(dim).log(); + let tensor = tensor.clone() - tensor.detach().max_dim(dim); + let tensor_tmp = tensor.clone().exp().sum_dim(dim).log(); - tensor.sub(tensor_tmp) + tensor.sub(tensor_tmp) } /// Applies the sigmoid function. pub fn sigmoid(tensor: Tensor) -> Tensor { - log_sigmoid(tensor).exp() + log_sigmoid(tensor).exp() } /// Applies the log sigmoid function. pub fn log_sigmoid(tensor: Tensor) -> Tensor { - match B::FloatElem::precision() { - Precision::Half => { - let tensor_full = tensor.to_full_precision(); - let tensor_tmp = tensor_full.neg().exp().add_scalar(1.0_f32).log().neg(); - Tensor::from_full_precision(tensor_tmp) + match B::FloatElem::precision() { + Precision::Half => { + let tensor_full = tensor.to_full_precision(); + let tensor_tmp = tensor_full.neg().exp().add_scalar(1.0_f32).log().neg(); + Tensor::from_full_precision(tensor_tmp) + } + _ => tensor.neg().exp().add_scalar(1.0_f32).log().neg(), } - _ => tensor.neg().exp().add_scalar(1.0_f32).log().neg(), - } } /// Applies the silu function pub fn silu(tensor: Tensor) -> Tensor { - tensor.clone().mul(sigmoid(tensor)) + tensor.clone().mul(sigmoid(tensor)) } /// Applies the tanh function pub fn tanh(tensor: Tensor) -> Tensor { - tensor.tanh() + tensor.tanh() } diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs index b80bd023f1..3a12b834e3 100644 --- a/burn-tensor/src/tensor/api/base.rs +++ b/burn-tensor/src/tensor/api/base.rs @@ -13,658 +13,658 @@ use burn_common::{reader::Reader, stub::Mutex}; use core::{fmt::Debug, ops::Range}; use crate::{ - backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind, + backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind, }; /// A tensor with a given backend, shape and data type. #[derive(new, Clone, Debug)] pub struct Tensor where - B: Backend, - K: TensorKind, + B: Backend, + K: TensorKind, { - pub(crate) primitive: K::Primitive, + pub(crate) primitive: K::Primitive, } impl Tensor where - B: Backend, - K: BasicOps, + B: Backend, + K: BasicOps, { - /// Converts the tensor into a primitive tensor. - pub fn into_primitive(self) -> K::Primitive { - self.primitive - } - - /// Converts from a primitive tensor into a tensor. - pub fn from_primitive(tensor: K::Primitive) -> Self { - Self::new(tensor) - } - - /// Create an empty tensor of the given shape. 
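// Illustrative sketch, not part of the patch: a plain-f32 reference of the
// `quiet_softmax` activation added in this diff, mirroring its tensor
// implementation (`exp(x - max)` divided by `sum + 1`). The helper name is
// hypothetical. Because of the extra `+ 1` term, the outputs sum to strictly
// less than one, unlike the regular `softmax`.
fn quiet_softmax_reference(logits: &[f32]) -> Vec<f32> {
    // Shift by the maximum for numerical stability, as the tensor version does.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|x| (x - max).exp()).collect();
    // The `+ 1.0` is the only difference from the standard softmax denominator.
    let denom: f32 = exps.iter().sum::<f32>() + 1.0;
    exps.iter().map(|e| e / denom).collect()
}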
- pub fn empty>>(shape: S) -> Self { - Self::empty_device(shape, &B::Device::default()) - } - - /// Create an empty tensor of the given shape. - pub fn empty_device>>(shape: S, device: &B::Device) -> Self { - Self::new(K::empty(shape.into(), device)) - } - - /// Returns the dimensions of the current tensor. - /// - /// Equivalent to `tensor.shape().dims`. - pub fn dims(&self) -> [usize; D] { - Self::shape(self).dims - } - - /// Returns the shape of the current tensor. - pub fn shape(&self) -> Shape { - K::shape(&self.primitive) - } - - /// Reshape the tensor to have the given shape. - /// - /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]` - /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12]. - /// - /// A `0` in the shape instructs to keep the current dimension from the original tensor, - /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4]. - /// This is useful when reshaping tensors with unknown dimensions and combining with `-1` - /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor - /// with [1, 3, 4] dimensions to [1, 12]. - /// - /// # Arguments - /// - `shape`: The new shape of the tensor. - /// - /// # Panics - /// - If the tensor contains more than one `-1` in the shape. - /// - If the tensor contains values that are not positive (other than -1). - /// - If the shape does not match the number of elements of the original shape. - /// - /// # Example - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let tensor = Tensor::::ones([2, 3, 4]); - /// // Given a 3D tensor with dimensions (2, 3, 4), reshape it to (2, 12) - /// let reshaped_tensor: Tensor:: = tensor.reshape([2, -1]); - /// // The resulting tensor will have dimensions (2, 12). - /// println!("{:?}", reshaped_tensor.shape()); - /// } - /// ``` - pub fn reshape>(self, shape: S) -> Tensor { - // Convert reshape args to shape - let shape = shape.into_shape(&self); - Tensor::new(K::reshape::(self.primitive, shape)) - } - - /// Transpose the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - pub fn transpose(self) -> Tensor { - Tensor::new(K::transpose(self.primitive)) - } - - /// Swaps two dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { - Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) - } - - /// Flatten the tensor along a given range of dimensions. - /// - /// This function collapses the specified range of dimensions into a single dimension, - /// effectively flattening the tensor in that range. - /// - /// # Arguments - /// - /// - `start_dim`: The starting dimension of the range to be flattened. - /// - `end_dim`: The ending dimension of the range to be flattened (inclusive). - /// - /// # Type Parameters - /// - /// - `D2`: The resulting number of dimensions in the flattened tensor. - /// - /// # Returns - /// - /// A new `Tensor` instance with the specified range of dimensions flattened. 
- /// - /// # Example - /// - /// ```rust - /// - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 4])); - /// - /// // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2: - /// let flattened_tensor: Tensor:: = tensor.flatten(1, 2); - /// - /// // The resulting tensor will have dimensions (2, 12). - /// println!("{:?}", flattened_tensor.shape()); - /// } - /// - /// ``` - pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { - check!(TensorCheck::flatten::(start_dim, end_dim)); - - let current_dims = self.shape().dims; - let mut new_dims: [usize; D2] = [0; D2]; - let mut flatten_dims = 1; - - for i in current_dims[start_dim..=end_dim].iter() { - flatten_dims *= i; - } - - new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]); - new_dims[start_dim] = flatten_dims; - new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]); - - Tensor::new(K::reshape::(self.primitive, new_dims.into())) - } - - /// Squeeze the tensor along the given dimension, removing the specified dimension - /// of size one, and effectively reducing the rank of the tensor by one. - /// - /// # Arguments - /// - /// - `dim`: The dimension to be squeezed. - /// - /// # Type Parameters - /// - /// - 'D2': The resulting number of dimensions in the squeezed tensor. - /// - /// # Returns - /// - /// A new `Tensor` instance with the specified dimenension removed. - /// - /// # Example - /// - /// ```rust - /// - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 1, 4])); - /// - /// // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1 - /// let squeezed_tensor: Tensor:: = tensor.squeeze(1); - /// - /// // Resulting tensor will have dimensions (2, 4) - /// println!("{:?}", squeezed_tensor.shape()); - /// } - /// ``` - pub fn squeeze(self, dim: usize) -> Tensor { - check!(TensorCheck::squeeze::(dim, &self.shape().dims)); - - let current_dims = self.shape().dims; - let mut new_dims: [usize; D2] = [0; D2]; - - new_dims[..dim].copy_from_slice(¤t_dims[..dim]); - new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]); - - Tensor::new(K::reshape::(self.primitive, new_dims.into())) - } - - /// Unsqueeze the current tensor. Create new dimensions to fit the given size. - /// - /// If the output size is higher than the current tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([3, 3])); - /// let tensor = tensor.unsqueeze::<4>(); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [1, 1, 3, 3] } - /// } - /// ``` - pub fn unsqueeze(self) -> Tensor { - check!(TensorCheck::unsqueeze::()); - - let mut dims = [1; D2]; - let num_ones = D2 - D; - let shape = self.shape(); - - dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]); - - let shape = Shape::new(dims); - self.reshape(shape) - } - - /// Returns a tensor containing the elements selected from the given ranges. - /// - /// # Panics - /// - /// If a range exceeds the number of elements on a dimension. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); - /// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]); - /// println!("{:?}", tensor_slices.dims()); // [1, 3, 2] - /// - /// } - /// ``` - pub fn slice(self, ranges: [core::ops::Range; D2]) -> Self { - check!(TensorCheck::slice(&self.shape(), &ranges)); - Self::new(K::slice(self.primitive, ranges)) - } - - /// Returns a copy of the current tensor with the selected elements changed to the new ones at - /// the selected indices. - /// - /// # Panics - /// - /// - If a range exceeds the number of elements on a dimension. - /// - If the given values don't match the given ranges. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let tensor = Tensor::::ones([2, 3, 3]); - /// let values = Tensor::::zeros([1, 1, 1]); - /// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values); - /// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3] - /// } - /// ``` - pub fn slice_assign( - self, - ranges: [core::ops::Range; D2], - values: Self, - ) -> Self { - check!(TensorCheck::slice_assign( - &self.shape(), - &values.shape(), - &ranges - )); - Self::new(K::slice_assign(self.primitive, ranges, values.primitive)) - } - - /// Returns the device of the current tensor. - pub fn device(&self) -> B::Device { - K::device(&self.primitive) - } - - /// Returns a new tensor on the given device. - pub fn to_device(self, device: &B::Device) -> Self { - Self::new(K::to_device(self.primitive, device)) - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Returns the data of the current tensor. - pub async fn into_data(self) -> Data { - K::into_data(self.primitive).read().await - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Returns the data of the current tensor. - pub fn into_data(self) -> Data { - K::into_data(self.primitive).read() - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Returns the data of the current tensor. - pub async fn to_data(&self) -> Data { - K::into_data(self.primitive.clone()).read().await - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Returns the data of the current tensor without taking ownership. - pub fn to_data(&self) -> Data { - Self::into_data(self.clone()) - } - - /// Create a tensor from the given data. - pub fn from_data(data: T) -> Self - where - T: Into>, - { - Self::from_data_device(data, &B::Device::default()) - } - - /// Create a tensor from the given data on the given device. - pub fn from_data_device(data: T, device: &B::Device) -> Self - where - T: Into>, - { - Self::new(K::from_data(data.into(), device)) - } - - /// Repeat the tensor along the given dimension. - /// - /// # Panics - /// - /// If the selected dimension more than one item. - pub fn repeat(self, dim: usize, times: usize) -> Self { - Self::new(K::repeat(self.primitive, dim, times)) - } - - /// Applies element wise equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); - K::equal(self.primitive, other.primitive) - } - - /// Concatenates all tensors into a new one along the given dimension. 
- /// - /// # Panics - /// - /// If all tensors don't have the same shape. - pub fn cat(tensors: Vec, dim: usize) -> Self { - check!(TensorCheck::cat(&tensors, dim)); - - Self::new(K::cat( - tensors.into_iter().map(|vector| vector.primitive).collect(), - dim, - )) - } - - /// Iterate over slices of tensors alongside a given dimension. - /// - /// # Panics - /// - /// Given dimension is less than tensor rank. - /// - /// # Returns - /// - /// A tensor iterator. - pub fn iter_dim(self, dim: usize) -> DimIter { - check!(TensorCheck::dim_ops::("iter_dim", dim)); - DimIter::new(self, dim) - } + /// Converts the tensor into a primitive tensor. + pub fn into_primitive(self) -> K::Primitive { + self.primitive + } + + /// Converts from a primitive tensor into a tensor. + pub fn from_primitive(tensor: K::Primitive) -> Self { + Self::new(tensor) + } + + /// Create an empty tensor of the given shape. + pub fn empty>>(shape: S) -> Self { + Self::empty_device(shape, &B::Device::default()) + } + + /// Create an empty tensor of the given shape. + pub fn empty_device>>(shape: S, device: &B::Device) -> Self { + Self::new(K::empty(shape.into(), device)) + } + + /// Returns the dimensions of the current tensor. + /// + /// Equivalent to `tensor.shape().dims`. + pub fn dims(&self) -> [usize; D] { + Self::shape(self).dims + } + + /// Returns the shape of the current tensor. + pub fn shape(&self) -> Shape { + K::shape(&self.primitive) + } + + /// Reshape the tensor to have the given shape. + /// + /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]` + /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12]. + /// + /// A `0` in the shape instructs to keep the current dimension from the original tensor, + /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4]. + /// This is useful when reshaping tensors with unknown dimensions and combining with `-1` + /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor + /// with [1, 3, 4] dimensions to [1, 12]. + /// + /// # Arguments + /// - `shape`: The new shape of the tensor. + /// + /// # Panics + /// - If the tensor contains more than one `-1` in the shape. + /// - If the tensor contains values that are not positive (other than -1). + /// - If the shape does not match the number of elements of the original shape. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let tensor = Tensor::::ones([2, 3, 4]); + /// // Given a 3D tensor with dimensions (2, 3, 4), reshape it to (2, 12) + /// let reshaped_tensor: Tensor:: = tensor.reshape([2, -1]); + /// // The resulting tensor will have dimensions (2, 12). + /// println!("{:?}", reshaped_tensor.shape()); + /// } + /// ``` + pub fn reshape>(self, shape: S) -> Tensor { + // Convert reshape args to shape + let shape = shape.into_shape(&self); + Tensor::new(K::reshape::(self.primitive, shape)) + } + + /// Transpose the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + pub fn transpose(self) -> Tensor { + Tensor::new(K::transpose(self.primitive)) + } + + /// Swaps two dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. 
+ pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { + Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) + } + + /// Flatten the tensor along a given range of dimensions. + /// + /// This function collapses the specified range of dimensions into a single dimension, + /// effectively flattening the tensor in that range. + /// + /// # Arguments + /// + /// - `start_dim`: The starting dimension of the range to be flattened. + /// - `end_dim`: The ending dimension of the range to be flattened (inclusive). + /// + /// # Type Parameters + /// + /// - `D2`: The resulting number of dimensions in the flattened tensor. + /// + /// # Returns + /// + /// A new `Tensor` instance with the specified range of dimensions flattened. + /// + /// # Example + /// + /// ```rust + /// + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 4])); + /// + /// // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2: + /// let flattened_tensor: Tensor:: = tensor.flatten(1, 2); + /// + /// // The resulting tensor will have dimensions (2, 12). + /// println!("{:?}", flattened_tensor.shape()); + /// } + /// + /// ``` + pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { + check!(TensorCheck::flatten::(start_dim, end_dim)); + + let current_dims = self.shape().dims; + let mut new_dims: [usize; D2] = [0; D2]; + let mut flatten_dims = 1; + + for i in current_dims[start_dim..=end_dim].iter() { + flatten_dims *= i; + } + + new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]); + new_dims[start_dim] = flatten_dims; + new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]); + + Tensor::new(K::reshape::(self.primitive, new_dims.into())) + } + + /// Squeeze the tensor along the given dimension, removing the specified dimension + /// of size one, and effectively reducing the rank of the tensor by one. + /// + /// # Arguments + /// + /// - `dim`: The dimension to be squeezed. + /// + /// # Type Parameters + /// + /// - 'D2': The resulting number of dimensions in the squeezed tensor. + /// + /// # Returns + /// + /// A new `Tensor` instance with the specified dimenension removed. + /// + /// # Example + /// + /// ```rust + /// + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 1, 4])); + /// + /// // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1 + /// let squeezed_tensor: Tensor:: = tensor.squeeze(1); + /// + /// // Resulting tensor will have dimensions (2, 4) + /// println!("{:?}", squeezed_tensor.shape()); + /// } + /// ``` + pub fn squeeze(self, dim: usize) -> Tensor { + check!(TensorCheck::squeeze::(dim, &self.shape().dims)); + + let current_dims = self.shape().dims; + let mut new_dims: [usize; D2] = [0; D2]; + + new_dims[..dim].copy_from_slice(¤t_dims[..dim]); + new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]); + + Tensor::new(K::reshape::(self.primitive, new_dims.into())) + } + + /// Unsqueeze the current tensor. Create new dimensions to fit the given size. + /// + /// If the output size is higher than the current tensor. 
+ /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([3, 3])); + /// let tensor = tensor.unsqueeze::<4>(); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [1, 1, 3, 3] } + /// } + /// ``` + pub fn unsqueeze(self) -> Tensor { + check!(TensorCheck::unsqueeze::()); + + let mut dims = [1; D2]; + let num_ones = D2 - D; + let shape = self.shape(); + + dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]); + + let shape = Shape::new(dims); + self.reshape(shape) + } + + /// Returns a tensor containing the elements selected from the given ranges. + /// + /// # Panics + /// + /// If a range exceeds the number of elements on a dimension. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); + /// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]); + /// println!("{:?}", tensor_slices.dims()); // [1, 3, 2] + /// + /// } + /// ``` + pub fn slice(self, ranges: [core::ops::Range; D2]) -> Self { + check!(TensorCheck::slice(&self.shape(), &ranges)); + Self::new(K::slice(self.primitive, ranges)) + } + + /// Returns a copy of the current tensor with the selected elements changed to the new ones at + /// the selected indices. + /// + /// # Panics + /// + /// - If a range exceeds the number of elements on a dimension. + /// - If the given values don't match the given ranges. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let tensor = Tensor::::ones([2, 3, 3]); + /// let values = Tensor::::zeros([1, 1, 1]); + /// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values); + /// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3] + /// } + /// ``` + pub fn slice_assign( + self, + ranges: [core::ops::Range; D2], + values: Self, + ) -> Self { + check!(TensorCheck::slice_assign( + &self.shape(), + &values.shape(), + &ranges + )); + Self::new(K::slice_assign(self.primitive, ranges, values.primitive)) + } + + /// Returns the device of the current tensor. + pub fn device(&self) -> B::Device { + K::device(&self.primitive) + } + + /// Returns a new tensor on the given device. + pub fn to_device(self, device: &B::Device) -> Self { + Self::new(K::to_device(self.primitive, device)) + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Returns the data of the current tensor. + pub async fn into_data(self) -> Data { + K::into_data(self.primitive).read().await + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Returns the data of the current tensor. + pub fn into_data(self) -> Data { + K::into_data(self.primitive).read() + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Returns the data of the current tensor. + pub async fn to_data(&self) -> Data { + K::into_data(self.primitive.clone()).read().await + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Returns the data of the current tensor without taking ownership. + pub fn to_data(&self) -> Data { + Self::into_data(self.clone()) + } + + /// Create a tensor from the given data. 
+ pub fn from_data(data: T) -> Self + where + T: Into>, + { + Self::from_data_device(data, &B::Device::default()) + } + + /// Create a tensor from the given data on the given device. + pub fn from_data_device(data: T, device: &B::Device) -> Self + where + T: Into>, + { + Self::new(K::from_data(data.into(), device)) + } + + /// Repeat the tensor along the given dimension. + /// + /// # Panics + /// + /// If the selected dimension more than one item. + pub fn repeat(self, dim: usize, times: usize) -> Self { + Self::new(K::repeat(self.primitive, dim, times)) + } + + /// Applies element wise equal comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn equal(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); + K::equal(self.primitive, other.primitive) + } + + /// Concatenates all tensors into a new one along the given dimension. + /// + /// # Panics + /// + /// If all tensors don't have the same shape. + pub fn cat(tensors: Vec, dim: usize) -> Self { + check!(TensorCheck::cat(&tensors, dim)); + + Self::new(K::cat( + tensors.into_iter().map(|vector| vector.primitive).collect(), + dim, + )) + } + + /// Iterate over slices of tensors alongside a given dimension. + /// + /// # Panics + /// + /// Given dimension is less than tensor rank. + /// + /// # Returns + /// + /// A tensor iterator. + pub fn iter_dim(self, dim: usize) -> DimIter { + check!(TensorCheck::dim_ops::("iter_dim", dim)); + DimIter::new(self, dim) + } } /// Iterator given by (Tensor::iter_dim). pub struct DimIter where - B: Backend, - K: BasicOps, + B: Backend, + K: BasicOps, { - counter: usize, - dim: usize, - end_idx: usize, - ranges: [Range; D], - tensor: Tensor, + counter: usize, + dim: usize, + end_idx: usize, + ranges: [Range; D], + tensor: Tensor, } impl> Iterator for DimIter { - type Item = Tensor; - - fn next(&mut self) -> Option { - let res = if self.counter < self.end_idx { - let mut ranges = self.ranges.clone(); - ranges[self.dim] = self.counter..(self.counter + 1); - let slice = self.tensor.clone().slice(ranges); - Some(slice) - } else { - None - }; - self.counter += 1; - res - } + type Item = Tensor; + + fn next(&mut self) -> Option { + let res = if self.counter < self.end_idx { + let mut ranges = self.ranges.clone(); + ranges[self.dim] = self.counter..(self.counter + 1); + let slice = self.tensor.clone().slice(ranges); + Some(slice) + } else { + None + }; + self.counter += 1; + res + } } impl> DimIter { - fn new(tensor: Tensor, dim: usize) -> Self { - let dims = tensor.dims(); - let ranges = dims - .iter() - .map(|&dim| 0..dim) - .collect::>>(); - let ranges: [Range; D] = ranges.try_into().unwrap(); - Self { - end_idx: dims[dim], - ranges, - counter: 0, - dim, - tensor, - } - } + fn new(tensor: Tensor, dim: usize) -> Self { + let dims = tensor.dims(); + let ranges = dims + .iter() + .map(|&dim| 0..dim) + .collect::>>(); + let ranges: [Range; D] = ranges.try_into().unwrap(); + Self { + end_idx: dims[dim], + ranges, + counter: 0, + dim, + tensor, + } + } } impl Tensor where - B: Backend, - K: BasicOps, - >::Elem: Debug, + B: Backend, + K: BasicOps, + >::Elem: Debug, { - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - #[inline] - fn push_newline_indent(acc: &mut String, indent: usize) { - acc.push('\n'); - for _ in 0..indent { - acc.push(' '); - } - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn fmt_inner_tensor( - &self, - acc: &mut String, - depth: usize, - 
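`Tensor::cat` and `iter_dim`/`DimIter` above carry no doc examples; a small usage sketch follows. It is generic over any backend, and the shapes, variable names, and the helper function itself are illustrative only.

// Usage sketch for `Tensor::cat` and `iter_dim` (not part of the patch).
use burn_tensor::{backend::Backend, Tensor};

fn cat_then_iter<B: Backend>() {
    let a = Tensor::<B, 2>::ones([2, 3]);
    let b = Tensor::<B, 2>::zeros([2, 3]);

    // Concatenate along dimension 0: the result has shape [4, 3].
    let stacked = Tensor::cat(vec![a, b], 0);

    // `iter_dim` yields one slice per index of the chosen dimension,
    // each keeping the original rank, i.e. four tensors of shape [1, 3].
    for row in stacked.iter_dim(0) {
        println!("{:?}", row.dims());
    }
}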
multi_index: &mut [usize], - range: (usize, usize), - ) { - let (start, end) = range; - for i in start..end { - if i > 0 { - acc.push_str(", "); - } - multi_index[depth] = i; - let range: [core::ops::Range; D] = - core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1); - - let elem = &self.clone().slice(range).into_data().value[0]; - acc.push_str(&format!("{elem:?}")); - } - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn fmt_outer_tensor( - &self, - acc: &mut String, - depth: usize, - multi_index: &mut [usize], - print_options: &PrintOptions, - summarize: bool, - range: (usize, usize), - ) { - let (start, end) = range; - for i in start..end { - if i > start { - acc.push(','); - Self::push_newline_indent(acc, depth + 1); - } - acc.push('['); - multi_index[depth] = i; - self.display_recursive(acc, depth + 1, multi_index, print_options, summarize); - acc.push(']'); - } - } - - /// Recursively formats the tensor data for display and appends it to the provided accumulator string. - /// - /// This function is designed to work with tensors of any dimensionality. - /// It traverses the tensor dimensions recursively, converting the elements - /// to strings and appending them to the accumulator string with the - /// appropriate formatting. - /// - /// # Arguments - /// - /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output. - /// * `depth` - The current depth of the tensor dimensions being processed. - /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn display_recursive( - &self, - acc: &mut String, - depth: usize, - multi_index: &mut [usize], - print_options: &PrintOptions, - summarize: bool, - ) { - let edge_items = print_options.edge_items; - - if depth == 0 { - acc.push('['); - } - - if depth == self.dims().len() - 1 { - // if we are at the innermost dimension, just push its elements into the accumulator - if summarize && self.dims()[depth] > 2 * edge_items { - // print the starting `edge_items` elements - self.fmt_inner_tensor(acc, depth, multi_index, (0, edge_items)); - acc.push_str(", ..."); - // print the last `edge_items` elements - self.fmt_inner_tensor( - acc, - depth, - multi_index, - (self.dims()[depth] - edge_items, self.dims()[depth]), - ); - } else { - // print all the elements - self.fmt_inner_tensor(acc, depth, multi_index, (0, self.dims()[depth])); - } - } else { - // otherwise, iterate through the current dimension and recursively display the inner tensors - if summarize && self.dims()[depth] > 2 * edge_items { - self.fmt_outer_tensor( - acc, - depth, - multi_index, - print_options, - summarize, - (0, edge_items), - ); - - acc.push(','); - Self::push_newline_indent(acc, depth + 1); - acc.push_str("..."); - Self::push_newline_indent(acc, depth + 1); - - self.fmt_outer_tensor( - acc, - depth, - multi_index, - print_options, - summarize, - (self.dims()[depth] - edge_items, self.dims()[depth]), - ); - } else { - self.fmt_outer_tensor( - acc, - depth, - multi_index, - print_options, - summarize, - (0, self.dims()[depth]), - ); - } - } - - if depth == 0 { - acc.push(']'); - } - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + #[inline] + fn push_newline_indent(acc: &mut String, indent: usize) { + acc.push('\n'); + for _ in 0..indent { + acc.push(' '); + } + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn fmt_inner_tensor( + &self, + acc: &mut 
String, + depth: usize, + multi_index: &mut [usize], + range: (usize, usize), + ) { + let (start, end) = range; + for i in start..end { + if i > 0 { + acc.push_str(", "); + } + multi_index[depth] = i; + let range: [core::ops::Range; D] = + core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1); + + let elem = &self.clone().slice(range).into_data().value[0]; + acc.push_str(&format!("{elem:?}")); + } + } + + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn fmt_outer_tensor( + &self, + acc: &mut String, + depth: usize, + multi_index: &mut [usize], + print_options: &PrintOptions, + summarize: bool, + range: (usize, usize), + ) { + let (start, end) = range; + for i in start..end { + if i > start { + acc.push(','); + Self::push_newline_indent(acc, depth + 1); + } + acc.push('['); + multi_index[depth] = i; + self.display_recursive(acc, depth + 1, multi_index, print_options, summarize); + acc.push(']'); + } + } + + /// Recursively formats the tensor data for display and appends it to the provided accumulator string. + /// + /// This function is designed to work with tensors of any dimensionality. + /// It traverses the tensor dimensions recursively, converting the elements + /// to strings and appending them to the accumulator string with the + /// appropriate formatting. + /// + /// # Arguments + /// + /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output. + /// * `depth` - The current depth of the tensor dimensions being processed. + /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension. + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + fn display_recursive( + &self, + acc: &mut String, + depth: usize, + multi_index: &mut [usize], + print_options: &PrintOptions, + summarize: bool, + ) { + let edge_items = print_options.edge_items; + + if depth == 0 { + acc.push('['); + } + + if depth == self.dims().len() - 1 { + // if we are at the innermost dimension, just push its elements into the accumulator + if summarize && self.dims()[depth] > 2 * edge_items { + // print the starting `edge_items` elements + self.fmt_inner_tensor(acc, depth, multi_index, (0, edge_items)); + acc.push_str(", ..."); + // print the last `edge_items` elements + self.fmt_inner_tensor( + acc, + depth, + multi_index, + (self.dims()[depth] - edge_items, self.dims()[depth]), + ); + } else { + // print all the elements + self.fmt_inner_tensor(acc, depth, multi_index, (0, self.dims()[depth])); + } + } else { + // otherwise, iterate through the current dimension and recursively display the inner tensors + if summarize && self.dims()[depth] > 2 * edge_items { + self.fmt_outer_tensor( + acc, + depth, + multi_index, + print_options, + summarize, + (0, edge_items), + ); + + acc.push(','); + Self::push_newline_indent(acc, depth + 1); + acc.push_str("..."); + Self::push_newline_indent(acc, depth + 1); + + self.fmt_outer_tensor( + acc, + depth, + multi_index, + print_options, + summarize, + (self.dims()[depth] - edge_items, self.dims()[depth]), + ); + } else { + self.fmt_outer_tensor( + acc, + depth, + multi_index, + print_options, + summarize, + (0, self.dims()[depth]), + ); + } + } + + if depth == 0 { + acc.push(']'); + } + } } /// Options for Tensor pretty printing pub struct PrintOptions { - /// number of elements to start summarizing tensor - pub threshold: usize, - /// number of starting elements and ending elements to display - pub edge_items: usize, + /// number of elements to start summarizing tensor + 
pub threshold: usize, + /// number of starting elements and ending elements to display + pub edge_items: usize, } static PRINT_OPTS: Mutex = Mutex::new(PrintOptions::const_default()); impl PrintOptions { - // We cannot use the default trait as it's not const. - const fn const_default() -> Self { - Self { - threshold: 1000, - edge_items: 3, + // We cannot use the default trait as it's not const. + const fn const_default() -> Self { + Self { + threshold: 1000, + edge_items: 3, + } } - } } /// Set print options pub fn set_print_options(options: PrintOptions) { - *PRINT_OPTS.lock().unwrap() = options + *PRINT_OPTS.lock().unwrap() = options } /// Pretty print tensors impl core::fmt::Display for Tensor where - B: Backend, - B::IntElem: core::fmt::Display, - K: BasicOps, - >::Elem: Debug, + B: Backend, + B::IntElem: core::fmt::Display, + K: BasicOps, + >::Elem: Debug, { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - writeln!(f, "Tensor {{")?; + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + writeln!(f, "Tensor {{")?; - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - { - let po = PRINT_OPTS.lock().unwrap(); - let mut acc = String::new(); - let mut multi_index = vec![0; D]; - let summarize = self.shape().num_elements() > po.threshold; + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + { + let po = PRINT_OPTS.lock().unwrap(); + let mut acc = String::new(); + let mut multi_index = vec![0; D]; + let summarize = self.shape().num_elements() > po.threshold; - self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize); + self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize); - writeln!(f, " data:")?; - write!(f, "{acc}")?; - writeln!(f, ",")?; - } + writeln!(f, " data:")?; + write!(f, "{acc}")?; + writeln!(f, ",")?; + } - writeln!(f, " shape: {:?},", self.dims())?; - writeln!(f, " device: {:?},", self.device())?; - writeln!(f, " backend: {:?},", B::name())?; - writeln!(f, " kind: {:?},", K::name())?; - writeln!(f, " dtype: {:?},", K::elem_type_name())?; - write!(f, "}}") - } + writeln!(f, " shape: {:?},", self.dims())?; + writeln!(f, " device: {:?},", self.device())?; + writeln!(f, " backend: {:?},", B::name())?; + writeln!(f, " kind: {:?},", K::name())?; + writeln!(f, " dtype: {:?},", K::elem_type_name())?; + write!(f, "}}") + } } /// Transpose marker (zero-size type). Used to sugar the transpose of a tensor, e.g. @@ -680,10 +680,10 @@ where pub struct T; impl core::ops::BitXor for Tensor { - type Output = Self; - fn bitxor(self, _: T) -> Self::Output { - self.transpose() - } + type Output = Self; + fn bitxor(self, _: T) -> Self::Output { + self.transpose() + } } /// Trait that list all operations that can be applied on all tensors. @@ -692,646 +692,660 @@ impl core::ops::BitXor for Tensor { /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). pub trait BasicOps: TensorKind { - /// The type of the tensor elements. - type Elem: 'static; - - /// Creates an empty tensor with the given shape. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The empty tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. 
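A sketch of how the print options and the `^ T` transpose sugar defined above can be exercised, assuming `PrintOptions`, `set_print_options` and the `T` marker are re-exported at the burn_tensor crate root like the other API items; the threshold and edge_items values are arbitrary examples.

// With `edge_items: 2`, a summarized row prints its first two and last two
// elements separated by `...`, per `display_recursive` above.
use burn_tensor::{backend::Backend, set_print_options, PrintOptions, Tensor, T};

fn pretty_print<B: Backend>(tensor: Tensor<B, 2>)
where
    B::IntElem: core::fmt::Display,
    B::FloatElem: core::fmt::Debug,
{
    // Summarize any tensor with more than 16 elements, keeping 2 edge items.
    set_print_options(PrintOptions {
        threshold: 16,
        edge_items: 2,
    });

    // `tensor ^ T` is sugar for `tensor.transpose()`.
    println!("{}", tensor ^ T);
}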
- /// - /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function, - /// which is more high-level and designed for public use. - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive; - - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function, - /// which is more high-level and designed for public use. - fn shape(tensor: &Self::Primitive) -> Shape; - - /// Reshapes the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `shape` - The new shape of the tensor. - /// - /// # Returns - /// - /// The reshaped tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function, - /// which is more high-level and designed for public use. - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive; - - /// Transposes a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn transpose(tensor: Self::Primitive) -> Self::Primitive; - - /// Swaps two dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive; - - /// Select tensor elements corresponding for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges of the elements to select. - /// - /// # Returns - /// - /// The selected elements. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function, - /// which is more high-level and designed for public use. - fn slice( - tensor: Self::Primitive, - range: [Range; D2], - ) -> Self::Primitive; - - /// Assigns the given value to the tensor elements corresponding for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges of the elements to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the assigned values. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. 
It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function, - /// which is more high-level and designed for public use. - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive; - - /// Returns the device on which the tensor is allocated. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device on which the tensor is allocated. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function, - /// which is more high-level and designed for public use. - fn device(tensor: &Self::Primitive) -> B::Device; - - /// Moves the tensor to the given device. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `device` - The device on which the tensor will be moved. - /// - /// # Returns - /// - /// The tensor on the given device. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function, - /// which is more high-level and designed for public use. - fn to_device( - tensor: Self::Primitive, - device: &B::Device, - ) -> Self::Primitive; - - /// Extracts the data from the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function, - /// which is more high-level and designed for public use. - fn into_data(tensor: Self::Primitive) -> Reader>; - - /// Creates a tensor from the given data. - /// - /// # Arguments - /// - /// * `data` - The data of the tensor. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function, - /// which is more high-level and designed for public use. - fn from_data(data: Data, device: &B::Device) - -> Self::Primitive; - - /// Repeat the tensor along the given dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension along which the tensor will be repeated. 
- /// * `times` - The number of times the tensor will be repeated. - /// - /// # Returns - /// - /// The repeated tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function, - /// which is more high-level and designed for public use. - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive; - - /// Concatenates the given tensors along the given dimension. - /// - /// # Arguments - /// - /// * `vectors` - The tensors to concatenate. - /// * `dim` - The dimension along which the tensors will be concatenated. - /// - /// # Returns - /// - /// The concatenated tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function, - /// which is more high-level and designed for public use. - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive; - - /// Equates the given tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The tensor of booleans indicating whether the corresponding elements are equal. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function, - /// which is more high-level and designed for public use. - fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor; - - /// Returns the name of the element type. - fn elem_type_name() -> &'static str { - core::any::type_name::() - } + /// The type of the tensor elements. + type Elem: 'static; + + /// Creates an empty tensor with the given shape. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The empty tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function, + /// which is more high-level and designed for public use. + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive; + + /// Returns the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. 
+ /// + /// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function, + /// which is more high-level and designed for public use. + fn shape(tensor: &Self::Primitive) -> Shape; + + /// Reshapes the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `shape` - The new shape of the tensor. + /// + /// # Returns + /// + /// The reshaped tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function, + /// which is more high-level and designed for public use. + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive; + + /// Transposes a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn transpose(tensor: Self::Primitive) -> Self::Primitive; + + /// Swaps two dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive; + + /// Select tensor elements corresponding for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges of the elements to select. + /// + /// # Returns + /// + /// The selected elements. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function, + /// which is more high-level and designed for public use. + fn slice( + tensor: Self::Primitive, + range: [Range; D2], + ) -> Self::Primitive; + + /// Assigns the given value to the tensor elements corresponding for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges of the elements to select. + /// * `value` - The value to assign. + /// + /// # Returns + /// + /// The tensor with the assigned values. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function, + /// which is more high-level and designed for public use. + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive; + + /// Returns the device on which the tensor is allocated. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device on which the tensor is allocated. 
+ /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function, + /// which is more high-level and designed for public use. + fn device(tensor: &Self::Primitive) -> B::Device; + + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device on which the tensor will be moved. + /// + /// # Returns + /// + /// The tensor on the given device. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function, + /// which is more high-level and designed for public use. + fn to_device( + tensor: Self::Primitive, + device: &B::Device, + ) -> Self::Primitive; + + /// Extracts the data from the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data of the tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function, + /// which is more high-level and designed for public use. + fn into_data(tensor: Self::Primitive) -> Reader>; + + /// Creates a tensor from the given data. + /// + /// # Arguments + /// + /// * `data` - The data of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function, + /// which is more high-level and designed for public use. + fn from_data( + data: Data, + device: &B::Device, + ) -> Self::Primitive; + + /// Repeat the tensor along the given dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension along which the tensor will be repeated. + /// * `times` - The number of times the tensor will be repeated. + /// + /// # Returns + /// + /// The repeated tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function, + /// which is more high-level and designed for public use. 
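Because `BasicOps` is implemented for `Float`, `Int` and `Bool` alike (see the impls that follow), user code can be written once over any tensor kind. A small sketch, assuming `BasicOps` is reachable from the burn_tensor crate root like the other API items; the helper itself is illustrative, not part of the patch.

// One helper that works for float, int and bool tensors via static dispatch
// through the `BasicOps` trait above.
use burn_tensor::{backend::Backend, BasicOps, Tensor};

fn first_row<B: Backend, K: BasicOps<B>>(tensor: Tensor<B, 2, K>) -> Tensor<B, 2, K> {
    let [_rows, cols] = tensor.dims();
    // `slice` keeps the rank, so the result has shape [1, cols].
    tensor.slice([0..1, 0..cols])
}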
+ fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive; + + /// Concatenates the given tensors along the given dimension. + /// + /// # Arguments + /// + /// * `vectors` - The tensors to concatenate. + /// * `dim` - The dimension along which the tensors will be concatenated. + /// + /// # Returns + /// + /// The concatenated tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function, + /// which is more high-level and designed for public use. + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive; + + /// Equates the given tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The tensor of booleans indicating whether the corresponding elements are equal. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function, + /// which is more high-level and designed for public use. + fn equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor; + + /// Returns the name of the element type. + fn elem_type_name() -> &'static str { + core::any::type_name::() + } } impl BasicOps for Float { - type Elem = B::FloatElem; - - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { - B::empty(shape, device) - } - fn shape(tensor: &Self::Primitive) -> Shape { - B::shape(tensor) - } - - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::reshape(tensor, shape) - } - - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - B::transpose(tensor) - } - - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive { - check!(TensorCheck::swap_dims::(dim1, dim2)); - B::swap_dims(tensor, dim1, dim2) - } - - fn slice( - tensor: Self::Primitive, - ranges: [Range; D2], - ) -> Self::Primitive { - B::slice(tensor, ranges) - } - - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - B::slice_assign(tensor, ranges, value) - } - - fn device(tensor: &Self::Primitive) -> ::Device { - B::device(tensor) - } - - fn to_device( - tensor: Self::Primitive, - device: &::Device, - ) -> Self::Primitive { - B::to_device(tensor, device) - } - - fn into_data(tensor: Self::Primitive) -> Reader> { - B::into_data(tensor) - } - - fn from_data( - data: Data, - device: &B::Device, - ) -> Self::Primitive { - B::from_data(data, device) - } - - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive { - B::repeat(tensor, dim, times) - } - - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - B::cat(vectors, dim) - } - - fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { - Tensor::new(B::equal(lhs, rhs)) - } + type Elem = B::FloatElem; + + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { + B::empty(shape, device) + } + fn shape(tensor: &Self::Primitive) -> Shape { + 
B::shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + B::reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + B::transpose(tensor) + } + + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + check!(TensorCheck::swap_dims::(dim1, dim2)); + B::swap_dims(tensor, dim1, dim2) + } + + fn slice( + tensor: Self::Primitive, + ranges: [Range; D2], + ) -> Self::Primitive { + B::slice(tensor, ranges) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + B::slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + B::device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + B::to_device(tensor, device) + } + + fn into_data(tensor: Self::Primitive) -> Reader> { + B::into_data(tensor) + } + + fn from_data( + data: Data, + device: &B::Device, + ) -> Self::Primitive { + B::from_data(data, device) + } + + fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + B::repeat(tensor, dim, times) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + B::cat(vectors, dim) + } + + fn equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::equal(lhs, rhs)) + } } impl BasicOps for Int { - type Elem = B::IntElem; - - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { - B::int_empty(shape, device) - } - fn shape(tensor: &Self::Primitive) -> Shape { - B::int_shape(tensor) - } - - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::int_reshape(tensor, shape) - } - - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - B::int_transpose(tensor) - } - - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive { - check!(TensorCheck::swap_dims::(dim1, dim2)); - B::int_swap_dims(tensor, dim1, dim2) - } - - fn slice( - tensor: Self::Primitive, - ranges: [Range; D2], - ) -> Self::Primitive { - B::int_slice(tensor, ranges) - } - - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - B::int_slice_assign(tensor, ranges, value) - } - - fn device(tensor: &Self::Primitive) -> ::Device { - B::int_device(tensor) - } - - fn to_device( - tensor: Self::Primitive, - device: &::Device, - ) -> Self::Primitive { - B::int_to_device(tensor, device) - } - - fn into_data(tensor: Self::Primitive) -> Reader> { - B::int_into_data(tensor) - } - - fn from_data( - data: Data, - device: &B::Device, - ) -> Self::Primitive { - B::int_from_data(data, device) - } - - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive { - B::int_repeat(tensor, dim, times) - } - - fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { - Tensor::new(B::int_equal(lhs, rhs)) - } - - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - B::int_cat(vectors, dim) - } + type Elem = B::IntElem; + + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { + B::int_empty(shape, device) + } + fn shape(tensor: &Self::Primitive) -> Shape { + B::int_shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + B::int_reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + B::int_transpose(tensor) + } + + fn swap_dims( + 
tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + check!(TensorCheck::swap_dims::(dim1, dim2)); + B::int_swap_dims(tensor, dim1, dim2) + } + + fn slice( + tensor: Self::Primitive, + ranges: [Range; D2], + ) -> Self::Primitive { + B::int_slice(tensor, ranges) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + B::int_slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + B::int_device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + B::int_to_device(tensor, device) + } + + fn into_data(tensor: Self::Primitive) -> Reader> { + B::int_into_data(tensor) + } + + fn from_data( + data: Data, + device: &B::Device, + ) -> Self::Primitive { + B::int_from_data(data, device) + } + + fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + B::int_repeat(tensor, dim, times) + } + + fn equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_equal(lhs, rhs)) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + B::int_cat(vectors, dim) + } } impl BasicOps for Bool { - type Elem = bool; - - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { - B::bool_empty(shape, device) - } - fn shape(tensor: &Self::Primitive) -> Shape { - B::bool_shape(tensor) - } - - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::bool_reshape(tensor, shape) - } - - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - B::bool_transpose(tensor) - } - - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive { - check!(TensorCheck::swap_dims::(dim1, dim2)); - B::bool_swap_dims(tensor, dim1, dim2) - } - - fn slice( - tensor: Self::Primitive, - ranges: [Range; D2], - ) -> Self::Primitive { - B::bool_slice(tensor, ranges) - } - - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - B::bool_slice_assign(tensor, ranges, value) - } - - fn device(tensor: &Self::Primitive) -> ::Device { - B::bool_device(tensor) - } - - fn to_device( - tensor: Self::Primitive, - device: &::Device, - ) -> Self::Primitive { - B::bool_to_device(tensor, device) - } - - fn into_data(tensor: Self::Primitive) -> Reader> { - B::bool_into_data(tensor) - } - - fn from_data( - data: Data, - device: &B::Device, - ) -> Self::Primitive { - B::bool_from_data(data, device) - } - - fn repeat( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive { - B::bool_repeat(tensor, dim, times) - } - - fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { - Tensor::new(B::bool_equal(lhs, rhs)) - } - - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - B::bool_cat(vectors, dim) - } + type Elem = bool; + + fn empty(shape: Shape, device: &B::Device) -> Self::Primitive { + B::bool_empty(shape, device) + } + fn shape(tensor: &Self::Primitive) -> Shape { + B::bool_shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + B::bool_reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + B::bool_transpose(tensor) + } + + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + check!(TensorCheck::swap_dims::(dim1, dim2)); + B::bool_swap_dims(tensor, dim1, dim2) + } + + fn slice( + tensor: 
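Because the Float, Int, and Bool implementations all route the same `BasicOps` methods to their plain, `int_*`-, and `bool_*`-prefixed backend functions, code can stay generic over the tensor kind. A small sketch under that assumption (the helper name is hypothetical):

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{BasicOps, Tensor};

// Works for Float, Int, and Bool tensors alike via static dispatch on K.
fn first_row<B: Backend, K: BasicOps<B>>(tensor: Tensor<B, 2, K>) -> Tensor<B, 2, K> {
    // Resolves to B::slice, B::int_slice, or B::bool_slice depending on K.
    tensor.slice([0..1])
}
```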
Self::Primitive, + ranges: [Range; D2], + ) -> Self::Primitive { + B::bool_slice(tensor, ranges) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + B::bool_slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + B::bool_device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + B::bool_to_device(tensor, device) + } + + fn into_data(tensor: Self::Primitive) -> Reader> { + B::bool_into_data(tensor) + } + + fn from_data( + data: Data, + device: &B::Device, + ) -> Self::Primitive { + B::bool_from_data(data, device) + } + + fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + B::bool_repeat(tensor, dim, times) + } + + fn equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::bool_equal(lhs, rhs)) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + B::bool_cat(vectors, dim) + } } /// Trait used for reshape arguments. pub trait ReshapeArgs { - /// Converts to a shape. - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape; + /// Converts to a shape. + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape; } impl ReshapeArgs for Shape { - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape { - check!(TensorCheck::reshape_args_usize(&self, &tensor.shape())); - - self - } + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape { + check!(TensorCheck::reshape_args_usize(&self, &tensor.shape())); + + self + } } impl ReshapeArgs for [usize; D2] { - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape { - let shape = Shape::from(self); + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape { + let shape = Shape::from(self); - check!(TensorCheck::reshape_args_usize(&shape, &tensor.shape())); + check!(TensorCheck::reshape_args_usize(&shape, &tensor.shape())); - shape - } + shape + } } impl ReshapeArgs for [i32; D2] { - fn into_shape>( - self, - tensor: &Tensor, - ) -> Shape { - // Validate the reshape arguments - check!(TensorCheck::reshape_args_i32(&self)); - - // Temporary shape - let mut new_shape: [i32; D2] = [1; D2]; - - // We need to find the index of the 0 dimension and - // replace it with the actual dimension value. - for (i, &s) in self.iter().enumerate() { - if s != 0 { - new_shape[i] = s; - } else { - new_shape[i] = tensor.dims()[i] as i32; - } - } - - // Find the index of the inferred dimension (-1) - let infer_index = new_shape.iter().position(|x| x == &-1); - - // Handle the case where the dimension is inferred (via -1) - if let Some(index) = infer_index { - // Handle the case where the dimension is inferred - let mut product = 1; - for (i, &s) in new_shape.iter().enumerate() { - if i != index { - product *= s; + fn into_shape>( + self, + tensor: &Tensor, + ) -> Shape { + // Validate the reshape arguments + check!(TensorCheck::reshape_args_i32(&self)); + + // Temporary shape + let mut new_shape: [i32; D2] = [1; D2]; + + // We need to find the index of the 0 dimension and + // replace it with the actual dimension value. 
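The `[i32; D2]` reshape arguments implemented in this hunk treat `0` as "keep the size at this position" and `-1` as "infer from the remaining elements" (the inference branch continues just below). A hedged usage sketch, assuming a 2x3 float tensor:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn reshape_i32_example<B: Backend>() {
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // [2, 3]

    // 0 keeps the current size at that position, -1 is inferred from the element count.
    let _same: Tensor<B, 2> = x.clone().reshape([0_i32, -1]); // stays [2, 3]
    let _swapped: Tensor<B, 2> = x.clone().reshape([3_i32, -1]); // becomes [3, 2]
    let _flat: Tensor<B, 1> = x.reshape([-1_i32]); // becomes [6]
}
```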
+ for (i, &s) in self.iter().enumerate() { + if s != 0 { + new_shape[i] = s; + } else { + new_shape[i] = tensor.dims()[i] as i32; + } } - } - let product_current = tensor.shape().num_elements() as i32; - - new_shape[index] = product_current / product; - - // Check if the reshape is valid - if product_current % product != 0 { - panic!( - "Cannot reshape tensor of shape {:?} to shape {:?}", - tensor.shape(), - new_shape - ); - } - }; - - // Convert each element to usize - let new_shape: [usize; D2] = new_shape.map(|x| x as usize); - - Shape::from(new_shape) - } + + // Find the index of the inferred dimension (-1) + let infer_index = new_shape.iter().position(|x| x == &-1); + + // Handle the case where the dimension is inferred (via -1) + if let Some(index) = infer_index { + // Handle the case where the dimension is inferred + let mut product = 1; + for (i, &s) in new_shape.iter().enumerate() { + if i != index { + product *= s; + } + } + let product_current = tensor.shape().num_elements() as i32; + + new_shape[index] = product_current / product; + + // Check if the reshape is valid + if product_current % product != 0 { + panic!( + "Cannot reshape tensor of shape {:?} to shape {:?}", + tensor.shape(), + new_shape + ); + } + }; + + // Convert each element to usize + let new_shape: [usize; D2] = new_shape.map(|x| x as usize); + + Shape::from(new_shape) + } } diff --git a/burn-tensor/src/tensor/api/bool.rs b/burn-tensor/src/tensor/api/bool.rs index 4047e64bac..e7d7e97460 100644 --- a/burn-tensor/src/tensor/api/bool.rs +++ b/burn-tensor/src/tensor/api/bool.rs @@ -2,30 +2,30 @@ use crate::{backend::Backend, Bool, Data, Int, Tensor}; impl Tensor where - B: Backend, + B: Backend, { - /// Create a boolean tensor from data. - pub fn from_bool(data: Data) -> Self { - Self::new(B::bool_from_data(data, &B::Device::default())) - } + /// Create a boolean tensor from data. + pub fn from_bool(data: Data) -> Self { + Self::new(B::bool_from_data(data, &B::Device::default())) + } - /// Create a boolean tensor from data on the given device. - pub fn from_bool_device(data: Data, device: &B::Device) -> Self { - Self::new(B::bool_from_data(data, device)) - } + /// Create a boolean tensor from data on the given device. + pub fn from_bool_device(data: Data, device: &B::Device) -> Self { + Self::new(B::bool_from_data(data, device)) + } - /// Convert the bool tensor into an int tensor. - pub fn int(self) -> Tensor { - Tensor::new(B::bool_into_int(self.primitive)) - } + /// Convert the bool tensor into an int tensor. + pub fn int(self) -> Tensor { + Tensor::new(B::bool_into_int(self.primitive)) + } - /// Convert the bool tensor into an float tensor. - pub fn float(self) -> Tensor { - Tensor::new(B::bool_into_float(self.primitive)) - } + /// Convert the bool tensor into an float tensor. + pub fn float(self) -> Tensor { + Tensor::new(B::bool_into_float(self.primitive)) + } - /// Inverses boolean values. - pub fn bool_not(self) -> Self { - Tensor::new(B::bool_not(self.primitive)) - } + /// Inverses boolean values. + pub fn bool_not(self) -> Self { + Tensor::new(B::bool_not(self.primitive)) + } } diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs index cf064f8f4d..3d81317513 100644 --- a/burn-tensor/src/tensor/api/check.rs +++ b/burn-tensor/src/tensor/api/check.rs @@ -33,59 +33,61 @@ use core::ops::Range; /// implementation might re-implement the same checks, which may result in uncessary code /// duplication. Maybe a combination of both strategies could help to cover all usecases. 
pub(crate) enum TensorCheck { - Ok, - Failed(FailedTensorCheck), + Ok, + Failed(FailedTensorCheck), } impl TensorCheck { - /// Checks device and shape compatibility for element wise binary operations. - pub(crate) fn binary_ops_ew>( - ops: &str, - lhs: &Tensor, - rhs: &Tensor, - ) -> Self { - Self::Ok - .binary_ops_device(ops, &lhs.device(), &rhs.device()) - .binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape()) - } - - pub(crate) fn into_scalar(shape: &Shape) -> Self { - let mut check = Self::Ok; - - if shape.num_elements() != 1 { - check = check.register( - "Into Scalar", - TensorError::new("Only tensors with 1 element can be converted into scalar.").details( - format!("Current tensor has {} elements", shape.num_elements()), - ), - ); + /// Checks device and shape compatibility for element wise binary operations. + pub(crate) fn binary_ops_ew>( + ops: &str, + lhs: &Tensor, + rhs: &Tensor, + ) -> Self { + Self::Ok + .binary_ops_device(ops, &lhs.device(), &rhs.device()) + .binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape()) } - check - } + pub(crate) fn into_scalar(shape: &Shape) -> Self { + let mut check = Self::Ok; - pub(crate) fn dim_ops(ops: &str, dim: usize) -> Self { - let mut check = Self::Ok; + if shape.num_elements() != 1 { + check = check.register( + "Into Scalar", + TensorError::new("Only tensors with 1 element can be converted into scalar.") + .details(format!( + "Current tensor has {} elements", + shape.num_elements() + )), + ); + } - if dim >= D { - check = check.register( - ops, - TensorError::new("Given dimension is higher than the tensor rank.") - .details(format!("Tensor rank: '{D}', given dimension: '{dim}'.")), - ); + check } - check - } + pub(crate) fn dim_ops(ops: &str, dim: usize) -> Self { + let mut check = Self::Ok; - pub(crate) fn reshape_args_usize( - original: &Shape, - target: &Shape, - ) -> Self { - let mut check = Self::Ok; + if dim >= D { + check = check.register( + ops, + TensorError::new("Given dimension is higher than the tensor rank.") + .details(format!("Tensor rank: '{D}', given dimension: '{dim}'.")), + ); + } + + check + } - if original.num_elements() != target.num_elements() { - check = check.register( + pub(crate) fn reshape_args_usize( + original: &Shape, + target: &Shape, + ) -> Self { + let mut check = Self::Ok; + + if original.num_elements() != target.num_elements() { + check = check.register( "Reshape", TensorError::new( "The given shape doesn't have the same number of elements as the current tensor.", @@ -95,187 +97,195 @@ impl TensorCheck { original.dims, target.dims )), ); + } + + check } - check - } + pub(crate) fn reshape_args_i32(target: &[i32; D]) -> Self { + let mut check = Self::Ok; - pub(crate) fn reshape_args_i32(target: &[i32; D]) -> Self { - let mut check = Self::Ok; + if target.iter().any(|&dim| dim < -1) { + check = check.register( + "Reshape", + TensorError::new( + "The given shape cannot contain negative dimensions (other than -1).", + ) + .details(format!("Target shape: {:?}.", target)), + ); + } - if target.iter().any(|&dim| dim < -1) { - check = check.register( - "Reshape", - TensorError::new("The given shape cannot contain negative dimensions (other than -1).") - .details(format!("Target shape: {:?}.", target)), - ); - } + if target.iter().filter(|&x| x == &-1).count() > 1 { + check = check.register( + "Reshape", + TensorError::new("The given shape cannot contain more than one -1.") + .details(format!("Target shape: {:?}.", target)), + ); + } - if target.iter().filter(|&x| x == &-1).count() > 1 { - check = check.register( 
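`TensorCheck::into_scalar` only accepts single-element tensors, so the user-facing `into_scalar` panics otherwise. A small illustrative sketch (non-wasm, synchronous variant assumed):

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn into_scalar_example<B: Backend>() {
    // Exactly one element, so the check passes.
    let x = Tensor::<B, 2>::from_floats([[42.0]]);
    let _value = x.into_scalar(); // B::FloatElem

    // A tensor with more than one element would panic with the message built above.
}
```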
- "Reshape", - TensorError::new("The given shape cannot contain more than one -1.") - .details(format!("Target shape: {:?}.", target)), - ); + check } - check - } - - pub(crate) fn flatten( - start_dim: usize, - end_dim: usize, - ) -> Self { - let mut check = Self::Ok; + pub(crate) fn flatten( + start_dim: usize, + end_dim: usize, + ) -> Self { + let mut check = Self::Ok; - if start_dim > end_dim { - check = check.register( - "Flatten", - TensorError::new(format!( - "The start dim ({start_dim}) must be smaller than the end dim ({end_dim})" - )), - ); - } + if start_dim > end_dim { + check = check.register( + "Flatten", + TensorError::new(format!( + "The start dim ({start_dim}) must be smaller than the end dim ({end_dim})" + )), + ); + } - if D2 > D1 { - check = check.register( - "Flatten", - TensorError::new(format!("Result dim ({D2}) must be smaller than ({D1})")), - ); - } + if D2 > D1 { + check = check.register( + "Flatten", + TensorError::new(format!("Result dim ({D2}) must be smaller than ({D1})")), + ); + } - if D1 < end_dim + 1 { - check = check.register( - "Flatten", - TensorError::new(format!( - "The end dim ({end_dim}) must be greater than the tensor dim ({D2})" - )), - ); - } + if D1 < end_dim + 1 { + check = check.register( + "Flatten", + TensorError::new(format!( + "The end dim ({end_dim}) must be greater than the tensor dim ({D2})" + )), + ); + } - if D2 < D1 - (end_dim - start_dim) { - check = check.register( + if D2 < D1 - (end_dim - start_dim) { + check = check.register( "Flatten", TensorError::new(format!( "The destination dimension ({D2}) must be large enough to accommodate the flattening operation." )), ); + } + + check } - check - } + pub(crate) fn squeeze(dim: usize, tensor_dims: &[usize]) -> Self { + let mut check = Self::Ok; + // This should actually be to check that the dimension to squeeze + // has a size of 1 + if tensor_dims[dim] != 1 { + check = check.register( + "Squeeze", + TensorError::new(format!( + "Can't squeeze dimension {} because its size is not 1", + dim + )), + ); + } - pub(crate) fn squeeze(dim: usize, tensor_dims: &[usize]) -> Self { - let mut check = Self::Ok; - // This should actually be to check that the dimension to squeeze - // has a size of 1 - if tensor_dims[dim] != 1 { - check = check.register( - "Squeeze", - TensorError::new(format!( - "Can't squeeze dimension {} because its size is not 1", - dim - )), - ); + check } - check - } + pub(crate) fn unsqueeze() -> Self { + let mut check = Self::Ok; + if D2 < D1 { + check = check.register( + "Unsqueeze", + TensorError::new(format!( + "Can't unsqueeze smaller tensor, got dim {D2}, expected > {D1}" + )), + ); + } - pub(crate) fn unsqueeze() -> Self { - let mut check = Self::Ok; - if D2 < D1 { - check = check.register( - "Unsqueeze", - TensorError::new(format!( - "Can't unsqueeze smaller tensor, got dim {D2}, expected > {D1}" - )), - ); + check } - check - } + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { + let mut check = Self::Ok; - pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { - let mut check = Self::Ok; + if dim1 > D || dim2 > D { + check = check.register( + "Swap Dims", + TensorError::new("The swap dimensions must be smaller than the tensor dimension") + .details(format!( + "Swap dims ({dim1}, {dim2}) on tensor with ({D}) dimensions." 
+ )), + ); + } - if dim1 > D || dim2 > D { - check = check.register( - "Swap Dims", - TensorError::new("The swap dimensions must be smaller than the tensor dimension").details( - format!("Swap dims ({dim1}, {dim2}) on tensor with ({D}) dimensions."), - ), - ); + check } - check - } - - pub(crate) fn matmul(lhs: &Tensor, rhs: &Tensor) -> Self { - let mut check = Self::Ok; + pub(crate) fn matmul( + lhs: &Tensor, + rhs: &Tensor, + ) -> Self { + let mut check = Self::Ok; - check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); + check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); - if D < 2 { - return check; - } + if D < 2 { + return check; + } - let shape_lhs = lhs.shape(); - let shape_rhs = rhs.shape(); + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); - let dim_lhs = shape_lhs.dims[D - 1]; - let dim_rhs = shape_rhs.dims[D - 2]; + let dim_lhs = shape_lhs.dims[D - 1]; + let dim_rhs = shape_rhs.dims[D - 2]; - if dim_lhs != dim_rhs { - check = check.register( - "Matmul", - TensorError::new(format!( + if dim_lhs != dim_rhs { + check = check.register( + "Matmul", + TensorError::new(format!( "The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}." )) - .details(format!( - "Lhs shape {:?}, rhs shape {:?}.", - shape_lhs.dims, shape_rhs.dims - )), - ); - } - - check - } - - pub(crate) fn cat>( - tensors: &[Tensor], - dim: usize, - ) -> Self { - let mut check = Self::Ok; + .details(format!( + "Lhs shape {:?}, rhs shape {:?}.", + shape_lhs.dims, shape_rhs.dims + )), + ); + } - if dim >= D { - check = check.register( - "Cat", - TensorError::new("Can't concatenate tensors on a dim that exceeds the tensors dimension") - .details(format!( - "Trying to concatenate tensors with {D} dimensions on axis {dim}." - )), - ); + check } - if tensors.is_empty() { - return check.register( - "Cat", - TensorError::new("Can't concatenate an empty list of tensors."), - ); - } + pub(crate) fn cat>( + tensors: &[Tensor], + dim: usize, + ) -> Self { + let mut check = Self::Ok; + + if dim >= D { + check = check.register( + "Cat", + TensorError::new( + "Can't concatenate tensors on a dim that exceeds the tensors dimension", + ) + .details(format!( + "Trying to concatenate tensors with {D} dimensions on axis {dim}." + )), + ); + } + + if tensors.is_empty() { + return check.register( + "Cat", + TensorError::new("Can't concatenate an empty list of tensors."), + ); + } - let mut shape_reference = tensors.get(0).unwrap().shape(); - shape_reference.dims[dim] = 1; // We want to check every dims except the one where the - // concatenation happens. + let mut shape_reference = tensors.get(0).unwrap().shape(); + shape_reference.dims[dim] = 1; // We want to check every dims except the one where the + // concatenation happens. - for tensor in tensors { - let mut shape = tensor.shape(); - shape.dims[dim] = 1; // Ignore the concatenate dim. + for tensor in tensors { + let mut shape = tensor.shape(); + shape.dims[dim] = 1; // Ignore the concatenate dim. 
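The matmul check compares the last dimension of the left operand with the second-to-last dimension of the right one. A minimal sketch of shapes that pass it:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn matmul_shape_example<B: Backend>() {
    let lhs = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // [2, 3]
    let rhs = Tensor::<B, 2>::from_floats([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]); // [3, 2]

    // Inner dimensions agree (3 == 3), so TensorCheck::matmul passes; result is [2, 2].
    let _out = lhs.matmul(rhs);
}
```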
- if shape_reference != shape { - return check.register( + if shape_reference != shape { + return check.register( "Cat", TensorError::new( "Can't concatenate tensors with different shapes, except for the provided dimension", @@ -286,36 +296,36 @@ impl TensorCheck { tensors.iter().map(Tensor::shape).collect::>() )), ); - } - } + } + } - check - } + check + } - pub(crate) fn slice( - shape: &Shape, - ranges: &[Range; D2], - ) -> Self { - let mut check = Self::Ok; - let n_dims_tensor = D1; - let n_dims_ranges = D2; + pub(crate) fn slice( + shape: &Shape, + ranges: &[Range; D2], + ) -> Self { + let mut check = Self::Ok; + let n_dims_tensor = D1; + let n_dims_ranges = D2; - if n_dims_tensor < n_dims_ranges { - check = check.register("Slice", + if n_dims_tensor < n_dims_ranges { + check = check.register("Slice", TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.") .details( format!( "The ranges array must be smaller or equal to the tensor number of dimensions. \ Tensor number of dimensions: {n_dims_tensor}, ranges array length {n_dims_ranges}." ))); - } + } - for i in 0..usize::min(D1, D2) { - let d_tensor = shape.dims[i]; - let range = ranges.get(i).unwrap(); + for i in 0..usize::min(D1, D2) { + let d_tensor = shape.dims[i]; + let range = ranges.get(i).unwrap(); - if range.end > d_tensor { - check = check.register( + if range.end > d_tensor { + check = check.register( "Slice", TensorError::new( "The provided ranges array has a range that exceeds the current tensor size.", @@ -326,10 +336,10 @@ impl TensorCheck { range.start, range.end, d_tensor, i, shape.dims, ranges, )), ); - } + } - if range.start >= range.end { - check = check.register( + if range.start >= range.end { + check = check.register( "Slice", TensorError::new("The provided range array has a range where the start index is bigger or equal to its end.") .details(format!( @@ -341,21 +351,21 @@ impl TensorCheck { shape.dims, ranges, ))); - } - } + } + } - check - } + check + } - pub(crate) fn slice_assign( - shape: &Shape, - shape_value: &Shape, - ranges: &[Range; D2], - ) -> Self { - let mut check = Self::Ok; + pub(crate) fn slice_assign( + shape: &Shape, + shape_value: &Shape, + ranges: &[Range; D2], + ) -> Self { + let mut check = Self::Ok; - if D1 < D2 { - check = check.register( + if D1 < D2 { + check = check.register( "Slice Assign", TensorError::new( "The provided ranges array has a higher number of dimensions than the current tensor.", @@ -365,15 +375,15 @@ impl TensorCheck { Tensor number of dimensions: {D1}, ranges array length {D2}." 
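The slice check allows fewer ranges than tensor dimensions but rejects ranges that leave the tensor bounds or are empty. A hedged usage sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn slice_example<B: Backend>() {
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // [2, 3]

    // Fewer ranges than dimensions is fine; the trailing dimension is kept whole.
    let _first_row = x.clone().slice([0..1]); // [1, 3]
    let _block = x.slice([0..2, 1..3]); // [2, 2]
}
```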
)), ); - } + } - for i in 0..usize::min(D1, D2) { - let d_tensor = shape.dims[i]; - let d_tensor_value = shape_value.dims[i]; - let range = ranges.get(i).unwrap(); + for i in 0..usize::min(D1, D2) { + let d_tensor = shape.dims[i]; + let d_tensor_value = shape_value.dims[i]; + let range = ranges.get(i).unwrap(); - if range.end > d_tensor { - check = check.register( + if range.end > d_tensor { + check = check.register( "Range Assign", TensorError::new( "The provided ranges array has a range that exceeds the current tensor size.", @@ -384,10 +394,10 @@ impl TensorCheck { range.start, range.end, d_tensor, i, shape.dims, shape_value.dims, ranges, )), ); - } + } - if range.end - range.start != d_tensor_value { - check = check.register( + if range.end - range.start != d_tensor_value { + check = check.register( "Slice Assign", TensorError::new("The value tensor must match the amount of elements selected with the ranges array") .details(format!( @@ -401,10 +411,10 @@ impl TensorCheck { shape_value.dims, ranges, ))); - } + } - if range.start >= range.end { - check = check.register( + if range.start >= range.end { + check = check.register( "Slice Assign", TensorError::new("The provided ranges array has a range where the start index is bigger or equal to its end.") .details(format!( @@ -417,240 +427,245 @@ impl TensorCheck { shape_value.dims, ranges, ))); - } - } - - check - } - - pub(crate) fn gather( - dim: usize, - shape: &Shape, - shape_indices: &Shape, - ) -> Self { - Self::check_gather_scatter_indices(Self::Ok, "Gather", dim, shape, shape_indices) - } - - pub(crate) fn scatter( - dim: usize, - shape: &Shape, - shape_indices: &Shape, - shape_value: &Shape, - ) -> Self { - let ops = "Scatter"; - let mut check = Self::check_gather_scatter_indices(Self::Ok, ops, dim, shape, shape_indices); - - if shape_indices != shape_value { - check = check.register( - ops, - TensorError::new( - "Indices tensor shape should be the same as the value tensor shape.".to_string(), - ) - .details(format!( - "The shape differs: {:?} != {:?}", - shape_indices.dims, shape_value.dims - )), - ); - } + } + } - check - } + check + } - pub(crate) fn select(dim: usize) -> Self { - Self::check_select_basic::(Self::Ok, "select", dim) - } + pub(crate) fn gather( + dim: usize, + shape: &Shape, + shape_indices: &Shape, + ) -> Self { + Self::check_gather_scatter_indices(Self::Ok, "Gather", dim, shape, shape_indices) + } - pub(crate) fn select_assign(dim: usize) -> Self { - Self::check_select_basic::(Self::Ok, "select_assign", dim) - } + pub(crate) fn scatter( + dim: usize, + shape: &Shape, + shape_indices: &Shape, + shape_value: &Shape, + ) -> Self { + let ops = "Scatter"; + let mut check = + Self::check_gather_scatter_indices(Self::Ok, ops, dim, shape, shape_indices); + + if shape_indices != shape_value { + check = check.register( + ops, + TensorError::new( + "Indices tensor shape should be the same as the value tensor shape." 
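For `slice_assign`, the value tensor has to match the number of elements selected by each range, as the checks above spell out. A small illustrative sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn slice_assign_example<B: Backend>() {
    let base = Tensor::<B, 2>::zeros([2, 3]);
    let patch = Tensor::<B, 2>::ones([2, 2]);

    // The ranges select a 2x2 region, matching the shape of `patch`.
    let _updated = base.slice_assign([0..2, 1..3], patch);
}
```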
+ .to_string(), + ) + .details(format!( + "The shape differs: {:?} != {:?}", + shape_indices.dims, shape_value.dims + )), + ); + } - fn check_select_basic(mut check: Self, ops: &str, dim: usize) -> Self { - if dim > D { - check = check.register( - ops, - TensorError::new(format!( - "Can't index a tensor with ({D}) dimensions on axis ({dim})" - )), - ); + check } - check - } - fn check_gather_scatter_indices( - mut check: Self, - ops: &str, - dim: usize, - shape: &Shape, - shape_indices: &Shape, - ) -> Self { - if dim > D { - check = check.register( - ops, - TensorError::new(format!( - "Can't index a tensor with ({D}) dimensions on axis ({dim})" - )), - ); + pub(crate) fn select(dim: usize) -> Self { + Self::check_select_basic::(Self::Ok, "select", dim) } - for i in 0..D { - if i == dim { - continue; - } + pub(crate) fn select_assign(dim: usize) -> Self { + Self::check_select_basic::(Self::Ok, "select_assign", dim) + } - let tensor_dim_i = shape.dims[i]; - let indices_dim_i = shape_indices.dims[i]; + fn check_select_basic(mut check: Self, ops: &str, dim: usize) -> Self { + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't index a tensor with ({D}) dimensions on axis ({dim})" + )), + ); + } - if tensor_dim_i != indices_dim_i { - check = check.register( - ops, - TensorError::new( - "The tensor shape should be the same as the index tensor shape.".to_string(), - ) - .details(format!( - "The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}" - )), - ); - } + check } + fn check_gather_scatter_indices( + mut check: Self, + ops: &str, + dim: usize, + shape: &Shape, + shape_indices: &Shape, + ) -> Self { + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't index a tensor with ({D}) dimensions on axis ({dim})" + )), + ); + } - check - } - - /// Checks aggregate dimension such as mean and sum. - pub(crate) fn aggregate_dim(ops: &str, dim: usize) -> Self { - let mut check = Self::Ok; + for i in 0..D { + if i == dim { + continue; + } + + let tensor_dim_i = shape.dims[i]; + let indices_dim_i = shape_indices.dims[i]; + + if tensor_dim_i != indices_dim_i { + check = check.register( + ops, + TensorError::new( + "The tensor shape should be the same as the index tensor shape." + .to_string(), + ) + .details(format!( + "The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}" + )), + ); + } + } - if dim > D { - check = check.register( - ops, - TensorError::new(format!( - "Can't aggregate a tensor with ({D}) dimensions on axis ({dim})" - )), - ); + check } - check - } + /// Checks aggregate dimension such as mean and sum. + pub(crate) fn aggregate_dim(ops: &str, dim: usize) -> Self { + let mut check = Self::Ok; - /// The goal is to minimize the cost of checks when there are no error, but it's way less - /// important when an error occurred, crafting a comprehensive error message is more important - /// than optimizing string manipulation. - fn register(self, ops: &str, error: TensorError) -> Self { - let errors = match self { - Self::Ok => vec![error], - Self::Failed(mut failed) => { - failed.errors.push(error); - failed.errors - } - }; - - Self::Failed(FailedTensorCheck { - ops: ops.to_string(), - errors, - }) - } - - /// Checks if shapes are compatible for element wise operations supporting broadcasting. 
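The select and gather/scatter checks require the indexed dimension to stay within the tensor rank and, for gather/scatter, the index shape to line up with the tensor everywhere except on `dim`. A hedged sketch of `select`, which these checks guard:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn select_example<B: Backend>() {
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // [2, 3]
    let indices = Tensor::<B, 1, Int>::from_ints([2, 0]);

    // Picks columns 2 and 0 along dim 1; dim 1 is within the tensor rank.
    let _picked = x.select(1, indices); // [2, 2]
}
```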
- pub(crate) fn binary_ops_ew_shape( - self, - ops: &str, - lhs: &Shape, - rhs: &Shape, - ) -> Self { - let mut check = self; - - for i in 0..D { - let d_lhs = lhs.dims[i]; - let d_rhs = rhs.dims[i]; - - if d_lhs != d_rhs { - let is_broadcast = d_lhs == 1 || d_rhs == 1; - - if is_broadcast { - continue; + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't aggregate a tensor with ({D}) dimensions on axis ({dim})" + )), + ); } - check = check.register( - ops, - TensorError::new("The provided tensors have incompatible shapes.").details(format!( + check + } + + /// The goal is to minimize the cost of checks when there are no error, but it's way less + /// important when an error occurred, crafting a comprehensive error message is more important + /// than optimizing string manipulation. + fn register(self, ops: &str, error: TensorError) -> Self { + let errors = match self { + Self::Ok => vec![error], + Self::Failed(mut failed) => { + failed.errors.push(error); + failed.errors + } + }; + + Self::Failed(FailedTensorCheck { + ops: ops.to_string(), + errors, + }) + } + + /// Checks if shapes are compatible for element wise operations supporting broadcasting. + pub(crate) fn binary_ops_ew_shape( + self, + ops: &str, + lhs: &Shape, + rhs: &Shape, + ) -> Self { + let mut check = self; + + for i in 0..D { + let d_lhs = lhs.dims[i]; + let d_rhs = rhs.dims[i]; + + if d_lhs != d_rhs { + let is_broadcast = d_lhs == 1 || d_rhs == 1; + + if is_broadcast { + continue; + } + + check = check.register( + ops, + TensorError::new("The provided tensors have incompatible shapes.").details( + format!( "Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \ Lhs tensor shape {:?}, Rhs tensor shape {:?}.", i, d_lhs, d_rhs, lhs.dims, rhs.dims, - )), - ); - } - } - - check - } - - /// Checks if tensor devices are equal. - fn binary_ops_device( - self, - ops: &str, - lhs: &Device, - rhs: &Device, - ) -> Self { - match lhs != rhs { - true => self.register( - ops, - TensorError::new("The provided tensors are not on the same device.").details(format!( - "Lhs tensor device {lhs:?}, Rhs tensor device {rhs:?}.", - )), - ), - false => self, + ), + ), + ); + } + } + + check + } + + /// Checks if tensor devices are equal. + fn binary_ops_device( + self, + ops: &str, + lhs: &Device, + rhs: &Device, + ) -> Self { + match lhs != rhs { + true => self.register( + ops, + TensorError::new("The provided tensors are not on the same device.").details( + format!("Lhs tensor device {lhs:?}, Rhs tensor device {rhs:?}.",), + ), + ), + false => self, + } } - } } pub(crate) struct FailedTensorCheck { - ops: String, - errors: Vec, + ops: String, + errors: Vec, } impl FailedTensorCheck { - /// Format all the checks into a single message ready to be printed by a [panic](core::panic). - pub(crate) fn format(self) -> String { - self.errors.into_iter().enumerate().fold( - format!( - "=== Tensor Operation Error ===\n Operation: '{}'\n Reason:", - self.ops - ), - |accum, (number, error)| accum + error.format(number + 1).as_str(), - ) + "\n" - } + /// Format all the checks into a single message ready to be printed by a [panic](core::panic). 
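`binary_ops_ew_shape` accepts mismatched dimensions only when one side has size 1, which is the broadcasting rule used by the element-wise operators. A minimal sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn broadcast_example<B: Backend>() {
    let a = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // [2, 3]
    let b = Tensor::<B, 2>::from_floats([[10.0, 20.0, 30.0]]); // [1, 3]

    // A size of 1 on either side broadcasts, so [2, 3] + [1, 3] passes the shape check.
    let _sum = a + b;
}
```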
+ pub(crate) fn format(self) -> String { + self.errors.into_iter().enumerate().fold( + format!( + "=== Tensor Operation Error ===\n Operation: '{}'\n Reason:", + self.ops + ), + |accum, (number, error)| accum + error.format(number + 1).as_str(), + ) + "\n" + } } struct TensorError { - description: String, - details: Option, + description: String, + details: Option, } impl TensorError { - pub(crate) fn new>(description: S) -> Self { - TensorError { - description: description.into(), - details: None, + pub(crate) fn new>(description: S) -> Self { + TensorError { + description: description.into(), + details: None, + } } - } - pub(crate) fn details>(mut self, details: S) -> Self { - self.details = Some(details.into()); - self - } + pub(crate) fn details>(mut self, details: S) -> Self { + self.details = Some(details.into()); + self + } - fn format(self, number: usize) -> String { - let mut message = format!("\n {number}. "); - message += self.description.as_str(); - message += " "; + fn format(self, number: usize) -> String { + let mut message = format!("\n {number}. "); + message += self.description.as_str(); + message += " "; - if let Some(details) = self.details { - message += details.as_str(); - message += " "; - } + if let Some(details) = self.details { + message += details.as_str(); + message += " "; + } - message - } + message + } } /// We use a macro for all checks, since the panic message file and line number will match the @@ -658,78 +673,78 @@ impl TensorError { /// and line number. #[macro_export(local_inner_macros)] macro_rules! check { - ($check:expr) => { - if let TensorCheck::Failed(check) = $check { - core::panic!("{}", check.format()); - } - }; + ($check:expr) => { + if let TensorCheck::Failed(check) = $check { + core::panic!("{}", check.format()); + } + }; } #[cfg(test)] mod tests { - use super::*; - - #[test] - #[should_panic] - fn reshape_invalid_shape() { - check!(TensorCheck::reshape_args_usize( - &Shape::new([2, 2]), - &Shape::new([1, 3]) - )); - } - - #[test] - fn reshape_valid_shape() { - check!(TensorCheck::reshape_args_usize( - &Shape::new([2, 2]), - &Shape::new([1, 4]) - )); - } - - #[test] - #[should_panic] - fn index_range_exceed_dimension() { - check!(TensorCheck::slice( - &Shape::new([3, 5, 7]), - &[0..2, 0..4, 1..8] - )); - } - - #[test] - #[should_panic] - fn index_range_exceed_number_of_dimensions() { - check!(TensorCheck::slice(&Shape::new([3, 5]), &[0..1, 0..1, 0..1])); - } - - #[test] - #[should_panic] - fn binary_ops_shapes_no_broadcast() { - check!(TensorCheck::binary_ops_ew_shape( - TensorCheck::Ok, - "TestOps", - &Shape::new([3, 5]), - &Shape::new([3, 6]) - )); - } - - #[test] - fn binary_ops_shapes_with_broadcast() { - check!(TensorCheck::binary_ops_ew_shape( - TensorCheck::Ok, - "Test", - &Shape::new([3, 5]), - &Shape::new([1, 5]) - )); - } - - #[test] - #[should_panic] - fn binary_ops_devices() { - check!(TensorCheck::binary_ops_device( - TensorCheck::Ok, - "Test", - &5, // We can pass anything that implements PartialEq as device - &8 - )); - } + use super::*; + + #[test] + #[should_panic] + fn reshape_invalid_shape() { + check!(TensorCheck::reshape_args_usize( + &Shape::new([2, 2]), + &Shape::new([1, 3]) + )); + } + + #[test] + fn reshape_valid_shape() { + check!(TensorCheck::reshape_args_usize( + &Shape::new([2, 2]), + &Shape::new([1, 4]) + )); + } + + #[test] + #[should_panic] + fn index_range_exceed_dimension() { + check!(TensorCheck::slice( + &Shape::new([3, 5, 7]), + &[0..2, 0..4, 1..8] + )); + } + + #[test] + #[should_panic] + fn 
index_range_exceed_number_of_dimensions() { + check!(TensorCheck::slice(&Shape::new([3, 5]), &[0..1, 0..1, 0..1])); + } + + #[test] + #[should_panic] + fn binary_ops_shapes_no_broadcast() { + check!(TensorCheck::binary_ops_ew_shape( + TensorCheck::Ok, + "TestOps", + &Shape::new([3, 5]), + &Shape::new([3, 6]) + )); + } + + #[test] + fn binary_ops_shapes_with_broadcast() { + check!(TensorCheck::binary_ops_ew_shape( + TensorCheck::Ok, + "Test", + &Shape::new([3, 5]), + &Shape::new([1, 5]) + )); + } + + #[test] + #[should_panic] + fn binary_ops_devices() { + check!(TensorCheck::binary_ops_device( + TensorCheck::Ok, + "Test", + &5, // We can pass anything that implements PartialEq as device + &8 + )); + } } diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs index 4115b8538b..d2cededbab 100644 --- a/burn-tensor/src/tensor/api/float.rs +++ b/burn-tensor/src/tensor/api/float.rs @@ -12,316 +12,316 @@ use crate::Tensor; impl Tensor where - B: Backend, + B: Backend, { - /// Executes an operation on the tensor and modifies its value. - /// - /// # Notes - /// - /// This won't necessary reuse the same tensor data/buffer, but it should if there is - /// no other reference pointing to the same tensor. - /// - /// Wrapping operations with inplace is not an optimization, it's mainly there if you - /// want to mutate a tensor by using owned operations. A plausible usage would be to - /// update the weights of a mutable model reference. - pub fn inplace Self>(&mut self, func: F) { - let mut tensor_owned = Tensor::empty([0; D]); - core::mem::swap(&mut tensor_owned, self); - - let mut tensor_new = func(tensor_owned); - core::mem::swap(&mut tensor_new, self); - } - - /// Applies element wise exponential operation. - /// - /// `y = e^x` - pub fn exp(self) -> Self { - Self::new(B::exp(self.primitive)) - } - - /// Applies element wise natural log operation *ln*. - /// - /// `y = log(x)` - pub fn log(self) -> Self { - Self::new(B::log(self.primitive)) - } - - /// Applies the natural logarithm of one plus the input tensor, element-wise. - /// - /// `y = log(x+1)` - pub fn log1p(self) -> Self { - Self::new(B::log1p(self.primitive)) - } - - /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise. - /// - /// `y = erf(x)` - pub fn erf(self) -> Self { - Self::new(B::erf(self.primitive)) - } - - /// Applies element wise power operation. - /// - /// `y = x^a` - pub fn powf(self, value: f32) -> Self { - Self::new(B::powf(self.primitive, value)) - } - - /// Applies element wise reciprocal operation. - pub fn recip(self) -> Self { - Self::new(B::recip(self.primitive)) - } - - /// Applies element wise root square operation. - pub fn sqrt(self) -> Self { - Self::new(B::sqrt(self.primitive)) - } - - /// Applies element wise cosine operation. - pub fn cos(self) -> Self { - Self::new(B::cos(self.primitive)) - } - - /// Applies element wise sine operation. - pub fn sin(self) -> Self { - Self::new(B::sin(self.primitive)) - } - - /// Applies element wise hyperbolic tangent operation. - pub fn tanh(self) -> Self { - Self::new(B::tanh(self.primitive)) - } - - /// Create a tensor from floats (f32). 
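The float API in this hunk exposes element-wise math (`exp`, `log`, `log1p`, `erf`, `powf`, `sqrt`, `cos`, `sin`, `tanh`), each of which maps the matching backend op over every element. A short hypothetical chain:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn elementwise_example<B: Backend>() {
    let x = Tensor::<B, 1>::from_floats([1.0, 4.0, 9.0]);

    // Each call is applied element-wise and returns a new tensor.
    let _y = x.clone().sqrt().log1p(); // ln(sqrt(x) + 1)
    let _z = x.powf(0.5).tanh(); // tanh(x^0.5)
}
```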
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let _ = Tensor::::from_floats([1.0, 2.0]); - /// let _ = Tensor::::from_floats([[1.0, 2.0], [3.0, 4.0]]); - /// } - /// ``` - pub fn from_floats>>(floats: A) -> Self { - Self::from_data(floats.into().convert()) - } - - /// Returns a new tensor with the same shape and device as the current tensor and the data - /// casted to Integer. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let float_tensor = Tensor::::from_floats([1.0, 2.0]); - /// let int_tensor = float_tensor.int(); - /// } - /// ``` - pub fn int(self) -> Tensor { - Tensor::new(B::into_int(self.primitive)) - } - - /// Returns a new tensor with the same shape and device as the current tensor filled with zeros. - pub fn zeros_like(&self) -> Self { - Tensor::new(B::zeros(self.shape(), &self.device())) - } - - /// Returns a new tensor with the same shape and device as the current tensor filled with ones. - pub fn ones_like(&self) -> Self { - Tensor::new(B::ones(self.shape(), &self.device())) - } - - /// Returns a new tensor with the same shape and device as the current tensor filled random - /// values sampled from the given distribution. - pub fn random_like(&self, distribution: Distribution) -> Self { - Tensor::new(B::random(self.shape(), distribution, &self.device())) - } - - /// Create a one hot tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let one_hot = Tensor::::one_hot(2, 10); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - /// ``` - pub fn one_hot(index: usize, num_classes: usize) -> Self { - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]))) - } - - /// Applies the matrix multiplication operation. - /// - /// `C = AB` - /// - /// # Panics - /// - /// If the two tensors dont' have a compatible shape. - pub fn matmul(self, other: Self) -> Self { - check!(TensorCheck::matmul(&self, &other)); - Self::new(B::matmul(self.primitive, other.primitive)) - } - - /// Calculate the variance along the given dimension. - pub fn var(self, dim: usize) -> Self { - stats::var(self, dim) - } - - /// Calculate the variance along the given dimension without applying the Bessel’s correction. - pub fn var_bias(self, dim: usize) -> Self { - stats::var_bias(self, dim) - } - - /// Calculate the variance along the given dimension and also returns the mean. - pub fn var_mean(self, dim: usize) -> (Self, Self) { - let mean = self.clone().mean_dim(dim); - let var = stats::var_with_mean(self, mean.clone(), dim); - (var, mean) - } - - /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean. 
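`var_mean` (and its `_bias` variant) reduces along one dimension and returns the variance together with the mean it was computed from. A hedged sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn var_mean_example<B: Backend>() {
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]]);

    // Variance (with Bessel's correction) and mean along dim 1.
    let (_var, _mean) = x.var_mean(1);
}
```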
- pub fn var_mean_bias(self, dim: usize) -> (Self, Self) { - let mean = self.clone().mean_dim(dim); - let var = stats::var_with_mean_bias(self, mean.clone(), dim); - (var, mean) - } - - /// Create a random tensor of the given shape where each element is sampled from the given - /// distribution. - pub fn random>>(shape: S, distribution: Distribution) -> Self { - let tensor = B::random(shape.into(), distribution, &B::Device::default()); - Self::new(tensor) - } - - /// Create a random tensor of the given shape on the given device where each element is - /// sampled from the given distribution. - pub fn random_device>>( - shape: S, - distribution: Distribution, - device: &B::Device, - ) -> Self { - let tensor = B::random(shape.into(), distribution, device); - Self::new(tensor) - } - /// Returns a tensor with full precision based on the selected backend. - pub fn to_full_precision(&self) -> Tensor { - Tensor::new(B::to_full_precision(&self.primitive)) - } - - /// Returns a tensor on the selected backend from a full precision tensor. - pub fn from_full_precision(tensor: Tensor) -> Self { - Self::new(B::from_full_precision(tensor.primitive)) - } - - /// Detach the current tensor from the autodiff graph. - /// This function does nothing when autodiff is not enabled. - /// This can be used in batchers or elsewhere to ensure that previous operations are not - /// considered in the autodiff graph. - pub fn detach(self) -> Self { - Self::new(B::detach(self.primitive)) - } - - /// Mark the tensor to keep gradients during the backward pass. - /// This function does nothing when autodiff is not enabled. - pub fn require_grad(self) -> Self { - self.set_require_grad(true) - } - - /// Returns true if the tensor requires gradients during the backward pass. - pub fn is_require_grad(&self) -> bool { - B::is_require_grad(&self.primitive) - } - - /// Mark the tensor as tracked or untracked depending on the require grad argument. - /// When tracked, the gradients will be available after the backward pass. - /// - /// This function does nothing when autodiff is not enabled. - pub fn set_require_grad(self, require_grad: bool) -> Self { - Self::new(B::set_require_grad(self.primitive, require_grad)) - } - - /// Applies the relu function to the tensor. - pub(crate) fn relu(self) -> Self { - Self::new(B::relu(self.primitive)) - } - - /// Calculate covaraince matrix between different entries alongside a given dimension. - /// - /// # Arguments - /// - /// * `size` - The size of the square matrix. - /// * `correction_factor` - Is usually 1 for samples and 0 for population. - pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor { - let n = self.dims()[dim]; - let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0); - centered - .clone() - .transpose() - .matmul(centered) - .div_scalar(n as f32 - correction_factor as f32) - } + /// Executes an operation on the tensor and modifies its value. + /// + /// # Notes + /// + /// This won't necessary reuse the same tensor data/buffer, but it should if there is + /// no other reference pointing to the same tensor. + /// + /// Wrapping operations with inplace is not an optimization, it's mainly there if you + /// want to mutate a tensor by using owned operations. A plausible usage would be to + /// update the weights of a mutable model reference. 
+ pub fn inplace Self>(&mut self, func: F) { + let mut tensor_owned = Tensor::empty([0; D]); + core::mem::swap(&mut tensor_owned, self); + + let mut tensor_new = func(tensor_owned); + core::mem::swap(&mut tensor_new, self); + } + + /// Applies element wise exponential operation. + /// + /// `y = e^x` + pub fn exp(self) -> Self { + Self::new(B::exp(self.primitive)) + } + + /// Applies element wise natural log operation *ln*. + /// + /// `y = log(x)` + pub fn log(self) -> Self { + Self::new(B::log(self.primitive)) + } + + /// Applies the natural logarithm of one plus the input tensor, element-wise. + /// + /// `y = log(x+1)` + pub fn log1p(self) -> Self { + Self::new(B::log1p(self.primitive)) + } + + /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise. + /// + /// `y = erf(x)` + pub fn erf(self) -> Self { + Self::new(B::erf(self.primitive)) + } + + /// Applies element wise power operation. + /// + /// `y = x^a` + pub fn powf(self, value: f32) -> Self { + Self::new(B::powf(self.primitive, value)) + } + + /// Applies element wise reciprocal operation. + pub fn recip(self) -> Self { + Self::new(B::recip(self.primitive)) + } + + /// Applies element wise root square operation. + pub fn sqrt(self) -> Self { + Self::new(B::sqrt(self.primitive)) + } + + /// Applies element wise cosine operation. + pub fn cos(self) -> Self { + Self::new(B::cos(self.primitive)) + } + + /// Applies element wise sine operation. + pub fn sin(self) -> Self { + Self::new(B::sin(self.primitive)) + } + + /// Applies element wise hyperbolic tangent operation. + pub fn tanh(self) -> Self { + Self::new(B::tanh(self.primitive)) + } + + /// Create a tensor from floats (f32). + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let _ = Tensor::::from_floats([1.0, 2.0]); + /// let _ = Tensor::::from_floats([[1.0, 2.0], [3.0, 4.0]]); + /// } + /// ``` + pub fn from_floats>>(floats: A) -> Self { + Self::from_data(floats.into().convert()) + } + + /// Returns a new tensor with the same shape and device as the current tensor and the data + /// casted to Integer. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let float_tensor = Tensor::::from_floats([1.0, 2.0]); + /// let int_tensor = float_tensor.int(); + /// } + /// ``` + pub fn int(self) -> Tensor { + Tensor::new(B::into_int(self.primitive)) + } + + /// Returns a new tensor with the same shape and device as the current tensor filled with zeros. + pub fn zeros_like(&self) -> Self { + Tensor::new(B::zeros(self.shape(), &self.device())) + } + + /// Returns a new tensor with the same shape and device as the current tensor filled with ones. + pub fn ones_like(&self) -> Self { + Tensor::new(B::ones(self.shape(), &self.device())) + } + + /// Returns a new tensor with the same shape and device as the current tensor filled random + /// values sampled from the given distribution. + pub fn random_like(&self, distribution: Distribution) -> Self { + Tensor::new(B::random(self.shape(), distribution, &self.device())) + } + + /// Create a one hot tensor. 
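The `*_like` constructors shown above reuse the shape and device of an existing tensor. A minimal illustrative sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn like_constructors_example<B: Backend>() {
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]]);

    // Same shape and device as `x`, different fill values.
    let _zeros = x.zeros_like();
    let _ones = x.ones_like();
}
```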
+ /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let one_hot = Tensor::::one_hot(2, 10); + /// println!("{}", one_hot.to_data()); + /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + /// } + /// ``` + pub fn one_hot(index: usize, num_classes: usize) -> Self { + let mut dims = [1; D]; + dims[D - 1] = num_classes; + let shape = Shape::new(dims); + let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); + let tensor = Tensor::zeros(shape); + let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); + ranges[D - 1] = index..index + 1; + + tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]))) + } + + /// Applies the matrix multiplication operation. + /// + /// `C = AB` + /// + /// # Panics + /// + /// If the two tensors dont' have a compatible shape. + pub fn matmul(self, other: Self) -> Self { + check!(TensorCheck::matmul(&self, &other)); + Self::new(B::matmul(self.primitive, other.primitive)) + } + + /// Calculate the variance along the given dimension. + pub fn var(self, dim: usize) -> Self { + stats::var(self, dim) + } + + /// Calculate the variance along the given dimension without applying the Bessel’s correction. + pub fn var_bias(self, dim: usize) -> Self { + stats::var_bias(self, dim) + } + + /// Calculate the variance along the given dimension and also returns the mean. + pub fn var_mean(self, dim: usize) -> (Self, Self) { + let mean = self.clone().mean_dim(dim); + let var = stats::var_with_mean(self, mean.clone(), dim); + (var, mean) + } + + /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean. + pub fn var_mean_bias(self, dim: usize) -> (Self, Self) { + let mean = self.clone().mean_dim(dim); + let var = stats::var_with_mean_bias(self, mean.clone(), dim); + (var, mean) + } + + /// Create a random tensor of the given shape where each element is sampled from the given + /// distribution. + pub fn random>>(shape: S, distribution: Distribution) -> Self { + let tensor = B::random(shape.into(), distribution, &B::Device::default()); + Self::new(tensor) + } + + /// Create a random tensor of the given shape on the given device where each element is + /// sampled from the given distribution. + pub fn random_device>>( + shape: S, + distribution: Distribution, + device: &B::Device, + ) -> Self { + let tensor = B::random(shape.into(), distribution, device); + Self::new(tensor) + } + /// Returns a tensor with full precision based on the selected backend. + pub fn to_full_precision(&self) -> Tensor { + Tensor::new(B::to_full_precision(&self.primitive)) + } + + /// Returns a tensor on the selected backend from a full precision tensor. + pub fn from_full_precision(tensor: Tensor) -> Self { + Self::new(B::from_full_precision(tensor.primitive)) + } + + /// Detach the current tensor from the autodiff graph. + /// This function does nothing when autodiff is not enabled. + /// This can be used in batchers or elsewhere to ensure that previous operations are not + /// considered in the autodiff graph. + pub fn detach(self) -> Self { + Self::new(B::detach(self.primitive)) + } + + /// Mark the tensor to keep gradients during the backward pass. + /// This function does nothing when autodiff is not enabled. + pub fn require_grad(self) -> Self { + self.set_require_grad(true) + } + + /// Returns true if the tensor requires gradients during the backward pass. 
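`require_grad`, `is_require_grad`, and `detach` control gradient tracking and, as the docs above note, do nothing unless autodiff is enabled. A small sketch under that assumption:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn grad_tracking_example<B: Backend>() {
    // No-ops on plain backends; meaningful once B is wrapped in the autodiff backend.
    let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0]).require_grad();
    let _tracked = x.is_require_grad();
    let _frozen = x.detach();
}
```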
+ pub fn is_require_grad(&self) -> bool { + B::is_require_grad(&self.primitive) + } + + /// Mark the tensor as tracked or untracked depending on the require grad argument. + /// When tracked, the gradients will be available after the backward pass. + /// + /// This function does nothing when autodiff is not enabled. + pub fn set_require_grad(self, require_grad: bool) -> Self { + Self::new(B::set_require_grad(self.primitive, require_grad)) + } + + /// Applies the relu function to the tensor. + pub(crate) fn relu(self) -> Self { + Self::new(B::relu(self.primitive)) + } + + /// Calculate covaraince matrix between different entries alongside a given dimension. + /// + /// # Arguments + /// + /// * `size` - The size of the square matrix. + /// * `correction_factor` - Is usually 1 for samples and 0 for population. + pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor { + let n = self.dims()[dim]; + let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0); + centered + .clone() + .transpose() + .matmul(centered) + .div_scalar(n as f32 - correction_factor as f32) + } } impl Tensor { - /// Backward pass of the tensor. - pub fn backward(&self) -> B::Gradients { - B::backward::(self.primitive.clone()) - } - - /// Get the gradients of a tensor if it exist. - /// - /// Returns a new reference to the same tensor. Therefore the same grad tensor can - /// be accessed multiple times. If you only need to get the gradients one time, - /// consider using [grad_remove](Tensor::grad_remove) for better performance. - pub fn grad(&self, grads: &B::Gradients) -> Option> { - B::grad(&self.primitive, grads).map(Tensor::new) - } - - /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result. - pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option> { - B::grad_remove(&self.primitive, grads).map(Tensor::new) - } - - /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided - /// gradient. - pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor) { - B::grad_replace(&self.primitive, grads, grad.primitive); - } - - /// Returns the inner tensor without the autodiff information. - pub fn inner(self) -> Tensor { - Tensor::new(B::inner(self.primitive)) - } - - /// Convert a tensor to the autodiff backend. - /// - /// # Arguments - /// - /// * `inner` - The tensor to convert. - /// - /// # Returns - /// - /// The tensor converted to the autodiff backend. - pub fn from_inner(inner: Tensor) -> Self { - Self::new(B::from_inner(inner.primitive)) - } + /// Backward pass of the tensor. + pub fn backward(&self) -> B::Gradients { + B::backward::(self.primitive.clone()) + } + + /// Get the gradients of a tensor if it exist. + /// + /// Returns a new reference to the same tensor. Therefore the same grad tensor can + /// be accessed multiple times. If you only need to get the gradients one time, + /// consider using [grad_remove](Tensor::grad_remove) for better performance. + pub fn grad(&self, grads: &B::Gradients) -> Option> { + B::grad(&self.primitive, grads).map(Tensor::new) + } + + /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result. + pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option> { + B::grad_remove(&self.primitive, grads).map(Tensor::new) + } + + /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided + /// gradient. 
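The autodiff-only methods above run the backward pass and read gradients out of the returned container. A hedged sketch, assuming a backend that implements `AutodiffBackend`:

```rust
use burn_tensor::backend::AutodiffBackend;
use burn_tensor::Tensor;

fn backward_example<B: AutodiffBackend>() {
    let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0]).require_grad();
    let y = x.clone().powf(2.0).sum();

    // Runs the backward pass, then reads d y / d x from the gradients container.
    let grads = y.backward();
    let _dx = x.grad(&grads); // Option<Tensor<B::InnerBackend, 1>>
}
```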
+ pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor) { + B::grad_replace(&self.primitive, grads, grad.primitive); + } + + /// Returns the inner tensor without the autodiff information. + pub fn inner(self) -> Tensor { + Tensor::new(B::inner(self.primitive)) + } + + /// Convert a tensor to the autodiff backend. + /// + /// # Arguments + /// + /// * `inner` - The tensor to convert. + /// + /// # Returns + /// + /// The tensor converted to the autodiff backend. + pub fn from_inner(inner: Tensor) -> Self { + Self::new(B::from_inner(inner.primitive)) + } } diff --git a/burn-tensor/src/tensor/api/int.rs b/burn-tensor/src/tensor/api/int.rs index f141c40ca8..395db8c280 100644 --- a/burn-tensor/src/tensor/api/int.rs +++ b/burn-tensor/src/tensor/api/int.rs @@ -3,84 +3,84 @@ use core::ops::Range; impl Tensor where - B: Backend, + B: Backend, { - /// Returns a new integer tensor on the default device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - pub fn arange(range: Range) -> Self { - Tensor::new(B::arange(range, &B::Device::default())) - } + /// Returns a new integer tensor on the default device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + pub fn arange(range: Range) -> Self { + Tensor::new(B::arange(range, &B::Device::default())) + } - /// Returns a new integer tensor on the default device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - /// * `step` - The step between each value. - pub fn arange_step(range: Range, step: usize) -> Self { - Tensor::new(B::arange_step(range, step, &B::Device::default())) - } + /// Returns a new integer tensor on the default device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `step` - The step between each value. + pub fn arange_step(range: Range, step: usize) -> Self { + Tensor::new(B::arange_step(range, step, &B::Device::default())) + } - /// Returns a new integer tensor on the specified device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - /// * `device` - The device to create the tensor on. - pub fn arange_device(range: Range, device: &B::Device) -> Self { - Tensor::new(B::arange(range, device)) - } + /// Returns a new integer tensor on the specified device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `device` - The device to create the tensor on. + pub fn arange_device(range: Range, device: &B::Device) -> Self { + Tensor::new(B::arange(range, device)) + } - /// Returns a new integer tensor on the specified device. - /// - /// # Arguments - /// - /// * `range` - The range of values to generate. - /// * `step` - The step between each value. - pub fn arange_step_device(range: Range, step: usize, device: &B::Device) -> Self { - Tensor::new(B::arange_step(range, step, device)) - } + /// Returns a new integer tensor on the specified device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `step` - The step between each value. + pub fn arange_step_device(range: Range, step: usize, device: &B::Device) -> Self { + Tensor::new(B::arange_step(range, step, device)) + } } impl Tensor where - B: Backend, + B: Backend, { - /// Create a tensor from integers (i32). 
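The `arange` family builds integer tensors from ranges, optionally with a step and a target device; `float` then casts element-wise. A minimal sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn arange_example<B: Backend>() {
    // [0, 2, 4, 6, 8] as an integer tensor on the default device.
    let steps = Tensor::<B, 1, Int>::arange_step(0..10, 2);

    // Cast to the backend's float element type.
    let _as_float = steps.float();
}
```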
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example() { - /// let _x: Tensor = Tensor::from_ints([1, 2]); - /// let _y: Tensor = Tensor::from_ints([[1, 2], [3, 4]]); - /// } - /// ``` - pub fn from_ints>>(ints: A) -> Self { - Self::from_data(ints.into().convert()) - } + /// Create a tensor from integers (i32). + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Int}; + /// + /// fn example() { + /// let _x: Tensor = Tensor::from_ints([1, 2]); + /// let _y: Tensor = Tensor::from_ints([[1, 2], [3, 4]]); + /// } + /// ``` + pub fn from_ints>>(ints: A) -> Self { + Self::from_data(ints.into().convert()) + } - /// Returns a new tensor with the same shape and device as the current tensor and the data - /// casted to Float. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor}; - /// - /// fn example() { - /// let int_tensor = Tensor::::arange(0..5); - /// let float_tensor = int_tensor.float(); - /// } - /// ``` - pub fn float(self) -> Tensor { - Tensor::new(B::int_into_float(self.primitive)) - } + /// Returns a new tensor with the same shape and device as the current tensor and the data + /// casted to Float. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Int, Tensor}; + /// + /// fn example() { + /// let int_tensor = Tensor::::arange(0..5); + /// let float_tensor = int_tensor.float(); + /// } + /// ``` + pub fn float(self) -> Tensor { + Tensor::new(B::int_into_float(self.primitive)) + } } diff --git a/burn-tensor/src/tensor/api/kind.rs b/burn-tensor/src/tensor/api/kind.rs index d7e7cc9eb5..208aa26ef4 100644 --- a/burn-tensor/src/tensor/api/kind.rs +++ b/burn-tensor/src/tensor/api/kind.rs @@ -14,30 +14,30 @@ pub struct Bool; /// A type-level representation of the kind of a tensor. pub trait TensorKind: Clone + core::fmt::Debug { - /// The primitive type of the tensor. - type Primitive: Clone + core::fmt::Debug; + /// The primitive type of the tensor. + type Primitive: Clone + core::fmt::Debug; - /// The name of the tensor kind. - fn name() -> &'static str; + /// The name of the tensor kind. 
+ fn name() -> &'static str; } impl TensorKind for Float { - type Primitive = B::TensorPrimitive; - fn name() -> &'static str { - "Float" - } + type Primitive = B::TensorPrimitive; + fn name() -> &'static str { + "Float" + } } impl TensorKind for Int { - type Primitive = B::IntTensorPrimitive; - fn name() -> &'static str { - "Int" - } + type Primitive = B::IntTensorPrimitive; + fn name() -> &'static str { + "Int" + } } impl TensorKind for Bool { - type Primitive = B::BoolTensorPrimitive; - fn name() -> &'static str { - "Bool" - } + type Primitive = B::BoolTensorPrimitive; + fn name() -> &'static str { + "Bool" + } } diff --git a/burn-tensor/src/tensor/api/numeric.rs b/burn-tensor/src/tensor/api/numeric.rs index 2ee70d39db..fe3d064120 100644 --- a/burn-tensor/src/tensor/api/numeric.rs +++ b/burn-tensor/src/tensor/api/numeric.rs @@ -1,486 +1,486 @@ use crate::{ - backend::Backend, check, check::TensorCheck, BasicOps, Bool, Element, ElementConversion, Float, - Int, Shape, Tensor, TensorKind, + backend::Backend, check, check::TensorCheck, BasicOps, Bool, Element, ElementConversion, Float, + Int, Shape, Tensor, TensorKind, }; impl Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Convert the tensor into a scalar. - /// - /// # Panics - /// - /// If the tensor doesn't have one element. - pub fn into_scalar(self) -> K::Elem { - check!(TensorCheck::into_scalar(&self.shape())); - let data = self.into_data(); - data.value[0] - } - - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Convert the tensor into a scalar. - /// - /// # Panics - /// - /// If the tensor doesn't have one element. - pub async fn into_scalar(self) -> K::Elem { - check!(TensorCheck::into_scalar(&self.shape())); - let data = self.into_data().await; - data.value[0] - } - - /// Applies element wise addition operation. - /// - /// `y = x2 + x1` - #[allow(clippy::should_implement_trait)] - pub fn add(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Add", &self, &other)); - Self::new(K::add(self.primitive, other.primitive)) - } - - /// Applies element wise addition operation with a scalar. - /// - /// `y = x + s` - pub fn add_scalar(self, other: E) -> Self { - Self::new(K::add_scalar(self.primitive, other)) - } - - /// Applies element wise subtraction operation. - /// - /// `y = x2 - x1` - #[allow(clippy::should_implement_trait)] - pub fn sub(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Sub", &self, &other)); - Self::new(K::sub(self.primitive, other.primitive)) - } - - /// Applies element wise subtraction operation with a scalar. - /// - /// `y = x - s` - pub fn sub_scalar(self, other: E) -> Self { - Self::new(K::sub_scalar(self.primitive, other)) - } - - /// Applies element wise division operation. - /// - /// `y = x2 / x1` - #[allow(clippy::should_implement_trait)] - pub fn div(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Div", &self, &other)); - Self::new(K::div(self.primitive, other.primitive)) - } - - /// Applies element wise division operation with a scalar. - /// - /// `y = x / s` - pub fn div_scalar(self, other: E) -> Self { - Self::new(K::div_scalar(self.primitive, other)) - } - /// - /// Applies element wise multiplication operation. 
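An aside on the `TensorKind` trait from `kind.rs` above: it is what lets `Float`, `Int`, and `Bool` serve as the type-level kind parameter of `Tensor`. Below is a hedged sketch of a function generic over the kind, with reconstructed generics and a hypothetical `describe` helper; the `BasicOps` bound is assumed here because `dims()` is provided by that trait rather than by `TensorKind` itself:

```rust
use burn_tensor::{backend::Backend, BasicOps, Tensor};

/// Hypothetical debug helper that works for Float, Int, and Bool tensors alike.
fn describe<B, const D: usize, K>(tensor: &Tensor<B, D, K>) -> String
where
    B: Backend,
    K: BasicOps<B>,
{
    // `K::name()` comes from `TensorKind`, which `BasicOps` extends.
    format!("{} tensor with shape {:?}", K::name(), tensor.dims())
}
```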
- /// - /// `y = x2 * x1` - #[allow(clippy::should_implement_trait)] - pub fn mul(self, other: Self) -> Self { - check!(TensorCheck::binary_ops_ew("Mul", &self, &other)); - Self::new(K::mul(self.primitive, other.primitive)) - } - - /// Applies element wise multiplication operation with a scalar. - /// - /// `y = x * s` - pub fn mul_scalar(self, other: E) -> Self { - Self::new(K::mul_scalar(self.primitive, other)) - } - - /// Switch sign of each element in the tensor. - /// - /// `y = -x` - #[allow(clippy::should_implement_trait)] - pub fn neg(self) -> Self { - Self::new(K::neg(self.primitive)) - } - - /// Create a tensor of the given shape where each element is zero. - pub fn zeros>>(shape: S) -> Self { - Self::zeros_device(shape, &B::Device::default()) - } - - /// Create a tensor of the given shape where each element is zero. - pub fn zeros_device>>(shape: S, device: &B::Device) -> Self { - Self::new(K::zeros(shape.into(), device)) - } - - /// Create a tensor of the given shape where each element is one. - pub fn ones>>(shape: S) -> Self { - Self::ones_device(shape, &B::Device::default()) - } - - /// Create a tensor of the given shape where each element is one. - pub fn ones_device>>(shape: S, device: &B::Device) -> Self { - Self::new(K::ones(shape.into(), device)) - } - - /// Create a tensor of the given shape where each element is equal to the provided value. - pub fn full>, E: ElementConversion>(shape: S, fill_value: E) -> Self { - Self::full_device(shape, fill_value, &B::Device::default()) - } - - /// Create a tensor of the given shape where each element is equal to the provided value. - pub fn full_device>, E: ElementConversion>( - shape: S, - fill_value: E, - device: &B::Device, - ) -> Self { - Self::new(K::full(shape.into(), fill_value, device)) - } - - /// Aggregate all elements in the tensor with the mean operation. - pub fn mean(self) -> Tensor { - Tensor::new(K::mean(self.primitive)) - } - - /// Aggregate all elements in the tensor with the sum operation. - pub fn sum(self) -> Tensor { - Tensor::new(K::sum(self.primitive)) - } - - /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation. - pub fn mean_dim(self, dim: usize) -> Self { - check!(TensorCheck::aggregate_dim::("Mean", dim)); - Self::new(K::mean_dim(self.primitive, dim)) - } - - /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation. - pub fn sum_dim(self, dim: usize) -> Self { - check!(TensorCheck::aggregate_dim::("Sum", dim)); - Self::new(K::sum_dim(self.primitive, dim)) - } - - /// Applies element wise equal comparison and returns a boolean tensor. - pub fn equal_elem(self, other: E) -> Tensor { - K::equal_elem::(self.primitive, other.elem()) - } - - /// Applies element wise greater comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn greater(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Greater", &self, &other)); - K::greater(self.primitive, other.primitive) - } - - /// Applies element wise greater-equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn greater_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other)); - K::greater_equal(self.primitive, other.primitive) - } - - /// Applies element wise lower comparison and returns a boolean tensor. 
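A short sketch tying together the element-wise and aggregation methods listed above; the `arithmetic_example` function is hypothetical and the generics are reconstructed:

```rust
use burn_tensor::{backend::Backend, Shape, Tensor};

fn arithmetic_example<B: Backend>() {
    let x = Tensor::<B, 2>::ones(Shape::new([2, 3]));
    let y = Tensor::<B, 2>::full(Shape::new([2, 3]), 4.0);

    // Element-wise ops consume their operands and return new tensors.
    let z = x.clone().mul(y).add_scalar(1.0).neg();

    // `sum_dim` keeps the reduced dimension with size 1; `mean` reduces to one element.
    let _row_sums = z.sum_dim(1); // shape [2, 1]
    let _mean = x.mean(); // shape [1]
}
```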
- /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn lower(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Lower", &self, &other)); - K::lower(self.primitive, other.primitive) - } - - /// Applies element wise lower-equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - pub fn lower_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other)); - K::lower_equal(self.primitive, other.primitive) - } - - /// Applies element wise greater comparison and returns a boolean tensor. - pub fn greater_elem(self, other: E) -> Tensor { - K::greater_elem(self.primitive, other.elem()) - } - - /// Applies element wise greater-equal comparison and returns a boolean tensor. - pub fn greater_equal_elem(self, other: E) -> Tensor { - K::greater_equal_elem(self.primitive, other.elem()) - } - - /// Applies element wise lower comparison and returns a boolean tensor. - pub fn lower_elem(self, other: E) -> Tensor { - K::lower_elem(self.primitive, other.elem()) - } - - /// Applies element wise lower-equal comparison and returns a boolean tensor. - pub fn lower_equal_elem(self, other: E) -> Tensor { - K::lower_equal_elem(self.primitive, other.elem()) - } - - /// Update the given tensor with the value tensor where the mask is true. - /// - /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of - /// a scalar. - pub fn mask_where(self, mask: Tensor, value: Self) -> Self { - Self::new(K::mask_where(self.primitive, mask, value.primitive)) - } - - /// Update the given tensor with the value where the mask is true. - /// - /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of - /// a tensor. - pub fn mask_fill(self, mask: Tensor, value: E) -> Self { - Self::new(K::mask_fill(self.primitive, mask, value.elem())) - } - - /// Gather tensor elements corresponding to the given indices from the specified dim. - /// - /// Example using a 3D tensor: - /// - /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0` - /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1` - /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2` - /// - /// # Notes - /// - /// The index tensor should have the same shape as the original tensor except for the dim - /// specified. - pub fn gather(self, dim: usize, indices: Tensor) -> Self { - check!(TensorCheck::gather::( - dim, - &self.shape(), - &indices.shape() - )); - - Self::new(K::gather(dim, self.primitive, indices)) - } - - /// Assign the gathered elements corresponding to the given indices along the specified dimension - /// from the value tensor to the original tensor using sum reduction. - /// - /// Example using a 3D tensor: - /// - /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0` - /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1` - /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2` - /// - /// # Notes - /// - /// The index tensor should have the same shape as the original tensor except for the specified - /// dimension. The value and index tensors should have the same shape. - /// - /// Other references to the input tensor will not be modified by this operation. 
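The comparison and masking methods above compose naturally; a minimal sketch with a hypothetical `zero_out_negatives` helper:

```rust
use burn_tensor::{backend::Backend, Tensor};

fn zero_out_negatives<B: Backend>(values: Tensor<B, 2>) -> Tensor<B, 2> {
    // Boolean mask of the strictly negative entries...
    let mask = values.clone().lower_elem(0.0);
    // ...replaced by zero; `mask_where` would take a tensor of replacements instead.
    values.mask_fill(mask, 0.0)
}
```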
- pub fn scatter(self, dim: usize, indices: Tensor, values: Self) -> Self { - check!(TensorCheck::scatter::( - dim, - &self.shape(), - &indices.shape(), - &values.shape() - )); - - Self::new(K::scatter(dim, self.primitive, indices, values.primitive)) - } - - /// Select the tensor elements along the given dimension corresponding to the given indices. - /// - /// Example using a 3D tensor: - /// - /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0` - /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1` - /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2` - pub fn select(self, dim: usize, indices: Tensor) -> Self { - check!(TensorCheck::select::(dim)); - Self::new(K::select(self.primitive, dim, indices)) - } - - /// Assign the selected elements along the given dimension corresponding to the given indices - /// from the value tensor to the original tensor using sum reduction. - /// - /// Example using a 3D tensor: - /// - /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0` - /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1` - /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2` - pub fn select_assign( - self, - dim: usize, - indices: Tensor, - values: Tensor, - ) -> Self { - check!(TensorCheck::select_assign::(dim)); - - Self::new(K::select_assign( - self.primitive, - dim, - indices, - values.primitive, - )) - } - - /// Applies the argmax function along the given dimension and returns an integer tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); - /// let tensor = tensor.argmax(1); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [2, 1, 3] } - /// } - /// ``` - pub fn argmax(self, dim: usize) -> Tensor { - Tensor::new(K::argmax(self.primitive, dim)) - } - - /// Find the maximum value. - pub fn max(self) -> Tensor { - Tensor::new(K::max(self.primitive)) - } - - /// Find the maximum value along the given dimension. - pub fn max_dim(self, dim: usize) -> Tensor { - check!(TensorCheck::aggregate_dim::("Max", dim)); - - Tensor::new(K::max_dim(self.primitive, dim)) - } - - /// Find the maximum value along the given dimension. - /// - /// Also returns the indices. - pub fn max_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { - check!(TensorCheck::aggregate_dim::("Max", dim)); - - let (tensor, index) = K::max_dim_with_indices(self.primitive, dim); - - let tensor = Tensor::new(tensor); - let index = Tensor::new(index); - - (tensor, index) - } - - /// Applies the argmin function along the given dimension and returns an integer tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); - /// let tensor = tensor.argmin(1); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [2, 1, 3] } - /// } - /// ``` - pub fn argmin(self, dim: usize) -> Tensor { - Tensor::new(K::argmin(self.primitive, dim)) - } - - /// Find the minimum value. - pub fn min(self) -> Tensor { - Tensor::new(K::min(self.primitive)) - } - - /// Find the minimum value along the given dimension. - pub fn min_dim(self, dim: usize) -> Tensor { - check!(TensorCheck::aggregate_dim::("Min", dim)); - Tensor::new(K::min_dim(self.primitive, dim)) - } - - /// Find the minimum value along the given dimension. - /// - /// Also returns the indices. 
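A sketch of how `argmax` and `gather` interact under the indexing rules documented above; the `best_scores` helper is hypothetical, and `max_dim_with_indices` returns the same values and indices in a single call:

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

fn best_scores<B: Backend>(scores: Tensor<B, 2>) -> Tensor<B, 2> {
    // Index of the maximum along dim 1; shape [batch, 1].
    let indices: Tensor<B, 2, Int> = scores.clone().argmax(1);
    // Gather the corresponding values back out; shape [batch, 1].
    scores.gather(1, indices)
}
```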
- pub fn min_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { - check!(TensorCheck::aggregate_dim::("Min", dim)); - - let (tensor, index) = K::min_dim_with_indices(self.primitive, dim); - - let tensor = Tensor::new(tensor); - let index = Tensor::new(index); - - (tensor, index) - } - - /// Clamp the tensor between the given min and max values. - /// - /// # Arguments - /// - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped between the given min and max values. - pub fn clamp(self, min: E, max: E) -> Self { - Self::new(K::clamp(self.primitive, min.elem(), max.elem())) - } - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped under the given min value. - pub fn clamp_min(self, min: E) -> Self { - Self::new(K::clamp_min(self.primitive, min.elem())) - } - - /// Clamps a tensor over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped over the given max value. - /// - pub fn clamp_max(self, max: E) -> Self { - Self::new(K::clamp_max(self.primitive, max.elem())) - } - - /// Apply element wise absolute value operation - pub fn abs(self) -> Self { - Self::new(K::abs(self.primitive)) - } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// Convert the tensor into a scalar. + /// + /// # Panics + /// + /// If the tensor doesn't have one element. + pub fn into_scalar(self) -> K::Elem { + check!(TensorCheck::into_scalar(&self.shape())); + let data = self.into_data(); + data.value[0] + } + + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] + /// Convert the tensor into a scalar. + /// + /// # Panics + /// + /// If the tensor doesn't have one element. + pub async fn into_scalar(self) -> K::Elem { + check!(TensorCheck::into_scalar(&self.shape())); + let data = self.into_data().await; + data.value[0] + } + + /// Applies element wise addition operation. + /// + /// `y = x2 + x1` + #[allow(clippy::should_implement_trait)] + pub fn add(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Add", &self, &other)); + Self::new(K::add(self.primitive, other.primitive)) + } + + /// Applies element wise addition operation with a scalar. + /// + /// `y = x + s` + pub fn add_scalar(self, other: E) -> Self { + Self::new(K::add_scalar(self.primitive, other)) + } + + /// Applies element wise subtraction operation. + /// + /// `y = x2 - x1` + #[allow(clippy::should_implement_trait)] + pub fn sub(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Sub", &self, &other)); + Self::new(K::sub(self.primitive, other.primitive)) + } + + /// Applies element wise subtraction operation with a scalar. + /// + /// `y = x - s` + pub fn sub_scalar(self, other: E) -> Self { + Self::new(K::sub_scalar(self.primitive, other)) + } + + /// Applies element wise division operation. + /// + /// `y = x2 / x1` + #[allow(clippy::should_implement_trait)] + pub fn div(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Div", &self, &other)); + Self::new(K::div(self.primitive, other.primitive)) + } + + /// Applies element wise division operation with a scalar. 
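And a one-liner for the clamping family above, with a hypothetical `clip_gradient` helper:

```rust
use burn_tensor::{backend::Backend, Tensor};

fn clip_gradient<B: Backend>(grad: Tensor<B, 1>) -> Tensor<B, 1> {
    // Keep every element in [-1.0, 1.0]; `clamp_min` / `clamp_max` bound one side only.
    grad.clamp(-1.0, 1.0)
}
```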
+ /// + /// `y = x / s` + pub fn div_scalar(self, other: E) -> Self { + Self::new(K::div_scalar(self.primitive, other)) + } + /// + /// Applies element wise multiplication operation. + /// + /// `y = x2 * x1` + #[allow(clippy::should_implement_trait)] + pub fn mul(self, other: Self) -> Self { + check!(TensorCheck::binary_ops_ew("Mul", &self, &other)); + Self::new(K::mul(self.primitive, other.primitive)) + } + + /// Applies element wise multiplication operation with a scalar. + /// + /// `y = x * s` + pub fn mul_scalar(self, other: E) -> Self { + Self::new(K::mul_scalar(self.primitive, other)) + } + + /// Switch sign of each element in the tensor. + /// + /// `y = -x` + #[allow(clippy::should_implement_trait)] + pub fn neg(self) -> Self { + Self::new(K::neg(self.primitive)) + } + + /// Create a tensor of the given shape where each element is zero. + pub fn zeros>>(shape: S) -> Self { + Self::zeros_device(shape, &B::Device::default()) + } + + /// Create a tensor of the given shape where each element is zero. + pub fn zeros_device>>(shape: S, device: &B::Device) -> Self { + Self::new(K::zeros(shape.into(), device)) + } + + /// Create a tensor of the given shape where each element is one. + pub fn ones>>(shape: S) -> Self { + Self::ones_device(shape, &B::Device::default()) + } + + /// Create a tensor of the given shape where each element is one. + pub fn ones_device>>(shape: S, device: &B::Device) -> Self { + Self::new(K::ones(shape.into(), device)) + } + + /// Create a tensor of the given shape where each element is equal to the provided value. + pub fn full>, E: ElementConversion>(shape: S, fill_value: E) -> Self { + Self::full_device(shape, fill_value, &B::Device::default()) + } + + /// Create a tensor of the given shape where each element is equal to the provided value. + pub fn full_device>, E: ElementConversion>( + shape: S, + fill_value: E, + device: &B::Device, + ) -> Self { + Self::new(K::full(shape.into(), fill_value, device)) + } + + /// Aggregate all elements in the tensor with the mean operation. + pub fn mean(self) -> Tensor { + Tensor::new(K::mean(self.primitive)) + } + + /// Aggregate all elements in the tensor with the sum operation. + pub fn sum(self) -> Tensor { + Tensor::new(K::sum(self.primitive)) + } + + /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation. + pub fn mean_dim(self, dim: usize) -> Self { + check!(TensorCheck::aggregate_dim::("Mean", dim)); + Self::new(K::mean_dim(self.primitive, dim)) + } + + /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation. + pub fn sum_dim(self, dim: usize) -> Self { + check!(TensorCheck::aggregate_dim::("Sum", dim)); + Self::new(K::sum_dim(self.primitive, dim)) + } + + /// Applies element wise equal comparison and returns a boolean tensor. + pub fn equal_elem(self, other: E) -> Tensor { + K::equal_elem::(self.primitive, other.elem()) + } + + /// Applies element wise greater comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn greater(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Greater", &self, &other)); + K::greater(self.primitive, other.primitive) + } + + /// Applies element wise greater-equal comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. 
+ pub fn greater_equal(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other)); + K::greater_equal(self.primitive, other.primitive) + } + + /// Applies element wise lower comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn lower(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Lower", &self, &other)); + K::lower(self.primitive, other.primitive) + } + + /// Applies element wise lower-equal comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + pub fn lower_equal(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other)); + K::lower_equal(self.primitive, other.primitive) + } + + /// Applies element wise greater comparison and returns a boolean tensor. + pub fn greater_elem(self, other: E) -> Tensor { + K::greater_elem(self.primitive, other.elem()) + } + + /// Applies element wise greater-equal comparison and returns a boolean tensor. + pub fn greater_equal_elem(self, other: E) -> Tensor { + K::greater_equal_elem(self.primitive, other.elem()) + } + + /// Applies element wise lower comparison and returns a boolean tensor. + pub fn lower_elem(self, other: E) -> Tensor { + K::lower_elem(self.primitive, other.elem()) + } + + /// Applies element wise lower-equal comparison and returns a boolean tensor. + pub fn lower_equal_elem(self, other: E) -> Tensor { + K::lower_equal_elem(self.primitive, other.elem()) + } + + /// Update the given tensor with the value tensor where the mask is true. + /// + /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of + /// a scalar. + pub fn mask_where(self, mask: Tensor, value: Self) -> Self { + Self::new(K::mask_where(self.primitive, mask, value.primitive)) + } + + /// Update the given tensor with the value where the mask is true. + /// + /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of + /// a tensor. + pub fn mask_fill(self, mask: Tensor, value: E) -> Self { + Self::new(K::mask_fill(self.primitive, mask, value.elem())) + } + + /// Gather tensor elements corresponding to the given indices from the specified dim. + /// + /// Example using a 3D tensor: + /// + /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0` + /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1` + /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2` + /// + /// # Notes + /// + /// The index tensor should have the same shape as the original tensor except for the dim + /// specified. + pub fn gather(self, dim: usize, indices: Tensor) -> Self { + check!(TensorCheck::gather::( + dim, + &self.shape(), + &indices.shape() + )); + + Self::new(K::gather(dim, self.primitive, indices)) + } + + /// Assign the gathered elements corresponding to the given indices along the specified dimension + /// from the value tensor to the original tensor using sum reduction. + /// + /// Example using a 3D tensor: + /// + /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0` + /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1` + /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2` + /// + /// # Notes + /// + /// The index tensor should have the same shape as the original tensor except for the specified + /// dimension. The value and index tensors should have the same shape. 
+ /// + /// Other references to the input tensor will not be modified by this operation. + pub fn scatter(self, dim: usize, indices: Tensor, values: Self) -> Self { + check!(TensorCheck::scatter::( + dim, + &self.shape(), + &indices.shape(), + &values.shape() + )); + + Self::new(K::scatter(dim, self.primitive, indices, values.primitive)) + } + + /// Select the tensor elements along the given dimension corresponding to the given indices. + /// + /// Example using a 3D tensor: + /// + /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0` + /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1` + /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2` + pub fn select(self, dim: usize, indices: Tensor) -> Self { + check!(TensorCheck::select::(dim)); + Self::new(K::select(self.primitive, dim, indices)) + } + + /// Assign the selected elements along the given dimension corresponding to the given indices + /// from the value tensor to the original tensor using sum reduction. + /// + /// Example using a 3D tensor: + /// + /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0` + /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1` + /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2` + pub fn select_assign( + self, + dim: usize, + indices: Tensor, + values: Tensor, + ) -> Self { + check!(TensorCheck::select_assign::(dim)); + + Self::new(K::select_assign( + self.primitive, + dim, + indices, + values.primitive, + )) + } + + /// Applies the argmax function along the given dimension and returns an integer tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); + /// let tensor = tensor.argmax(1); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [2, 1, 3] } + /// } + /// ``` + pub fn argmax(self, dim: usize) -> Tensor { + Tensor::new(K::argmax(self.primitive, dim)) + } + + /// Find the maximum value. + pub fn max(self) -> Tensor { + Tensor::new(K::max(self.primitive)) + } + + /// Find the maximum value along the given dimension. + pub fn max_dim(self, dim: usize) -> Tensor { + check!(TensorCheck::aggregate_dim::("Max", dim)); + + Tensor::new(K::max_dim(self.primitive, dim)) + } + + /// Find the maximum value along the given dimension. + /// + /// Also returns the indices. + pub fn max_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { + check!(TensorCheck::aggregate_dim::("Max", dim)); + + let (tensor, index) = K::max_dim_with_indices(self.primitive, dim); + + let tensor = Tensor::new(tensor); + let index = Tensor::new(index); + + (tensor, index) + } + + /// Applies the argmin function along the given dimension and returns an integer tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 3, 3])); + /// let tensor = tensor.argmin(1); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [2, 1, 3] } + /// } + /// ``` + pub fn argmin(self, dim: usize) -> Tensor { + Tensor::new(K::argmin(self.primitive, dim)) + } + + /// Find the minimum value. + pub fn min(self) -> Tensor { + Tensor::new(K::min(self.primitive)) + } + + /// Find the minimum value along the given dimension. 
+ pub fn min_dim(self, dim: usize) -> Tensor { + check!(TensorCheck::aggregate_dim::("Min", dim)); + Tensor::new(K::min_dim(self.primitive, dim)) + } + + /// Find the minimum value along the given dimension. + /// + /// Also returns the indices. + pub fn min_dim_with_indices(self, dim: usize) -> (Tensor, Tensor) { + check!(TensorCheck::aggregate_dim::("Min", dim)); + + let (tensor, index) = K::min_dim_with_indices(self.primitive, dim); + + let tensor = Tensor::new(tensor); + let index = Tensor::new(index); + + (tensor, index) + } + + /// Clamp the tensor between the given min and max values. + /// + /// # Arguments + /// + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped between the given min and max values. + pub fn clamp(self, min: E, max: E) -> Self { + Self::new(K::clamp(self.primitive, min.elem(), max.elem())) + } + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped under the given min value. + pub fn clamp_min(self, min: E) -> Self { + Self::new(K::clamp_min(self.primitive, min.elem())) + } + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped over the given max value. + /// + pub fn clamp_max(self, max: E) -> Self { + Self::new(K::clamp_max(self.primitive, max.elem())) + } + + /// Apply element wise absolute value operation + pub fn abs(self) -> Self { + Self::new(K::abs(self.primitive)) + } } impl Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - /// Create diagonal matrix. - /// - /// # Arguments - /// - /// * `size` - The size of the square matrix. - pub fn diagonal(size: usize) -> Self { - let indices = Tensor::::arange(0..size).unsqueeze(); - let ones = K::ones([1, size].into(), &B::Device::default()); - let zeros = K::zeros([size, size].into(), &B::Device::default()); - Self::new(K::scatter(0, zeros, indices, ones)) - } + /// Create diagonal matrix. + /// + /// # Arguments + /// + /// * `size` - The size of the square matrix. + pub fn diagonal(size: usize) -> Self { + let indices = Tensor::::arange(0..size).unsqueeze(); + let ones = K::ones([1, size].into(), &B::Device::default()); + let zeros = K::zeros([size, size].into(), &B::Device::default()); + Self::new(K::scatter(0, zeros, indices, ones)) + } } /// Trait that list all operations that can be applied on all numerical tensors. @@ -490,1623 +490,1647 @@ where /// This is an internal trait, use the public API provided by [tensor struct](Tensor). pub trait Numeric: BasicOps where - Self::Elem: Element, + Self::Elem: Element, { - /// Adds two tensors together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The sum of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. 
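The `diagonal` constructor above builds an identity-like matrix by scattering a row of ones into a zero matrix; a usage sketch with a hypothetical `identity3` helper:

```rust
use burn_tensor::{backend::Backend, Tensor};

fn identity3<B: Backend>() -> Tensor<B, 2> {
    // 3x3 matrix with ones on the diagonal, zeros elsewhere, on the default device.
    Tensor::diagonal(3)
}
```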
- /// - /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function, - /// which is more high-level and designed for public use. - fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Adds a scalar to a tensor element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The sum of the tensor and the scalar. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function, - /// which is more high-level and designed for public use. - fn add_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Subtracts two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The difference of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function, - /// which is more high-level and designed for public use. - fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Subtracts a scalar from a tensor element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The difference of the tensor and the scalar. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function, - /// which is more high-level and designed for public use. - fn sub_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Divides two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The quotient of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function, - /// which is more high-level and designed for public use. - fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Divides a tensor by a scalar element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The quotient of the tensor and the scalar. 
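Because `Numeric` is the trait that provides these methods for both `Float` and `Int` kinds through static dispatch, user code can stay generic over the kind by bounding on it; a hedged sketch with reconstructed generics and a hypothetical `total` helper (the bounds mirror the impl block shown earlier in this file):

```rust
use burn_tensor::{backend::Backend, Element, Numeric, Tensor};

/// Hypothetical helper: sum all elements of a Float or Int tensor.
fn total<B, const D: usize, K>(tensor: Tensor<B, D, K>) -> Tensor<B, 1, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    tensor.sum()
}
```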
- /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function, - /// which is more high-level and designed for public use. - fn div_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Multiplies two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The product of the two tensors. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function, - /// which is more high-level and designed for public use. - fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; - - /// Multiplies a tensor by a scalar element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The product of the tensor and the scalar. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function, - /// which is more high-level and designed for public use. - fn mul_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive; - - /// Negates a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to negate. - /// - /// # Returns - /// - /// The negated tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function, - /// which is more high-level and designed for public use. - fn neg(tensor: Self::Primitive) -> Self::Primitive; - - /// Creates a tensor filled with zeros. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor filled with zeros. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function, - /// which is more high-level and designed for public use. - fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive; - - /// Creates a tensor filled with ones. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. 
- /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor filled with ones. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function, - /// which is more high-level and designed for public use. - fn ones(shape: Shape, device: &B::Device) -> Self::Primitive; - - /// Creates a tensor filled with elements equal to the given value. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `fill_value` - The value with which to fill the tensor - /// * `device` - The device on which the tensor will be allocated. - /// - /// # Returns - /// - /// The tensor filled with elements equal to the given value - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function, - /// which is more high-level and designed for public use. - fn full( - shape: Shape, - fill_value: E, - device: &B::Device, - ) -> Self::Primitive; - - /// Sums all the elements of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// - /// # Returns - /// - /// The sum of all the elements of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function, - /// which is more high-level and designed for public use. - fn sum(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Sums all the elements of the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// * `dim` - The dimension along which to sum. - /// - /// # Returns - /// - /// The sum of all the elements of the tensor along the specified dimension. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function, - /// which is more high-level and designed for public use. - fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Computes the mean of all the elements of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// - /// # Returns - /// - /// The mean of all the elements of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. 
It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function, - /// which is more high-level and designed for public use. - fn mean(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Computes the mean of all the elements of the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// * `dim` - The dimension along which to compute the mean. - /// - /// # Returns - /// - /// The mean of all the elements of the tensor along the specified dimension. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For computing the mean of all the elements of a tensor along a dimension, users should prefer - /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use. - fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Element-wise equality between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding elements of the input tensors are equal, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise equality between two tensors, users should prefer the [Tensor::equal_elem](Tensor::equal_elem) function, - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; - - /// Element-wise greater than comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is greater than the corresponding element - /// of the right hand side tensor, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function, - /// which is more high-level and designed for public use. - fn greater( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Element-wise greater than comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. 
- /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is greater than the right hand side - /// scalar, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than comparison between a tensor and a scalar, users should prefer - /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use. - fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; - - /// Element-wise greater than or equal comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is greater than or equal to the - /// corresponding element of the right hand side tensor, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than or equal comparison between two tensors, users should prefer - /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use. - fn greater_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Element-wise greater than or equal comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is greater than or equal to the right - /// hand side scalar, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer - /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use. - fn greater_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor; - - /// Element-wise less than comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is less than the corresponding element of - /// the right hand side tensor, and false otherwise. 
- /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function, - /// which is more high-level and designed for public use. - fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor; - - /// Element-wise less than comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is less than the right hand side scalar, - /// and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than comparison between a tensor and a scalar, users should prefer - /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use. - fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; - - /// Element-wise less than or equal comparison between two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensors, where each element is true if the - /// corresponding element of the left hand side tensor is less than or equal to the corresponding - /// element of the right hand side tensor, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than or equal comparison between two tensors, users should prefer - /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use. - fn lower_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; - - /// Element-wise less than or equal comparison between a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the same shape as the input tensor, where each element is true if the - /// corresponding element of the left hand side tensor is less than or equal to the right hand - /// side scalar, and false otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer - /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use. 
- fn lower_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor; - - /// Selects elements from a tensor based on a boolean mask. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true. - /// * `mask` - The boolean mask to use for selecting elements. - /// * `source` - The tensor to select elements from when the corresponding element of the mask is false. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensors, where each element is taken from the - /// corresponding element of the left hand side tensor if the corresponding element of the mask - /// is true, and from the corresponding element of the right hand side tensor otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For selecting elements from a tensor based on a boolean mask, users should prefer the - /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use. - fn mask_where( - tensor: Self::Primitive, - mask: Tensor, - source: Self::Primitive, - ) -> Self::Primitive; - - /// Fills elements of a tensor based on a boolean mask. - /// - /// # Arguments - /// - /// * `tensor` - The tensor where will be overwritten with the value - /// when the corresponding element of the mask is true. - /// * `mask` - The boolean mask to use for filling elements. - /// * `value` - The value to fill elements with when the corresponding element of the mask is true. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensors, where each element is taken from the - /// corresponding element unmodified if the corresponding element of the mask is false, and - /// filled with the value otherwise. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For filling elements of a tensor based on a boolean mask, users should prefer the - /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use. - fn mask_fill( - tensor: Self::Primitive, - mask: Tensor, - value: Self::Elem, - ) -> Self::Primitive; - - /// Gathers elements from a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to gather elements. - /// * `tensor` - The tensor to gather elements from. - /// * `indices` - The indices of the elements to gather. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For gathering elements from a tensor along an axis, users should prefer the - /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use. 
- fn gather( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - ) -> Self::Primitive; - - /// Scatters elements into a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to scatter elements. - /// * `tensor` - The tensor to scatter elements into. - /// * `indices` - The indices of the elements to scatter. - /// * `values` - The values to scatter into the tensor. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis, - /// except for the elements at the specified indices, which are taken from the corresponding - /// element of the values tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function, - /// which is more high-level and designed for public use. - fn scatter( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive; - - /// Select tensor elements along the given dimension corresponding for the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select elements from. - /// * `dim` - The axis along which to select elements. - /// * `indices` - The indices of the elements to select. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For selecting elements from a tensor along an axis, users should prefer the - /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use. - fn select( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - ) -> Self::Primitive; - - /// Assign the selected elements along the given dimension corresponding to the given indices - /// from the value tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to assign elements to. - /// * `dim` - The axis along which to assign elements. - /// * `indices` - The indices of the elements to assign. - /// * `values` - The values to assign to the tensor. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is taken from the - /// corresponding element of the input tensor at the corresponding index along the specified axis, - /// except for the elements at the specified indices, which are taken from the corresponding - /// element of the values tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. 
- /// - /// For assigning elements to a tensor along an axis, users should prefer the - /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use. - fn select_assign( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive; - - /// Gets the indices of the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to get the indices of the maximum elements. - /// * `tensor` - The tensor to get the indices of the maximum elements from. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the index of the - /// maximum element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use. - fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive; - - /// Gets the indices of the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to get the indices of the minimum elements. - /// * `tensor` - The tensor to get the indices of the minimum elements from. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the index of the - /// minimum element of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use. - fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive; - - /// Gets the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `dim` - The axis along which to get the maximum elements. - /// - /// # Returns - /// - /// A single-element tensor containing the maximum element of the input tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use. - fn max(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Gets the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements from. - /// * `dim` - The axis along which to get the maximum elements. 
- /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the maximum element - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use. - fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Gets the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements from. - /// * `dim` - The axis along which to get the maximum elements. - /// - /// # Returns - /// - /// A tuple containing the maximum element of the input tensor, and a tensor with the same shape - /// as the input tensor, where each element is the index of the maximum element of the input tensor - /// at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the maximum elements of a tensor along an axis, users should prefer the - /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use. - fn max_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, B::IntTensorPrimitive); - - /// Gets the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements from. - /// - /// # Returns - /// - /// A single-element tensor containing the minimum element of the input tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use. - fn min(tensor: Self::Primitive) -> Self::Primitive<1>; - - /// Gets the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements from. - /// * `dim` - The axis along which to get the minimum elements. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor, where each element is the minimum element - /// of the input tensor at the corresponding index along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use. 
- fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; - - /// Gets the minimum elements and indices of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements from. - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensor and corresponding indices, where - /// each element is the minimum element of the input tensor at the corresponding index - /// along the specified axis. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the minimum elements of a tensor along an axis, users should prefer the - /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use. - fn min_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, B::IntTensorPrimitive); - - /// Clamp the tensor between the given min and max values. - /// - /// # Arguments - /// - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped between the given min and max values. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users. - /// - /// For clamping a tensor between the given min and max values, users should prefer the - /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use. - fn clamp( - tensor: Self::Primitive, - min: Self::Elem, - max: Self::Elem, - ) -> Self::Primitive; - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped under the given min value. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users. - /// - /// For clamping a tensor under a minimum value, users should prefer the - /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use. - fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive; - - /// Clamps a tensor over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped over the given max value. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users. - /// - /// For clamping a tensor over a maximum value, users should prefer the - /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use. - fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive; - - /// Calculate absolute value on all elements of a tensor - /// - /// # Arguments - /// - /// * `tensor` - The tensor to apply abs to. - /// - /// # Returns - /// - /// A tensor with absolute values. 
- /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function, - /// which is more high-level and designed for public use. - fn abs(tensor: Self::Primitive) -> Self::Primitive; + /// Adds two tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The sum of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function, + /// which is more high-level and designed for public use. + fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Adds a scalar to a tensor element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The sum of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function, + /// which is more high-level and designed for public use. + fn add_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The difference of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function, + /// which is more high-level and designed for public use. + fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Subtracts a scalar from a tensor element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The difference of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function, + /// which is more high-level and designed for public use. + fn sub_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Divides two tensors. 
+ /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The quotient of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function, + /// which is more high-level and designed for public use. + fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Divides a tensor by a scalar element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The quotient of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function, + /// which is more high-level and designed for public use. + fn div_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Multiplies two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The product of the two tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function, + /// which is more high-level and designed for public use. + fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Multiplies a tensor by a scalar element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The product of the tensor and the scalar. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function, + /// which is more high-level and designed for public use. + fn mul_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive; + + /// Negates a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to negate. + /// + /// # Returns + /// + /// The negated tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function, + /// which is more high-level and designed for public use. 
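// Editor's sketch (not part of this patch): how the arithmetic entry points documented
// above are reached from the public API. `Tensor::add`, `Tensor::sub_scalar`,
// `Tensor::mul` and `Tensor::neg` are the high-level counterparts named in the docs;
// the function name, the import path and the generic backend `B` are illustrative
// assumptions.
use burn_tensor::{backend::Backend, Tensor};

fn arithmetic_demo<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>) -> Tensor<B, 2> {
    // Operator overloads forward to the same add/mul entry points; the `*_scalar`
    // variants accept any element-convertible value such as a plain f64.
    let weighted = (a.clone() + b) * a;
    weighted.sub_scalar(1.0).neg()
}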
+ fn neg(tensor: Self::Primitive) -> Self::Primitive; + + /// Creates a tensor filled with zeros. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor filled with zeros. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function, + /// which is more high-level and designed for public use. + fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive; + + /// Creates a tensor filled with ones. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor filled with ones. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function, + /// which is more high-level and designed for public use. + fn ones(shape: Shape, device: &B::Device) -> Self::Primitive; + + /// Creates a tensor filled with elements equal to the given value. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `fill_value` - The value with which to fill the tensor + /// * `device` - The device on which the tensor will be allocated. + /// + /// # Returns + /// + /// The tensor filled with elements equal to the given value + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function, + /// which is more high-level and designed for public use. + fn full( + shape: Shape, + fill_value: E, + device: &B::Device, + ) -> Self::Primitive; + + /// Sums all the elements of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// + /// # Returns + /// + /// The sum of all the elements of the tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function, + /// which is more high-level and designed for public use. + fn sum(tensor: Self::Primitive) -> Self::Primitive<1>; + + /// Sums all the elements of the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension along which to sum. + /// + /// # Returns + /// + /// The sum of all the elements of the tensor along the specified dimension. 
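// Editor's sketch (not part of this patch): the reduction entry points documented here
// (`Tensor::sum`, `Tensor::sum_dim`, `Tensor::mean`, `Tensor::mean_dim`) as seen from
// user code. The function name, import path and generic backend `B` are illustrative
// assumptions.
use burn_tensor::{backend::Backend, Tensor};

fn reduction_demo<B: Backend>(x: Tensor<B, 2>) -> (Tensor<B, 1>, Tensor<B, 2>) {
    // `sum` folds every element into a single-element, rank-1 tensor.
    let total = x.clone().sum();
    // `mean_dim` reduces along dimension 1 and keeps that axis with size 1.
    let per_row = x.mean_dim(1);
    (total, per_row)
}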
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
+    /// which is more high-level and designed for public use.
+    fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
+
+    /// Computes the mean of all the elements of the tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the mean of.
+    ///
+    /// # Returns
+    ///
+    /// The mean of all the elements of the tensor.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
+    /// which is more high-level and designed for public use.
+    fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
+
+    /// Computes the mean of all the elements of the tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the mean of.
+    /// * `dim` - The dimension along which to compute the mean.
+    ///
+    /// # Returns
+    ///
+    /// The mean of all the elements of the tensor along the specified dimension.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
+    /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
+    fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
+
+    /// Element-wise equality comparison between a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
+    /// corresponding element of the left hand side tensor is equal to the right hand side scalar,
+    /// and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise equality comparison between a tensor and a scalar, users should prefer the
+    /// [Tensor::equal_elem](Tensor::equal_elem) function, which is more high-level and designed for public use.
+    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
+
+    /// Element-wise greater than comparison between two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
+    /// corresponding element of the left hand side tensor is greater than the corresponding element
+    /// of the right hand side tensor, and false otherwise.
+ /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function, + /// which is more high-level and designed for public use. + fn greater( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor; + + /// Element-wise greater than comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is greater than the right hand side + /// scalar, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise greater than comparison between a tensor and a scalar, users should prefer + /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use. + fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) + -> Tensor; + + /// Element-wise greater than or equal comparison between two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensors, where each element is true if the + /// corresponding element of the left hand side tensor is greater than or equal to the + /// corresponding element of the right hand side tensor, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise greater than or equal comparison between two tensors, users should prefer + /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use. + fn greater_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor; + + /// Element-wise greater than or equal comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is greater than or equal to the right + /// hand side scalar, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. 
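// Editor's sketch (not part of this patch): the comparison entry points documented here
// (`Tensor::greater`, `Tensor::greater_elem`, `Tensor::greater_equal`, ...) all return
// boolean tensors of the same shape. Function names, the import path and the generic
// backend `B` are illustrative assumptions.
use burn_tensor::{backend::Backend, Bool, Tensor};

fn greater_equal_demo<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>) -> Tensor<B, 2, Bool> {
    // Element-wise `a >= b`.
    a.greater_equal(b)
}

fn positive_mask<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2, Bool> {
    // Element-wise comparison against a scalar: true where `x > 0`.
    x.greater_elem(0.0)
}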
+ /// + /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer + /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use. + fn greater_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor; + + /// Element-wise less than comparison between two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensors, where each element is true if the + /// corresponding element of the left hand side tensor is less than the corresponding element of + /// the right hand side tensor, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function, + /// which is more high-level and designed for public use. + fn lower( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor; + + /// Element-wise less than comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is less than the right hand side scalar, + /// and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than comparison between a tensor and a scalar, users should prefer + /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use. + fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; + + /// Element-wise less than or equal comparison between two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensors, where each element is true if the + /// corresponding element of the left hand side tensor is less than or equal to the corresponding + /// element of the right hand side tensor, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than or equal comparison between two tensors, users should prefer + /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use. + fn lower_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor; + + /// Element-wise less than or equal comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. 
+ /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is less than or equal to the right hand + /// side scalar, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer + /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use. + fn lower_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor; + + /// Selects elements from a tensor based on a boolean mask. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true. + /// * `mask` - The boolean mask to use for selecting elements. + /// * `source` - The tensor to select elements from when the corresponding element of the mask is false. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensors, where each element is taken from the + /// corresponding element of the left hand side tensor if the corresponding element of the mask + /// is true, and from the corresponding element of the right hand side tensor otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For selecting elements from a tensor based on a boolean mask, users should prefer the + /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use. + fn mask_where( + tensor: Self::Primitive, + mask: Tensor, + source: Self::Primitive, + ) -> Self::Primitive; + + /// Fills elements of a tensor based on a boolean mask. + /// + /// # Arguments + /// + /// * `tensor` - The tensor where will be overwritten with the value + /// when the corresponding element of the mask is true. + /// * `mask` - The boolean mask to use for filling elements. + /// * `value` - The value to fill elements with when the corresponding element of the mask is true. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensors, where each element is taken from the + /// corresponding element unmodified if the corresponding element of the mask is false, and + /// filled with the value otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For filling elements of a tensor based on a boolean mask, users should prefer the + /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use. + fn mask_fill( + tensor: Self::Primitive, + mask: Tensor, + value: Self::Elem, + ) -> Self::Primitive; + + /// Gathers elements from a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to gather elements. 
+ /// * `tensor` - The tensor to gather elements from. + /// * `indices` - The indices of the elements to gather. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For gathering elements from a tensor along an axis, users should prefer the + /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use. + fn gather( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + ) -> Self::Primitive; + + /// Scatters elements into a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to scatter elements. + /// * `tensor` - The tensor to scatter elements into. + /// * `indices` - The indices of the elements to scatter. + /// * `values` - The values to scatter into the tensor. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis, + /// except for the elements at the specified indices, which are taken from the corresponding + /// element of the values tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function, + /// which is more high-level and designed for public use. + fn scatter( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive; + + /// Select tensor elements along the given dimension corresponding for the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select elements from. + /// * `dim` - The axis along which to select elements. + /// * `indices` - The indices of the elements to select. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For selecting elements from a tensor along an axis, users should prefer the + /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use. + fn select( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + ) -> Self::Primitive; + + /// Assign the selected elements along the given dimension corresponding to the given indices + /// from the value tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to assign elements to. + /// * `dim` - The axis along which to assign elements. 
+ /// * `indices` - The indices of the elements to assign. + /// * `values` - The values to assign to the tensor. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is taken from the + /// corresponding element of the input tensor at the corresponding index along the specified axis, + /// except for the elements at the specified indices, which are taken from the corresponding + /// element of the values tensor. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For assigning elements to a tensor along an axis, users should prefer the + /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use. + fn select_assign( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive; + + /// Gets the indices of the maximum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to get the indices of the maximum elements. + /// * `tensor` - The tensor to get the indices of the maximum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the index of the + /// maximum element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the + /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use. + fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive; + + /// Gets the indices of the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to get the indices of the minimum elements. + /// * `tensor` - The tensor to get the indices of the minimum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the index of the + /// minimum element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the + /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use. + fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive; + + /// Gets the maximum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to get the maximum elements. + /// + /// # Returns + /// + /// A single-element tensor containing the maximum element of the input tensor. 
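// Editor's sketch (not part of this patch): combining the masking, arg-reduction and
// gathering entry points documented above (`Tensor::mask_fill`, `Tensor::argmax`,
// `Tensor::gather`). The function name, import path, generic backend `B` and the
// shapes involved are illustrative assumptions.
use burn_tensor::{backend::Backend, Int, Tensor};

fn best_score_per_row<B: Backend>(scores: Tensor<B, 2>) -> Tensor<B, 2> {
    // Zero out negative scores before reducing.
    let negative = scores.clone().lower_elem(0.0);
    let clean = scores.mask_fill(negative, 0.0);
    // `argmax` keeps the reduced axis with size 1, so its output can be fed
    // straight back into `gather` along the same dimension.
    let best_idx: Tensor<B, 2, Int> = clean.clone().argmax(1);
    clean.gather(1, best_idx)
}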
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For getting the maximum element of a tensor, users should prefer the
+    /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
+    fn max<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
+
+    /// Gets the maximum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements from.
+    /// * `dim` - The axis along which to get the maximum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor, where each element is the maximum element
+    /// of the input tensor at the corresponding index along the specified axis.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For getting the maximum elements of a tensor along an axis, users should prefer the
+    /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
+    fn max_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
+
+    /// Gets the maximum elements and indices of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements from.
+    /// * `dim` - The axis along which to get the maximum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tuple containing the maximum element of the input tensor, and a tensor with the same shape
+    /// as the input tensor, where each element is the index of the maximum element of the input tensor
+    /// at the corresponding index along the specified axis.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For getting the maximum elements of a tensor along an axis, users should prefer the
+    /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
+    fn max_dim_with_indices<const D: usize>(
+        tensor: Self::Primitive<D>,
+        dim: usize,
+    ) -> (Self::Primitive<D>, B::IntTensorPrimitive<D>);
+
+    /// Gets the minimum element of a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the minimum elements from.
+    ///
+    /// # Returns
+    ///
+    /// A single-element tensor containing the minimum element of the input tensor.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For getting the minimum element of a tensor, users should prefer the
+    /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
+    fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
+
+    /// Gets the minimum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the minimum elements from.
+    /// * `dim` - The axis along which to get the minimum elements.
+ /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the minimum element + /// of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the minimum elements of a tensor along an axis, users should prefer the + /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use. + fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; + + /// Gets the minimum elements and indices of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor and corresponding indices, where + /// each element is the minimum element of the input tensor at the corresponding index + /// along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the minimum elements of a tensor along an axis, users should prefer the + /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use. + fn min_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, B::IntTensorPrimitive); + + /// Clamp the tensor between the given min and max values. + /// + /// # Arguments + /// + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped between the given min and max values. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users. + /// + /// For clamping a tensor between the given min and max values, users should prefer the + /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use. + fn clamp( + tensor: Self::Primitive, + min: Self::Elem, + max: Self::Elem, + ) -> Self::Primitive; + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped under the given min value. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users. + /// + /// For clamping a tensor under a minimum value, users should prefer the + /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use. + fn clamp_min(tensor: Self::Primitive, min: Self::Elem) + -> Self::Primitive; + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// A new tensor with the values clamped over the given max value. 
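// Editor's sketch (not part of this patch): the clamping entry points documented here
// (`Tensor::clamp`, `Tensor::clamp_min`, `Tensor::clamp_max`). The function name,
// import path and generic backend `B` are illustrative assumptions.
use burn_tensor::{backend::Backend, Tensor};

fn clamp_demo<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    // Equivalent to `x.clamp_min(-1.0).clamp_max(1.0)`: values are limited to [-1.0, 1.0].
    x.clamp(-1.0, 1.0)
}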
+ /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users. + /// + /// For clamping a tensor over a maximum value, users should prefer the + /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use. + fn clamp_max(tensor: Self::Primitive, max: Self::Elem) + -> Self::Primitive; + + /// Calculate absolute value on all elements of a tensor + /// + /// # Arguments + /// + /// * `tensor` - The tensor to apply abs to. + /// + /// # Returns + /// + /// A tensor with absolute values. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function, + /// which is more high-level and designed for public use. + fn abs(tensor: Self::Primitive) -> Self::Primitive; } impl Numeric for Int { - fn add( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_add(lhs, rhs) - } - fn add_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_add_scalar(lhs, rhs.elem()) - } - fn sub( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_sub(lhs, rhs) - } - fn sub_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_sub_scalar(lhs, rhs.elem()) - } - fn div( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_div(lhs, rhs) - } - fn div_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_div_scalar(lhs, rhs.elem()) - } - fn mul( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::int_mul(lhs, rhs) - } - fn mul_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::int_mul_scalar(lhs, rhs.elem()) - } - fn neg(tensor: Self::Primitive) -> Self::Primitive { - B::int_neg(tensor) - } - fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { - B::int_zeros(shape, device) - } - fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { - B::int_ones(shape, device) - } - fn full( - shape: Shape, - fill_value: E, - device: &B::Device, - ) -> Self::Primitive { - B::int_full(shape, fill_value.elem(), device) - } - fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_sum(tensor) - } - fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_sum_dim(tensor, dim) - } - fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_mean(tensor) - } - fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_mean_dim(tensor, dim) - } - - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::int_equal_elem(lhs, rhs)) - } - fn greater( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_greater(lhs, rhs)) - } - - fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::int_greater_elem(lhs, rhs)) - } - - fn greater_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_greater_equal(lhs, rhs)) - } - - fn greater_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::int_greater_equal_elem(lhs, rhs)) - } - - fn 
lower(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { - Tensor::new(B::int_lower(lhs, rhs)) - } - - fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::int_lower_elem(lhs, rhs)) - } - - fn lower_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::int_lower_equal(lhs, rhs)) - } - - fn lower_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::int_lower_equal_elem(lhs, rhs)) - } - - fn mask_where( - tensor: Self::Primitive, - mask: Tensor, - source: Self::Primitive, - ) -> Self::Primitive { - B::int_mask_where(tensor, mask.primitive, source) - } - - fn mask_fill( - tensor: Self::Primitive, - mask: Tensor, - value: Self::Elem, - ) -> Self::Primitive { - B::int_mask_fill(tensor, mask.primitive, value) - } - - fn select( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - ) -> Self::Primitive { - B::int_select(tensor, dim, indices.primitive) - } - - fn select_assign( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::int_select_assign(tensor, dim, indices.primitive, values) - } - fn gather( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - ) -> Self::Primitive { - B::int_gather(dim, tensor, indices.primitive) - } - - fn scatter( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::int_scatter(dim, tensor, indices.primitive, values) - } - - fn argmax( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::int_argmax(tensor, dim) - } - - fn argmin( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::int_argmin(tensor, dim) - } - - fn max(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_max(tensor) - } - - fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_max_dim(tensor, dim) - } - - fn max_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::int_max_dim_with_indices(tensor, dim) - } - - fn min(tensor: Self::Primitive) -> Self::Primitive<1> { - B::int_min(tensor) - } - - fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::int_min_dim(tensor, dim) - } - - fn min_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::int_min_dim_with_indices(tensor, dim) - } - - fn clamp( - tensor: Self::Primitive, - min: B::IntElem, - max: B::IntElem, - ) -> Self::Primitive { - B::int_clamp(tensor, min, max) - } - - fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive { - B::int_clamp_min(tensor, min) - } - - fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive { - B::int_clamp_max(tensor, max) - } - - fn abs(tensor: Self::Primitive) -> Self::Primitive { - B::int_abs(tensor) - } + fn add( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::int_add(lhs, rhs) + } + fn add_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::int_add_scalar(lhs, rhs.elem()) + } + fn sub( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::int_sub(lhs, rhs) + } + fn sub_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::int_sub_scalar(lhs, rhs.elem()) + } + fn div( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::int_div(lhs, rhs) + } + fn div_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + 
B::int_div_scalar(lhs, rhs.elem()) + } + fn mul( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::int_mul(lhs, rhs) + } + fn mul_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::int_mul_scalar(lhs, rhs.elem()) + } + fn neg(tensor: Self::Primitive) -> Self::Primitive { + B::int_neg(tensor) + } + fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { + B::int_zeros(shape, device) + } + fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { + B::int_ones(shape, device) + } + fn full( + shape: Shape, + fill_value: E, + device: &B::Device, + ) -> Self::Primitive { + B::int_full(shape, fill_value.elem(), device) + } + fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_sum(tensor) + } + fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_sum_dim(tensor, dim) + } + fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_mean(tensor) + } + fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_mean_dim(tensor, dim) + } + + fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::int_equal_elem(lhs, rhs)) + } + fn greater( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_greater(lhs, rhs)) + } + + fn greater_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::int_greater_elem(lhs, rhs)) + } + + fn greater_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_greater_equal(lhs, rhs)) + } + + fn greater_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::int_greater_equal_elem(lhs, rhs)) + } + + fn lower( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_lower(lhs, rhs)) + } + + fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::int_lower_elem(lhs, rhs)) + } + + fn lower_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::int_lower_equal(lhs, rhs)) + } + + fn lower_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::int_lower_equal_elem(lhs, rhs)) + } + + fn mask_where( + tensor: Self::Primitive, + mask: Tensor, + source: Self::Primitive, + ) -> Self::Primitive { + B::int_mask_where(tensor, mask.primitive, source) + } + + fn mask_fill( + tensor: Self::Primitive, + mask: Tensor, + value: Self::Elem, + ) -> Self::Primitive { + B::int_mask_fill(tensor, mask.primitive, value) + } + + fn select( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + ) -> Self::Primitive { + B::int_select(tensor, dim, indices.primitive) + } + + fn select_assign( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::int_select_assign(tensor, dim, indices.primitive, values) + } + fn gather( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + ) -> Self::Primitive { + B::int_gather(dim, tensor, indices.primitive) + } + + fn scatter( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::int_scatter(dim, tensor, indices.primitive, values) + } + + fn argmax( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::int_argmax(tensor, dim) + } + + fn argmin( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::int_argmin(tensor, dim) + } + + fn max(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_max(tensor) 
+ } + + fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_max_dim(tensor, dim) + } + + fn max_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::int_max_dim_with_indices(tensor, dim) + } + + fn min(tensor: Self::Primitive) -> Self::Primitive<1> { + B::int_min(tensor) + } + + fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_min_dim(tensor, dim) + } + + fn min_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::int_min_dim_with_indices(tensor, dim) + } + + fn clamp( + tensor: Self::Primitive, + min: B::IntElem, + max: B::IntElem, + ) -> Self::Primitive { + B::int_clamp(tensor, min, max) + } + + fn clamp_min( + tensor: Self::Primitive, + min: B::IntElem, + ) -> Self::Primitive { + B::int_clamp_min(tensor, min) + } + + fn clamp_max( + tensor: Self::Primitive, + max: B::IntElem, + ) -> Self::Primitive { + B::int_clamp_max(tensor, max) + } + + fn abs(tensor: Self::Primitive) -> Self::Primitive { + B::int_abs(tensor) + } } impl Numeric for Float { - fn add( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::add(lhs, rhs) - } - fn add_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::add_scalar(lhs, rhs.elem()) - } - fn sub( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::sub(lhs, rhs) - } - fn sub_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sub_scalar(lhs, rhs.elem()) - } - fn div( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::div(lhs, rhs) - } - fn div_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::div_scalar(lhs, rhs.elem()) - } - fn mul( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> >::Primitive { - B::mul(lhs, rhs) - } - fn mul_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::mul_scalar(lhs, rhs.elem()) - } - fn neg(tensor: Self::Primitive) -> Self::Primitive { - B::neg(tensor) - } - fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { - B::zeros(shape, device) - } - fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { - B::ones(shape, device) - } - fn full( - shape: Shape, - fill_value: E, - device: &B::Device, - ) -> Self::Primitive { - B::full(shape, fill_value.elem(), device) - } - fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { - B::sum(tensor) - } - fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::sum_dim(tensor, dim) - } - fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { - B::mean(tensor) - } - fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::mean_dim(tensor, dim) - } - - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::equal_elem(lhs, rhs)) - } - fn greater( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::greater(lhs, rhs)) - } - - fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::greater_elem(lhs, rhs)) - } - - fn greater_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::greater_equal(lhs, rhs)) - } - - fn greater_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::greater_equal_elem(lhs, rhs)) - } - - fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> Tensor { - Tensor::new(B::lower(lhs, rhs)) - } - - fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> 
Tensor { - Tensor::new(B::lower_elem(lhs, rhs)) - } - - fn lower_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::lower_equal(lhs, rhs)) - } - - fn lower_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::lower_equal_elem(lhs, rhs)) - } - - fn mask_where( - tensor: Self::Primitive, - mask: Tensor, - source: Self::Primitive, - ) -> Self::Primitive { - B::mask_where(tensor, mask.primitive, source) - } - - fn mask_fill( - tensor: Self::Primitive, - mask: Tensor, - value: Self::Elem, - ) -> Self::Primitive { - B::mask_fill(tensor, mask.primitive, value) - } - - fn select( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - ) -> Self::Primitive { - B::select(tensor, dim, indices.primitive) - } - - fn select_assign( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::select_assign(tensor, dim, indices.primitive, values) - } - - fn gather( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - ) -> Self::Primitive { - B::gather(dim, tensor, indices.primitive) - } - - fn scatter( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::scatter(dim, tensor, indices.primitive, values) - } - - fn argmax( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::argmax(tensor, dim) - } - - fn argmin( - tensor: Self::Primitive, - dim: usize, - ) -> ::IntTensorPrimitive { - B::argmin(tensor, dim) - } - - fn max(tensor: Self::Primitive) -> Self::Primitive<1> { - B::max(tensor) - } - - fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::max_dim(tensor, dim) - } - - fn max_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::max_dim_with_indices(tensor, dim) - } - - fn min(tensor: Self::Primitive) -> Self::Primitive<1> { - B::min(tensor) - } - - fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::min_dim(tensor, dim) - } - - fn min_dim_with_indices( - tensor: Self::Primitive, - dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - B::min_dim_with_indices(tensor, dim) - } - - fn clamp( - tensor: Self::Primitive, - min: B::FloatElem, - max: B::FloatElem, - ) -> Self::Primitive { - B::clamp(tensor, min, max) - } - - fn clamp_min( - tensor: Self::Primitive, - min: B::FloatElem, - ) -> Self::Primitive { - B::clamp_min(tensor, min) - } - - fn clamp_max( - tensor: Self::Primitive, - max: B::FloatElem, - ) -> Self::Primitive { - B::clamp_max(tensor, max) - } - - fn abs(tensor: Self::Primitive) -> Self::Primitive { - B::abs(tensor) - } + fn add( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::add(lhs, rhs) + } + fn add_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::add_scalar(lhs, rhs.elem()) + } + fn sub( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::sub(lhs, rhs) + } + fn sub_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::sub_scalar(lhs, rhs.elem()) + } + fn div( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::div(lhs, rhs) + } + fn div_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::div_scalar(lhs, rhs.elem()) + } + fn mul( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> >::Primitive { + B::mul(lhs, rhs) + } + fn mul_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::mul_scalar(lhs, 
rhs.elem()) + } + fn neg(tensor: Self::Primitive) -> Self::Primitive { + B::neg(tensor) + } + fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { + B::zeros(shape, device) + } + fn ones(shape: Shape, device: &B::Device) -> Self::Primitive { + B::ones(shape, device) + } + fn full( + shape: Shape, + fill_value: E, + device: &B::Device, + ) -> Self::Primitive { + B::full(shape, fill_value.elem(), device) + } + fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { + B::sum(tensor) + } + fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::sum_dim(tensor, dim) + } + fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { + B::mean(tensor) + } + fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::mean_dim(tensor, dim) + } + + fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::equal_elem(lhs, rhs)) + } + fn greater( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::greater(lhs, rhs)) + } + + fn greater_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::greater_elem(lhs, rhs)) + } + + fn greater_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::greater_equal(lhs, rhs)) + } + + fn greater_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::greater_equal_elem(lhs, rhs)) + } + + fn lower( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::lower(lhs, rhs)) + } + + fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::lower_elem(lhs, rhs)) + } + + fn lower_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + Tensor::new(B::lower_equal(lhs, rhs)) + } + + fn lower_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + Tensor::new(B::lower_equal_elem(lhs, rhs)) + } + + fn mask_where( + tensor: Self::Primitive, + mask: Tensor, + source: Self::Primitive, + ) -> Self::Primitive { + B::mask_where(tensor, mask.primitive, source) + } + + fn mask_fill( + tensor: Self::Primitive, + mask: Tensor, + value: Self::Elem, + ) -> Self::Primitive { + B::mask_fill(tensor, mask.primitive, value) + } + + fn select( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + ) -> Self::Primitive { + B::select(tensor, dim, indices.primitive) + } + + fn select_assign( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::select_assign(tensor, dim, indices.primitive, values) + } + + fn gather( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + ) -> Self::Primitive { + B::gather(dim, tensor, indices.primitive) + } + + fn scatter( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + B::scatter(dim, tensor, indices.primitive, values) + } + + fn argmax( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::argmax(tensor, dim) + } + + fn argmin( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + B::argmin(tensor, dim) + } + + fn max(tensor: Self::Primitive) -> Self::Primitive<1> { + B::max(tensor) + } + + fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::max_dim(tensor, dim) + } + + fn max_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::max_dim_with_indices(tensor, dim) + } + + fn min(tensor: Self::Primitive) -> Self::Primitive<1> { + B::min(tensor) + 
} + + fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::min_dim(tensor, dim) + } + + fn min_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + B::min_dim_with_indices(tensor, dim) + } + + fn clamp( + tensor: Self::Primitive, + min: B::FloatElem, + max: B::FloatElem, + ) -> Self::Primitive { + B::clamp(tensor, min, max) + } + + fn clamp_min( + tensor: Self::Primitive, + min: B::FloatElem, + ) -> Self::Primitive { + B::clamp_min(tensor, min) + } + + fn clamp_max( + tensor: Self::Primitive, + max: B::FloatElem, + ) -> Self::Primitive { + B::clamp_max(tensor, max) + } + + fn abs(tensor: Self::Primitive) -> Self::Primitive { + B::abs(tensor) + } } impl core::ops::Add for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn add(self, rhs: Tensor) -> Self { - Self::add(self, rhs) - } + fn add(self, rhs: Tensor) -> Self { + Self::add(self, rhs) + } } impl core::ops::Add for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn add(self, other: E) -> Self { - Tensor::add_scalar(self, other) - } + fn add(self, other: E) -> Self { + Tensor::add_scalar(self, other) + } } impl core::ops::Sub> for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn sub(self, rhs: Tensor) -> Self { - Tensor::sub(self, rhs) - } + fn sub(self, rhs: Tensor) -> Self { + Tensor::sub(self, rhs) + } } impl core::ops::Sub for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn sub(self, other: E) -> Self { - Tensor::sub_scalar(self, other) - } + fn sub(self, other: E) -> Self { + Tensor::sub_scalar(self, other) + } } impl core::ops::Div> for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn div(self, rhs: Tensor) -> Self { - Tensor::div(self, rhs) - } + fn div(self, rhs: Tensor) -> Self { + Tensor::div(self, rhs) + } } impl core::ops::Div for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn div(self, other: E) -> Self { - Tensor::div_scalar(self, other) - } + fn div(self, other: E) -> Self { + Tensor::div_scalar(self, other) + } } impl core::ops::Mul> for Tensor where - B: Backend, - K: Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn mul(self, rhs: Tensor) -> Self { - Tensor::mul(self, rhs) - } + fn mul(self, rhs: Tensor) -> Self { + Tensor::mul(self, rhs) + } } impl core::ops::Mul for Tensor where - E: ElementConversion, - B: Backend, - K: Numeric, - K::Elem: Element, + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn mul(self, other: E) -> Self { - Tensor::mul_scalar(self, other) - } + fn mul(self, other: E) -> Self { + Tensor::mul_scalar(self, other) + } } impl core::ops::Neg for Tensor where - B: Backend, - K: 
Numeric, - K::Elem: Element, + B: Backend, + K: Numeric, + K::Elem: Element, { - type Output = Self; + type Output = Self; - fn neg(self) -> Self { - Tensor::neg(self) - } + fn neg(self) -> Self { + Tensor::neg(self) + } } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 7e6141ea63..131855cce8 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -50,144 +50,145 @@ use crate::tensor::Element; /// Most of the documentation for each function can be found on the user API [tensor struct](crate::Tensor). /// For modules, public functions are often created, which can be used by `burn-core` modules. pub trait Backend: - TensorOps - + BoolTensorOps - + IntTensorOps - + ModuleOps - + ActivationOps - + Clone - + Sized - + Default - + Send - + Sync - + core::fmt::Debug - + 'static + TensorOps + + BoolTensorOps + + IntTensorOps + + ModuleOps + + ActivationOps + + Clone + + Sized + + Default + + Send + + Sync + + core::fmt::Debug + + 'static { - /// Device type. - type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; + /// Device type. + type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; - /// Pointer to another backend that have a full precision float element type - type FullPrecisionBackend: Backend; - /// Full precision float element type. - type FullPrecisionElem: Element; + /// Pointer to another backend that have a full precision float element type + type FullPrecisionBackend: Backend; + /// Full precision float element type. + type FullPrecisionElem: Element; - /// Tensor primitive to be used for all float operations. - type TensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; - /// Float element type. - type FloatElem: Element; + /// Tensor primitive to be used for all float operations. + type TensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + /// Float element type. + type FloatElem: Element; - /// Tensor primitive to be used for all int operations. - type IntTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; - /// Int element type. - type IntElem: Element; + /// Tensor primitive to be used for all int operations. + type IntTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + /// Int element type. + type IntElem: Element; - /// Tensor primitive to be used for all bool operations. - type BoolTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + /// Tensor primitive to be used for all bool operations. + type BoolTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; - /// If autodiff is enabled. - fn ad_enabled() -> bool { - false - } + /// If autodiff is enabled. + fn ad_enabled() -> bool { + false + } - /// Name of the backend. - fn name() -> String; + /// Name of the backend. + fn name() -> String; - /// Seed the backend. - fn seed(seed: u64); + /// Seed the backend. + fn seed(seed: u64); - /// Sync the backend, ensure that all computation are finished. - fn sync(_device: &Self::Device) {} + /// Sync the backend, ensure that all computation are finished. + fn sync(_device: &Self::Device) {} } /// Trait that allows a backend to support autodiff. pub trait AutodiffBackend: Backend { - /// The inner backend type. - type InnerBackend: Backend< - Device = Self::Device, - FloatElem = Self::FloatElem, - IntElem = Self::IntElem, - FullPrecisionElem = Self::FullPrecisionElem, - >; - - /// Gradients type. - type Gradients: Send + Sync; - - /// Backward pass. 
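For readers skimming the trait plumbing above: the `core::ops` implementations only forward to the corresponding `Numeric` methods, so tensor arithmetic reads naturally at the call site. A minimal sketch, assuming a generic backend `B`; the helper name is illustrative only:

use burn_tensor::{backend::Backend, Tensor};

// Combines the overloaded operators shown above; each one forwards to the
// matching Numeric method (add, mul_scalar, sub_scalar, neg, ...).
fn affine_then_negate<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>) -> Tensor<B, 2> {
    let summed = a + b;          // core::ops::Add    -> Numeric::add
    let scaled = summed * 2.0;   // core::ops::Mul<E> -> Numeric::mul_scalar
    let shifted = scaled - 1.0;  // core::ops::Sub<E> -> Numeric::sub_scalar
    -shifted                     // core::ops::Neg    -> Numeric::neg
}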
- /// - /// # Arguments - /// - /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed. - /// - /// # Returns - /// - /// The gradients. - fn backward(tensor: FloatTensor) -> Self::Gradients; - - /// Returns the gradients of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to extract the gradients from. - /// - /// # Returns - /// - /// An optional tensor containing the gradient. - fn grad( - tensor: &FloatTensor, - grads: &Self::Gradients, - ) -> Option>; - - /// Pops the gradients of a tensor and returns them. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to pop the gradients from. - /// * `grads` - The gradients. - /// - /// # Returns - /// - /// An optional tensor containing the given gradients. - fn grad_remove( - tensor: &FloatTensor, - grads: &mut Self::Gradients, - ) -> Option>; - - /// Replace the gradients of a tensor with the one provided. - /// - /// If no gradient existed for the provided tensor, register it. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to pop the gradients from. - /// * `grads` - The gradients. - /// * `grad` - The updated grad tensor. - fn grad_replace( - tensor: &FloatTensor, - grads: &mut Self::Gradients, - grad: FloatTensor, - ); - - /// Returns the tensor with inner backend type. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the inner backend tensor for. - /// - /// # Returns - /// - /// The inner backend tensor. - fn inner(tensor: FloatTensor) -> FloatTensor; - - /// Converts the inner backend tensor to the autodiff backend tensor. - /// - /// # Arguments - /// - /// * `tensor` - The inner backend tensor to convert. - /// - /// - /// # Returns - /// - /// The autodiff backend tensor. - fn from_inner(tensor: FloatTensor) - -> FloatTensor; + /// The inner backend type. + type InnerBackend: Backend< + Device = Self::Device, + FloatElem = Self::FloatElem, + IntElem = Self::IntElem, + FullPrecisionElem = Self::FullPrecisionElem, + >; + + /// Gradients type. + type Gradients: Send + Sync; + + /// Backward pass. + /// + /// # Arguments + /// + /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed. + /// + /// # Returns + /// + /// The gradients. + fn backward(tensor: FloatTensor) -> Self::Gradients; + + /// Returns the gradients of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to extract the gradients from. + /// + /// # Returns + /// + /// An optional tensor containing the gradient. + fn grad( + tensor: &FloatTensor, + grads: &Self::Gradients, + ) -> Option>; + + /// Pops the gradients of a tensor and returns them. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to pop the gradients from. + /// * `grads` - The gradients. + /// + /// # Returns + /// + /// An optional tensor containing the given gradients. + fn grad_remove( + tensor: &FloatTensor, + grads: &mut Self::Gradients, + ) -> Option>; + + /// Replace the gradients of a tensor with the one provided. + /// + /// If no gradient existed for the provided tensor, register it. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to pop the gradients from. + /// * `grads` - The gradients. + /// * `grad` - The updated grad tensor. + fn grad_replace( + tensor: &FloatTensor, + grads: &mut Self::Gradients, + grad: FloatTensor, + ); + + /// Returns the tensor with inner backend type. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the inner backend tensor for. 
+ /// + /// # Returns + /// + /// The inner backend tensor. + fn inner(tensor: FloatTensor) -> FloatTensor; + + /// Converts the inner backend tensor to the autodiff backend tensor. + /// + /// # Arguments + /// + /// * `tensor` - The inner backend tensor to convert. + /// + /// + /// # Returns + /// + /// The autodiff backend tensor. + fn from_inner( + tensor: FloatTensor, + ) -> FloatTensor; } diff --git a/burn-tensor/src/tensor/container.rs b/burn-tensor/src/tensor/container.rs index 7432d4ee70..76bbde6151 100644 --- a/burn-tensor/src/tensor/container.rs +++ b/burn-tensor/src/tensor/container.rs @@ -12,80 +12,79 @@ use crate::{backend::Backend, Tensor}; /// Contains tensor of arbitrary dimension. #[derive(Debug)] pub struct TensorContainer { - tensors: HashMap>, + tensors: HashMap>, } impl Default for TensorContainer where - ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, + ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } type TensorPrimitive = ::TensorPrimitive; impl TensorContainer where - ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, + ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, { - /// Create an empty container. - pub fn new() -> Self { - Self { - tensors: HashMap::new(), + /// Create an empty container. + pub fn new() -> Self { + Self { + tensors: HashMap::new(), + } } - } - /// Get a tensor with the given ID. - pub fn get(&self, id: &ID) -> Option> - where - B: Backend, - { - let grad = match self.tensors.get(id) { - Some(grad) => grad, - None => return None, - }; + /// Get a tensor with the given ID. + pub fn get(&self, id: &ID) -> Option> + where + B: Backend, + { + let grad = match self.tensors.get(id) { + Some(grad) => grad, + None => return None, + }; - let tensor = grad - .downcast_ref::>() - .map(|primitive| Tensor::::from_primitive(primitive.clone())) - .unwrap(); + let tensor = grad + .downcast_ref::>() + .map(|primitive| Tensor::::from_primitive(primitive.clone())) + .unwrap(); - Some(tensor) - } + Some(tensor) + } - /// Register a new tensor for the given ID. - /// - /// # Notes - /// - /// If a tensor is already registered for the given ID, it will be replaced. - pub fn register(&mut self, id: ID, value: Tensor) - where - B: Backend, - { - self.tensors.insert(id, Box::new(value.into_primitive())); - } + /// Register a new tensor for the given ID. + /// + /// # Notes + /// + /// If a tensor is already registered for the given ID, it will be replaced. + pub fn register(&mut self, id: ID, value: Tensor) + where + B: Backend, + { + self.tensors.insert(id, Box::new(value.into_primitive())); + } - /// Remove a tensor for the given ID and returns it. - pub fn remove(&mut self, id: &ID) -> Option> - where - B: Backend, - { - self - .tensors - .remove(id) - .map(|item| item.downcast::>().unwrap()) - .map(|primitive| Tensor::from_primitive(*primitive)) - } + /// Remove a tensor for the given ID and returns it. + pub fn remove(&mut self, id: &ID) -> Option> + where + B: Backend, + { + self.tensors + .remove(id) + .map(|item| item.downcast::>().unwrap()) + .map(|primitive| Tensor::from_primitive(*primitive)) + } - /// The number of tensors registered. - pub fn len(&self) -> usize { - self.tensors.len() - } + /// The number of tensors registered. + pub fn len(&self) -> usize { + self.tensors.len() + } - /// If any tensor is contained. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } + /// If any tensor is contained. 
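To make the `AutodiffBackend` contract above concrete, here is a minimal sketch through the public `Tensor` API, assuming an autodiff backend `A`; `backward` and `grad` below ultimately dispatch to the trait methods documented above, and the helper name is hypothetical:

use burn_tensor::{backend::AutodiffBackend, Tensor};

// Builds a scalar loss, runs the backward pass, and extracts the gradient of
// `weight` as a tensor on the inner (non-autodiff) backend.
fn weight_gradient<A: AutodiffBackend>(
    weight: Tensor<A, 2>,
    input: Tensor<A, 2>,
) -> Option<Tensor<A::InnerBackend, 2>> {
    let loss = input.matmul(weight.clone()).sum(); // Tensor<A, 1>
    let grads = loss.backward();                   // AutodiffBackend::backward
    weight.grad(&grads)                            // AutodiffBackend::grad
}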
+ pub fn is_empty(&self) -> bool { + self.len() == 0 + } } diff --git a/burn-tensor/src/tensor/data.rs b/burn-tensor/src/tensor/data.rs index 5cbf215199..4c75571788 100644 --- a/burn-tensor/src/tensor/data.rs +++ b/burn-tensor/src/tensor/data.rs @@ -9,514 +9,522 @@ use rand::{distributions::Standard, Rng, RngCore}; /// Data structure for serializing and deserializing tensor data. #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)] pub struct DataSerialize { - /// The values of the tensor. - pub value: Vec, - /// The shape of the tensor. - pub shape: Vec, + /// The values of the tensor. + pub value: Vec, + /// The shape of the tensor. + pub shape: Vec, } /// Data structure for tensors. #[derive(new, Debug, Clone, PartialEq, Eq)] pub struct Data { - /// The values of the tensor. - pub value: Vec, + /// The values of the tensor. + pub value: Vec, - /// The shape of the tensor. - pub shape: Shape, + /// The shape of the tensor. + pub shape: Shape, } /// Distribution for random value of a tensor. #[derive(Debug, Clone, Copy)] pub enum Distribution { - /// Uniform distribution from 0 (inclusive) to 1 (exclusive). - Default, + /// Uniform distribution from 0 (inclusive) to 1 (exclusive). + Default, - /// Bernoulli distribution with the given probability. - Bernoulli(f64), + /// Bernoulli distribution with the given probability. + Bernoulli(f64), - /// Uniform distribution. The range is inclusive. - Uniform(E, E), + /// Uniform distribution. The range is inclusive. + Uniform(E, E), - /// Normal distribution with the given mean and standard deviation. - Normal(f64, f64), + /// Normal distribution with the given mean and standard deviation. + Normal(f64, f64), } /// Distribution sampler for random value of a tensor. #[derive(new)] pub struct DistributionSampler<'a, E, R> where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, - R: RngCore, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, + R: RngCore, { - kind: DistributionSamplerKind, - rng: &'a mut R, + kind: DistributionSamplerKind, + rng: &'a mut R, } /// Distribution sampler kind for random value of a tensor. pub enum DistributionSamplerKind where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, { - /// Standard distribution. - Standard(rand::distributions::Standard), + /// Standard distribution. + Standard(rand::distributions::Standard), - /// Uniform distribution. - Uniform(rand::distributions::Uniform), + /// Uniform distribution. + Uniform(rand::distributions::Uniform), - /// Bernoulli distribution. - Bernoulli(rand::distributions::Bernoulli), + /// Bernoulli distribution. + Bernoulli(rand::distributions::Bernoulli), - /// Normal distribution. - Normal(rand_distr::Normal), + /// Normal distribution. + Normal(rand_distr::Normal), } impl<'a, E, R> DistributionSampler<'a, E, R> where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, - E: Element, - R: RngCore, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, + E: Element, + R: RngCore, { - /// Sames a random value from the distribution. 
- pub fn sample(&mut self) -> E { - match &self.kind { - DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution), - DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution), - DistributionSamplerKind::Bernoulli(distribution) => { - if self.rng.sample(distribution) { - 1.elem() - } else { - 0.elem() + /// Sames a random value from the distribution. + pub fn sample(&mut self) -> E { + match &self.kind { + DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution), + DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution), + DistributionSamplerKind::Bernoulli(distribution) => { + if self.rng.sample(distribution) { + 1.elem() + } else { + 0.elem() + } + } + DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(), } - } - DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(), } - } } impl Distribution where - Standard: rand::distributions::Distribution, - E: rand::distributions::uniform::SampleUniform, + Standard: rand::distributions::Distribution, + E: rand::distributions::uniform::SampleUniform, { - /// Creates a new distribution sampler. - /// - /// # Arguments - /// - /// * `rng` - The random number generator. - /// - /// # Returns - /// - /// The distribution sampler. - pub fn sampler(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> { - let kind = match self { - Distribution::Default => DistributionSamplerKind::Standard(rand::distributions::Standard {}), - Distribution::Uniform(low, high) => { - DistributionSamplerKind::Uniform(rand::distributions::Uniform::new(low, high)) - } - Distribution::Bernoulli(prob) => { - DistributionSamplerKind::Bernoulli(rand::distributions::Bernoulli::new(prob).unwrap()) - } - Distribution::Normal(mean, std) => { - DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap()) - } - }; - - DistributionSampler::new(kind, rng) - } + /// Creates a new distribution sampler. + /// + /// # Arguments + /// + /// * `rng` - The random number generator. + /// + /// # Returns + /// + /// The distribution sampler. + pub fn sampler(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> { + let kind = match self { + Distribution::Default => { + DistributionSamplerKind::Standard(rand::distributions::Standard {}) + } + Distribution::Uniform(low, high) => { + DistributionSamplerKind::Uniform(rand::distributions::Uniform::new(low, high)) + } + Distribution::Bernoulli(prob) => DistributionSamplerKind::Bernoulli( + rand::distributions::Bernoulli::new(prob).unwrap(), + ), + Distribution::Normal(mean, std) => { + DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap()) + } + }; + + DistributionSampler::new(kind, rng) + } } impl Distribution where - E: Element, + E: Element, { - /// Converts the distribution to a different element type. - /// - /// # Returns - /// - /// The converted distribution. - pub fn convert(self) -> Distribution { - match self { - Distribution::Default => Distribution::Default, - Distribution::Uniform(a, b) => { - Distribution::Uniform(EOther::from_elem(a), EOther::from_elem(b)) - } - Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob), - Distribution::Normal(mean, std) => Distribution::Normal(mean, std), + /// Converts the distribution to a different element type. + /// + /// # Returns + /// + /// The converted distribution. 
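A short sketch of how `Distribution` and its `sampler` method above fit together, assuming the `rand` crate for the RNG; the seed, element type, and bounds are arbitrary:

use burn_tensor::Distribution;
use rand::{rngs::StdRng, SeedableRng};

fn sample_uniform() -> Vec<f32> {
    let mut rng = StdRng::seed_from_u64(42);
    // Build a uniform distribution over f32 and draw a few samples from it.
    let distribution = Distribution::Uniform(0.0f32, 1.0f32);
    let mut sampler = distribution.sampler(&mut rng);
    (0..4).map(|_| sampler.sample()).collect()
}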
+ pub fn convert(self) -> Distribution { + match self { + Distribution::Default => Distribution::Default, + Distribution::Uniform(a, b) => { + Distribution::Uniform(EOther::from_elem(a), EOther::from_elem(b)) + } + Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob), + Distribution::Normal(mean, std) => Distribution::Normal(mean, std), + } } - } } impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); + /// Converts the data to a different element type. + pub fn convert(self) -> Data { + let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - Data { - value, - shape: self.shape, + Data { + value, + shape: self.shape, + } } - } - - /// Asserts each value is within a given range. - /// - /// # Arguments - /// - /// * `range` - The range. - /// - /// # Panics - /// - /// If any value is not within the half-open range bounded inclusively below - /// and exclusively above (`start..end`). - pub fn assert_within_range(&self, range: core::ops::Range) { - let start = range.start.elem::(); - let end = range.end.elem::(); - - for elem in self.value.iter() { - let elem = elem.elem::(); - if elem < start || elem >= end { - panic!("Element ({elem:?}) is not within range {range:?}"); - } + + /// Asserts each value is within a given range. + /// + /// # Arguments + /// + /// * `range` - The range. + /// + /// # Panics + /// + /// If any value is not within the half-open range bounded inclusively below + /// and exclusively above (`start..end`). + pub fn assert_within_range(&self, range: core::ops::Range) { + let start = range.start.elem::(); + let end = range.end.elem::(); + + for elem in self.value.iter() { + let elem = elem.elem::(); + if elem < start || elem >= end { + panic!("Element ({elem:?}) is not within range {range:?}"); + } + } } - } } impl DataSerialize { - /// Converts the data to a different element type. - pub fn convert(self) -> DataSerialize { - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); + /// Converts the data to a different element type. + pub fn convert(self) -> DataSerialize { + let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - DataSerialize { - value, - shape: self.shape, + DataSerialize { + value, + shape: self.shape, + } } - } } impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| (a as i64).elem()).collect(); + /// Converts the data to a different element type. + pub fn convert(self) -> Data { + let value: Vec = self.value.into_iter().map(|a| (a as i64).elem()).collect(); - Data { - value, - shape: self.shape, + Data { + value, + shape: self.shape, + } } - } } impl Data { - /// Populates the data with random values. - pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); + /// Populates the data with random values. + pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); - for _ in 0..num_elements { - data.push(E::random(distribution, rng)); - } + for _ in 0..num_elements { + data.push(E::random(distribution, rng)); + } - Data::new(data, shape) - } + Data::new(data, shape) + } } impl Data where - E: Element, + E: Element, { - /// Populates the data with zeros. 
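A quick sketch of the conversion and range-check helpers on `Data` shown above; the element types and bounds are arbitrary choices for illustration:

use burn_tensor::Data;

fn check_labels(raw: Data<i64, 1>) -> Data<f32, 1> {
    // Panics unless every value lies in the half-open range 0..10.
    raw.assert_within_range(0..10);

    // Element-wise conversion to another element type; the shape is kept.
    raw.convert()
}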
- pub fn zeros>>(shape: S) -> Data { - let shape = shape.into(); - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(0.elem()); - } + /// Populates the data with zeros. + pub fn zeros>>(shape: S) -> Data { + let shape = shape.into(); + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); + + for _ in 0..num_elements { + data.push(0.elem()); + } - Data::new(data, shape) - } + Data::new(data, shape) + } } impl Data where - E: Element, + E: Element, { - /// Populates the data with ones. - pub fn ones(shape: Shape) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); + /// Populates the data with ones. + pub fn ones(shape: Shape) -> Data { + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); - for _ in 0..num_elements { - data.push(1.elem()); - } + for _ in 0..num_elements { + data.push(1.elem()); + } - Data::new(data, shape) - } + Data::new(data, shape) + } } impl Data where - E: Element, + E: Element, { - /// Populates the data with the given value - pub fn full(shape: Shape, fill_value: E) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - for _ in 0..num_elements { - data.push(fill_value) - } + /// Populates the data with the given value + pub fn full(shape: Shape, fill_value: E) -> Data { + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); + for _ in 0..num_elements { + data.push(fill_value) + } - Data::new(data, shape) - } + Data::new(data, shape) + } } impl Data { - /// Serializes the data. - /// - /// # Returns - /// - /// The serialized data. - pub fn serialize(&self) -> DataSerialize { - DataSerialize { - value: self.value.clone(), - shape: self.shape.dims.to_vec(), + /// Serializes the data. + /// + /// # Returns + /// + /// The serialized data. + pub fn serialize(&self) -> DataSerialize { + DataSerialize { + value: self.value.clone(), + shape: self.shape.dims.to_vec(), + } } - } } impl + Clone + core::fmt::Debug + PartialEq, const D: usize> Data { - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `precision` - The precision of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq(&self, other: &Self, precision: usize) { - let tolerance = libm::pow(0.1, precision as f64); - - self.assert_approx_eq_diff(other, tolerance) - } - - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `tolerance` - The tolerance of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) { - let mut message = String::new(); - if self.shape != other.shape { - message += format!( - "\n => Shape is different: {:?} != {:?}", - self.shape.dims, other.shape.dims - ) - .as_str(); + /// Asserts the data is approximately equal to another data. + /// + /// # Arguments + /// + /// * `other` - The other data. + /// * `precision` - The precision of the comparison. + /// + /// # Panics + /// + /// Panics if the data is not approximately equal. 
+ #[track_caller] + pub fn assert_approx_eq(&self, other: &Self, precision: usize) { + let tolerance = libm::pow(0.1, precision as f64); + + self.assert_approx_eq_diff(other, tolerance) } - let iter = self.value.clone().into_iter().zip(other.value.clone()); + /// Asserts the data is approximately equal to another data. + /// + /// # Arguments + /// + /// * `other` - The other data. + /// * `tolerance` - The tolerance of the comparison. + /// + /// # Panics + /// + /// Panics if the data is not approximately equal. + #[track_caller] + pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) { + let mut message = String::new(); + if self.shape != other.shape { + message += format!( + "\n => Shape is different: {:?} != {:?}", + self.shape.dims, other.shape.dims + ) + .as_str(); + } + + let iter = self.value.clone().into_iter().zip(other.value.clone()); - let mut num_diff = 0; - let max_num_diff = 5; + let mut num_diff = 0; + let max_num_diff = 5; - for (i, (a, b)) in iter.enumerate() { - let a: f64 = a.into(); - let b: f64 = b.into(); + for (i, (a, b)) in iter.enumerate() { + let a: f64 = a.into(); + let b: f64 = b.into(); - let err = libm::sqrt(libm::pow(a - b, 2.0)); + let err = libm::sqrt(libm::pow(a - b, 2.0)); - if err > tolerance { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += + if err > tolerance { + // Only print the first 5 different values. + if num_diff < max_num_diff { + message += format!("\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}") .as_str(); + } + num_diff += 1; + } } - num_diff += 1; - } - } - if num_diff >= max_num_diff { - message += format!("\n{} more errors...", num_diff - 5).as_str(); - } + if num_diff >= max_num_diff { + message += format!("\n{} more errors...", num_diff - 5).as_str(); + } - if !message.is_empty() { - panic!("Tensors are not approx eq:{}", message); + if !message.is_empty() { + panic!("Tensors are not approx eq:{}", message); + } } - } } impl Data { - /// Converts the usize data to a different element type. - pub fn from_usize(self) -> Data { - let value: Vec = self - .value - .into_iter() - .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap()) - .collect(); - - Data { - value, - shape: self.shape, + /// Converts the usize data to a different element type. 
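The `precision` argument above translates into an absolute tolerance of `0.1^precision`, so the two assertions below are equivalent ways of allowing a gap of up to 0.01; the values are arbitrary:

use burn_tensor::Data;

fn compare_outputs() {
    let expected = Data::<f32, 2>::from([[3.0, 5.0], [6.0, 7.0]]);
    let actual = Data::<f32, 2>::from([[3.004, 5.0], [6.0, 7.0]]);

    // precision = 2 means tolerance = 0.1^2 = 0.01, so a 0.004 gap passes.
    expected.assert_approx_eq(&actual, 2);

    // Same check with the tolerance given explicitly.
    expected.assert_approx_eq_diff(&actual, 1e-2);
}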
+ pub fn from_usize(self) -> Data { + let value: Vec = self + .value + .into_iter() + .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap()) + .collect(); + + Data { + value, + shape: self.shape, + } } - } } impl From<&DataSerialize> for Data { - fn from(data: &DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value.clone(), Shape::new(dims)) - } + fn from(data: &DataSerialize) -> Self { + let mut dims = [0; D]; + dims[..D].copy_from_slice(&data.shape[..D]); + Data::new(data.value.clone(), Shape::new(dims)) + } } impl From> for Data { - fn from(data: DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value, Shape::new(dims)) - } + fn from(data: DataSerialize) -> Self { + let mut dims = [0; D]; + dims[..D].copy_from_slice(&data.shape[..D]); + Data::new(data.value, Shape::new(dims)) + } } impl From<[E; A]> for Data { - fn from(elems: [E; A]) -> Self { - let mut data = Vec::with_capacity(2 * A); - for elem in elems.into_iter() { - data.push(elem); - } + fn from(elems: [E; A]) -> Self { + let mut data = Vec::with_capacity(2 * A); + for elem in elems.into_iter() { + data.push(elem); + } - Data::new(data, Shape::new([A])) - } + Data::new(data, Shape::new([A])) + } } impl From<&[E]> for Data { - fn from(elems: &[E]) -> Self { - let mut data = Vec::with_capacity(elems.len()); - for elem in elems.iter() { - data.push(*elem); - } + fn from(elems: &[E]) -> Self { + let mut data = Vec::with_capacity(elems.len()); + for elem in elems.iter() { + data.push(*elem); + } - Data::new(data, Shape::new([elems.len()])) - } + Data::new(data, Shape::new([elems.len()])) + } } impl From<[[E; B]; A]> for Data { - fn from(elems: [[E; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B); - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - data.push(elem); - } - } + fn from(elems: [[E; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B); + for elem in elems.into_iter().take(A) { + for elem in elem.into_iter().take(B) { + data.push(elem); + } + } - Data::new(data, Shape::new([A, B])) - } + Data::new(data, Shape::new([A, B])) + } } impl - From<[[[E; C]; B]; A]> for Data + From<[[[E; C]; B]; A]> for Data { - fn from(elems: [[[E; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - data.push(elem); + fn from(elems: [[[E; C]; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B * C); + + for elem in elems.into_iter().take(A) { + for elem in elem.into_iter().take(B) { + for elem in elem.into_iter().take(C) { + data.push(elem); + } + } } - } - } - Data::new(data, Shape::new([A, B, C])) - } + Data::new(data, Shape::new([A, B, C])) + } } -impl - From<[[[[E; D]; C]; B]; A]> for Data +impl< + E: core::fmt::Debug + Copy, + const A: usize, + const B: usize, + const C: usize, + const D: usize, + > From<[[[[E; D]; C]; B]; A]> for Data { - fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C * D); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - for elem in elem.into_iter().take(D) { - data.push(elem); - } + fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B * C * D); + + for elem in elems.into_iter().take(A) { + for elem in elem.into_iter().take(B) { 
+ for elem in elem.into_iter().take(C) { + for elem in elem.into_iter().take(D) { + data.push(elem); + } + } + } } - } - } - Data::new(data, Shape::new([A, B, C, D])) - } + Data::new(data, Shape::new([A, B, C, D])) + } } impl core::fmt::Display for Data { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("{:?}", &self.value).as_str()) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(format!("{:?}", &self.value).as_str()) + } } #[cfg(test)] mod tests { - use super::*; - use rand::{rngs::StdRng, SeedableRng}; - - #[test] - fn should_have_right_num_elements() { - let shape = Shape::new([3, 5, 6]); - let num_elements = shape.num_elements(); - let data = Data::::random(shape, Distribution::Default, &mut StdRng::from_entropy()); - - assert_eq!(num_elements, data.value.len()); - } - - #[test] - fn should_have_right_shape() { - let data = Data::from([[3.0, 5.0, 6.0]]); - assert_eq!(data.shape, Shape::new([1, 3])); - - let data = Data::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]); - assert_eq!(data.shape, Shape::new([2, 3])); - - let data = Data::from([3.0, 5.0, 6.0]); - assert_eq!(data.shape, Shape::new([3])); - } - - #[test] - fn should_assert_appox_eq_limit() { - let data1 = Data::::from([[3.0, 5.0, 6.0]]); - let data2 = Data::::from([[3.01, 5.0, 6.0]]); - - data1.assert_approx_eq(&data2, 2); - } - - #[test] - #[should_panic] - fn should_assert_appox_eq_above_limit() { - let data1 = Data::::from([[3.0, 5.0, 6.0]]); - let data2 = Data::::from([[3.011, 5.0, 6.0]]); - - data1.assert_approx_eq(&data2, 2); - } - - #[test] - #[should_panic] - fn should_assert_appox_eq_check_shape() { - let data1 = Data::::from([[3.0, 5.0, 6.0, 7.0]]); - let data2 = Data::::from([[3.0, 5.0, 6.0]]); - - data1.assert_approx_eq(&data2, 2); - } + use super::*; + use rand::{rngs::StdRng, SeedableRng}; + + #[test] + fn should_have_right_num_elements() { + let shape = Shape::new([3, 5, 6]); + let num_elements = shape.num_elements(); + let data = + Data::::random(shape, Distribution::Default, &mut StdRng::from_entropy()); + + assert_eq!(num_elements, data.value.len()); + } + + #[test] + fn should_have_right_shape() { + let data = Data::from([[3.0, 5.0, 6.0]]); + assert_eq!(data.shape, Shape::new([1, 3])); + + let data = Data::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]); + assert_eq!(data.shape, Shape::new([2, 3])); + + let data = Data::from([3.0, 5.0, 6.0]); + assert_eq!(data.shape, Shape::new([3])); + } + + #[test] + fn should_assert_appox_eq_limit() { + let data1 = Data::::from([[3.0, 5.0, 6.0]]); + let data2 = Data::::from([[3.01, 5.0, 6.0]]); + + data1.assert_approx_eq(&data2, 2); + } + + #[test] + #[should_panic] + fn should_assert_appox_eq_above_limit() { + let data1 = Data::::from([[3.0, 5.0, 6.0]]); + let data2 = Data::::from([[3.011, 5.0, 6.0]]); + + data1.assert_approx_eq(&data2, 2); + } + + #[test] + #[should_panic] + fn should_assert_appox_eq_check_shape() { + let data1 = Data::::from([[3.0, 5.0, 6.0, 7.0]]); + let data2 = Data::::from([[3.0, 5.0, 6.0]]); + + data1.assert_approx_eq(&data2, 2); + } } diff --git a/burn-tensor/src/tensor/element.rs b/burn-tensor/src/tensor/element.rs index 5bcd40dcc5..2108a033b0 100644 --- a/burn-tensor/src/tensor/element.rs +++ b/burn-tensor/src/tensor/element.rs @@ -5,110 +5,110 @@ use rand::RngCore; /// Element trait for tensor. 
pub trait Element: - ToPrimitive - + ElementRandom - + ElementConversion - + ElementPrecision - + core::fmt::Debug - + core::fmt::Display - + Default - + Send - + Sync - + Copy - + 'static + ToPrimitive + + ElementRandom + + ElementConversion + + ElementPrecision + + core::fmt::Debug + + core::fmt::Display + + Default + + Send + + Sync + + Copy + + 'static { } /// Element conversion trait for tensor. pub trait ElementConversion { - /// Converts an element to another element. - /// - /// # Arguments - /// - /// * `elem` - The element to convert. - /// - /// # Returns - /// - /// The converted element. - fn from_elem(elem: E) -> Self; - - /// Converts and returns the converted element. - fn elem(self) -> E; + /// Converts an element to another element. + /// + /// # Arguments + /// + /// * `elem` - The element to convert. + /// + /// # Returns + /// + /// The converted element. + fn from_elem(elem: E) -> Self; + + /// Converts and returns the converted element. + fn elem(self) -> E; } /// Element trait for random value of a tensor. pub trait ElementRandom { - /// Returns a random value for the given distribution. - /// - /// # Arguments - /// - /// * `distribution` - The distribution to sample from. - /// * `rng` - The random number generator. - /// - /// # Returns - /// - /// The random value. - fn random(distribution: Distribution, rng: &mut R) -> Self - where - Self: Sized; + /// Returns a random value for the given distribution. + /// + /// # Arguments + /// + /// * `distribution` - The distribution to sample from. + /// * `rng` - The random number generator. + /// + /// # Returns + /// + /// The random value. + fn random(distribution: Distribution, rng: &mut R) -> Self + where + Self: Sized; } /// Element precision trait for tensor. #[derive(Clone, PartialEq, Eq, Copy, Debug)] pub enum Precision { - /// Double precision, e.g. f64. - Double, + /// Double precision, e.g. f64. + Double, - /// Full precision, e.g. f32. - Full, + /// Full precision, e.g. f32. + Full, - /// Half precision, e.g. f16. - Half, + /// Half precision, e.g. f16. + Half, - /// Other precision. - Other, + /// Other precision. + Other, } /// Element precision trait for tensor. pub trait ElementPrecision { - /// Returns the precision of the element. - fn precision() -> Precision; + /// Returns the precision of the element. + fn precision() -> Precision; } /// Macro to implement the element trait for a type. #[macro_export] macro_rules! 
make_element { - ( + ( ty $type:ident $precision:expr, convert $convert:expr, random $random:expr ) => { - impl Element for $type {} - - impl ElementConversion for $type { - fn from_elem(elem: E) -> Self { - #[allow(clippy::redundant_closure_call)] - $convert(&elem) - } - fn elem(self) -> E { - E::from_elem(self) - } - } - - impl ElementPrecision for $type { - fn precision() -> Precision { - $precision - } - } - - impl ElementRandom for $type { - fn random(distribution: Distribution, rng: &mut R) -> Self { - #[allow(clippy::redundant_closure_call)] - $random(distribution, rng) - } - } - }; + impl Element for $type {} + + impl ElementConversion for $type { + fn from_elem(elem: E) -> Self { + #[allow(clippy::redundant_closure_call)] + $convert(&elem) + } + fn elem(self) -> E { + E::from_elem(self) + } + } + + impl ElementPrecision for $type { + fn precision() -> Precision { + $precision + } + } + + impl ElementRandom for $type { + fn random(distribution: Distribution, rng: &mut R) -> Self { + #[allow(clippy::redundant_closure_call)] + $random(distribution, rng) + } + } + }; } make_element!( diff --git a/burn-tensor/src/tensor/loss/mod.rs b/burn-tensor/src/tensor/loss/mod.rs index 535427a601..339c4071f3 100644 --- a/burn-tensor/src/tensor/loss/mod.rs +++ b/burn-tensor/src/tensor/loss/mod.rs @@ -12,12 +12,12 @@ use crate::{activation, Tensor}; /// /// The log softmax cross entropy. pub fn cross_entropy_with_logits( - logits: Tensor, - target_probs: Tensor, + logits: Tensor, + target_probs: Tensor, ) -> Tensor { - let tensor = activation::log_softmax(logits, D - 1); - let tensor = tensor.mul(target_probs); - let tensor = tensor.sum_dim(D - 1); + let tensor = activation::log_softmax(logits, D - 1); + let tensor = tensor.mul(target_probs); + let tensor = tensor.sum_dim(D - 1); - tensor.mean().neg() + tensor.mean().neg() } diff --git a/burn-tensor/src/tensor/module.rs b/burn-tensor/src/tensor/module.rs index be8ec8c906..4ad956db56 100644 --- a/burn-tensor/src/tensor/module.rs +++ b/burn-tensor/src/tensor/module.rs @@ -1,221 +1,221 @@ use crate::{ - backend::Backend, - ops::{ConvOptions, ConvTransposeOptions, UnfoldOptions}, - Int, Tensor, + backend::Backend, + ops::{ConvOptions, ConvTransposeOptions, UnfoldOptions}, + Int, Tensor, }; /// Applies the [embedding module](crate::ops::ModuleOps::embedding). pub fn embedding(weights: Tensor, indices: Tensor) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::embedding(weights.primitive, indices.primitive)) + Tensor::new(B::embedding(weights.primitive, indices.primitive)) } /// Applies a [1D convolution](crate::ops::ModuleOps::conv2d). pub fn conv1d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvOptions<1>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvOptions<1>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv1d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv1d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [2D convolution](crate::ops::ModuleOps::conv2d). 
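Note that `cross_entropy_with_logits` above takes target probabilities rather than class indices. A minimal call sketch, assuming the loss module is exposed as `burn_tensor::loss` and a generic backend `B`; the helper name is hypothetical:

use burn_tensor::{backend::Backend, loss::cross_entropy_with_logits, Tensor};

// logits and target_probs share the shape [batch_size, num_classes]; each target
// row is expected to be a probability distribution (e.g. one-hot or label-smoothed).
fn batch_loss<B: Backend>(logits: Tensor<B, 2>, target_probs: Tensor<B, 2>) -> Tensor<B, 1> {
    cross_entropy_with_logits(logits, target_probs)
}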
pub fn conv2d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvOptions<2>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvOptions<2>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv2d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv2d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d). pub fn conv_transpose1d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvTransposeOptions<1>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvTransposeOptions<1>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv_transpose1d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv_transpose1d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d). pub fn conv_transpose2d( - x: Tensor, - weight: Tensor, - bias: Option>, - options: ConvTransposeOptions<2>, + x: Tensor, + weight: Tensor, + bias: Option>, + options: ConvTransposeOptions<2>, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::conv_transpose2d( - x.primitive, - weight.primitive, - bias.map(|b| b.primitive), - options, - )) + Tensor::new(B::conv_transpose2d( + x.primitive, + weight.primitive, + bias.map(|b| b.primitive), + options, + )) } /// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d). pub fn unfold4d(x: Tensor, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::unfold4d(x.primitive, kernel_size, options)) + Tensor::new(B::unfold4d(x.primitive, kernel_size, options)) } /// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d). pub fn max_pool1d( - x: Tensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: Tensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::max_pool1d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )) + Tensor::new(B::max_pool1d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )) } /// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d). pub fn max_pool2d( - x: Tensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::max_pool2d( - x.primitive, - kernel_size, - stride, - padding, - dilation, - )) + Tensor::new(B::max_pool2d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + )) } /// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d). pub fn avg_pool2d( - x: Tensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::avg_pool2d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )) + Tensor::new(B::avg_pool2d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )) } /// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d). 
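The wrappers in this hunk only unwrap the tensor primitives and forward to the backend's ModuleOps. A minimal sketch of the 2D max-pooling wrapper on an NCHW tensor, assuming the usual burn_tensor::module re-export; the helper name and pooling configuration are illustrative:

use burn_tensor::{backend::Backend, module::max_pool2d, Tensor};

/// Halve the spatial resolution of a [batch, channels, height, width] tensor:
/// 2x2 window, stride 2, no padding, dilation 1.
fn halve_spatial<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1])
}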
pub fn avg_pool1d( - x: Tensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, + x: Tensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, ) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::avg_pool1d( - x.primitive, - kernel_size, - stride, - padding, - count_include_pad, - )) + Tensor::new(B::avg_pool1d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + )) } /// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d). pub fn max_pool1d_with_indices( - x: Tensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: Tensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> (Tensor, Tensor) where - B: Backend, + B: Backend, { - let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - (Tensor::new(output.output), Tensor::new(output.indices)) + (Tensor::new(output.output), Tensor::new(output.indices)) } /// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices). pub fn max_pool2d_with_indices( - x: Tensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (Tensor, Tensor) where - B: Backend, + B: Backend, { - let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); + let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation); - (Tensor::new(output.output), Tensor::new(output.indices)) + (Tensor::new(output.output), Tensor::new(output.indices)) } /// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d). pub fn adaptive_avg_pool2d(x: Tensor, output_size: [usize; 2]) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::adaptive_avg_pool2d(x.primitive, output_size)) + Tensor::new(B::adaptive_avg_pool2d(x.primitive, output_size)) } /// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d). pub fn adaptive_avg_pool1d(x: Tensor, output_size: usize) -> Tensor where - B: Backend, + B: Backend, { - Tensor::new(B::adaptive_avg_pool1d(x.primitive, output_size)) + Tensor::new(B::adaptive_avg_pool1d(x.primitive, output_size)) } diff --git a/burn-tensor/src/tensor/named/base.rs b/burn-tensor/src/tensor/named/base.rs index a958050556..ab1f5aa9d8 100644 --- a/burn-tensor/src/tensor/named/base.rs +++ b/burn-tensor/src/tensor/named/base.rs @@ -6,76 +6,76 @@ use crate::{Distribution, NamedDims, Shape, Tensor}; /// A tensor with named dimensions. 
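The *_with_indices wrappers above additionally return the integer positions of the selected elements, which is useful when a later step (for example unpooling) needs to know which inputs were picked. A hedged sketch under the same burn_tensor::module re-export assumption; the helper name is illustrative:

use burn_tensor::{backend::Backend, module::max_pool2d_with_indices, Int, Tensor};

/// 2D max pooling that also returns the indices chosen by the backend.
fn pool_with_argmax<B: Backend>(x: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) {
    max_pool2d_with_indices(x, [2, 2], [2, 2], [0, 0], [1, 1])
}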
#[derive(Debug, Clone)] pub struct NamedTensor> { - pub(crate) tensor: D::Tensor, + pub(crate) tensor: D::Tensor, } impl>, const D: usize> From> - for Tensor + for Tensor { - fn from(nt: NamedTensor) -> Self { - nt.tensor - } + fn from(nt: NamedTensor) -> Self { + nt.tensor + } } impl>, const D: usize> From> - for NamedTensor + for NamedTensor { - fn from(tensor: Tensor) -> Self { - Self::from_tensor(tensor) - } + fn from(tensor: Tensor) -> Self { + Self::from_tensor(tensor) + } } impl> core::fmt::Display for NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(&format!( - "NamedTensor[shape={:?}, dims={}]", - self.shape().dims, - ND::to_string(), - )) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(&format!( + "NamedTensor[shape={:?}, dims={}]", + self.shape().dims, + ND::to_string(), + )) + } } impl NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - /// Create a named tensor from a tensor. - pub fn from_tensor(tensor: Tensor) -> Self { - Self { tensor } - } + /// Create a named tensor from a tensor. + pub fn from_tensor(tensor: Tensor) -> Self { + Self { tensor } + } - /// Create a random named tensor of the given shape where each element is sampled from - /// the given distribution. - pub fn random>>(shape: S, distribution: Distribution) -> Self { - Self::from_tensor(Tensor::random(shape, distribution)) - } + /// Create a random named tensor of the given shape where each element is sampled from + /// the given distribution. + pub fn random>>(shape: S, distribution: Distribution) -> Self { + Self::from_tensor(Tensor::random(shape, distribution)) + } - /// Returns the shape of the current tensor. - pub fn shape(&self) -> Shape { - self.tensor.shape() - } + /// Returns the shape of the current tensor. + pub fn shape(&self) -> Shape { + self.tensor.shape() + } - /// Applies element wise multiplication operation. - /// - /// `y = x2 * x1` - #[allow(clippy::should_implement_trait)] - pub fn mul(self, rhs: Self) -> Self { - Self::from_tensor(self.tensor.mul(rhs.tensor)) - } + /// Applies element wise multiplication operation. + /// + /// `y = x2 * x1` + #[allow(clippy::should_implement_trait)] + pub fn mul(self, rhs: Self) -> Self { + Self::from_tensor(self.tensor.mul(rhs.tensor)) + } - /// Reshape the tensor to have the given shape. - /// - /// # Panics - /// - /// If the tensor can not be reshape to the given shape. - pub fn reshape(self, shape: S, _: ND2) -> NamedTensor - where - S: Into>, - ND2: NamedDims>, - { - NamedTensor::from_tensor(self.tensor.reshape(shape.into())) - } + /// Reshape the tensor to have the given shape. + /// + /// # Panics + /// + /// If the tensor can not be reshape to the given shape. + pub fn reshape(self, shape: S, _: ND2) -> NamedTensor + where + S: Into>, + ND2: NamedDims>, + { + NamedTensor::from_tensor(self.tensor.reshape(shape.into())) + } } diff --git a/burn-tensor/src/tensor/named/dims.rs b/burn-tensor/src/tensor/named/dims.rs index 5b4c37908d..6f5631db77 100644 --- a/burn-tensor/src/tensor/named/dims.rs +++ b/burn-tensor/src/tensor/named/dims.rs @@ -6,90 +6,90 @@ use crate::Tensor; /// Dimension trait. pub trait Dim: core::fmt::Debug { - /// Converts the dimension to a string. - fn to_string() -> String; + /// Converts the dimension to a string. + fn to_string() -> String; } /// Named dimensions trait. pub trait NamedDims: core::fmt::Debug { - /// Tensor type. - type Tensor; + /// Tensor type. 
+ type Tensor; - /// Converts the named dimensions to a string. - fn to_string() -> String; + /// Converts the named dimensions to a string. + fn to_string() -> String; } /// Named dimension macro. #[macro_export] macro_rules! NamedDim { - ($name:ident) => { - #[derive(Debug, Clone)] - pub struct $name; - impl Dim for $name { - fn to_string() -> String { - stringify!($name).to_string() - } - } - }; + ($name:ident) => { + #[derive(Debug, Clone)] + pub struct $name; + impl Dim for $name { + fn to_string() -> String { + stringify!($name).to_string() + } + } + }; } impl NamedDims for (D1,) where - B: Backend, - D1: Dim, + B: Backend, + D1: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!("[{}]", D1::to_string()) - } + type Tensor = Tensor; + fn to_string() -> String { + format!("[{}]", D1::to_string()) + } } impl NamedDims for (D1, D2) where - B: Backend, - D1: Dim, - D2: Dim, + B: Backend, + D1: Dim, + D2: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!("[{}, {}]", D1::to_string(), D2::to_string()) - } + type Tensor = Tensor; + fn to_string() -> String { + format!("[{}, {}]", D1::to_string(), D2::to_string()) + } } impl NamedDims for (D1, D2, D3) where - B: Backend, - D1: Dim, - D2: Dim, - D3: Dim, + B: Backend, + D1: Dim, + D2: Dim, + D3: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!( - "[{}, {}, {}]", - D1::to_string(), - D2::to_string(), - D3::to_string() - ) - } + type Tensor = Tensor; + fn to_string() -> String { + format!( + "[{}, {}, {}]", + D1::to_string(), + D2::to_string(), + D3::to_string() + ) + } } impl NamedDims for (D1, D2, D3, D4) where - B: Backend, - D1: Dim, - D2: Dim, - D3: Dim, - D4: Dim, + B: Backend, + D1: Dim, + D2: Dim, + D3: Dim, + D4: Dim, { - type Tensor = Tensor; - fn to_string() -> String { - format!( - "[{}, {}, {}, {}]", - D1::to_string(), - D2::to_string(), - D3::to_string(), - D4::to_string() - ) - } + type Tensor = Tensor; + fn to_string() -> String { + format!( + "[{}, {}, {}, {}]", + D1::to_string(), + D2::to_string(), + D3::to_string(), + D4::to_string() + ) + } } diff --git a/burn-tensor/src/tensor/named/matmul.rs b/burn-tensor/src/tensor/named/matmul.rs index 0a7df3534e..ef8e9849d0 100644 --- a/burn-tensor/src/tensor/named/matmul.rs +++ b/burn-tensor/src/tensor/named/matmul.rs @@ -2,58 +2,58 @@ use crate::backend::Backend; use crate::{Dim, NamedDims, NamedTensor, Tensor}; pub trait Matmul { - fn matmul(self, rhs: Rhs) -> Out; + fn matmul(self, rhs: Rhs) -> Out; } impl NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - /// Applies the matrix multiplication operation. - /// - /// `C = AB` - /// - /// # Panics - /// - /// If the two tensors dont' have a compatible shape. - pub fn matmul( - self, - rhs: NamedTensor, - ) -> NamedTensor - where - NamedDimsRhs: NamedDims>, - NamedDimsOut: NamedDims>, - Self: Matmul, NamedTensor>, - { - Matmul::matmul(self, rhs) - } + /// Applies the matrix multiplication operation. + /// + /// `C = AB` + /// + /// # Panics + /// + /// If the two tensors dont' have a compatible shape. 
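The NamedDim! macro in this hunk generates a unit struct implementing Dim, and tuples of such dims implement NamedDims, which pins both the rank and the dimension names of a NamedTensor at the type level. A minimal sketch, assuming Dim, NamedTensor and the NamedDim macro are re-exported at the burn_tensor crate root; the dimension names Batch/Feature and the helper are illustrative:

use burn_tensor::{backend::Backend, Dim, NamedDim, NamedTensor, Tensor};

NamedDim!(Batch);
NamedDim!(Feature);

/// Wrap two plain 2D tensors so the element-wise product keeps its dimension names.
fn scale<B: Backend>(x: Tensor<B, 2>, gain: Tensor<B, 2>) -> NamedTensor<B, (Batch, Feature)> {
    NamedTensor::<B, (Batch, Feature)>::from_tensor(x).mul(NamedTensor::from_tensor(gain))
}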
+ pub fn matmul( + self, + rhs: NamedTensor, + ) -> NamedTensor + where + NamedDimsRhs: NamedDims>, + NamedDimsOut: NamedDims>, + Self: Matmul, NamedTensor>, + { + Matmul::matmul(self, rhs) + } } impl Matmul, NamedTensor> - for NamedTensor + for NamedTensor { - fn matmul(self, rhs: NamedTensor) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) - } + fn matmul(self, rhs: NamedTensor) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) + } } impl - Matmul, NamedTensor> - for NamedTensor + Matmul, NamedTensor> + for NamedTensor { - fn matmul(self, rhs: NamedTensor) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) - } + fn matmul(self, rhs: NamedTensor) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) + } } impl - Matmul, NamedTensor> - for NamedTensor + Matmul, NamedTensor> + for NamedTensor { - fn matmul( - self, - rhs: NamedTensor, - ) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) - } + fn matmul( + self, + rhs: NamedTensor, + ) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor)) + } } diff --git a/burn-tensor/src/tensor/named/swap_dims.rs b/burn-tensor/src/tensor/named/swap_dims.rs index 1a2f4f18f5..51d0b3b11b 100644 --- a/burn-tensor/src/tensor/named/swap_dims.rs +++ b/burn-tensor/src/tensor/named/swap_dims.rs @@ -2,53 +2,53 @@ use crate::backend::Backend; use crate::{Dim, NamedDims, NamedTensor, Tensor}; pub trait SwapDims { - fn swap_dims(self) -> N; + fn swap_dims(self) -> N; } impl NamedTensor where - ND: NamedDims>, + ND: NamedDims>, { - /// Swap two dimensions. - pub fn swap_dims(self) -> NamedTensor - where - ND2: NamedDims>, - Self: SwapDims, D1, D2>, - { - SwapDims::swap_dims(self) - } + /// Swap two dimensions. + pub fn swap_dims(self) -> NamedTensor + where + ND2: NamedDims>, + Self: SwapDims, D1, D2>, + { + SwapDims::swap_dims(self) + } } macro_rules! 
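The Matmul impls in this hunk contract the shared inner dimension name, e.g. for the batched 3D case (Batch, X, Y) x (Batch, Y, Z) -> (Batch, X, Z); a name mismatch fails to compile, while a size mismatch still panics at runtime as the doc comment above notes. A sketch with illustrative dimension names, under the same re-export assumptions as the previous sketch:

use burn_tensor::{backend::Backend, Dim, NamedDim, NamedTensor};

NamedDim!(Batch);
NamedDim!(SeqLen);
NamedDim!(DModel);
NamedDim!(DOut);

/// Batched projection: (Batch, SeqLen, DModel) x (Batch, DModel, DOut) -> (Batch, SeqLen, DOut).
fn project<B: Backend>(
    x: NamedTensor<B, (Batch, SeqLen, DModel)>,
    w: NamedTensor<B, (Batch, DModel, DOut)>,
) -> NamedTensor<B, (Batch, SeqLen, DOut)> {
    x.matmul(w)
}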
generate_permut { - (2 => $output:ty, ($dim1:expr, $dim2:expr)) => { - impl SwapDims, $dim1, $dim2> - for NamedTensor - { - fn swap_dims(self) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) - } - } - }; + (2 => $output:ty, ($dim1:expr, $dim2:expr)) => { + impl SwapDims, $dim1, $dim2> + for NamedTensor + { + fn swap_dims(self) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) + } + } + }; - (3 => $output:ty, ($dim1:expr, $dim2:expr)) => { - impl SwapDims, $dim1, $dim2> - for NamedTensor - { - fn swap_dims(self) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) - } - } - }; + (3 => $output:ty, ($dim1:expr, $dim2:expr)) => { + impl SwapDims, $dim1, $dim2> + for NamedTensor + { + fn swap_dims(self) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) + } + } + }; - (4 => $output:ty, ($dim1:expr, $dim2:expr)) => { - impl - SwapDims, $dim1, $dim2> for NamedTensor - { - fn swap_dims(self) -> NamedTensor { - NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) - } - } - }; + (4 => $output:ty, ($dim1:expr, $dim2:expr)) => { + impl + SwapDims, $dim1, $dim2> for NamedTensor + { + fn swap_dims(self) -> NamedTensor { + NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2)) + } + } + }; } generate_permut!(2 => (D2, D1), (0, 1)); diff --git a/burn-tensor/src/tensor/ops/activation.rs b/burn-tensor/src/tensor/ops/activation.rs index a2f53f7a5a..b0aef546d5 100644 --- a/burn-tensor/src/tensor/ops/activation.rs +++ b/burn-tensor/src/tensor/ops/activation.rs @@ -7,99 +7,99 @@ use super::FloatTensor; /// /// This trait let backend implementations override activation functions for better performance. pub trait ActivationOps { - /// Applies the ReLU activation function. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The output tensor. - fn relu(tensor: FloatTensor) -> FloatTensor { - let mask = B::lower_equal_elem(tensor.clone(), 0.elem()); - - B::mask_fill(tensor, mask, 0.elem()) - } - - /// Applies the ReLU activation function backward. - /// - /// # Arguments - /// - /// * `output` - The output tensor. - /// - /// # Returns - /// - /// The gradient. - fn relu_backward( - output: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - let mask = B::lower_equal_elem(output, 0.elem()); - - B::mask_fill(grad, mask, 0.elem()) - } - - /// Applies the Gelu activation function. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The output tensor. - fn gelu(tensor: FloatTensor) -> FloatTensor { - let x = B::div_scalar(tensor.clone(), SQRT_2.elem()); - let x = B::erf(x); - let x = B::add_scalar(x, 1i32.elem()); - let x = B::mul(tensor, x); - - B::div_scalar(x, 2i32.elem()) - } - - /// Applies the Gelu activation function backward. - /// - /// # Arguments - /// - /// * `x` - The tensor. - /// * `grad` - The gradient. - /// - /// # Returns - /// - /// The output tensor. - fn gelu_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - // Derivative of the approximate gelu implementation based on tanh. 
- - let constant_1 = 0.0356774; - let constant_2 = 0.797885; - let constant_3 = 0.0535161; - let constant_4 = 0.398942; - - let x3 = B::powf(x.clone(), 3.0); - - let c1 = B::mul_scalar(x3.clone(), constant_1.elem()); - let c2 = B::mul_scalar(x.clone(), constant_2.elem()); - let c3 = B::mul_scalar(x3, constant_3.elem()); - let c4 = B::mul_scalar(x, constant_4.elem()); - - let inner1 = B::add(c1, c2); - let inner2 = B::add(c3, c4); - - let tanh = B::tanh(inner1); - - let sech = B::powf(tanh.clone(), 2.0); - let sech = B::neg(sech); - let sech = B::add_scalar(sech, 1.elem()); - - let y1 = B::mul_scalar(tanh, 0.5.elem()); - let y2 = B::mul(inner2, sech); - let y2 = B::add_scalar(y2, 0.5.elem()); - let y = B::add(y1, y2); - - B::mul(y, grad) - } + /// Applies the ReLU activation function. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The output tensor. + fn relu(tensor: FloatTensor) -> FloatTensor { + let mask = B::lower_equal_elem(tensor.clone(), 0.elem()); + + B::mask_fill(tensor, mask, 0.elem()) + } + + /// Applies the ReLU activation function backward. + /// + /// # Arguments + /// + /// * `output` - The output tensor. + /// + /// # Returns + /// + /// The gradient. + fn relu_backward( + output: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + let mask = B::lower_equal_elem(output, 0.elem()); + + B::mask_fill(grad, mask, 0.elem()) + } + + /// Applies the Gelu activation function. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The output tensor. + fn gelu(tensor: FloatTensor) -> FloatTensor { + let x = B::div_scalar(tensor.clone(), SQRT_2.elem()); + let x = B::erf(x); + let x = B::add_scalar(x, 1i32.elem()); + let x = B::mul(tensor, x); + + B::div_scalar(x, 2i32.elem()) + } + + /// Applies the Gelu activation function backward. + /// + /// # Arguments + /// + /// * `x` - The tensor. + /// * `grad` - The gradient. + /// + /// # Returns + /// + /// The output tensor. + fn gelu_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + // Derivative of the approximate gelu implementation based on tanh. + + let constant_1 = 0.0356774; + let constant_2 = 0.797885; + let constant_3 = 0.0535161; + let constant_4 = 0.398942; + + let x3 = B::powf(x.clone(), 3.0); + + let c1 = B::mul_scalar(x3.clone(), constant_1.elem()); + let c2 = B::mul_scalar(x.clone(), constant_2.elem()); + let c3 = B::mul_scalar(x3, constant_3.elem()); + let c4 = B::mul_scalar(x, constant_4.elem()); + + let inner1 = B::add(c1, c2); + let inner2 = B::add(c3, c4); + + let tanh = B::tanh(inner1); + + let sech = B::powf(tanh.clone(), 2.0); + let sech = B::neg(sech); + let sech = B::add_scalar(sech, 1.elem()); + + let y1 = B::mul_scalar(tanh, 0.5.elem()); + let y2 = B::mul(inner2, sech); + let y2 = B::add_scalar(y2, 0.5.elem()); + let y = B::add(y1, y2); + + B::mul(y, grad) + } } diff --git a/burn-tensor/src/tensor/ops/bool_tensor.rs b/burn-tensor/src/tensor/ops/bool_tensor.rs index fcf2a55d6a..cf478f972a 100644 --- a/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -7,254 +7,255 @@ use core::ops::Range; /// Bool Tensor API for basic operations, see [tensor](crate::Tensor) /// for documentation on each function. pub trait BoolTensorOps { - /// Creates a new bool tensor. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The boolean tensor with the given shape. 
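The default gelu above implements the exact erf form, gelu(x) = x/2 * (1 + erf(x/sqrt(2))), while gelu_backward hard-codes the derivative of the common tanh approximation: 0.797885 is sqrt(2/pi), 0.0356774 is 0.044715 * sqrt(2/pi), and 0.398942 and 0.0535161 are half and 1.5 times those values. A scalar f32 sanity check of that approximation and its derivative (plain std, not part of the patch):

/// Tanh-based GELU approximation whose derivative `gelu_backward` encodes.
fn gelu_tanh(x: f32) -> f32 {
    let inner = 0.797885 * x + 0.0356774 * x.powi(3);
    0.5 * x * (1.0 + inner.tanh())
}

/// d/dx of `gelu_tanh`, matching the `y1 + y2` structure in the hunk:
/// 0.5 * (1 + tanh(u)) + (0.398942 * x + 0.0535161 * x^3) * sech^2(u).
fn gelu_tanh_grad(x: f32) -> f32 {
    let inner = 0.797885 * x + 0.0356774 * x.powi(3);
    let sech2 = 1.0 - inner.tanh().powi(2);
    0.5 * (1.0 + inner.tanh()) + (0.398942 * x + 0.0535161 * x.powi(3)) * sech2
}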
- fn bool_empty(shape: Shape, device: &Device) -> BoolTensor; + /// Creates a new bool tensor. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The boolean tensor with the given shape. + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor; - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn bool_shape(tensor: &BoolTensor) -> Shape; + /// Returns the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + fn bool_shape(tensor: &BoolTensor) -> Shape; - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn bool_into_data(tensor: BoolTensor) -> Reader>; + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn bool_into_data(tensor: BoolTensor) -> Reader>; - /// Gets the data from the tensor. - /// - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// - /// # Returns - /// - /// The data cloned from the data structure. - fn bool_to_data(tensor: &BoolTensor) -> Reader> { - Self::bool_into_data(tensor.clone()) - } + /// Gets the data from the tensor. + /// + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// + /// # Returns + /// + /// The data cloned from the data structure. + fn bool_to_data(tensor: &BoolTensor) -> Reader> { + Self::bool_into_data(tensor.clone()) + } - /// Creates a tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the data. - fn bool_from_data(data: Data, device: &Device) -> BoolTensor; + /// Creates a tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the data. + fn bool_from_data(data: Data, device: &Device) -> BoolTensor; - /// Converts bool tensor to int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The int tensor with the same data as the bool tensor. - fn bool_into_int(tensor: BoolTensor) -> IntTensor; + /// Converts bool tensor to int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The int tensor with the same data as the bool tensor. + fn bool_into_int(tensor: BoolTensor) -> IntTensor; - /// Converts bool tensor to float tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The float tensor with the same data as the bool tensor. - fn bool_into_float(tensor: BoolTensor) -> FloatTensor; + /// Converts bool tensor to float tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The float tensor with the same data as the bool tensor. + fn bool_into_float(tensor: BoolTensor) -> FloatTensor; - /// Gets the device of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device of the tensor. 
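bool_into_int and bool_into_float above give the 0/1 numeric view of a mask, and bool_not flips it. At the high-level Tensor API these surface as bool_not(), int() and float(); a hedged sketch assuming those convenience methods, which dispatch to the ops in this trait:

use burn_tensor::{backend::Backend, Bool, Int, Tensor};

/// Count the positions where `mask` is false: invert the mask, view it as
/// zeros and ones, then sum.
fn count_unmasked<B: Backend>(mask: Tensor<B, 2, Bool>) -> Tensor<B, 1, Int> {
    mask.bool_not().int().sum()
}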
- fn bool_device(tensor: &BoolTensor) -> Device; + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn bool_device(tensor: &BoolTensor) -> Device; - /// Moves the tensor to the device. - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor; + /// Moves the tensor to the device. + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor; - /// Reshapes the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `shape` - The new shape. - /// - /// # Returns - /// - /// The tensor with the new shape. - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor; + /// Reshapes the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `shape` - The new shape. + /// + /// # Returns + /// + /// The tensor with the new shape. + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor; - /// Gets the values from the tensor for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges to get the values from. - /// - /// # Returns - /// - /// The tensor with the values for the given ranges. - fn bool_slice( - tensor: BoolTensor, - ranges: [Range; D2], - ) -> BoolTensor; + /// Gets the values from the tensor for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges to get the values from. + /// + /// # Returns + /// + /// The tensor with the values for the given ranges. + fn bool_slice( + tensor: BoolTensor, + ranges: [Range; D2], + ) -> BoolTensor; - /// Sets the values in the tensor for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `ranges` - The ranges to set the values for. - /// * `value` - The values to set. - /// - /// # Returns - /// - /// The tensor with the values set for the given ranges. - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [Range; D2], - value: BoolTensor, - ) -> BoolTensor; + /// Sets the values in the tensor for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `ranges` - The ranges to set the values for. + /// * `value` - The values to set. + /// + /// # Returns + /// + /// The tensor with the values set for the given ranges. + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [Range; D2], + value: BoolTensor, + ) -> BoolTensor; - /// Repeats one dimension of the tensor a given number of times along that dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to repeat. - /// * `times` - The number of times to repeat the dimension. - /// - /// # Returns - /// - /// The tensor with the dimension repeated. - fn bool_repeat( - tensor: BoolTensor, - dim: usize, - times: usize, - ) -> BoolTensor { - let mut shape = Self::bool_shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } - shape.dims[dim] = times; + /// Repeats one dimension of the tensor a given number of times along that dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to repeat. + /// * `times` - The number of times to repeat the dimension. + /// + /// # Returns + /// + /// The tensor with the dimension repeated. 
+ fn bool_repeat( + tensor: BoolTensor, + dim: usize, + times: usize, + ) -> BoolTensor { + let mut shape = Self::bool_shape(&tensor); + if shape.dims[dim] != 1 { + panic!("Can only repeat dimension with dim=1"); + } + shape.dims[dim] = times; - let mut i = 0; - let ranges_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); + let mut i = 0; + let ranges_select_all = [0; D].map(|_| { + let start = 0; + let end = shape.dims[i]; + i += 1; + start..end + }); - let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor)); - for i in 0..times { - let mut ranges = ranges_select_all.clone(); - ranges[dim] = i..i + 1; - tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone()); - } + let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor)); + for i in 0..times { + let mut ranges = ranges_select_all.clone(); + ranges[dim] = i..i + 1; + tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone()); + } - tensor_output - } + tensor_output + } - /// Concatenates the tensors along the given dimension. - /// - /// # Arguments - /// - /// * `tensors` - The tensors to concatenate. - /// * `dim` - The dimension to concatenate along. - /// - /// # Returns - /// - /// The tensor with the tensors concatenated along the given dimension. - fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor; + /// Concatenates the tensors along the given dimension. + /// + /// # Arguments + /// + /// * `tensors` - The tensors to concatenate. + /// * `dim` - The dimension to concatenate along. + /// + /// # Returns + /// + /// The tensor with the tensors concatenated along the given dimension. + fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor; - /// Equates the two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The tensor with the result of the equate. - fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor; + /// Equates the two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The tensor with the result of the equate. + fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) + -> BoolTensor; - /// Inverses boolean values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The tensor with the result of the negation. - fn bool_not(tensor: BoolTensor) -> BoolTensor; + /// Inverses boolean values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The tensor with the result of the negation. + fn bool_not(tensor: BoolTensor) -> BoolTensor; - /// Transposes a bool tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn bool_transpose(tensor: BoolTensor) -> BoolTensor { - Self::bool_swap_dims(tensor, D - 2, D - 1) - } + /// Transposes a bool tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn bool_transpose(tensor: BoolTensor) -> BoolTensor { + Self::bool_swap_dims(tensor, D - 2, D - 1) + } - /// Swaps two dimensions of a bool tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. 
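The bool_repeat default above only supports repeating a dimension whose current size is 1 (anything else panics) and builds the result by slice-assigning the input `times` times. At the Tensor level this corresponds to the repeat method; a sketch assuming that method and an illustrative [batch, 1, seq_len] mask layout:

use burn_tensor::{backend::Backend, Bool, Tensor};

/// Broadcast a per-row mask of shape [batch, 1, seq_len] across `heads`
/// attention heads; the middle dimension must be exactly 1.
fn expand_mask<B: Backend>(mask: Tensor<B, 3, Bool>, heads: usize) -> Tensor<B, 3, Bool> {
    mask.repeat(1, heads)
}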
- /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn bool_swap_dims( - tensor: BoolTensor, - dim1: usize, - dim2: usize, - ) -> BoolTensor; + /// Swaps two dimensions of a bool tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. + fn bool_swap_dims( + tensor: BoolTensor, + dim1: usize, + dim2: usize, + ) -> BoolTensor; } diff --git a/burn-tensor/src/tensor/ops/int_tensor.rs b/burn-tensor/src/tensor/ops/int_tensor.rs index e61da7622b..bf0b324c00 100644 --- a/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/burn-tensor/src/tensor/ops/int_tensor.rs @@ -7,844 +7,847 @@ use core::ops::Range; /// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor) /// for documentation on each function. pub trait IntTensorOps { - /// Creates a new int tensor. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The integer tensor with the given shape. - fn int_empty(shape: Shape, device: &Device) -> IntTensor; - - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn int_shape(tensor: &IntTensor) -> Shape; - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn int_into_data(tensor: IntTensor) -> Reader, D>>; - - /// Gets the data from the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data cloned from the data structure. - fn int_to_data(tensor: &IntTensor) -> Reader, D>> { - Self::int_into_data(tensor.clone()) - } - - /// Creates a tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the data. - fn int_from_data( - data: Data, D>, - device: &Device, - ) -> IntTensor; - - /// Gets the device of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device of the tensor. - fn int_device(tensor: &IntTensor) -> Device; - - /// Moves the tensor to the given device. - fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor; - - /// Reshapes the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `shape` - The new shape. - /// - /// # Returns - /// - /// The tensor with the new shape. - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor; - - /// Gets the element at the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `indices` - The indices. - /// - /// # Returns - /// - /// The elements at the given indices. - fn int_slice( - tensor: IntTensor, - indices: [Range; D2], - ) -> IntTensor; - - /// Sets the element at the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `indices` - The indices. - /// - /// # Returns - /// - /// The tensor with the element at the given indices set. - fn int_slice_assign( - tensor: IntTensor, - indices: [Range; D2], - value: IntTensor, - ) -> IntTensor; - - /// Converts int tensor to float tensor. 
- /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The int tensor with the same data as the float tensor. - fn int_into_float(tensor: IntTensor) -> FloatTensor; - - /// Fills the tensor with values from the source tensor if the mask is true at the given - /// indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `mask` - The mask. - /// * `source` - The source tensor. - /// - /// # Returns - /// - /// The tensor with the values filled. - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - source: IntTensor, - ) -> IntTensor; - - /// Fills the tensor with the given value if the mask is true at the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `mask` - The mask. - /// * `value` - The value. - /// - /// # Returns - /// - /// The tensor with the values filled. - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor; - - /// Gather elements from the tensor at the given indices. - /// - /// # Arguments - /// - /// * `dim` - The dimension to gather from. - /// * `tensor` - The tensor. - /// * `indices` - The indices. - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor; - - /// Scatter a given value to the tensor at the given indices. - /// - /// # Arguments - /// - /// * `dim` - The dimension to scatter to. - /// * `tensor` - The tensor. - /// * `indices` - The indices. - /// * `value` - The value. - /// - /// # Returns - /// - /// The tensor with the values scattered. - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor; - - /// Select tensor elements along the given dimension corresponding to the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices. - /// - /// # Returns - /// - /// The tensor with the selected elements. - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor; - - /// Assign the selected elements along the given dimension corresponding to the given indices - /// to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices. - /// * `value` - The value. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor; - - /// Repeats the tensor along the given dimension the given number of times. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to repeat. - /// * `times` - The number of times to repeat. - /// - /// # Returns - /// - /// The tensor with the given dimension repeated the given number of times. - fn int_repeat( - tensor: IntTensor, - dim: usize, - times: usize, - ) -> IntTensor { - let mut shape = Self::int_shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); + /// Creates a new int tensor. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The integer tensor with the given shape. + fn int_empty(shape: Shape, device: &Device) -> IntTensor; + + /// Returns the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. 
+ /// + /// # Returns + /// + /// The shape of the tensor. + fn int_shape(tensor: &IntTensor) -> Shape; + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn int_into_data(tensor: IntTensor) -> Reader, D>>; + + /// Gets the data from the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data cloned from the data structure. + fn int_to_data(tensor: &IntTensor) -> Reader, D>> { + Self::int_into_data(tensor.clone()) } - shape.dims[dim] = times; - - let mut i = 0; - let indices_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); - - let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor)); - for i in 0..times { - let mut indices = indices_select_all.clone(); - indices[dim] = i..i + 1; - tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone()); + + /// Creates a tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the data. + fn int_from_data( + data: Data, D>, + device: &Device, + ) -> IntTensor; + + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn int_device(tensor: &IntTensor) -> Device; + + /// Moves the tensor to the given device. + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor; + + /// Reshapes the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `shape` - The new shape. + /// + /// # Returns + /// + /// The tensor with the new shape. + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor; + + /// Gets the element at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The elements at the given indices. + fn int_slice( + tensor: IntTensor, + indices: [Range; D2], + ) -> IntTensor; + + /// Sets the element at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The tensor with the element at the given indices set. + fn int_slice_assign( + tensor: IntTensor, + indices: [Range; D2], + value: IntTensor, + ) -> IntTensor; + + /// Converts int tensor to float tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The int tensor with the same data as the float tensor. + fn int_into_float(tensor: IntTensor) -> FloatTensor; + + /// Fills the tensor with values from the source tensor if the mask is true at the given + /// indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `mask` - The mask. + /// * `source` - The source tensor. + /// + /// # Returns + /// + /// The tensor with the values filled. + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + source: IntTensor, + ) -> IntTensor; + + /// Fills the tensor with the given value if the mask is true at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `mask` - The mask. + /// * `value` - The value. + /// + /// # Returns + /// + /// The tensor with the values filled. 
+ fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor; + + /// Gather elements from the tensor at the given indices. + /// + /// # Arguments + /// + /// * `dim` - The dimension to gather from. + /// * `tensor` - The tensor. + /// * `indices` - The indices. + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor; + + /// Scatter a given value to the tensor at the given indices. + /// + /// # Arguments + /// + /// * `dim` - The dimension to scatter to. + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// * `value` - The value. + /// + /// # Returns + /// + /// The tensor with the values scattered. + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor; + + /// Select tensor elements along the given dimension corresponding to the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The tensor with the selected elements. + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor; + + /// Assign the selected elements along the given dimension corresponding to the given indices + /// to the given value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices. + /// * `value` - The value. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor; + + /// Repeats the tensor along the given dimension the given number of times. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to repeat. + /// * `times` - The number of times to repeat. + /// + /// # Returns + /// + /// The tensor with the given dimension repeated the given number of times. + fn int_repeat( + tensor: IntTensor, + dim: usize, + times: usize, + ) -> IntTensor { + let mut shape = Self::int_shape(&tensor); + if shape.dims[dim] != 1 { + panic!("Can only repeat dimension with dim=1"); + } + shape.dims[dim] = times; + + let mut i = 0; + let indices_select_all = [0; D].map(|_| { + let start = 0; + let end = shape.dims[i]; + i += 1; + start..end + }); + + let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor)); + for i in 0..times { + let mut indices = indices_select_all.clone(); + indices[dim] = i..i + 1; + tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone()); + } + + tensor_output + } + + /// Concatenates the given tensors along the given dimension. + /// + /// # Arguments + /// + /// * `tensors` - The tensors. + /// * `dim` - The dimension to concatenate along. + /// + /// # Returns + /// + /// The concatenated tensor. + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor; + + /// Elementwise equality comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; + + /// Elementwise equality comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. 
+ /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; + + /// Elementwise greater than comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; + + /// Elementwise greater than comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; + + /// Elementwise greater than or equal comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor; + + /// Elementwise greater than or equal comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor; + + /// Elementwise less than comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; + + /// Elementwise less than comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; + + /// Elementwise less than or equal comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor; + + /// Elementwise less than or equal comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The boolean tensor with the result of the comparison. + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor; + + // ==== NUMERIC ==== // + + /// Elementwise addition. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the addition. + fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise addition with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the addition. + fn int_add_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. 
+ /// + /// # Returns + /// + /// The clamped tensor. + fn int_clamp_min(tensor: IntTensor, min: IntElem) -> IntTensor { + let mask = Self::int_lower_elem(tensor.clone(), min); + Self::int_mask_fill(tensor, mask, min) + } + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn int_clamp_max(tensor: IntTensor, max: IntElem) -> IntTensor { + let mask = Self::int_greater_elem(tensor.clone(), max); + Self::int_mask_fill(tensor, mask, max) + } + + /// Clamps a tensor between a minimum and maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn int_clamp( + tensor: IntTensor, + min: IntElem, + max: IntElem, + ) -> IntTensor { + Self::int_clamp_min(Self::int_clamp_max(tensor, max), min) + } + + /// Elementwise subtraction. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the subtraction. + fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise subtraction with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the subtraction. + fn int_sub_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Elementwise multiplication. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the multiplication. + fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise multiplication with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the multiplication. + fn int_mul_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Elementwise division. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the division. + fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Elementwise division with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of the division. + fn int_div_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Elementwise negation. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to negate. + /// + /// # Returns + /// + /// The negated tensor. + fn int_neg(tensor: IntTensor) -> IntTensor { + Self::int_mul_scalar(tensor, (-1.0).elem::>()) + } + + /// Creates a tensor of zeros. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor of zeros. + fn int_zeros(shape: Shape, device: &Device) -> IntTensor; + + /// Creates a tensor of ones. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor of ones. 
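The clamping defaults above are built from comparison masks plus mask_fill: int_clamp_min fills every position below min, int_clamp_max fills every position above max, and int_clamp is simply their composition. A sketch of that identity against the backend ops, assuming the IntTensor/IntElem aliases exposed under burn_tensor::ops; the helper name is illustrative:

use burn_tensor::{
    backend::Backend,
    ops::{IntElem, IntTensor},
};

/// Equivalent to `B::int_clamp(t, min, max)` per the default implementation:
/// clamp to the upper bound first, then to the lower bound.
fn clamp_via_composition<B: Backend, const D: usize>(
    t: IntTensor<B, D>,
    min: IntElem<B>,
    max: IntElem<B>,
) -> IntTensor<B, D> {
    B::int_clamp_min(B::int_clamp_max(t, max), min)
}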
+ fn int_ones(shape: Shape, device: &Device) -> IntTensor; + + /// Creates a tensor filled with given value. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `fill_value` - The value with which to fill the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor filled with given value + fn int_full( + shape: Shape, + fill_value: IntElem, + device: &Device, + ) -> IntTensor { + Self::int_add_scalar(Self::int_zeros(shape, device), fill_value) + } + + /// Sums all elements in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// + /// # Returns + /// + /// The sum of all elements in the tensor. + fn int_sum(tensor: IntTensor) -> IntTensor; + + /// Sums all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension to sum along. + /// + /// # Returns + /// + /// The sum of all elements in the tensor along the dimension. + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Computes the mean of all elements in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to compute the mean of. + /// + /// # Returns + /// + /// The mean of all elements in the tensor. + fn int_mean(tensor: IntTensor) -> IntTensor { + let num_elems = B::int_shape(&tensor).num_elements(); + B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem()) + } + + /// Computes the mean of all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to compute the mean of. + /// + /// # Returns + /// + /// The mean of all elements in the tensor along the dimension. + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Gets the indices of the maximum elements along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum indices of. + /// * `dim` - The dimension to get the maximum indices along. + /// + /// # Returns + /// + /// The indices of the maximum elements along the dimension. + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Gets the indices of the minimum elements along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum indices of. + /// * `dim` - The dimension to get the minimum indices along. + /// + /// # Returns + /// + /// The indices of the minimum elements along the dimension. + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor; + + /// Gets the maximum element in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum element of. + /// + /// # Returns + /// + /// The maximum element in the tensor. + fn int_max(tensor: IntTensor) -> IntTensor { + let shape = B::int_shape(&tensor); + let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); + + B::int_max_dim(tensor, 0) + } + + /// Gets the maximum element in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum element of. + /// * `dim` - The dimension to get the maximum element along. + /// + /// # Returns + /// + /// The maximum element in the tensor along the dimension. + fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let index = B::int_argmax(tensor.clone(), dim); + + B::int_gather(D - 1, tensor, index) + } + + /// Gets the maximum elements and corresponding indices along a dimension. 
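int_mean above divides int_sum by the element count with integer division, so on typical integer backends the result truncates rather than rounds; the count is routed through ElementConversion as an i64. The same computation spelled out, assuming ElementConversion is re-exported at the crate root; the helper name is illustrative:

use burn_tensor::{backend::Backend, ops::IntTensor, ElementConversion};

/// Same computation as the `int_mean` default: sum everything, then perform
/// an integer division by the number of elements (which truncates).
fn int_mean_manual<B: Backend, const D: usize>(t: IntTensor<B, D>) -> IntTensor<B, 1> {
    let num_elems = B::int_shape(&t).num_elements();
    B::int_div_scalar(B::int_sum(t), (num_elems as i64).elem())
}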
+ /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements and indices of. + /// * `dim` - The dimension to get the maximum elements and indices along. + /// + /// # Returns + /// + /// The maximum elements and corresponding indices along the dimension. + fn int_max_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + let index = B::int_argmax(tensor.clone(), dim); + let values = B::int_gather(D - 1, tensor, index.clone()); + + (values, index) + } + + /// Gets the minimum element in the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum element of. + /// + /// # Returns + /// + /// The minimum element in the tensor. + fn int_min(tensor: IntTensor) -> IntTensor { + let shape = B::int_shape(&tensor); + let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); + + B::int_min_dim(tensor, 0) + } + + /// Gets the minimum elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum element of. + /// * `dim` - The dimension to get the minimum element along. + /// + /// # Returns + /// + /// The minimum element in the tensor along the dimension. + fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let index = B::int_argmin(tensor.clone(), dim); + + B::int_gather(D - 1, tensor, index) + } + + /// Gets the minimum elements and corresponding indices along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements and indices of. + /// * `dim` - The dimension to get the minimum elements and indices along. + /// + /// # Returns + /// + /// The minimum elements and corresponding indices along the dimension. + fn int_min_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + let indices = B::int_argmin(tensor.clone(), dim); + let values = B::int_gather(D - 1, tensor, indices.clone()); + + (values, indices) + } + + /// Returns a new tensor with absolute values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take absolute value of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with absolute values. + fn int_abs(tensor: IntTensor) -> IntTensor; + + /// Transposes an int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn int_transpose(tensor: IntTensor) -> IntTensor { + Self::int_swap_dims(tensor, D - 2, D - 1) } - tensor_output - } - - /// Concatenates the given tensors along the given dimension. - /// - /// # Arguments - /// - /// * `tensors` - The tensors. - /// * `dim` - The dimension to concatenate along. - /// - /// # Returns - /// - /// The concatenated tensor. - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor; - - /// Elementwise equality comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; - - /// Elementwise equality comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; - - /// Elementwise greater than comparison. 
- /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; - - /// Elementwise greater than comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; - - /// Elementwise greater than or equal comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor; - - /// Elementwise greater than or equal comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor; - - /// Elementwise less than comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; - - /// Elementwise less than comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor; - - /// Elementwise less than or equal comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor; - - /// Elementwise less than or equal comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The boolean tensor with the result of the comparison. - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor; - - // ==== NUMERIC ==== // - - /// Elementwise addition. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the addition. - fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise addition with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the addition. - fn int_add_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// The clamped tensor. 
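// Illustrative sketch: the clamp defaults just below build clamping out of an
// elementwise comparison plus `mask_fill`: compare against the bound, then
// overwrite the offending positions with the bound itself. The same pattern on
// plain slices (not the backend API; names are illustrative only):
fn lower_elem(values: &[i64], rhs: i64) -> Vec<bool> {
    // Boolean mask: true where the element is below the scalar.
    values.iter().map(|&v| v < rhs).collect()
}

fn mask_fill(values: &[i64], mask: &[bool], fill: i64) -> Vec<i64> {
    // Replace masked positions with `fill`, keep the rest untouched.
    values
        .iter()
        .zip(mask)
        .map(|(&v, &m)| if m { fill } else { v })
        .collect()
}

fn clamp_min(values: &[i64], min: i64) -> Vec<i64> {
    // Mirrors the default: mask = lower_elem(tensor, min), then mask_fill.
    let mask = lower_elem(values, min);
    mask_fill(values, &mask, min)
}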
- fn int_clamp_min(tensor: IntTensor, min: IntElem) -> IntTensor { - let mask = Self::int_lower_elem(tensor.clone(), min); - Self::int_mask_fill(tensor, mask, min) - } - - /// Clamps a tensor over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn int_clamp_max(tensor: IntTensor, max: IntElem) -> IntTensor { - let mask = Self::int_greater_elem(tensor.clone(), max); - Self::int_mask_fill(tensor, mask, max) - } - - /// Clamps a tensor between a minimum and maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn int_clamp( - tensor: IntTensor, - min: IntElem, - max: IntElem, - ) -> IntTensor { - Self::int_clamp_min(Self::int_clamp_max(tensor, max), min) - } - - /// Elementwise subtraction. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the subtraction. - fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise subtraction with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the subtraction. - fn int_sub_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Elementwise multiplication. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the multiplication. - fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise multiplication with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the multiplication. - fn int_mul_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Elementwise division. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of the division. - fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor; - - /// Elementwise division with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of the division. - fn int_div_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; - - /// Elementwise negation. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to negate. - /// - /// # Returns - /// - /// The negated tensor. - fn int_neg(tensor: IntTensor) -> IntTensor { - Self::int_mul_scalar(tensor, (-1.0).elem::>()) - } - - /// Creates a tensor of zeros. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor of zeros. - fn int_zeros(shape: Shape, device: &Device) -> IntTensor; - - /// Creates a tensor of ones. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor of ones. - fn int_ones(shape: Shape, device: &Device) -> IntTensor; - - /// Creates a tensor filled with given value. 
- /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `fill_value` - The value with which to fill the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor filled with given value - fn int_full( - shape: Shape, - fill_value: IntElem, - device: &Device, - ) -> IntTensor { - Self::int_add_scalar(Self::int_zeros(shape, device), fill_value) - } - - /// Sums all elements in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// - /// # Returns - /// - /// The sum of all elements in the tensor. - fn int_sum(tensor: IntTensor) -> IntTensor; - - /// Sums all elements in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// * `dim` - The dimension to sum along. - /// - /// # Returns - /// - /// The sum of all elements in the tensor along the dimension. - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Computes the mean of all elements in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// - /// # Returns - /// - /// The mean of all elements in the tensor. - fn int_mean(tensor: IntTensor) -> IntTensor { - let num_elems = B::int_shape(&tensor).num_elements(); - B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem()) - } - - /// Computes the mean of all elements in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to compute the mean of. - /// - /// # Returns - /// - /// The mean of all elements in the tensor along the dimension. - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Gets the indices of the maximum elements along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum indices of. - /// * `dim` - The dimension to get the maximum indices along. - /// - /// # Returns - /// - /// The indices of the maximum elements along the dimension. - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Gets the indices of the minimum elements along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum indices of. - /// * `dim` - The dimension to get the minimum indices along. - /// - /// # Returns - /// - /// The indices of the minimum elements along the dimension. - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor; - - /// Gets the maximum element in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum element of. - /// - /// # Returns - /// - /// The maximum element in the tensor. - fn int_max(tensor: IntTensor) -> IntTensor { - let shape = B::int_shape(&tensor); - let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); - - B::int_max_dim(tensor, 0) - } - - /// Gets the maximum element in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum element of. - /// * `dim` - The dimension to get the maximum element along. - /// - /// # Returns - /// - /// The maximum element in the tensor along the dimension. - fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let index = B::int_argmax(tensor.clone(), dim); - - B::int_gather(D - 1, tensor, index) - } - - /// Gets the maximum elements and corresponding indices along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements and indices of. 
- /// * `dim` - The dimension to get the maximum elements and indices along. - /// - /// # Returns - /// - /// The maximum elements and corresponding indices along the dimension. - fn int_max_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - let index = B::int_argmax(tensor.clone(), dim); - let values = B::int_gather(D - 1, tensor, index.clone()); - - (values, index) - } - - /// Gets the minimum element in the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum element of. - /// - /// # Returns - /// - /// The minimum element in the tensor. - fn int_min(tensor: IntTensor) -> IntTensor { - let shape = B::int_shape(&tensor); - let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); - - B::int_min_dim(tensor, 0) - } - - /// Gets the minimum elements in the tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum element of. - /// * `dim` - The dimension to get the minimum element along. - /// - /// # Returns - /// - /// The minimum element in the tensor along the dimension. - fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let index = B::int_argmin(tensor.clone(), dim); - - B::int_gather(D - 1, tensor, index) - } - - /// Gets the minimum elements and corresponding indices along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements and indices of. - /// * `dim` - The dimension to get the minimum elements and indices along. - /// - /// # Returns - /// - /// The minimum elements and corresponding indices along the dimension. - fn int_min_dim_with_indices( - tensor: IntTensor, - dim: usize, - ) -> (IntTensor, IntTensor) { - let indices = B::int_argmin(tensor.clone(), dim); - let values = B::int_gather(D - 1, tensor, indices.clone()); - - (values, indices) - } - - /// Returns a new tensor with absolute values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take absolute value of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with absolute values. - fn int_abs(tensor: IntTensor) -> IntTensor; - - /// Transposes an int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn int_transpose(tensor: IntTensor) -> IntTensor { - Self::int_swap_dims(tensor, D - 2, D - 1) - } - - /// Swaps two dimensions of an int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn int_swap_dims( - tensor: IntTensor, - dim1: usize, - dim2: usize, - ) -> IntTensor; + /// Swaps two dimensions of an int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. 
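// Illustrative note: `int_transpose` above is just `int_swap_dims` applied to
// the two innermost dimensions (D - 2 and D - 1), so a tensor of shape
// [2, 3, 4] becomes one of shape [2, 4, 3]. A minimal sketch of that shape
// bookkeeping (standalone, not the backend API):
fn transposed_shape(mut dims: Vec<usize>) -> Vec<usize> {
    let d = dims.len();
    assert!(d >= 2, "transpose needs at least two dimensions");
    dims.swap(d - 2, d - 1);
    dims
}
// transposed_shape(vec![2, 3, 4]) == vec![2, 4, 3]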
+ fn int_swap_dims( + tensor: IntTensor, + dim1: usize, + dim2: usize, + ) -> IntTensor; } diff --git a/burn-tensor/src/tensor/ops/modules/base.rs b/burn-tensor/src/tensor/ops/modules/base.rs index 636c073bc4..4290fbca23 100644 --- a/burn-tensor/src/tensor/ops/modules/base.rs +++ b/burn-tensor/src/tensor/ops/modules/base.rs @@ -1,434 +1,441 @@ use super::{conv, pool, unfold::unfold4d_using_conv2d}; use crate::{ - backend::Backend, - ops::{FloatTensor, IntTensor}, - Shape, + backend::Backend, + ops::{FloatTensor, IntTensor}, + Shape, }; /// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d). #[derive(new)] pub struct Conv2dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, - /// Weights gradient. - pub weights_grad: FloatTensor, + /// Weights gradient. + pub weights_grad: FloatTensor, - /// Bias gradient. - pub bias_grad: Option>, + /// Bias gradient. + pub bias_grad: Option>, } /// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d). #[derive(new)] pub struct MaxPool1dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, } /// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices). #[derive(new)] pub struct MaxPool1dWithIndices { - /// The output tensor. - pub output: FloatTensor, + /// The output tensor. + pub output: FloatTensor, - /// The indices tensor. - pub indices: IntTensor, + /// The indices tensor. + pub indices: IntTensor, } /// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d). #[derive(new)] pub struct MaxPool2dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, } /// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices). #[derive(new)] pub struct MaxPool2dWithIndices { - /// The output tensor. - pub output: FloatTensor, + /// The output tensor. + pub output: FloatTensor, - /// The indices tensor. - pub indices: IntTensor, + /// The indices tensor. + pub indices: IntTensor, } /// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d). #[derive(new)] pub struct Conv1dBackward { - /// Gradient. - pub x_grad: FloatTensor, + /// Gradient. + pub x_grad: FloatTensor, - /// Weights gradient. - pub weights_grad: FloatTensor, + /// Weights gradient. + pub weights_grad: FloatTensor, - /// Bias gradient. - pub bias_grad: Option>, + /// Bias gradient. + pub bias_grad: Option>, } /// Convolution options. #[derive(new, Debug, Clone, Hash)] pub struct ConvOptions { - /// Stride. - pub stride: [usize; N], + /// Stride. + pub stride: [usize; N], - /// Padding. - pub padding: [usize; N], + /// Padding. + pub padding: [usize; N], - /// Dilation. - pub dilation: [usize; N], + /// Dilation. + pub dilation: [usize; N], - /// Groups. - pub groups: usize, + /// Groups. + pub groups: usize, } /// Transposed convolution options. #[derive(new, Debug, Clone, Hash)] pub struct ConvTransposeOptions { - /// Stride. - pub stride: [usize; N], + /// Stride. + pub stride: [usize; N], - /// Padding. - pub padding: [usize; N], + /// Padding. + pub padding: [usize; N], - /// Padding out. - pub padding_out: [usize; N], + /// Padding out. + pub padding_out: [usize; N], - /// Dilation. - pub dilation: [usize; N], + /// Dilation. + pub dilation: [usize; N], - /// Groups. - pub groups: usize, + /// Groups. + pub groups: usize, } /// Unfold operation options. 
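// Usage sketch: `#[derive(new)]` gives the option structs above positional
// constructors whose argument order follows field order, which is how the
// backward passes later in this patch build them. Values here are illustrative
// only:
#[allow(dead_code)]
fn example_options() -> (ConvOptions<2>, ConvTransposeOptions<2>) {
    // stride, padding, dilation, groups
    let conv = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);
    // stride, padding, padding_out, dilation, groups
    let conv_transpose = ConvTransposeOptions::new([2, 2], [1, 1], [0, 0], [1, 1], 1);
    (conv, conv_transpose)
}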
#[derive(new, Debug, Clone)] pub struct UnfoldOptions { - /// The number of positions to slide over the input tensor in each dimension. - /// A stride of `[1, 1]` will slide the kernel one pixel at a time. - pub stride: [usize; 2], + /// The number of positions to slide over the input tensor in each dimension. + /// A stride of `[1, 1]` will slide the kernel one pixel at a time. + pub stride: [usize; 2], - /// The number of zero-padding pixels added to each side of the input tensor in each dimension. - pub padding: [usize; 2], + /// The number of zero-padding pixels added to each side of the input tensor in each dimension. + pub padding: [usize; 2], - /// The spacing between the blocks (patches) in the original input tensor. - pub dilation: [usize; 2], + /// The spacing between the blocks (patches) in the original input tensor. + pub dilation: [usize; 2], } /// Module operations trait. pub trait ModuleOps { - /// Embedding operation. - /// - /// # Arguments - /// - /// * `weights` - The embedding weights. - /// * `indices` - The indices tensor. - /// - /// # Returns - /// - /// The output tensor. - fn embedding(weights: FloatTensor, indices: IntTensor) -> FloatTensor { - let [batch_size, seq_length] = B::int_shape(&indices).dims; - let [_, d_model] = B::shape(&weights).dims; - - let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); - let output = B::select(weights, 0, indices); - - B::reshape(output, Shape::new([batch_size, seq_length, d_model])) - } - - /// Embedding backward operation. - /// - /// # Arguments - /// - /// * `weights` - The embedding weights. - /// * `output_grad` - The output gradient. - /// * `indices` - The indices tensor. - /// - /// # Returns - /// - /// The gradient. - fn embedding_backward( - weights: FloatTensor, - output_grad: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - let [batch_size, seq_length] = B::int_shape(&indices).dims; - let [n_embeddings, d_model] = B::shape(&weights).dims; - let device = B::device(&weights); - - let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); - let output_grad = B::reshape(output_grad, Shape::new([batch_size * seq_length, d_model])); - let grad = B::zeros(Shape::new([n_embeddings, d_model]), &device); - - B::select_assign(grad, 0, indices, output_grad) - } - /// One dimensional convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, length]`, - /// weight: `[channels_out, channels_in, kernel_size]`, - /// bias: `[channels_out]`, - fn conv1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, - ) -> FloatTensor { - conv::conv1d_from_conv2d::(x, weight, bias, options) - } - /// Backward pass for the [conv1d](ModuleOps::conv1d) operation. - fn conv1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<1>, - ) -> Conv1dBackward { - conv::conv1d_backward(x, weight, bias, output_grad, options) - } - /// Two dimensional convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, height, width]`, - /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`, - /// bias: `[channels_out]`, - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor; - /// Backward pass for the [conv2d](ModuleOps::conv2d) operation. 
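// Illustrative sketch: the `embedding` default above flattens the
// [batch_size, seq_length] index tensor, selects the matching rows of the
// [n_embeddings, d_model] weight matrix, and reshapes the result back to
// [batch_size, seq_length, d_model]. The row lookup itself, on plain Vecs
// rather than the backend API (names are illustrative only):
fn embedding_lookup(weights: &[Vec<f32>], flat_indices: &[usize]) -> Vec<Vec<f32>> {
    // One weight row per (flattened) token index.
    flat_indices.iter().map(|&i| weights[i].clone()).collect()
}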
- fn conv2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<2>, - ) -> Conv2dBackward { - conv::conv2d_backward(x, weight, bias, output_grad, options) - } - /// One dimensional transposed convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, length]`, - /// weight: `[channels_in, channels_out, length]`, - /// bias: `[channels_out]`, - fn conv_transpose1d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, - ) -> FloatTensor { - conv::conv_transpose1d_from_conv_transpose2d::(x, weight, bias, options) - } - /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation. - fn conv_transpose1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<1>, - ) -> Conv1dBackward { - conv::conv_transpose1d_backward(x, weight, bias, output_grad, options) - } - /// Two dimensional transposed convolution. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, height, width]`, - /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`, - /// bias: `[channels_out]`, - fn conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor; - - /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation. - fn conv_transpose2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<2>, - ) -> Conv2dBackward { - conv::conv_transpose2d_backward(x, weight, bias, output_grad, options) - } - - /// Four-dimensional unfolding. - /// - /// # Shapes - /// - /// x: `[batch_size, channels_in, height, width]`, - /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, - fn unfold4d( - x: FloatTensor, - kernel_size: [usize; 2], - options: UnfoldOptions, - ) -> FloatTensor { - unfold4d_using_conv2d::(x, kernel_size, options) - } - - /// One dimensional avg pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, length], - fn avg_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - pool::avg_pool1d_from_2d::(x, kernel_size, stride, padding, count_include_pad) - } - /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation. - fn avg_pool1d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, - ) -> FloatTensor { - pool::avg_pool1d_backward_from_2d::(x, grad, kernel_size, stride, padding, count_include_pad) - } - /// Two dimensional avg pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor; - /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation. - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor; - /// Two dimensional adaptive avg pooling. 
- /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor; - /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation. - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor; - /// One dimensional adaptive avg pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, length], - fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { - pool::adaptive_avg_pool1d_from_2d::(x, output_size) - } - /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation. - fn adaptive_avg_pool1d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - pool::adaptive_avg_pool1d_backward_from_2d::(x, grad) - } - /// One dimensional max pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, length], - fn max_pool1d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> FloatTensor { - pool::max_pool1d_from_2d::(x, kernel_size, stride, padding, dilation) - } - - /// One dimensional max pooling with indices. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn max_pool1d_with_indices( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - ) -> MaxPool1dWithIndices { - pool::max_pool1d_with_indices_from_2d::(x, kernel_size, stride, padding, dilation) - } - /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation. - fn max_pool1d_with_indices_backward( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool1dBackward { - pool::max_pool1d_with_indices_backward_from_2d::( - x, - kernel_size, - stride, - padding, - dilation, - output_grad, - indices, - ) - } - - /// Two dimensional max pooling. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor; - - /// Two dimensional max pooling with indices. - /// - /// # Shapes - /// - /// x: [batch_size, channels, height, width], - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices; - /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation. - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward; + /// Embedding operation. + /// + /// # Arguments + /// + /// * `weights` - The embedding weights. + /// * `indices` - The indices tensor. + /// + /// # Returns + /// + /// The output tensor. + fn embedding(weights: FloatTensor, indices: IntTensor) -> FloatTensor { + let [batch_size, seq_length] = B::int_shape(&indices).dims; + let [_, d_model] = B::shape(&weights).dims; + + let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); + let output = B::select(weights, 0, indices); + + B::reshape(output, Shape::new([batch_size, seq_length, d_model])) + } + + /// Embedding backward operation. + /// + /// # Arguments + /// + /// * `weights` - The embedding weights. 
+ /// * `output_grad` - The output gradient. + /// * `indices` - The indices tensor. + /// + /// # Returns + /// + /// The gradient. + fn embedding_backward( + weights: FloatTensor, + output_grad: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + let [batch_size, seq_length] = B::int_shape(&indices).dims; + let [n_embeddings, d_model] = B::shape(&weights).dims; + let device = B::device(&weights); + + let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); + let output_grad = B::reshape(output_grad, Shape::new([batch_size * seq_length, d_model])); + let grad = B::zeros(Shape::new([n_embeddings, d_model]), &device); + + B::select_assign(grad, 0, indices, output_grad) + } + /// One dimensional convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, length]`, + /// weight: `[channels_out, channels_in, kernel_size]`, + /// bias: `[channels_out]`, + fn conv1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> FloatTensor { + conv::conv1d_from_conv2d::(x, weight, bias, options) + } + /// Backward pass for the [conv1d](ModuleOps::conv1d) operation. + fn conv1d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<1>, + ) -> Conv1dBackward { + conv::conv1d_backward(x, weight, bias, output_grad, options) + } + /// Two dimensional convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, height, width]`, + /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`, + /// bias: `[channels_out]`, + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor; + /// Backward pass for the [conv2d](ModuleOps::conv2d) operation. + fn conv2d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<2>, + ) -> Conv2dBackward { + conv::conv2d_backward(x, weight, bias, output_grad, options) + } + /// One dimensional transposed convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, length]`, + /// weight: `[channels_in, channels_out, length]`, + /// bias: `[channels_out]`, + fn conv_transpose1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> FloatTensor { + conv::conv_transpose1d_from_conv_transpose2d::(x, weight, bias, options) + } + /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation. + fn conv_transpose1d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<1>, + ) -> Conv1dBackward { + conv::conv_transpose1d_backward(x, weight, bias, output_grad, options) + } + /// Two dimensional transposed convolution. + /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, height, width]`, + /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`, + /// bias: `[channels_out]`, + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor; + + /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation. + fn conv_transpose2d_backward( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<2>, + ) -> Conv2dBackward { + conv::conv_transpose2d_backward(x, weight, bias, output_grad, options) + } + + /// Four-dimensional unfolding. 
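// Illustrative sketch: `embedding_backward` above scatters the flattened
// output gradients into a zero-initialised [n_embeddings, d_model] matrix, so
// every occurrence of an index contributes its gradient row to that
// embedding's row (this relies on `select_assign` accumulating rather than
// overwriting, which is what the backward pass requires). The accumulation on
// plain Vecs (not the backend API; names are illustrative only):
fn embedding_grad(
    n_embeddings: usize,
    d_model: usize,
    flat_indices: &[usize],
    output_grad: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let mut grad = vec![vec![0.0f32; d_model]; n_embeddings];
    for (&idx, grad_row) in flat_indices.iter().zip(output_grad) {
        // Repeated indices accumulate their gradients.
        for (acc, g) in grad[idx].iter_mut().zip(grad_row) {
            *acc += *g;
        }
    }
    grad
}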
+ /// + /// # Shapes + /// + /// x: `[batch_size, channels_in, height, width]`, + /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, + fn unfold4d( + x: FloatTensor, + kernel_size: [usize; 2], + options: UnfoldOptions, + ) -> FloatTensor { + unfold4d_using_conv2d::(x, kernel_size, options) + } + + /// One dimensional avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, length], + fn avg_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + pool::avg_pool1d_from_2d::(x, kernel_size, stride, padding, count_include_pad) + } + /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation. + fn avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + pool::avg_pool1d_backward_from_2d::( + x, + grad, + kernel_size, + stride, + padding, + count_include_pad, + ) + } + /// Two dimensional avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor; + /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation. + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor; + /// Two dimensional adaptive avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor; + /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation. + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor; + /// One dimensional adaptive avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, length], + fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { + pool::adaptive_avg_pool1d_from_2d::(x, output_size) + } + /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation. + fn adaptive_avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + pool::adaptive_avg_pool1d_backward_from_2d::(x, grad) + } + /// One dimensional max pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, length], + fn max_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> FloatTensor { + pool::max_pool1d_from_2d::(x, kernel_size, stride, padding, dilation) + } + + /// One dimensional max pooling with indices. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool1d_with_indices( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices { + pool::max_pool1d_with_indices_from_2d::(x, kernel_size, stride, padding, dilation) + } + /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation. 
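// Shape sketch: `unfold4d` above returns
// [batch_size, channels_in * kernel_size_1 * kernel_size_2, n_blocks], where
// n_blocks is the product of the per-dimension sliding-window output sizes
// (the same formula as `calculate_conv_output_size` further down in this
// patch). A standalone helper computing that output shape (illustrative only):
fn unfold4d_output_shape(
    x: [usize; 4], // [batch, channels, height, width]
    kernel: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> [usize; 3] {
    let out = |k: usize, s: usize, p: usize, d: usize, size: usize| {
        (size + 2 * p - d * (k - 1) - 1) / s + 1
    };
    let blocks_h = out(kernel[0], stride[0], padding[0], dilation[0], x[2]);
    let blocks_w = out(kernel[1], stride[1], padding[1], dilation[1], x[3]);
    [x[0], x[1] * kernel[0] * kernel[1], blocks_h * blocks_w]
}
// e.g. x = [2, 3, 32, 32], kernel = [4, 4], stride = [1, 1], padding = [0, 0],
// dilation = [1, 1]  ->  [2, 48, 841]  (29 * 29 = 841 blocks)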
+ fn max_pool1d_with_indices_backward( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool1dBackward { + pool::max_pool1d_with_indices_backward_from_2d::( + x, + kernel_size, + stride, + padding, + dilation, + output_grad, + indices, + ) + } + + /// Two dimensional max pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor; + + /// Two dimensional max pooling with indices. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices; + /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation. + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward; } diff --git a/burn-tensor/src/tensor/ops/modules/conv.rs b/burn-tensor/src/tensor/ops/modules/conv.rs index d97969a9b9..666068f566 100644 --- a/burn-tensor/src/tensor/ops/modules/conv.rs +++ b/burn-tensor/src/tensor/ops/modules/conv.rs @@ -5,757 +5,763 @@ use libm::ceilf; /// Calculate the expected padding size required when applying a convolution. pub fn calculate_conv_padding( - kernel_size: usize, - stride: usize, - size_in: usize, - size_out: usize, + kernel_size: usize, + stride: usize, + size_in: usize, + size_out: usize, ) -> usize { - let kernel_size = kernel_size as f32; - let stride = stride as f32; - let size_in = size_in as f32; - let size_out = size_out as f32; + let kernel_size = kernel_size as f32; + let stride = stride as f32; + let size_in = size_in as f32; + let size_out = size_out as f32; - let padding = stride * (size_out - 1.) - size_in + kernel_size; - let padding = ceilf(padding / 2.); + let padding = stride * (size_out - 1.) - size_in + kernel_size; + let padding = ceilf(padding / 2.); - padding as usize + padding as usize } /// Calculate the expected output size when doing a convolution operation. pub fn calculate_conv_output_size( - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - size_in: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + size_in: usize, ) -> usize { - (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 } /// Calculate the expected output size when doing a transposed convolution operation. pub fn calculate_conv_transpose_output_size( - kernel_size: usize, - stride: usize, - padding: usize, - padding_out: usize, - dilation: usize, - size_in: usize, + kernel_size: usize, + stride: usize, + padding: usize, + padding_out: usize, + dilation: usize, + size_in: usize, ) -> usize { - (size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1 + (size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1 } /// Calculate the expected output size when doing a pooling operation. 
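// Worked example for the two output-size formulas above (values illustrative):
// a 3-wide, stride-2, padding-1, dilation-1 convolution halves a size of 224,
// and the matching transposed convolution (with padding_out = 1) maps it back.
//
//   calculate_conv_output_size(3, 2, 1, 1, 224)
//       = (224 + 2*1 - 1*(3 - 1) - 1) / 2 + 1 = 112
//
//   calculate_conv_transpose_output_size(3, 2, 1, 1, 1, 112)
//       = (112 - 1)*2 + 1*(3 - 1) + 1 - 2*1 + 1 = 224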
pub fn calculate_pool_output_size( - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - size_in: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + size_in: usize, ) -> usize { - ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 + ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 } /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions. pub(crate) fn conv1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<1>, ) -> Conv1dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _, length_in] = B::shape(&x).dims; - let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims; - let [_, _, kernel_size] = weight_shape.dims; - - let padding_out = calculate_padding_out( - kernel_size, - options.stride[0], - options.padding[0], - options.dilation[0], - length_in, - length_out, - ); - - let x_grad = B::conv_transpose1d( - output_grad.clone(), - weight, - None, - ConvTransposeOptions::new( - options.stride, - options.padding, - [padding_out], - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => conv1d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), - false => conv1d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv1dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _, length_in] = B::shape(&x).dims; + let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims; + let [_, _, kernel_size] = weight_shape.dims; + + let padding_out = calculate_padding_out( + kernel_size, + options.stride[0], + options.padding[0], + options.dilation[0], + length_in, + length_out, + ); + + let x_grad = B::conv_transpose1d( + output_grad.clone(), + weight, + None, + ConvTransposeOptions::new( + options.stride, + options.padding, + [padding_out], + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => conv1d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), + false => conv1d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv1dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass using convolutions. 
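// Illustrative sketch: in `conv1d_backward` above, the bias gradient is the
// output gradient summed over every position the bias was broadcast to, i.e.
// over the batch and length dimensions, leaving one value per output channel
// (the swap_dims / reshape / sum_dim sequence). The same reduction on plain
// Vecs (not the backend API; names are illustrative only):
fn bias_grad(output_grad: &[Vec<Vec<f32>>]) -> Vec<f32> {
    // output_grad is indexed as [batch][channel][position].
    let channels_out = output_grad.first().map_or(0, |sample| sample.len());
    let mut grad = vec![0.0f32; channels_out];
    for sample in output_grad {
        for (c, channel) in sample.iter().enumerate() {
            grad[c] += channel.iter().sum::<f32>();
        }
    }
    grad
}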
pub(crate) fn conv2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvOptions<2>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvOptions<2>, ) -> Conv2dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims; - let [_, _, height_out, width_out] = B::shape(&output_grad).dims; - let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims; - - let padding_1_out = calculate_padding_out( - kernel_size_1, - options.stride[0], - options.padding[0], - options.dilation[0], - height_in, - height_out, - ); - let padding_2_out = calculate_padding_out( - kernel_size_2, - options.stride[1], - options.padding[1], - options.dilation[1], - width_in, - width_out, - ); - - let x_grad = B::conv_transpose2d( - output_grad.clone(), - weight, - None, - ConvTransposeOptions::new( - options.stride, - options.padding, - [padding_1_out, padding_2_out], - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => conv2d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), - false => conv2d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv2dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape( - grad, - Shape::new([channels_out, batch_size * height_out * width_out]), - ); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims; + let [_, _, height_out, width_out] = B::shape(&output_grad).dims; + let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims; + + let padding_1_out = calculate_padding_out( + kernel_size_1, + options.stride[0], + options.padding[0], + options.dilation[0], + height_in, + height_out, + ); + let padding_2_out = calculate_padding_out( + kernel_size_2, + options.stride[1], + options.padding[1], + options.dilation[1], + width_in, + width_out, + ); + + let x_grad = B::conv_transpose2d( + output_grad.clone(), + weight, + None, + ConvTransposeOptions::new( + options.stride, + options.padding, + [padding_1_out, padding_2_out], + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => conv2d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options), + false => conv2d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv2dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape( + grad, + Shape::new([channels_out, batch_size * height_out * width_out]), + ); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass using convolutions. 
pub(crate) fn conv_transpose2d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<2>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<2>, ) -> Conv2dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _channels_in, _, _] = B::shape(&x).dims; - let [_, channels_out, height_out, width_out] = B::shape(&output_grad).dims; - - let x_grad = B::conv2d( - output_grad.clone(), - weight, - None, - ConvOptions::new( - options.stride, - options.padding, - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => { - conv_transpose2d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options) - } - false => conv_transpose2d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv2dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape( - grad, - Shape::new([channels_out, batch_size * height_out * width_out]), - ); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _channels_in, _, _] = B::shape(&x).dims; + let [_, channels_out, height_out, width_out] = B::shape(&output_grad).dims; + + let x_grad = B::conv2d( + output_grad.clone(), + weight, + None, + ConvOptions::new( + options.stride, + options.padding, + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => conv_transpose2d_weight_grad_no_groups::( + x, + output_grad.clone(), + weight_shape, + options, + ), + false => conv_transpose2d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv2dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape( + grad, + Shape::new([channels_out, batch_size * height_out * width_out]), + ); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass using convolutions. 
pub(crate) fn conv_transpose1d_backward( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - output_grad: FloatTensor, - options: ConvTransposeOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + output_grad: FloatTensor, + options: ConvTransposeOptions<1>, ) -> Conv1dBackward { - let weight_shape = B::shape(&weight); - let weight_device = B::device(&weight); - - let [batch_size, _channels_in, _] = B::shape(&x).dims; - let [_, channels_out, length_out] = B::shape(&output_grad).dims; - - let x_grad = B::conv1d( - output_grad.clone(), - weight, - None, - ConvOptions::new( - options.stride, - options.padding, - options.dilation, - options.groups, - ), - ); - - let weight_grad = match options.groups == 1 { - true => { - conv_transpose1d_weight_grad_no_groups::(x, output_grad.clone(), weight_shape, options) - } - false => conv_transpose1d_weight_grad_groups::( - x, - B::zeros(weight_shape, &weight_device), - output_grad.clone(), - options, - ), - }; - - Conv1dBackward::new( - x_grad, - weight_grad, - bias.map(|b| { - let grad = B::swap_dims(output_grad, 0, 1); - let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); - let grad = B::sum_dim(grad, 1); - - B::reshape(grad, B::shape(&b)) - }), - ) + let weight_shape = B::shape(&weight); + let weight_device = B::device(&weight); + + let [batch_size, _channels_in, _] = B::shape(&x).dims; + let [_, channels_out, length_out] = B::shape(&output_grad).dims; + + let x_grad = B::conv1d( + output_grad.clone(), + weight, + None, + ConvOptions::new( + options.stride, + options.padding, + options.dilation, + options.groups, + ), + ); + + let weight_grad = match options.groups == 1 { + true => conv_transpose1d_weight_grad_no_groups::( + x, + output_grad.clone(), + weight_shape, + options, + ), + false => conv_transpose1d_weight_grad_groups::( + x, + B::zeros(weight_shape, &weight_device), + output_grad.clone(), + options, + ), + }; + + Conv1dBackward::new( + x_grad, + weight_grad, + bias.map(|b| { + let grad = B::swap_dims(output_grad, 0, 1); + let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out])); + let grad = B::sum_dim(grad, 1); + + B::reshape(grad, B::shape(&b)) + }), + ) } /// Execute a 1D convolution using a 2D convolution. 
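// Shape note for `conv1d_from_conv2d` below: a trailing singleton spatial
// dimension lets the 2-D kernel do the 1-D work,
//
//   x:      [batch, channels_in, length]          -> [batch, channels_in, length, 1]
//   weight: [channels_out, channels_in/groups, k] -> [channels_out, channels_in/groups, k, 1]
//   stride s, padding p, dilation d               -> [s, 1], [p, 0], [d, 1]
//
// so, for example, x = [8, 16, 100] with kernel size 3, stride 1, padding 0
// runs as a [3, 1] conv2d over [8, 16, 100, 1] and is reshaped back to
// [8, 32, 98] when channels_out = 32 (values illustrative only).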
pub(crate) fn conv1d_from_conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, ) -> FloatTensor { - let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims; - let [batch_size, channels_in, length_in] = B::shape(&x).dims; - - let weight = B::reshape( - weight, - Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]), - ); - let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); - - let tensor = B::conv2d( - x, - weight, - bias, - ConvOptions::new( - [options.stride[0], 1], - [options.padding[0], 0], - [options.dilation[0], 1], - options.groups, - ), - ); - let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; - B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) + let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims; + let [batch_size, channels_in, length_in] = B::shape(&x).dims; + + let weight = B::reshape( + weight, + Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]), + ); + let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); + + let tensor = B::conv2d( + x, + weight, + bias, + ConvOptions::new( + [options.stride[0], 1], + [options.padding[0], 0], + [options.dilation[0], 1], + options.groups, + ), + ); + let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; + B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } /// Execute a 1D transposed convolution using a 2D transposed convolution. pub(crate) fn conv_transpose1d_from_conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<1>, + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, ) -> FloatTensor { - let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims; - let [batch_size, _channels_in, length_in] = B::shape(&x).dims; - - let weight = B::reshape( - weight, - Shape::new([channels_in, channels_out, kernel_size, 1]), - ); - let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); - - let tensor = B::conv_transpose2d( - x, - weight, - bias, - ConvTransposeOptions::new( - [options.stride[0], 1], - [options.padding[0], 0], - [options.padding_out[0], 0], - [options.dilation[0], 1], - options.groups, - ), - ); - let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; - B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) + let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims; + let [batch_size, _channels_in, length_in] = B::shape(&x).dims; + + let weight = B::reshape( + weight, + Shape::new([channels_in, channels_out, kernel_size, 1]), + ); + let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); + + let tensor = B::conv_transpose2d( + x, + weight, + bias, + ConvTransposeOptions::new( + [options.stride[0], 1], + [options.padding[0], 0], + [options.padding_out[0], 0], + [options.dilation[0], 1], + options.groups, + ), + ); + let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; + B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } fn conv1d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvOptions<1>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: 
ConvOptions<1>, ) -> FloatTensor { - let [channels_out, increment_ci, kernel_size] = B::shape(&weight_grad).dims; - let increment_co = channels_out / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv1d( - x, - grad, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - weight_grad = B::slice_assign( - weight_grad, - [start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size], - weight_grad_tmp, - ); - } + let [channels_out, increment_ci, kernel_size] = B::shape(&weight_grad).dims; + let increment_co = channels_out / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) * increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv1d( + x, + grad, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + weight_grad = B::slice_assign( + weight_grad, + [start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size], + weight_grad_tmp, + ); + } - weight_grad + weight_grad } fn conv2d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvOptions<2>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: ConvOptions<2>, ) -> FloatTensor { - let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; - let increment_co = channels_out / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv2d( - x, - grad, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - weight_grad = B::slice_assign( - weight_grad, - [ - start_idx_co..end_idx_co, - 0..increment_ci, - 0..kernel_size_1, - 0..kernel_size_2, - ], - weight_grad_tmp, - ); - } + let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; + let increment_co = channels_out / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) 
* increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv2d( + x, + grad, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + weight_grad = B::slice_assign( + weight_grad, + [ + start_idx_co..end_idx_co, + 0..increment_ci, + 0..kernel_size_1, + 0..kernel_size_2, + ], + weight_grad_tmp, + ); + } - weight_grad + weight_grad } fn conv_transpose2d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvTransposeOptions<2>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: ConvTransposeOptions<2>, ) -> FloatTensor { - let [channels_in, increment_co, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; - let increment_ci = channels_in / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv2d( - grad, - x, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::shape(&weight_grad_tmp).dims; - - if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { - weight_grad_tmp = B::slice( - weight_grad_tmp, - [ - 0..increment_ci, - 0..increment_co, - 0..kernel_size_1, - 0..kernel_size_2, - ], - ); + let [channels_in, increment_co, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims; + let increment_ci = channels_in / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) * increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv2d( + grad, + x, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::shape(&weight_grad_tmp).dims; + + if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { + weight_grad_tmp = B::slice( + weight_grad_tmp, + [ + 0..increment_ci, + 0..increment_co, + 0..kernel_size_1, + 0..kernel_size_2, + ], + ); + } + + weight_grad = B::slice_assign( + weight_grad, + [ + start_idx_ci..end_idx_ci, + 0..increment_co, + 0..kernel_size_1, + 0..kernel_size_2, + ], + weight_grad_tmp, + ); } - weight_grad = B::slice_assign( - weight_grad, - [ - start_idx_ci..end_idx_ci, - 0..increment_co, - 0..kernel_size_1, - 0..kernel_size_2, - ], - weight_grad_tmp, - ); - } - - weight_grad + weight_grad } fn conv1d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<3>, - options: 
ConvOptions<1>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<3>, + options: ConvOptions<1>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv1d( - x_swapped, - output_grad_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); - - if B::shape(&weight_grad) != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - ], + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv1d( + x_swapped, + output_grad_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); - } - weight_grad + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + if B::shape(&weight_grad) != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + ], + ); + } + weight_grad } fn conv_transpose1d_weight_grad_groups( - x: FloatTensor, - mut weight_grad: FloatTensor, - output_grad: FloatTensor, - options: ConvTransposeOptions<1>, + x: FloatTensor, + mut weight_grad: FloatTensor, + output_grad: FloatTensor, + options: ConvTransposeOptions<1>, ) -> FloatTensor { - let [channels_in, increment_co, kernel_size] = B::shape(&weight_grad).dims; - let increment_ci = channels_in / options.groups; - - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - - for g in 0..options.groups { - let start_idx_ci = g * increment_ci; - let end_idx_ci = (g + 1) * increment_ci; - let start_idx_co = g * increment_co; - let end_idx_co = (g + 1) * increment_co; - - let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); - let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); - let mut weight_grad_tmp = B::conv1d( - grad, - x, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); - let [_, _, kernel_size_tmp] = B::shape(&weight_grad_tmp).dims; - - if kernel_size_tmp != kernel_size { - weight_grad_tmp = B::slice( - weight_grad_tmp, - [0..increment_ci, 0..increment_co, 0..kernel_size], - ); + let [channels_in, increment_co, kernel_size] = B::shape(&weight_grad).dims; + let increment_ci = channels_in / options.groups; + + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + + for g in 0..options.groups { + let start_idx_ci = g * increment_ci; + let end_idx_ci = (g + 1) * increment_ci; + let start_idx_co = g * increment_co; + let end_idx_co = (g + 1) * increment_co; + + let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]); + let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]); + let mut weight_grad_tmp = B::conv1d( + grad, + x, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), + ); + weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1); + let [_, _, kernel_size_tmp] = B::shape(&weight_grad_tmp).dims; + + if kernel_size_tmp != kernel_size { + weight_grad_tmp = B::slice( + weight_grad_tmp, + [0..increment_ci, 0..increment_co, 0..kernel_size], + ); + } + + weight_grad = B::slice_assign( + weight_grad, + [start_idx_ci..end_idx_ci, 0..increment_co, 
0..kernel_size], + weight_grad_tmp, + ); } - weight_grad = B::slice_assign( - weight_grad, - [start_idx_ci..end_idx_ci, 0..increment_co, 0..kernel_size], - weight_grad_tmp, - ); - } - - weight_grad + weight_grad } fn conv2d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<4>, - options: ConvOptions<2>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<4>, + options: ConvOptions<2>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv2d( - x_swapped, - output_grad_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); - - if B::shape(&weight_grad) != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - 0..weight_shape.dims[3], - ], + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv2d( + x_swapped, + output_grad_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); - } - weight_grad + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + if B::shape(&weight_grad) != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + 0..weight_shape.dims[3], + ], + ); + } + weight_grad } fn conv_transpose1d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<3>, - options: ConvTransposeOptions<1>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<3>, + options: ConvTransposeOptions<1>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv1d( - output_grad_swapped, - x_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); - - let grad_shape = B::shape(&weight_grad); - - if grad_shape != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - ], + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv1d( + output_grad_swapped, + x_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); - } - weight_grad + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + let grad_shape = B::shape(&weight_grad); + + if grad_shape != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + ], + ); + } + weight_grad } fn conv_transpose2d_weight_grad_no_groups( - x: FloatTensor, - output_grad: FloatTensor, - weight_shape: Shape<4>, - options: ConvTransposeOptions<2>, + x: FloatTensor, + output_grad: FloatTensor, + weight_shape: Shape<4>, + options: ConvTransposeOptions<2>, ) -> FloatTensor { - let x_swapped = B::swap_dims(x, 0, 1); - let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv2d( - output_grad_swapped, - x_swapped, - None, - ConvOptions::new(options.dilation, options.padding, options.stride, 1), - ); - let mut weight_grad = 
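The *_weight_grad_groups helpers above all slice channels the same way: one per-group increment is read straight off the weight shape, the other is the matching channel count divided by options.groups, and group g then covers the half-open range g * increment .. (g + 1) * increment on each side. A standalone sketch of that arithmetic, with illustrative names that are not part of the burn API:

use std::ops::Range;

/// Channel ranges touched by group `g`, mirroring the index arithmetic in the
/// grouped weight-gradient helpers above (illustrative only).
fn group_channel_ranges(
    channels_in: usize,
    channels_out: usize,
    groups: usize,
    g: usize,
) -> (Range<usize>, Range<usize>) {
    let increment_ci = channels_in / groups;
    let increment_co = channels_out / groups;
    (
        g * increment_ci..(g + 1) * increment_ci,
        g * increment_co..(g + 1) * increment_co,
    )
}

fn main() {
    // 8 input channels, 4 output channels, 2 groups: group 1 reads input
    // channels 4..8 and writes output channels 2..4.
    let (ci, co) = group_channel_ranges(8, 4, 2, 1);
    assert_eq!(ci, 4..8);
    assert_eq!(co, 2..4);
}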
B::swap_dims(weight_grad_swapped, 0, 1); - - let grad_shape = B::shape(&weight_grad); - - if grad_shape != weight_shape { - weight_grad = B::slice( - weight_grad, - [ - 0..weight_shape.dims[0], - 0..weight_shape.dims[1], - 0..weight_shape.dims[2], - 0..weight_shape.dims[3], - ], + let x_swapped = B::swap_dims(x, 0, 1); + let output_grad_swapped = B::swap_dims(output_grad, 0, 1); + let weight_grad_swapped = B::conv2d( + output_grad_swapped, + x_swapped, + None, + ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); - } - weight_grad + let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); + + let grad_shape = B::shape(&weight_grad); + + if grad_shape != weight_shape { + weight_grad = B::slice( + weight_grad, + [ + 0..weight_shape.dims[0], + 0..weight_shape.dims[1], + 0..weight_shape.dims[2], + 0..weight_shape.dims[3], + ], + ); + } + weight_grad } fn calculate_padding_out( - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - size_in: usize, - size_out: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + size_in: usize, + size_out: usize, ) -> usize { - if stride <= 1 { - return 0; - } - - let out = 1 - + libm::ceil((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64) - as usize; - i64::max(0, out as i64 - size_out as i64) as usize + if stride <= 1 { + return 0; + } + + let out = 1 + libm::ceil( + (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64, + ) as usize; + i64::max(0, out as i64 - size_out as i64) as usize } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_calculate_output_size_1() { - let kernel_size = 3; - let stride = 1; - let padding = 1; - let size_in = 3; - let dilation = 1; - - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_out, 3); - } - - #[test] - fn test_calculate_output_size_2() { - let kernel_size = 5; - let stride = 2; - let padding = 3; - let size_in = 27; - let dilation = 1; - - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_out, 15); - } - - #[test] - fn test_calculate_output_size_3() { - let kernel_size = 5; - let stride = 2; - let padding = 3; - let size_in = 27; - let dilation = 2; - - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_out, 13); - } - - #[test] - fn test_calculate_same_padding_1() { - let kernel_size = 3; - let stride = 1; - let size_in = 3; - let dilation = 1; - - let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_in, size_out, "Expected size"); - } - - #[test] - fn test_calculate_same_padding_2() { - let kernel_size = 3; - let stride = 2; - let size_in = 7; - let dilation = 1; - - let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); - let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - assert_eq!(size_in, size_out, "Expected size"); - } - - #[test] - fn test_calculate_output_padding_1() { - let kernel_size = 3; - let stride = 2; - let size_in = 7; - let size_out = 10; - let dilation = 1; - - let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out); - let size_out_expected = - calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); - - 
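calculate_padding_out above inverts the usual convolution output-length relation to recover the extra output padding needed by the transposed-convolution gradients. The tests in this module check it against calculate_conv_output_size, whose body sits outside this hunk; assuming the standard formula, which reproduces the expected values in those tests (3, 15, 13), a self-contained restatement looks like this:

/// Standard convolution output length. This is an illustrative restatement,
/// not the burn implementation, but it matches the expected values asserted
/// in the tests of this module.
fn conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}

fn main() {
    assert_eq!(conv_output_size(3, 1, 1, 1, 3), 3);
    assert_eq!(conv_output_size(5, 2, 3, 1, 27), 15);
    assert_eq!(conv_output_size(5, 2, 3, 2, 27), 13);
}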
assert_eq!(size_out, size_out_expected, "Expected size"); - } + use super::*; + + #[test] + fn test_calculate_output_size_1() { + let kernel_size = 3; + let stride = 1; + let padding = 1; + let size_in = 3; + let dilation = 1; + + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, 3); + } + + #[test] + fn test_calculate_output_size_2() { + let kernel_size = 5; + let stride = 2; + let padding = 3; + let size_in = 27; + let dilation = 1; + + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, 15); + } + + #[test] + fn test_calculate_output_size_3() { + let kernel_size = 5; + let stride = 2; + let padding = 3; + let size_in = 27; + let dilation = 2; + + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, 13); + } + + #[test] + fn test_calculate_same_padding_1() { + let kernel_size = 3; + let stride = 1; + let size_in = 3; + let dilation = 1; + + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_in, size_out, "Expected size"); + } + + #[test] + fn test_calculate_same_padding_2() { + let kernel_size = 3; + let stride = 2; + let size_in = 7; + let dilation = 1; + + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_in, size_out, "Expected size"); + } + + #[test] + fn test_calculate_output_padding_1() { + let kernel_size = 3; + let stride = 2; + let size_in = 7; + let size_out = 10; + let dilation = 1; + + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out); + let size_out_expected = + calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, size_out_expected, "Expected size"); + } } diff --git a/burn-tensor/src/tensor/ops/modules/pool.rs b/burn-tensor/src/tensor/ops/modules/pool.rs index 0b3c830687..4d096ae4a5 100644 --- a/burn-tensor/src/tensor/ops/modules/pool.rs +++ b/burn-tensor/src/tensor/ops/modules/pool.rs @@ -1,167 +1,167 @@ use crate::{ - backend::Backend, - ops::{FloatTensor, IntTensor}, - Shape, + backend::Backend, + ops::{FloatTensor, IntTensor}, + Shape, }; use super::{MaxPool1dBackward, MaxPool1dWithIndices}; pub(crate) fn avg_pool1d_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, ) -> FloatTensor { - let [batch_size, channels, length] = B::shape(&x).dims; + let [batch_size, channels, length] = B::shape(&x).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); - let x = B::avg_pool2d( - x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - count_include_pad, - ); + let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); + let x = B::avg_pool2d( + x, + [kernel_size, 1], + [stride, 1], + [padding, 0], + count_include_pad, + ); - let [batch_size, channels, length, _] = B::shape(&x).dims; + let [batch_size, channels, length, _] = B::shape(&x).dims; - B::reshape(x, Shape::from([batch_size, channels, length])) + B::reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn avg_pool1d_backward_from_2d( - x: FloatTensor, - grad: 
FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - count_include_pad: bool, + x: FloatTensor, + grad: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, ) -> FloatTensor { - let [batch_size, channels, length_in] = B::shape(&x).dims; - let [_, _, length_out] = B::shape(&grad).dims; - - let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); - let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); - - let grad_x = B::avg_pool2d_backward( - x, - grad_x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - count_include_pad, - ); - - B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) + let [batch_size, channels, length_in] = B::shape(&x).dims; + let [_, _, length_out] = B::shape(&grad).dims; + + let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); + let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); + + let grad_x = B::avg_pool2d_backward( + x, + grad_x, + [kernel_size, 1], + [stride, 1], + [padding, 0], + count_include_pad, + ); + + B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) } pub(crate) fn adaptive_avg_pool1d_from_2d( - x: FloatTensor, - output_size: usize, + x: FloatTensor, + output_size: usize, ) -> FloatTensor { - let [batch_size, channels, length] = B::shape(&x).dims; + let [batch_size, channels, length] = B::shape(&x).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); - let x = B::adaptive_avg_pool2d(x, [output_size, 1]); + let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); + let x = B::adaptive_avg_pool2d(x, [output_size, 1]); - let [batch_size, channels, length, _] = B::shape(&x).dims; + let [batch_size, channels, length, _] = B::shape(&x).dims; - B::reshape(x, Shape::from([batch_size, channels, length])) + B::reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn adaptive_avg_pool1d_backward_from_2d( - x: FloatTensor, - grad: FloatTensor, + x: FloatTensor, + grad: FloatTensor, ) -> FloatTensor { - let [batch_size, channels, length_in] = B::shape(&x).dims; - let [_, _, length_out] = B::shape(&grad).dims; + let [batch_size, channels, length_in] = B::shape(&x).dims; + let [_, _, length_out] = B::shape(&grad).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); - let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); + let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); + let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1])); - let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x); + let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x); - B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) + B::reshape(grad_x, Shape::from([batch_size, channels, length_in])) } pub(crate) fn max_pool1d_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> FloatTensor { - let [batch_size, channels, length] = B::shape(&x).dims; + let [batch_size, channels, length] = B::shape(&x).dims; - let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); - let x = B::max_pool2d( - x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - [dilation, 1], - ); + let x = B::reshape(x, Shape::from([batch_size, channels, length, 1])); + let x = B::max_pool2d( + x, + 
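avg_pool1d_from_2d and its backward above implement 1D average pooling by appending a unit width ([batch, channels, length] -> [batch, channels, length, 1]) and calling the existing 2D kernels with a [kernel_size, 1] window, so no dedicated 1D kernel is required. For reference, the value the forward pass produces on a single channel can be sketched directly; this is an illustrative reimplementation under the usual count_include_pad semantics, not the burn kernel:

/// 1D average pooling over one channel. When `count_include_pad` is true,
/// every window divides by `kernel_size`; otherwise only the positions that
/// fall inside the original signal are counted.
fn avg_pool1d_ref(
    x: &[f32],
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
) -> Vec<f32> {
    let length_out = (x.len() + 2 * padding - kernel_size) / stride + 1;
    (0..length_out)
        .map(|o| {
            let start = o * stride; // window start in the padded signal
            let mut sum = 0.0;
            let mut count = 0usize;
            for k in 0..kernel_size {
                let pos = start + k;
                // Only positions inside the original signal contribute.
                if pos >= padding && pos < padding + x.len() {
                    sum += x[pos - padding];
                    count += 1;
                }
            }
            let divisor = if count_include_pad { kernel_size } else { count.max(1) };
            sum / divisor as f32
        })
        .collect()
}

fn main() {
    // kernel 3, stride 1, padding 1: output length equals input length.
    let y = avg_pool1d_ref(&[1.0, 2.0, 3.0], 3, 1, 1, true);
    assert_eq!(y, vec![1.0, 2.0, 5.0 / 3.0]);
}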
[kernel_size, 1], + [stride, 1], + [padding, 0], + [dilation, 1], + ); - let [batch_size, channels, length, _] = B::shape(&x).dims; + let [batch_size, channels, length, _] = B::shape(&x).dims; - B::reshape(x, Shape::from([batch_size, channels, length])) + B::reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn max_pool1d_with_indices_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, ) -> MaxPool1dWithIndices { - let [batch_size, channels, length] = B::shape(&x).dims; - - let x = B::reshape(x, Shape::from([batch_size, channels, 1, length])); - let x = B::max_pool2d_with_indices( - x, - [1, kernel_size], - [1, stride], - [0, padding], - [1, dilation], - ); - let [batch_size, channels, _, length] = B::shape(&x.output).dims; - let output = B::reshape(x.output, Shape::from([batch_size, channels, length])); - let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length])); - MaxPool1dWithIndices::new(output, indices) + let [batch_size, channels, length] = B::shape(&x).dims; + + let x = B::reshape(x, Shape::from([batch_size, channels, 1, length])); + let x = B::max_pool2d_with_indices( + x, + [1, kernel_size], + [1, stride], + [0, padding], + [1, dilation], + ); + let [batch_size, channels, _, length] = B::shape(&x.output).dims; + let output = B::reshape(x.output, Shape::from([batch_size, channels, length])); + let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length])); + MaxPool1dWithIndices::new(output, indices) } pub(crate) fn max_pool1d_with_indices_backward_from_2d( - x: FloatTensor, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - output_grad: FloatTensor, - indices: IntTensor, + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: FloatTensor, + indices: IntTensor, ) -> MaxPool1dBackward { - let [batch_size, channels, length_in] = B::shape(&x).dims; - let [_, _, length_out] = B::shape(&output_grad).dims; - - let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); - let grad_x = B::reshape( - output_grad, - Shape::from([batch_size, channels, length_out, 1]), - ); - let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1])); - - let grad_x = B::max_pool2d_with_indices_backward( - x, - [kernel_size, 1], - [stride, 1], - [padding, 0], - [dilation, 1], - grad_x, - indices, - ) - .x_grad; - - MaxPool1dBackward::new(B::reshape( - grad_x, - Shape::from([batch_size, channels, length_in]), - )) + let [batch_size, channels, length_in] = B::shape(&x).dims; + let [_, _, length_out] = B::shape(&output_grad).dims; + + let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1])); + let grad_x = B::reshape( + output_grad, + Shape::from([batch_size, channels, length_out, 1]), + ); + let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1])); + + let grad_x = B::max_pool2d_with_indices_backward( + x, + [kernel_size, 1], + [stride, 1], + [padding, 0], + [dilation, 1], + grad_x, + indices, + ) + .x_grad; + + MaxPool1dBackward::new(B::reshape( + grad_x, + Shape::from([batch_size, channels, length_in]), + )) } diff --git a/burn-tensor/src/tensor/ops/modules/unfold.rs b/burn-tensor/src/tensor/ops/modules/unfold.rs index b542bfe1a5..65f47562e8 100644 --- a/burn-tensor/src/tensor/ops/modules/unfold.rs +++ 
b/burn-tensor/src/tensor/ops/modules/unfold.rs @@ -15,71 +15,72 @@ use super::{ConvOptions, UnfoldOptions}; /// the convolution operation's mechanism as it moves across the input tensor, picking up the desired /// values in the pattern of the unfolding operation. pub(crate) fn create_unfolding_weight( - in_channels: usize, - kernel_size: [usize; 2], - device: &B::Device, + in_channels: usize, + kernel_size: [usize; 2], + device: &B::Device, ) -> FloatTensor { - let shape = Shape::new([ - in_channels * kernel_size[0] * kernel_size[1], - in_channels, - kernel_size[0], - kernel_size[1], - ]); + let shape = Shape::new([ + in_channels * kernel_size[0] * kernel_size[1], + in_channels, + kernel_size[0], + kernel_size[1], + ]); - let mut strides = [0; 4]; - let mut current = 1; - shape - .dims - .iter() - .enumerate() - .rev() - .for_each(|(index, val)| { - strides[index] = current; - current *= val; - }); + let mut strides = [0; 4]; + let mut current = 1; + shape + .dims + .iter() + .enumerate() + .rev() + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); - let num_elements = shape.num_elements(); + let num_elements = shape.num_elements(); - let mut weight: Vec = vec![0.0.elem(); num_elements]; + let mut weight: Vec = vec![0.0.elem(); num_elements]; - for k in 0..in_channels { - for i in 0..kernel_size[0] { - for j in 0..kernel_size[1] { - let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j; - let index = output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3]; + for k in 0..in_channels { + for i in 0..kernel_size[0] { + for j in 0..kernel_size[1] { + let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j; + let index = + output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3]; - weight[index] = 1.elem(); - } + weight[index] = 1.elem(); + } + } } - } - B::from_data(Data::new(weight, shape), device) + B::from_data(Data::new(weight, shape), device) } /// Compute the unfold4d operation using the conv2d operations. 
pub(crate) fn unfold4d_using_conv2d( - x: FloatTensor, - kernel_size: [usize; 2], - options: UnfoldOptions, + x: FloatTensor, + kernel_size: [usize; 2], + options: UnfoldOptions, ) -> FloatTensor { - let [_batch_size, in_channels, _in_height, _in_width] = B::shape(&x).dims; - let weight = create_unfolding_weight::(in_channels, kernel_size, &B::device(&x)); - let unfolded = B::conv2d( - x, - weight, - None, - ConvOptions { - stride: options.stride, - padding: options.padding, - dilation: options.dilation, - groups: 1, - }, - ); + let [_batch_size, in_channels, _in_height, _in_width] = B::shape(&x).dims; + let weight = create_unfolding_weight::(in_channels, kernel_size, &B::device(&x)); + let unfolded = B::conv2d( + x, + weight, + None, + ConvOptions { + stride: options.stride, + padding: options.padding, + dilation: options.dilation, + groups: 1, + }, + ); - let [batch_size, channels_out, out_height, out_width] = B::shape(&unfolded).dims; + let [batch_size, channels_out, out_height, out_width] = B::shape(&unfolded).dims; - B::reshape( - unfolded, - Shape::new([batch_size, channels_out, out_height * out_width]), - ) + B::reshape( + unfolded, + Shape::new([batch_size, channels_out, out_height * out_width]), + ) } diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs index 056c6d92d9..63c081df40 100644 --- a/burn-tensor/src/tensor/ops/tensor.rs +++ b/burn-tensor/src/tensor/ops/tensor.rs @@ -6,1064 +6,1073 @@ use core::ops::Range; /// Operations on float tensors. pub trait TensorOps { - /// Creates a new tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given data. - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor; - - /// Creates a new tensor with random values. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `distribution` - The distribution to sample from. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given shape and random values. - fn random( - shape: Shape, - distribution: Distribution>, - device: &Device, - ) -> FloatTensor; - - /// Creates a new tensor with zeros. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given shape and zeros. - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::zeros(shape), device) - } - - /// Creates a new tensor with ones. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given shape and ones. - fn ones(shape: Shape, device: &Device) -> FloatTensor { - Self::from_data(Data::ones(shape), device) - } - - /// Creates a tensor filled with given value. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `fill_value` - The value with which to fill the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor filled with given value - fn full( - shape: Shape, - fill_value: FloatElem, - device: &Device, - ) -> FloatTensor { - Self::add_scalar(Self::zeros(shape, device), fill_value) - } - - /// Gets the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. 
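create_unfolding_weight above builds a one-hot kernel of shape [in_channels * kh * kw, in_channels, kh, kw], so the conv2d inside unfold4d_using_conv2d copies each patch element onto its own output channel before the final reshape to [batch, channels_out, out_height * out_width]. The same index arithmetic on a plain Vec<f32>, as a dependency-free sketch:

/// Builds the flattened one-hot unfolding weight: output channel
/// `k * kh * kw + i * kw + j` selects input channel `k` at kernel position `(i, j)`.
fn unfolding_weight(in_channels: usize, kernel_size: [usize; 2]) -> Vec<f32> {
    let [kh, kw] = kernel_size;
    let dims = [in_channels * kh * kw, in_channels, kh, kw];

    // Row-major strides, computed back to front as in the patch.
    let mut strides = [0usize; 4];
    let mut current = 1;
    for (index, val) in dims.iter().enumerate().rev() {
        strides[index] = current;
        current *= *val;
    }

    let mut weight = vec![0.0_f32; dims.iter().product::<usize>()];
    for k in 0..in_channels {
        for i in 0..kh {
            for j in 0..kw {
                let output_channel = k * kh * kw + i * kw + j;
                let index =
                    output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3];
                weight[index] = 1.0;
            }
        }
    }
    weight
}

fn main() {
    // Every output channel contains exactly one 1.0.
    let w = unfolding_weight(2, [3, 3]);
    let per_channel = 2 * 3 * 3; // elements per output channel
    for chunk in w.chunks(per_channel) {
        assert_eq!(chunk.iter().filter(|&&v| v == 1.0).count(), 1);
    }
}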
- /// - /// # Returns - /// - /// The shape of the tensor. - fn shape(tensor: &FloatTensor) -> Shape; - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn to_data(tensor: &FloatTensor) -> Reader, D>> { - Self::into_data(tensor.clone()) - } - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn into_data(tensor: FloatTensor) -> Reader, D>>; - - /// Gets the device of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device of the tensor. - fn device(tensor: &FloatTensor) -> Device; - - /// Moves the tensor to the given device. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `device` - The device to move the tensor to. - /// - /// # Returns - /// - /// The tensor on the given device. - fn to_device(tensor: FloatTensor, device: &Device) -> FloatTensor; - - /// Creates a new tensor with values from the given range. - /// - /// # Arguments - /// - /// * `range` - The range of values. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given values. - /// - /// # Remarks - /// - /// Uses `arange_step` with a step size of 1 under the hood. - fn arange(range: Range, device: &Device) -> IntTensor { - Self::arange_step(range, 1, device) - } - - /// Converts float tensor to int tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The int tensor with the same data as the float tensor. - fn into_int(tensor: FloatTensor) -> IntTensor; - - /// Creates a new tensor with values from the given range with the given step size. - /// - /// # Arguments - /// - /// * `range` - The range of values. - /// * `step` - The step size. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the given values. - fn arange_step(range: Range, step: usize, device: &Device) -> IntTensor { - let value = range - .step_by(step) - .map(|i| (i as i64).elem()) - .collect::>>(); - let shape = Shape::new([value.len()]); - let data = Data::new(value, shape); - B::int_from_data(data, device) - } - - /// Creates an empty tensor with the given shape. - /// - /// # Arguments - /// - /// * `shape` - The shape of the tensor. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The empty tensor with the given shape. - fn empty(shape: Shape, device: &Device) -> FloatTensor; - - /// Repeat the tensor along the given dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `dim` - The dimension to repeat. - /// * `times` - The number of times to repeat the dimension. - /// - /// # Returns - /// - /// The tensor with the given dimension repeated. - fn repeat( - tensor: FloatTensor, - dim: usize, - times: usize, - ) -> FloatTensor { - let mut shape = B::shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); + /// Creates a new tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given data. 
+ fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor; + + /// Creates a new tensor with random values. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `distribution` - The distribution to sample from. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given shape and random values. + fn random( + shape: Shape, + distribution: Distribution>, + device: &Device, + ) -> FloatTensor; + + /// Creates a new tensor with zeros. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given shape and zeros. + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::zeros(shape), device) } - shape.dims[dim] = times; - - let mut i = 0; - let indices_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); - - let mut tensor_output = B::empty(shape, &B::device(&tensor)); - for i in 0..times { - let mut indices = indices_select_all.clone(); - indices[dim] = i..i + 1; - tensor_output = B::slice_assign(tensor_output, indices, tensor.clone()); + + /// Creates a new tensor with ones. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given shape and ones. + fn ones(shape: Shape, device: &Device) -> FloatTensor { + Self::from_data(Data::ones(shape), device) + } + + /// Creates a tensor filled with given value. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `fill_value` - The value with which to fill the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor filled with given value + fn full( + shape: Shape, + fill_value: FloatElem, + device: &Device, + ) -> FloatTensor { + Self::add_scalar(Self::zeros(shape, device), fill_value) + } + + /// Gets the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + fn shape(tensor: &FloatTensor) -> Shape; + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn to_data(tensor: &FloatTensor) -> Reader, D>> { + Self::into_data(tensor.clone()) + } + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn into_data(tensor: FloatTensor) -> Reader, D>>; + + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn device(tensor: &FloatTensor) -> Device; + + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device to move the tensor to. + /// + /// # Returns + /// + /// The tensor on the given device. + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor; + + /// Creates a new tensor with values from the given range. + /// + /// # Arguments + /// + /// * `range` - The range of values. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given values. 
+ /// + /// # Remarks + /// + /// Uses `arange_step` with a step size of 1 under the hood. + fn arange(range: Range, device: &Device) -> IntTensor { + Self::arange_step(range, 1, device) + } + + /// Converts float tensor to int tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The int tensor with the same data as the float tensor. + fn into_int(tensor: FloatTensor) -> IntTensor; + + /// Creates a new tensor with values from the given range with the given step size. + /// + /// # Arguments + /// + /// * `range` - The range of values. + /// * `step` - The step size. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given values. + fn arange_step(range: Range, step: usize, device: &Device) -> IntTensor { + let value = range + .step_by(step) + .map(|i| (i as i64).elem()) + .collect::>>(); + let shape = Shape::new([value.len()]); + let data = Data::new(value, shape); + B::int_from_data(data, device) + } + + /// Creates an empty tensor with the given shape. + /// + /// # Arguments + /// + /// * `shape` - The shape of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The empty tensor with the given shape. + fn empty(shape: Shape, device: &Device) -> FloatTensor; + + /// Repeat the tensor along the given dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to repeat. + /// * `times` - The number of times to repeat the dimension. + /// + /// # Returns + /// + /// The tensor with the given dimension repeated. + fn repeat( + tensor: FloatTensor, + dim: usize, + times: usize, + ) -> FloatTensor { + let mut shape = B::shape(&tensor); + if shape.dims[dim] != 1 { + panic!("Can only repeat dimension with dim=1"); + } + shape.dims[dim] = times; + + let mut i = 0; + let indices_select_all = [0; D].map(|_| { + let start = 0; + let end = shape.dims[i]; + i += 1; + start..end + }); + + let mut tensor_output = B::empty(shape, &B::device(&tensor)); + for i in 0..times { + let mut indices = indices_select_all.clone(); + indices[dim] = i..i + 1; + tensor_output = B::slice_assign(tensor_output, indices, tensor.clone()); + } + + tensor_output + } + + /// Adds two tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. + fn add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Adds a scalar to a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of adding the scalar to the tensor. + fn add_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn clamp_min( + tensor: FloatTensor, + min: FloatElem, + ) -> FloatTensor { + // Default implementation + let mask = Self::lower_elem(tensor.clone(), min); + B::mask_fill(tensor, mask, min) + } + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. 
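The repeat default above only accepts a dimension of size 1 (anything else panics) and fills the enlarged output by slice-assigning the input into every index slot along dim. Its effect on a flat row-major buffer, sketched for the last axis with illustrative names:

/// Repeats the last axis of an `[n, 1]` row-major buffer `times` times,
/// producing an `[n, times]` buffer; every slot along the repeated axis
/// receives a copy of the input, just as the slice_assign loop does.
fn repeat_last_axis(x: &[f32], times: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(x.len() * times);
    for &value in x {
        out.extend(std::iter::repeat(value).take(times));
    }
    out
}

fn main() {
    assert_eq!(
        repeat_last_axis(&[1.0, 2.0], 3),
        vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
    );
}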
+ fn clamp_max( + tensor: FloatTensor, + max: FloatElem, + ) -> FloatTensor { + // Default implementation + let mask = Self::greater_elem(tensor.clone(), max); + B::mask_fill(tensor, mask, max) + } + + /// Clamps a tensor between a minimum and maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + // Default implementation + Self::clamp_min(Self::clamp_max(tensor, max), min) + } + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Subtracts a scalar from a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of subtracting the scalar from the tensor. + fn sub_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Multiplies two tensors together element-wise. + fn mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Multiplies a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of multiplying the tensor by the scalar. + fn mul_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Divides two tensors element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of dividing the two tensors. + fn div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Divides a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of dividing the tensor by the scalar. + fn div_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; + + /// Multiplies two tensors together using matrix multiplication. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of multiplying the two tensors together using matrix multiplication. + fn matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; + + /// Negates a tensor element-wise. + fn neg(tensor: FloatTensor) -> FloatTensor { + Self::mul_scalar(tensor, (-1.0_f32).elem::>()) + } + + /// Calculates the reciprocals elementwise + fn recip(tensor: FloatTensor) -> FloatTensor; + + /// Transposes a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn transpose(tensor: FloatTensor) -> FloatTensor { + Self::swap_dims(tensor, D - 2, D - 1) + } + + /// Swaps two dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor; + + /// Reshapes a tensor. 
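The clamp defaults above are built from comparisons plus mask_fill: clamp_max fills every element greater than max, clamp_min fills every element lower than min, and clamp chains the two. The same composition on a plain Vec<f32>, as a sketch rather than a backend kernel:

/// Mirrors the default clamp_min / clamp_max structure: build a comparison
/// mask, then fill the masked positions with the bound (illustrative only).
fn mask_fill(values: &mut [f32], mask: &[bool], fill: f32) {
    for (v, &m) in values.iter_mut().zip(mask) {
        if m {
            *v = fill;
        }
    }
}

fn clamp(mut values: Vec<f32>, min: f32, max: f32) -> Vec<f32> {
    // clamp_max: positions greater than `max` are filled with `max`.
    let greater: Vec<bool> = values.iter().map(|&v| v > max).collect();
    mask_fill(&mut values, &greater, max);
    // clamp_min: positions lower than `min` are filled with `min`.
    let lower: Vec<bool> = values.iter().map(|&v| v < min).collect();
    mask_fill(&mut values, &lower, min);
    values
}

fn main() {
    assert_eq!(clamp(vec![-2.0, 0.5, 3.0], -1.0, 1.0), vec![-1.0, 0.5, 1.0]);
}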
+ /// + /// # Arguments + /// + /// * `tensor` - The tensor to reshape. + /// * `shape` - The new shape of the tensor. + /// + /// # Returns + /// + /// The tensor with the new shape. + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor; + + /// Gather elements from a tensor. + /// + /// # Arguments + /// + /// * `dim` - The dimension to gather from. + /// * `tensor` - The tensor to gather from. + /// * `indices` - The indices to gather. + /// + /// # Returns + /// + /// The gathered elements. + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor; + + /// Scatter elements into a tensor. + /// + /// # Arguments + /// + /// * `dim` - The dimension to scatter into. + /// * `tensor` - The tensor to scatter into. + /// * `indices` - The indices to scatter into. + /// * `value` - The value to scatter. + /// + /// # Returns + /// + /// The tensor with the scattered elements. + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor; + + /// Select tensor elements along the given dimension corresponding for the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices to select. + /// + /// # Returns + /// + /// The selected elements. + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor; + + /// Assign the selected elements along the given dimension corresponding for the given indices + /// to the given value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices to select. + /// * `value` - The value to assign. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor; + + /// Select tensor elements corresponding for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `ranges` - The ranges to select. + /// + /// # Returns + /// + /// The selected elements in a new tensor. + fn slice( + tensor: FloatTensor, + ranges: [Range; D2], + ) -> FloatTensor; + + /// Assign the selected elements corresponding for the given ranges to the given value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `ranges` - The ranges to select. + /// * `value` - The value to assign. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn slice_assign( + tensor: FloatTensor, + ranges: [Range; D2], + value: FloatTensor, + ) -> FloatTensor; + + /// Update the given tensor with the value tensor where the mask is true. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `mask` - The boolean mask to select with. + /// * `value` - The value to assign to the selected elements from the value tensor. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor; + + /// Update the given tensor with the value where the mask is true. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `mask` - The boolean mask to select with. 
+ /// * `value` - The value to assign to the selected elements. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor; + + /// Equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; + + /// Equal comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; + + /// Greater than comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; + + /// Greater than comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; + + /// Greater than or equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor; + + /// Greater than or equal comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor; + + /// Less than comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; + + /// Less than comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; + + /// Less than or equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor; + + /// Less than or equal comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. 
+ fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor; + + /// Detaches a tensor from the computation graph. + fn detach(tensor: FloatTensor) -> FloatTensor { + // Should only be overridden by autodiff backends. + tensor + } + + /// Sets the `require_grad` flag of a tensor. + fn set_require_grad( + tensor: FloatTensor, + _require_grad: bool, + ) -> FloatTensor { + // Should only be overridden by autodiff backends. + tensor + } + + /// Returns the `require_grad` flag of a tensor. + fn is_require_grad(_tensor: &FloatTensor) -> bool { + // Should only be overridden by autodiff backends. + false + } + + /// Sum of all elements in a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// + /// # Returns + /// + /// A scalar tensor with the sum of all elements in `tensor`. + fn sum(tensor: FloatTensor) -> FloatTensor; + + /// Sum of all elements in a tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension along which to sum. + /// + /// # Returns + /// + /// A tensor with the sum of all elements in `tensor` along `dim`. + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; + + /// Mean of all elements in a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to mean. + /// + /// # Returns + /// + /// A scalar tensor with the mean of all elements in `tensor`. + fn mean(tensor: FloatTensor) -> FloatTensor { + let num_elems = B::shape(&tensor).num_elements(); + B::div_scalar(B::sum(tensor), (num_elems as i64).elem()) + } + + /// Mean of all elements in a tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to mean. + /// * `dim` - The dimension along which to mean. + /// + /// # Returns + /// + /// A tensor with the mean of all elements in `tensor` along `dim`. + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; + + /// Converts a tensor to full precision. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to convert. + /// + /// # Returns + /// + /// A tensor with the same values as `tensor` but with full precision. + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D>; + + /// Converts a tensor from full precision. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to convert. + /// + /// # Returns + /// + /// A tensor with the same values as `tensor` but with the precision of the backend. + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor; + + /// Returns a new tensor with exponential values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to exponentiate. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with exponential values. + fn exp(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with natural logarithm values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the logarithm of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with natural logarithm values. + fn log(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with logarithm values of (1 + Xi). + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the logarithm of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). + fn log1p(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with values raised to the power of `value`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to exponentiate. 
+ /// * `value` - The exponent. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with values raised to the power of `value`. + fn powf(tensor: FloatTensor, value: f32) -> FloatTensor; + + /// Returns a new tensor with square root values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the square root of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with square root values. + fn sqrt(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with absolute values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take absolute value of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with absolute values. + fn abs(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with cosine values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the cosine of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with cosine values. + fn cos(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with sine values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the sine of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with sine values. + fn sin(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with tangent values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the tangent of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with tangent values. + fn tanh(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with the error function values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to take the error function of. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with error function values. + fn erf(tensor: FloatTensor) -> FloatTensor; + + /// Catcatenates tensors along a dimension. + /// + /// # Arguments + /// + /// * `tensors` - The tensors to catcatenate. + /// * `dim` - The dimension along which to catcatenate. + /// + /// # Returns + /// + /// A tensor with the catcatenated tensors along `dim`. + fn cat(tensors: Vec>, dim: usize) -> FloatTensor; + + /// Gets the indices of the maximum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// * `dim` - The dimension along which to get the maximum elements. + /// + /// # Returns + /// + /// A tensor with the indices of the maximum elements of `tensor` along `dim`. + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor; + + /// Gets the indices of the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// * `dim` - The dimension along which to get the minimum elements. + /// + /// # Returns + /// + /// A tensor with the indices of the minimum elements of `tensor` along `dim`. + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor; + + /// Gets the maximum element of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// + /// # Returns + /// + /// A tensor with the maximum element of `tensor`. + fn max(tensor: FloatTensor) -> FloatTensor { + let shape = B::shape(&tensor); + let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); + + B::max_dim(tensor, 0) + } + + /// Gets the maximum elements of a tensor along an axis. 
+ /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// * `dim` - The dimension along which to get the maximum elements. + /// + /// # Returns + /// + /// A tensor with the maximum elements of `tensor` along `dim`. + fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let index = B::argmax(tensor.clone(), dim); + + B::gather(D - 1, tensor, index) + } + + /// Gets the maximum elements of a tensor along an axis and their indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the maximum elements of. + /// * `dim` - The dimension along which to get the maximum elements. + /// + /// # Returns + /// + /// A tuple with the maximum elements of `tensor` along `dim` and their indices. + fn max_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + let index = B::argmax(tensor.clone(), dim); + let values = B::gather(D - 1, tensor, index.clone()); + + (values, index) } - tensor_output - } - - /// Adds two tensors together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of adding the two tensors together. - fn add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Adds a scalar to a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of adding the scalar to the tensor. - fn add_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Clamps a tensor under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn clamp_min(tensor: FloatTensor, min: FloatElem) -> FloatTensor { - // Default implementation - let mask = Self::lower_elem(tensor.clone(), min); - B::mask_fill(tensor, mask, min) - } - - /// Clamps a tensor over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn clamp_max(tensor: FloatTensor, max: FloatElem) -> FloatTensor { - // Default implementation - let mask = Self::greater_elem(tensor.clone(), max); - B::mask_fill(tensor, mask, max) - } - - /// Clamps a tensor between a minimum and maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// The clamped tensor. - fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - // Default implementation - Self::clamp_min(Self::clamp_max(tensor, max), min) - } - - /// Subtracts two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of subtracting the two tensors. - fn sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Subtracts a scalar from a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of subtracting the scalar from the tensor. - fn sub_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Multiplies two tensors together element-wise. 
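The max, max_dim, and max_dim_with_indices defaults above reduce a maximum to argmax followed by a gather along the last dimension, so a backend only has to provide argmax and gather. A plain-Rust sketch of that composition for a row-major 2D case; the names are illustrative and the tie-breaking is not meant to match any particular backend:

/// Mirrors the max_dim_with_indices default along the last axis: compute the
/// argmax per row, then gather the values at those indices.
fn max_last_dim_with_indices(rows: &[Vec<f32>]) -> (Vec<f32>, Vec<usize>) {
    let indices: Vec<usize> = rows
        .iter()
        .map(|row| {
            // argmax: index of the largest element in the row.
            row.iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i)
                .unwrap()
        })
        .collect();
    // gather: pick the value at the argmax index of each row.
    let values = rows.iter().zip(&indices).map(|(row, &i)| row[i]).collect();
    (values, indices)
}

fn main() {
    let (values, indices) =
        max_last_dim_with_indices(&[vec![1.0, 3.0, 2.0], vec![5.0, 4.0, 0.0]]);
    assert_eq!(values, vec![3.0, 5.0]);
    assert_eq!(indices, vec![1, 0]);
}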
- fn mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Multiplies a tensor by a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of multiplying the tensor by the scalar. - fn mul_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Divides two tensors element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of dividing the two tensors. - fn div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Divides a tensor by a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of dividing the tensor by the scalar. - fn div_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor; - - /// Multiplies two tensors together using matrix multiplication. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of multiplying the two tensors together using matrix multiplication. - fn matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; - - /// Negates a tensor element-wise. - fn neg(tensor: FloatTensor) -> FloatTensor { - Self::mul_scalar(tensor, (-1.0_f32).elem::>()) - } - - /// Calculates the reciprocals elementwise - fn recip(tensor: FloatTensor) -> FloatTensor; - - /// Transposes a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn transpose(tensor: FloatTensor) -> FloatTensor { - Self::swap_dims(tensor, D - 2, D - 1) - } - - /// Swaps two dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor; - - /// Reshapes a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to reshape. - /// * `shape` - The new shape of the tensor. - /// - /// # Returns - /// - /// The tensor with the new shape. - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor; - - /// Gather elements from a tensor. - /// - /// # Arguments - /// - /// * `dim` - The dimension to gather from. - /// * `tensor` - The tensor to gather from. - /// * `indices` - The indices to gather. - /// - /// # Returns - /// - /// The gathered elements. - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor; - - /// Scatter elements into a tensor. - /// - /// # Arguments - /// - /// * `dim` - The dimension to scatter into. - /// * `tensor` - The tensor to scatter into. - /// * `indices` - The indices to scatter into. - /// * `value` - The value to scatter. - /// - /// # Returns - /// - /// The tensor with the scattered elements. - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor; - - /// Select tensor elements along the given dimension corresponding for the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices to select. 
- /// - /// # Returns - /// - /// The selected elements. - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor; - - /// Assign the selected elements along the given dimension corresponding for the given indices - /// to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor; - - /// Select tensor elements corresponding for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `ranges` - The ranges to select. - /// - /// # Returns - /// - /// The selected elements in a new tensor. - fn slice( - tensor: FloatTensor, - ranges: [Range; D2], - ) -> FloatTensor; - - /// Assign the selected elements corresponding for the given ranges to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `ranges` - The ranges to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn slice_assign( - tensor: FloatTensor, - ranges: [Range; D2], - value: FloatTensor, - ) -> FloatTensor; - - /// Update the given tensor with the value tensor where the mask is true. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `mask` - The boolean mask to select with. - /// * `value` - The value to assign to the selected elements from the value tensor. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor; - - /// Update the given tensor with the value where the mask is true. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `mask` - The boolean mask to select with. - /// * `value` - The value to assign to the selected elements. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor; - - /// Equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; - - /// Equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; - - /// Greater than comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; - - /// Greater than comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. 
- /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; - - /// Greater than or equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor; - - /// Greater than or equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor; - - /// Less than comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; - - /// Less than comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor; - - /// Less than or equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor; - - /// Less than or equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor; - - /// Detaches a tensor from the computation graph. - fn detach(tensor: FloatTensor) -> FloatTensor { - // Should only be overridden by autodiff backends. - tensor - } - - /// Sets the `require_grad` flag of a tensor. - fn set_require_grad( - tensor: FloatTensor, - _require_grad: bool, - ) -> FloatTensor { - // Should only be overridden by autodiff backends. - tensor - } - - /// Returns the `require_grad` flag of a tensor. - fn is_require_grad(_tensor: &FloatTensor) -> bool { - // Should only be overridden by autodiff backends. - false - } - - /// Sum of all elements in a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// - /// # Returns - /// - /// A scalar tensor with the sum of all elements in `tensor`. - fn sum(tensor: FloatTensor) -> FloatTensor; - - /// Sum of all elements in a tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to sum. - /// * `dim` - The dimension along which to sum. - /// - /// # Returns - /// - /// A tensor with the sum of all elements in `tensor` along `dim`. - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; - - /// Mean of all elements in a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to mean. - /// - /// # Returns - /// - /// A scalar tensor with the mean of all elements in `tensor`. 
- fn mean(tensor: FloatTensor) -> FloatTensor { - let num_elems = B::shape(&tensor).num_elements(); - B::div_scalar(B::sum(tensor), (num_elems as i64).elem()) - } - - /// Mean of all elements in a tensor along a dimension. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to mean. - /// * `dim` - The dimension along which to mean. - /// - /// # Returns - /// - /// A tensor with the mean of all elements in `tensor` along `dim`. - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; - - /// Converts a tensor to full precision. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to convert. - /// - /// # Returns - /// - /// A tensor with the same values as `tensor` but with full precision. - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D>; - - /// Converts a tensor from full precision. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to convert. - /// - /// # Returns - /// - /// A tensor with the same values as `tensor` but with the precision of the backend. - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor; - - /// Returns a new tensor with exponential values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to exponentiate. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with exponential values. - fn exp(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with natural logarithm values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the logarithm of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with natural logarithm values. - fn log(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with logarithm values of (1 + Xi). - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the logarithm of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). - fn log1p(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with values raised to the power of `value`. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to exponentiate. - /// * `value` - The exponent. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with values raised to the power of `value`. - fn powf(tensor: FloatTensor, value: f32) -> FloatTensor; - - /// Returns a new tensor with square root values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the square root of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with square root values. - fn sqrt(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with absolute values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take absolute value of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with absolute values. - fn abs(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with cosine values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the cosine of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with cosine values. - fn cos(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with sine values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the sine of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with sine values. - fn sin(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with tangent values. 
- /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the tangent of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with tangent values. - fn tanh(tensor: FloatTensor) -> FloatTensor; - - /// Returns a new tensor with the error function values. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to take the error function of. - /// - /// # Returns - /// - /// A tensor with the same shape as `tensor` with error function values. - fn erf(tensor: FloatTensor) -> FloatTensor; - - /// Catcatenates tensors along a dimension. - /// - /// # Arguments - /// - /// * `tensors` - The tensors to catcatenate. - /// * `dim` - The dimension along which to catcatenate. - /// - /// # Returns - /// - /// A tensor with the catcatenated tensors along `dim`. - fn cat(tensors: Vec>, dim: usize) -> FloatTensor; - - /// Gets the indices of the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// * `dim` - The dimension along which to get the maximum elements. - /// - /// # Returns - /// - /// A tensor with the indices of the maximum elements of `tensor` along `dim`. - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor; - - /// Gets the indices of the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// * `dim` - The dimension along which to get the minimum elements. - /// - /// # Returns - /// - /// A tensor with the indices of the minimum elements of `tensor` along `dim`. - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor; - - /// Gets the maximum element of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// - /// # Returns - /// - /// A tensor with the maximum element of `tensor`. - fn max(tensor: FloatTensor) -> FloatTensor { - let shape = B::shape(&tensor); - let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); - - B::max_dim(tensor, 0) - } - - /// Gets the maximum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// * `dim` - The dimension along which to get the maximum elements. - /// - /// # Returns - /// - /// A tensor with the maximum elements of `tensor` along `dim`. - fn max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - let index = B::argmax(tensor.clone(), dim); - - B::gather(D - 1, tensor, index) - } - - /// Gets the maximum elements of a tensor along an axis and their indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the maximum elements of. - /// * `dim` - The dimension along which to get the maximum elements. - /// - /// # Returns - /// - /// A tuple with the maximum elements of `tensor` along `dim` and their indices. - fn max_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - let index = B::argmax(tensor.clone(), dim); - let values = B::gather(D - 1, tensor, index.clone()); - - (values, index) - } - - /// Gets the minimum element of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// - /// # Returns - /// - /// A tensor with the minimum element of `tensor`. 
- fn min(tensor: FloatTensor) -> FloatTensor { - let shape = B::shape(&tensor); - let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); - - B::min_dim(tensor, 0) - } - - /// Gets the minimum elements of a tensor along an axis. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// * `dim` - The dimension along which to get the minimum elements. - /// - /// # Returns - /// - /// A tensor with the minimum elements of `tensor` along `dim`. - fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - let index = B::argmin(tensor.clone(), dim); - - B::gather(D - 1, tensor, index) - } - - /// Gets the minimum elements of a tensor along an axis and their indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to get the minimum elements of. - /// * `dim` - The dimension along which to get the minimum elements. - /// - /// # Returns - /// - /// A tuple with the minimum elements of `tensor` along `dim` and their indices. - fn min_dim_with_indices( - tensor: FloatTensor, - dim: usize, - ) -> (FloatTensor, IntTensor) { - let index = B::argmin(tensor.clone(), dim); - let values = B::gather(D - 1, tensor, index.clone()); - - (values, index) - } + /// Gets the minimum element of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// + /// # Returns + /// + /// A tensor with the minimum element of `tensor`. + fn min(tensor: FloatTensor) -> FloatTensor { + let shape = B::shape(&tensor); + let tensor = B::reshape(tensor, Shape::new([shape.num_elements()])); + + B::min_dim(tensor, 0) + } + + /// Gets the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// * `dim` - The dimension along which to get the minimum elements. + /// + /// # Returns + /// + /// A tensor with the minimum elements of `tensor` along `dim`. + fn min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let index = B::argmin(tensor.clone(), dim); + + B::gather(D - 1, tensor, index) + } + + /// Gets the minimum elements of a tensor along an axis and their indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the minimum elements of. + /// * `dim` - The dimension along which to get the minimum elements. + /// + /// # Returns + /// + /// A tuple with the minimum elements of `tensor` along `dim` and their indices. + fn min_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + let index = B::argmin(tensor.clone(), dim); + let values = B::gather(D - 1, tensor, index.clone()); + + (values, index) + } } diff --git a/burn-tensor/src/tensor/shape.rs b/burn-tensor/src/tensor/shape.rs index b2bf9744c9..00c69501b6 100644 --- a/burn-tensor/src/tensor/shape.rs +++ b/burn-tensor/src/tensor/shape.rs @@ -3,76 +3,76 @@ use alloc::vec::Vec; /// Shape of a tensor. #[derive(new, Debug, Clone, PartialEq, Eq)] pub struct Shape { - /// The dimensions of the tensor. - pub dims: [usize; D], + /// The dimensions of the tensor. 
+ pub dims: [usize; D], } impl Shape { - /// Returns the total number of elements of a tensor having this shape - pub fn num_elements(&self) -> usize { - let mut num_elements = 1; - for i in 0..D { - num_elements *= self.dims[i]; - } + /// Returns the total number of elements of a tensor having this shape + pub fn num_elements(&self) -> usize { + let mut num_elements = 1; + for i in 0..D { + num_elements *= self.dims[i]; + } - num_elements - } + num_elements + } } impl From<[usize; D]> for Shape { - fn from(dims: [usize; D]) -> Self { - Shape::new(dims) - } + fn from(dims: [usize; D]) -> Self { + Shape::new(dims) + } } impl From> for Shape { - fn from(shape: Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.into_iter().enumerate() { - dims[i] = dim as usize; + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim as usize; + } + Self::new(dims) } - Self::new(dims) - } } impl From> for Shape { - fn from(shape: Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.into_iter().enumerate() { - dims[i] = dim as usize; + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim as usize; + } + Self::new(dims) } - Self::new(dims) - } } impl From> for Shape { - fn from(shape: Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.into_iter().enumerate() { - dims[i] = dim; + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim; + } + Self::new(dims) } - Self::new(dims) - } } impl From<&Vec> for Shape { - fn from(shape: &Vec) -> Self { - let mut dims = [1; D]; - for (i, dim) in shape.iter().enumerate() { - dims[i] = *dim; + fn from(shape: &Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.iter().enumerate() { + dims[i] = *dim; + } + Self::new(dims) } - Self::new(dims) - } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn num_elements() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - assert_eq!(120, shape.num_elements()); - } + #[test] + fn num_elements() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + assert_eq!(120, shape.num_elements()); + } } diff --git a/burn-tensor/src/tensor/stats/mod.rs b/burn-tensor/src/tensor/stats/mod.rs index 70dda6288a..0ad39dc69a 100644 --- a/burn-tensor/src/tensor/stats/mod.rs +++ b/burn-tensor/src/tensor/stats/mod.rs @@ -1,38 +1,38 @@ use crate::{backend::Backend, Tensor}; pub fn var(tensor: Tensor, dim: usize) -> Tensor { - let mean = tensor.clone().mean_dim(dim); - var_with_mean(tensor, mean, dim) + let mean = tensor.clone().mean_dim(dim); + var_with_mean(tensor, mean, dim) } pub fn var_with_mean( - tensor: Tensor, - mean: Tensor, - dim: usize, + tensor: Tensor, + mean: Tensor, + dim: usize, ) -> Tensor { - let n = tensor.shape().dims[dim] - 1; - var_with_mean_n(tensor, mean, dim, n) + let n = tensor.shape().dims[dim] - 1; + var_with_mean_n(tensor, mean, dim, n) } pub fn var_bias(tensor: Tensor, dim: usize) -> Tensor { - let mean = tensor.clone().mean_dim(dim); - var_with_mean_bias(tensor, mean, dim) + let mean = tensor.clone().mean_dim(dim); + var_with_mean_bias(tensor, mean, dim) } pub fn var_with_mean_bias( - tensor: Tensor, - mean: Tensor, - dim: usize, + tensor: Tensor, + mean: Tensor, + dim: usize, ) -> Tensor { - let n = tensor.shape().dims[dim]; - var_with_mean_n(tensor, mean, dim, n) + let n = tensor.shape().dims[dim]; + var_with_mean_n(tensor, mean, dim, n) } pub 
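// Quick sketch, not part of the patch: exercising the `Shape` helpers
// reformatted above (construction via `From<[usize; D]>` and `num_elements`).
// The generic parameter elided in this diff is assumed to be `const D: usize`,
// as implied by `[usize; D]`.
use burn_tensor::Shape;

fn shape_sketch() {
    let shape: Shape<3> = Shape::from([2, 3, 4]);
    assert_eq!(shape.dims, [2, 3, 4]);
    // num_elements is the product of all dimensions.
    assert_eq!(shape.num_elements(), 2 * 3 * 4);
}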
fn var_with_mean_n( - tensor: Tensor, - mean: Tensor, - dim: usize, - n: usize, + tensor: Tensor, + mean: Tensor, + dim: usize, + n: usize, ) -> Tensor { - tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(n as f32) + tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(n as f32) } diff --git a/burn-tensor/src/tests/activation/gelu.rs b/burn-tensor/src/tests/activation/gelu.rs index aad5288645..a6dc2a617d 100644 --- a/burn-tensor/src/tests/activation/gelu.rs +++ b/burn-tensor/src/tests/activation/gelu.rs @@ -1,21 +1,21 @@ #[burn_tensor_testgen::testgen(gelu)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_gelu() { - let data = Data::from([[ - 0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737, - ]]); - let tensor = Tensor::::from_data(data).clone().clone(); + #[test] + fn test_gelu() { + let data = Data::from([[ + 0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737, + ]]); + let tensor = Tensor::::from_data(data).clone().clone(); - let data_actual = activation::gelu(tensor).to_data(); + let data_actual = activation::gelu(tensor).to_data(); - let data_expected = Data::from([[ - 0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051, - ]]); - data_expected.assert_approx_eq(&data_actual, 2); // Low precision to allow approximation - // implementation using tanh - } + let data_expected = Data::from([[ + 0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051, + ]]); + data_expected.assert_approx_eq(&data_actual, 2); // Low precision to allow approximation + // implementation using tanh + } } diff --git a/burn-tensor/src/tests/activation/relu.rs b/burn-tensor/src/tests/activation/relu.rs index cbd23cccb4..b9e5ff6623 100644 --- a/burn-tensor/src/tests/activation/relu.rs +++ b/burn-tensor/src/tests/activation/relu.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(relu)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_relu_d2() { - let data = Data::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_relu_d2() { + let data = Data::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::relu(tensor).to_data(); + let data_actual = activation::relu(tensor).to_data(); - let data_expected = Data::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/activation/sigmoid.rs b/burn-tensor/src/tests/activation/sigmoid.rs index 54b889a49b..17d55ba7ca 100644 --- a/burn-tensor/src/tests/activation/sigmoid.rs +++ b/burn-tensor/src/tests/activation/sigmoid.rs @@ -1,27 +1,27 @@ #[burn_tensor_testgen::testgen(sigmoid)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_sigmoid() { - let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_sigmoid() { + let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::sigmoid(tensor).to_data(); + let data_actual = 
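// Sketch, not part of the patch: the variance helpers above all reduce to
// `sum((x - mean)^2, dim) / n`, with `n = dims[dim] - 1` for the unbiased
// estimate (`var`) and `n = dims[dim]` for the biased one (`var_bias`).
// Spelled out with the same tensor ops this file already uses:
use burn_tensor::{backend::Backend, Tensor};

fn unbiased_var_sketch<B: Backend>(tensor: Tensor<B, 2>, dim: usize) -> Tensor<B, 2> {
    // Bessel's correction: divide by (n - 1) rather than n.
    let n = tensor.shape().dims[dim] - 1;
    // Mean along `dim`, kept with size 1 so it broadcasts in the subtraction.
    let mean = tensor.clone().mean_dim(dim);
    tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(n as f32)
}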
activation::sigmoid(tensor).to_data(); - let data_expected = Data::from([[0.7311, 0.9991], [1.0, 0.0474]]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[0.7311, 0.9991], [1.0, 0.0474]]); + data_actual.assert_approx_eq(&data_expected, 4); + } - #[test] - fn test_sigmoid_overflow() { - let data = Data::from([f32::MAX, f32::MIN]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_sigmoid_overflow() { + let data = Data::from([f32::MAX, f32::MIN]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::sigmoid(tensor).to_data(); + let data_actual = activation::sigmoid(tensor).to_data(); - let data_expected = Data::from([1.0, 0.0]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([1.0, 0.0]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/activation/silu.rs b/burn-tensor/src/tests/activation/silu.rs index 32728a6427..f207bc6145 100644 --- a/burn-tensor/src/tests/activation/silu.rs +++ b/burn-tensor/src/tests/activation/silu.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(silu)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_silu() { - let data = Data::from([[1.0, 2.0], [3.0, 4.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_silu() { + let data = Data::from([[1.0, 2.0], [3.0, 4.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::silu(tensor).to_data(); + let data_actual = activation::silu(tensor).to_data(); - let data_expected = Data::from([[0.7311, 1.7616], [2.8577, 3.9281]]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[0.7311, 1.7616], [2.8577, 3.9281]]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/activation/softmax.rs b/burn-tensor/src/tests/activation/softmax.rs index f3a761de90..7e04168bff 100644 --- a/burn-tensor/src/tests/activation/softmax.rs +++ b/burn-tensor/src/tests/activation/softmax.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(softmax)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_softmax_d2() { - let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_softmax_d2() { + let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::softmax(tensor, 1).to_data(); + let data_actual = activation::softmax(tensor, 1).to_data(); - let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/activation/tanh_activation.rs b/burn-tensor/src/tests/activation/tanh_activation.rs index a3012b1ac4..1aaa9e3d0d 100644 --- a/burn-tensor/src/tests/activation/tanh_activation.rs +++ b/burn-tensor/src/tests/activation/tanh_activation.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(tanh_activation)] mod tests { - use super::*; - use burn_tensor::{activation, Data, Tensor}; + use super::*; + use burn_tensor::{activation, Data, Tensor}; - #[test] - fn test_tanh() { - let data = Data::from([[1., 2.], [3., 4.]]); - let tensor 
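// Sketch, not part of the patch: a plain-f32 check of the first row of the
// softmax test above, softmax(x)_i = exp(x_i) / sum_j exp(x_j) for
// x = [1.0, 7.0], which matches the expected [2.47e-3, 9.975e-1].
fn softmax_row_sketch() {
    let row = [1.0f32, 7.0];
    let denom: f32 = row.iter().map(|x| x.exp()).sum();
    let softmax: Vec<f32> = row.iter().map(|x| x.exp() / denom).collect();

    assert!((softmax[0] - 2.47e-3).abs() < 1e-4);
    assert!((softmax[1] - 0.9975).abs() < 1e-4);
}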
= Tensor::::from_data(data); + #[test] + fn test_tanh() { + let data = Data::from([[1., 2.], [3., 4.]]); + let tensor = Tensor::::from_data(data); - let data_actual = activation::tanh(tensor).to_data(); + let data_actual = activation::tanh(tensor).to_data(); - let data_expected = Data::from([[0.7616, 0.9640], [0.9951, 0.9993]]); - data_actual.assert_approx_eq(&data_expected, 4); - } + let data_expected = Data::from([[0.7616, 0.9640], [0.9951, 0.9993]]); + data_actual.assert_approx_eq(&data_expected, 4); + } } diff --git a/burn-tensor/src/tests/clone_invariance.rs b/burn-tensor/src/tests/clone_invariance.rs index 879b07a55f..a70973e4b3 100644 --- a/burn-tensor/src/tests/clone_invariance.rs +++ b/burn-tensor/src/tests/clone_invariance.rs @@ -6,713 +6,713 @@ /// and use different kernels in such cases. We ensure that the results are consistent regardless /// of the approach and that the input tensors are not modified when cloned. mod tests { - use super::*; - use burn_tensor::activation::{ - gelu, log_sigmoid, log_softmax, relu, sigmoid, silu, softmax, tanh, - }; - use burn_tensor::{Data, Distribution, Tensor}; + use super::*; + use burn_tensor::activation::{ + gelu, log_sigmoid, log_softmax, relu, sigmoid, silu, softmax, tanh, + }; + use burn_tensor::{Data, Distribution, Tensor}; - pub trait CloneInvarianceTest { - type Args; + pub trait CloneInvarianceTest { + type Args; - fn args(&self) -> Self::Args; + fn args(&self) -> Self::Args; - fn run(&self, args: &Self::Args, inplace: bool) -> Data; + fn run(&self, args: &Self::Args, inplace: bool) -> Data; - fn check(&self) { - let args = self.args(); - let out = self.run(&args, false); - let out_inplace = self.run(&args, true); + fn check(&self) { + let args = self.args(); + let out = self.run(&args, false); + let out_inplace = self.run(&args, true); - out.assert_approx_eq(&out_inplace, 4); - } - } - - macro_rules! clone_invariance_test { - (unary: $name:ident, ops_float: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = Data; - - fn args(&self) -> Self::Args { - TestTensor::random([32, 32], Distribution::Default) - .into_data() - .convert() - } - - fn run(&self, args: &Self::Args, inplace: bool) -> Data { - let lhs = TestTensor::from_data(args.clone().convert()); - - if inplace { - $ops(lhs).into_data().convert() - } else { - let out = $ops(lhs.clone()).into_data().convert(); - lhs.into_data().assert_approx_eq(args, 4); - out - } - } + out.assert_approx_eq(&out_inplace, 4); } + } - CloneInvarianceTest::<2>::check(&$name); - } - }; - - (binary: $name:ident, ops_float: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = (Data, Data); - - fn args(&self) -> Self::Args { - ( - TestTensor::random([32, 32], Distribution::Default) - .into_data() - .convert(), - // Avoid div by zero. - TestTensor::random([32, 32], Distribution::Uniform(1., 3.)) - .into_data() - .convert(), - ) - } - - fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { - let lhs = TestTensor::from_data(lhs_arg.clone().convert()); - let rhs = TestTensor::from_data(rhs_arg.clone().convert()); - - if inplace { - $ops(lhs, rhs).into_data().convert() - } else { - let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); - - lhs.into_data().assert_approx_eq(lhs_arg, 4); - rhs.into_data().assert_approx_eq(rhs_arg, 4); - - out + macro_rules! 
clone_invariance_test { + (unary: $name:ident, ops_float: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = Data; + + fn args(&self) -> Self::Args { + TestTensor::random([32, 32], Distribution::Default) + .into_data() + .convert() + } + + fn run(&self, args: &Self::Args, inplace: bool) -> Data { + let lhs = TestTensor::from_data(args.clone().convert()); + + if inplace { + $ops(lhs).into_data().convert() + } else { + let out = $ops(lhs.clone()).into_data().convert(); + lhs.into_data().assert_approx_eq(args, 4); + out + } + } + } + + CloneInvarianceTest::<2>::check(&$name); } - } - } - - CloneInvarianceTest::<2>::check(&$name); - } - }; - - (unary: $name:ident, ops_int: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = Data; - - fn args(&self) -> Self::Args { - TestTensor::random([32, 32], Distribution::Uniform(0.0, 50.0)) - .into_data() - .convert() - } - - fn run(&self, args: &Self::Args, inplace: bool) -> Data { - let lhs = TestTensorInt::from_data(args.clone().convert()); - - if inplace { - $ops(lhs).into_data().convert() - } else { - let out = $ops(lhs.clone()).into_data().convert(); - lhs.into_data().convert().assert_approx_eq(args, 4); - out + }; + + (binary: $name:ident, ops_float: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = (Data, Data); + + fn args(&self) -> Self::Args { + ( + TestTensor::random([32, 32], Distribution::Default) + .into_data() + .convert(), + // Avoid div by zero. + TestTensor::random([32, 32], Distribution::Uniform(1., 3.)) + .into_data() + .convert(), + ) + } + + fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { + let lhs = TestTensor::from_data(lhs_arg.clone().convert()); + let rhs = TestTensor::from_data(rhs_arg.clone().convert()); + + if inplace { + $ops(lhs, rhs).into_data().convert() + } else { + let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); + + lhs.into_data().assert_approx_eq(lhs_arg, 4); + rhs.into_data().assert_approx_eq(rhs_arg, 4); + + out + } + } + } + + CloneInvarianceTest::<2>::check(&$name); } - } - } - - CloneInvarianceTest::<2>::check(&$name); - } - }; - - (binary: $name:ident, ops_int: $ops:expr) => { - #[test] - #[allow(non_snake_case)] - fn $name() { - struct $name; - - impl CloneInvarianceTest<2> for $name { - type Args = (Data, Data); - - fn args(&self) -> Self::Args { - ( - TestTensor::random([32, 32], Distribution::Uniform(0., 50.)) - .into_data() - .convert(), - // Avoid div by zero. 
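// Sketch, not part of the patch: what one expansion of the `unary`/`ops_float`
// arm of `clone_invariance_test!` above amounts to when written by hand, using
// `sqrt` as the op. The generic parameters elided in this diff are assumed to
// be `<const D: usize>` on `CloneInvarianceTest` and `Data<f32, D>` for the
// associated `Args`/return types; `TestTensor` comes from the test harness.
struct SqrtByHand;

impl CloneInvarianceTest<2> for SqrtByHand {
    type Args = Data<f32, 2>;

    fn args(&self) -> Self::Args {
        TestTensor::random([32, 32], Distribution::Default)
            .into_data()
            .convert()
    }

    fn run(&self, args: &Self::Args, inplace: bool) -> Data<f32, 2> {
        let tensor = TestTensor::from_data(args.clone().convert());

        if inplace {
            // Single-owner path: a backend is free to reuse the buffer in place.
            tensor.sqrt().into_data().convert()
        } else {
            // Cloned path: the original tensor must be left unmodified.
            let out = tensor.clone().sqrt().into_data().convert();
            tensor.into_data().assert_approx_eq(args, 4);
            out
        }
    }
}
// `CloneInvarianceTest::<2>::check(&SqrtByHand)` then asserts both paths agree.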
- TestTensor::random([32, 32], Distribution::Uniform(1., 51.)) - .into_data() - .convert(), - ) - } - - fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { - let lhs = TestTensorInt::from_data(lhs_arg.clone().convert()); - let rhs = TestTensorInt::from_data(rhs_arg.clone().convert()); - - if inplace { - $ops(lhs, rhs).into_data().convert() - } else { - let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); - - lhs.into_data().convert().assert_approx_eq(lhs_arg, 4); - rhs.into_data().convert().assert_approx_eq(rhs_arg, 4); - - out + }; + + (unary: $name:ident, ops_int: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = Data; + + fn args(&self) -> Self::Args { + TestTensor::random([32, 32], Distribution::Uniform(0.0, 50.0)) + .into_data() + .convert() + } + + fn run(&self, args: &Self::Args, inplace: bool) -> Data { + let lhs = TestTensorInt::from_data(args.clone().convert()); + + if inplace { + $ops(lhs).into_data().convert() + } else { + let out = $ops(lhs.clone()).into_data().convert(); + lhs.into_data().convert().assert_approx_eq(args, 4); + out + } + } + } + + CloneInvarianceTest::<2>::check(&$name); } - } - } - - CloneInvarianceTest::<2>::check(&$name); - } - }; - } - - mod float { - use super::*; - - // Unary ops - clone_invariance_test!( - unary: AddScalar, - ops_float: |tensor: TestTensor<2>| tensor.add_scalar(2.0) - ); - clone_invariance_test!( - unary: SubScalar, - ops_float: |tensor: TestTensor<2>| tensor.sub_scalar(2.0) - ); - clone_invariance_test!( - unary: DivScalar, - ops_float: |tensor: TestTensor<2>| tensor.div_scalar(2.0) - ); - clone_invariance_test!( - unary: MulScalar, - ops_float: |tensor: TestTensor<2>| tensor.mul_scalar(2.0) - ); - clone_invariance_test!( - unary: PowScalar, - ops_float: |tensor: TestTensor<2>| tensor.powf(2.0) - ); - clone_invariance_test!( - unary: Sqrt, - ops_float: |tensor: TestTensor<2>| tensor.sqrt() - ); - clone_invariance_test!( - unary: Exp, - ops_float: |tensor: TestTensor<2>| tensor.exp() - ); - clone_invariance_test!( - unary: Neg, - ops_float: |tensor: TestTensor<2>| tensor.neg() - ); - clone_invariance_test!( - unary: MeanDim, - ops_float: |tensor: TestTensor<2>| tensor.mean_dim(1) - ); - clone_invariance_test!( - unary: SumDim, - ops_float: |tensor: TestTensor<2>| tensor.sum_dim(1) - ); - clone_invariance_test!( - unary: Sum, - ops_float: |tensor: TestTensor<2>| tensor.sum().unsqueeze() - ); - clone_invariance_test!( - unary: Mean, - ops_float: |tensor: TestTensor<2>| tensor.mean().unsqueeze() - ); - clone_invariance_test!( - unary: Clamp, - ops_float: |tensor: TestTensor<2>| tensor.clamp(-2., 2.) - ); - clone_invariance_test!( - unary: ClampMin, - ops_float: |tensor: TestTensor<2>| tensor.clamp_min(-2.) - ); - clone_invariance_test!( - unary: ClampMax, - ops_float: |tensor: TestTensor<2>| tensor.clamp_max(2.) 
- ); - clone_invariance_test!( - unary: Abs, - ops_float: |tensor: TestTensor<2>| tensor.abs() - ); - clone_invariance_test!( - unary: Cos, - ops_float: |tensor: TestTensor<2>| tensor.cos() - ); - clone_invariance_test!( - unary: Sin, - ops_float: |tensor: TestTensor<2>| tensor.sin() - ); - clone_invariance_test!( - unary: Log, - ops_float: |tensor: TestTensor<2>| tensor.log() - ); - clone_invariance_test!( - unary: Log1P, - ops_float: |tensor: TestTensor<2>| tensor.log1p() - ); - clone_invariance_test!( - unary: SwapDims, - ops_float: |tensor: TestTensor<2>| tensor.swap_dims(0, 1) - ); - clone_invariance_test!( - unary: Transpose, - ops_float: |tensor: TestTensor<2>| tensor.transpose() - ); - clone_invariance_test!( - unary: Slice, - ops_float: |tensor: TestTensor<2>| tensor.slice([0..12, 12..24]) - ); - clone_invariance_test!( - unary: Erf, - ops_float: |tensor: TestTensor<2>| tensor.erf() - ); - clone_invariance_test!( - unary: EqualElem, - ops_float: |tensor: TestTensor<2>| tensor.equal_elem(0.5) - ); - clone_invariance_test!( - unary: GreaterElem, - ops_float: |tensor: TestTensor<2>| tensor.greater_elem(0.5) - ); - clone_invariance_test!( - unary: GreaterEqualElem, - ops_float: |tensor: TestTensor<2>| tensor.greater_equal_elem(0.5) - ); - clone_invariance_test!( - unary: LowerElem, - ops_float: |tensor: TestTensor<2>| tensor.lower_elem(0.5) - ); - clone_invariance_test!( - unary: LowerEqualElem, - ops_float: |tensor: TestTensor<2>| tensor.lower_equal_elem(0.5) - ); - clone_invariance_test!( - unary: Argmax, - ops_float: |tensor: TestTensor<2>| tensor.argmax(0) - ); - clone_invariance_test!( - unary: Argmin, - ops_float: |tensor: TestTensor<2>| tensor.argmin(0) - ); - clone_invariance_test!( - unary: Max, - ops_float: |tensor: TestTensor<2>| tensor.max().unsqueeze() - ); - clone_invariance_test!( - unary: Min, - ops_float: |tensor: TestTensor<2>| tensor.min().unsqueeze() - ); - clone_invariance_test!( - unary: MaxDim, - ops_float: |tensor: TestTensor<2>| tensor.max_dim(1) - ); - clone_invariance_test!( - unary: MaxDimWithIndices, - ops_float: |tensor: TestTensor<2>| tensor.max_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDimWithIndices, - ops_float: |tensor: TestTensor<2>| tensor.min_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDim, - ops_float: |tensor: TestTensor<2>| tensor.min_dim(1) - ); - clone_invariance_test!( - unary: Repeat, - ops_float: |tensor: TestTensor<2>| { - tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) - } - ); - clone_invariance_test!( - unary: Reshape, - ops_float: |tensor: TestTensor<2>| { - let shape = tensor.shape(); - let new_shape = [shape.num_elements(), 1]; - tensor.reshape(new_shape) - } - ); - clone_invariance_test!( - unary: Gatter, - ops_float: |tensor: TestTensor<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.gather(0, indices) - } - ); - clone_invariance_test!( - unary: Select, - ops_float: |tensor: TestTensor<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - tensor.select(0, indices) - } - ); - clone_invariance_test!( - unary: MaskFill, - ops_float: |tensor: TestTensor<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_fill(mask, 77.0) - } - ); - - // Activation - clone_invariance_test!( - unary: Softmax, - ops_float: |tensor: TestTensor<2>| softmax(tensor, 1) - ); - clone_invariance_test!( - unary: LogSoftmax, - ops_float: |tensor: TestTensor<2>| log_softmax(tensor, 1) - ); - clone_invariance_test!( - unary: Sigmoid, - 
ops_float: |tensor: TestTensor<2>| sigmoid(tensor) - ); - clone_invariance_test!( - unary: LogSigmoid, - ops_float: |tensor: TestTensor<2>| log_sigmoid(tensor) - ); - clone_invariance_test!( - unary: Relu, - ops_float: |tensor: TestTensor<2>| relu(tensor) - ); - clone_invariance_test!( - unary: Gelu, - ops_float: |tensor: TestTensor<2>| gelu(tensor) - ); - clone_invariance_test!( - unary: Silu, - ops_float: |tensor: TestTensor<2>| silu(tensor) - ); - clone_invariance_test!( - unary: Tanh, - ops_float: |tensor: TestTensor<2>| tanh(tensor) - ); - - // Binary ops - clone_invariance_test!( - binary: Add, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.add(rhs) - ); - clone_invariance_test!( - binary: Sub, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.sub(rhs) - ); - clone_invariance_test!( - binary: Div, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.div(rhs) - ); - clone_invariance_test!( - binary: Mul, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.mul(rhs) - ); - clone_invariance_test!( - binary: Matmul, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.matmul(rhs) - ); - clone_invariance_test!( - binary: Equal, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.equal(rhs) - ); - clone_invariance_test!( - binary: Greater, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater(rhs) - ); - clone_invariance_test!( - binary: GreaterEqual, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater_equal(rhs) - ); - clone_invariance_test!( - binary: Lower, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower(rhs) - ); - clone_invariance_test!( - binary: LowerEqual, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower_equal(rhs) - ); - clone_invariance_test!( - binary: Cat, - ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| { - let lhs = lhs.reshape([1usize, 32, 32]); - let rhs = rhs.reshape([1usize, 32, 32]); - - TestTensor::cat(vec![lhs, rhs], 0).reshape([64, 32]) - } - ); - clone_invariance_test!( - binary: Scatter, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.scatter(0, indices, values) - } - ); - clone_invariance_test!( - binary: SliceAssign, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) - } - ); - clone_invariance_test!( - binary: MaskWhere, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_where(mask, values) - } - ); - clone_invariance_test!( - binary: SelectAssign, - ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - let values = values.select(0, indices.clone()); - tensor.select_assign(0, indices, values) - } - ); - } + }; + + (binary: $name:ident, ops_int: $ops:expr) => { + #[test] + #[allow(non_snake_case)] + fn $name() { + struct $name; + + impl CloneInvarianceTest<2> for $name { + type Args = (Data, Data); + + fn args(&self) -> Self::Args { + ( + TestTensor::random([32, 32], Distribution::Uniform(0., 50.)) + .into_data() + .convert(), + // Avoid div by zero. 
+ TestTensor::random([32, 32], Distribution::Uniform(1., 51.)) + .into_data() + .convert(), + ) + } + + fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> Data { + let lhs = TestTensorInt::from_data(lhs_arg.clone().convert()); + let rhs = TestTensorInt::from_data(rhs_arg.clone().convert()); + + if inplace { + $ops(lhs, rhs).into_data().convert() + } else { + let out = $ops(lhs.clone(), rhs.clone()).into_data().convert(); + + lhs.into_data().convert().assert_approx_eq(lhs_arg, 4); + rhs.into_data().convert().assert_approx_eq(rhs_arg, 4); + + out + } + } + } + + CloneInvarianceTest::<2>::check(&$name); + } + }; + } - mod int { - use super::*; + mod float { + use super::*; + + // Unary ops + clone_invariance_test!( + unary: AddScalar, + ops_float: |tensor: TestTensor<2>| tensor.add_scalar(2.0) + ); + clone_invariance_test!( + unary: SubScalar, + ops_float: |tensor: TestTensor<2>| tensor.sub_scalar(2.0) + ); + clone_invariance_test!( + unary: DivScalar, + ops_float: |tensor: TestTensor<2>| tensor.div_scalar(2.0) + ); + clone_invariance_test!( + unary: MulScalar, + ops_float: |tensor: TestTensor<2>| tensor.mul_scalar(2.0) + ); + clone_invariance_test!( + unary: PowScalar, + ops_float: |tensor: TestTensor<2>| tensor.powf(2.0) + ); + clone_invariance_test!( + unary: Sqrt, + ops_float: |tensor: TestTensor<2>| tensor.sqrt() + ); + clone_invariance_test!( + unary: Exp, + ops_float: |tensor: TestTensor<2>| tensor.exp() + ); + clone_invariance_test!( + unary: Neg, + ops_float: |tensor: TestTensor<2>| tensor.neg() + ); + clone_invariance_test!( + unary: MeanDim, + ops_float: |tensor: TestTensor<2>| tensor.mean_dim(1) + ); + clone_invariance_test!( + unary: SumDim, + ops_float: |tensor: TestTensor<2>| tensor.sum_dim(1) + ); + clone_invariance_test!( + unary: Sum, + ops_float: |tensor: TestTensor<2>| tensor.sum().unsqueeze() + ); + clone_invariance_test!( + unary: Mean, + ops_float: |tensor: TestTensor<2>| tensor.mean().unsqueeze() + ); + clone_invariance_test!( + unary: Clamp, + ops_float: |tensor: TestTensor<2>| tensor.clamp(-2., 2.) + ); + clone_invariance_test!( + unary: ClampMin, + ops_float: |tensor: TestTensor<2>| tensor.clamp_min(-2.) + ); + clone_invariance_test!( + unary: ClampMax, + ops_float: |tensor: TestTensor<2>| tensor.clamp_max(2.) 
+ ); + clone_invariance_test!( + unary: Abs, + ops_float: |tensor: TestTensor<2>| tensor.abs() + ); + clone_invariance_test!( + unary: Cos, + ops_float: |tensor: TestTensor<2>| tensor.cos() + ); + clone_invariance_test!( + unary: Sin, + ops_float: |tensor: TestTensor<2>| tensor.sin() + ); + clone_invariance_test!( + unary: Log, + ops_float: |tensor: TestTensor<2>| tensor.log() + ); + clone_invariance_test!( + unary: Log1P, + ops_float: |tensor: TestTensor<2>| tensor.log1p() + ); + clone_invariance_test!( + unary: SwapDims, + ops_float: |tensor: TestTensor<2>| tensor.swap_dims(0, 1) + ); + clone_invariance_test!( + unary: Transpose, + ops_float: |tensor: TestTensor<2>| tensor.transpose() + ); + clone_invariance_test!( + unary: Slice, + ops_float: |tensor: TestTensor<2>| tensor.slice([0..12, 12..24]) + ); + clone_invariance_test!( + unary: Erf, + ops_float: |tensor: TestTensor<2>| tensor.erf() + ); + clone_invariance_test!( + unary: EqualElem, + ops_float: |tensor: TestTensor<2>| tensor.equal_elem(0.5) + ); + clone_invariance_test!( + unary: GreaterElem, + ops_float: |tensor: TestTensor<2>| tensor.greater_elem(0.5) + ); + clone_invariance_test!( + unary: GreaterEqualElem, + ops_float: |tensor: TestTensor<2>| tensor.greater_equal_elem(0.5) + ); + clone_invariance_test!( + unary: LowerElem, + ops_float: |tensor: TestTensor<2>| tensor.lower_elem(0.5) + ); + clone_invariance_test!( + unary: LowerEqualElem, + ops_float: |tensor: TestTensor<2>| tensor.lower_equal_elem(0.5) + ); + clone_invariance_test!( + unary: Argmax, + ops_float: |tensor: TestTensor<2>| tensor.argmax(0) + ); + clone_invariance_test!( + unary: Argmin, + ops_float: |tensor: TestTensor<2>| tensor.argmin(0) + ); + clone_invariance_test!( + unary: Max, + ops_float: |tensor: TestTensor<2>| tensor.max().unsqueeze() + ); + clone_invariance_test!( + unary: Min, + ops_float: |tensor: TestTensor<2>| tensor.min().unsqueeze() + ); + clone_invariance_test!( + unary: MaxDim, + ops_float: |tensor: TestTensor<2>| tensor.max_dim(1) + ); + clone_invariance_test!( + unary: MaxDimWithIndices, + ops_float: |tensor: TestTensor<2>| tensor.max_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDimWithIndices, + ops_float: |tensor: TestTensor<2>| tensor.min_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDim, + ops_float: |tensor: TestTensor<2>| tensor.min_dim(1) + ); + clone_invariance_test!( + unary: Repeat, + ops_float: |tensor: TestTensor<2>| { + tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) + } + ); + clone_invariance_test!( + unary: Reshape, + ops_float: |tensor: TestTensor<2>| { + let shape = tensor.shape(); + let new_shape = [shape.num_elements(), 1]; + tensor.reshape(new_shape) + } + ); + clone_invariance_test!( + unary: Gatter, + ops_float: |tensor: TestTensor<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.gather(0, indices) + } + ); + clone_invariance_test!( + unary: Select, + ops_float: |tensor: TestTensor<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + tensor.select(0, indices) + } + ); + clone_invariance_test!( + unary: MaskFill, + ops_float: |tensor: TestTensor<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_fill(mask, 77.0) + } + ); + + // Activation + clone_invariance_test!( + unary: Softmax, + ops_float: |tensor: TestTensor<2>| softmax(tensor, 1) + ); + clone_invariance_test!( + unary: LogSoftmax, + ops_float: |tensor: TestTensor<2>| log_softmax(tensor, 1) + ); + clone_invariance_test!( + unary: Sigmoid, + 
ops_float: |tensor: TestTensor<2>| sigmoid(tensor) + ); + clone_invariance_test!( + unary: LogSigmoid, + ops_float: |tensor: TestTensor<2>| log_sigmoid(tensor) + ); + clone_invariance_test!( + unary: Relu, + ops_float: |tensor: TestTensor<2>| relu(tensor) + ); + clone_invariance_test!( + unary: Gelu, + ops_float: |tensor: TestTensor<2>| gelu(tensor) + ); + clone_invariance_test!( + unary: Silu, + ops_float: |tensor: TestTensor<2>| silu(tensor) + ); + clone_invariance_test!( + unary: Tanh, + ops_float: |tensor: TestTensor<2>| tanh(tensor) + ); + + // Binary ops + clone_invariance_test!( + binary: Add, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.add(rhs) + ); + clone_invariance_test!( + binary: Sub, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.sub(rhs) + ); + clone_invariance_test!( + binary: Div, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.div(rhs) + ); + clone_invariance_test!( + binary: Mul, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.mul(rhs) + ); + clone_invariance_test!( + binary: Matmul, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.matmul(rhs) + ); + clone_invariance_test!( + binary: Equal, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.equal(rhs) + ); + clone_invariance_test!( + binary: Greater, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater(rhs) + ); + clone_invariance_test!( + binary: GreaterEqual, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater_equal(rhs) + ); + clone_invariance_test!( + binary: Lower, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower(rhs) + ); + clone_invariance_test!( + binary: LowerEqual, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower_equal(rhs) + ); + clone_invariance_test!( + binary: Cat, + ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| { + let lhs = lhs.reshape([1usize, 32, 32]); + let rhs = rhs.reshape([1usize, 32, 32]); + + TestTensor::cat(vec![lhs, rhs], 0).reshape([64, 32]) + } + ); + clone_invariance_test!( + binary: Scatter, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.scatter(0, indices, values) + } + ); + clone_invariance_test!( + binary: SliceAssign, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) + } + ); + clone_invariance_test!( + binary: MaskWhere, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_where(mask, values) + } + ); + clone_invariance_test!( + binary: SelectAssign, + ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + let values = values.select(0, indices.clone()); + tensor.select_assign(0, indices, values) + } + ); + } - // Unary ops - clone_invariance_test!( - unary: AddScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.add_scalar(2.0) - ); - clone_invariance_test!( - unary: SubScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.sub_scalar(2.0) - ); - clone_invariance_test!( - unary: DivScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.div_scalar(2.0) - ); - clone_invariance_test!( - unary: MulScalar, - ops_int: |tensor: TestTensorInt<2>| tensor.mul_scalar(2.0) - ); - clone_invariance_test!( - unary: Neg, - ops_int: |tensor: TestTensorInt<2>| tensor.neg() - ); - clone_invariance_test!( - unary: MeanDim, - ops_int: |tensor: TestTensorInt<2>| tensor.mean_dim(1) - 
); - clone_invariance_test!( - unary: SumDim, - ops_int: |tensor: TestTensorInt<2>| tensor.sum_dim(1) - ); - clone_invariance_test!( - unary: Sum, - ops_int: |tensor: TestTensorInt<2>| tensor.sum().unsqueeze() - ); - clone_invariance_test!( - unary: Mean, - ops_int: |tensor: TestTensorInt<2>| tensor.mean().unsqueeze() - ); - clone_invariance_test!( - unary: Clamp, - ops_int: |tensor: TestTensorInt<2>| tensor.clamp(-2., 2.) - ); - clone_invariance_test!( - unary: ClampMin, - ops_int: |tensor: TestTensorInt<2>| tensor.clamp_min(-2.) - ); - clone_invariance_test!( - unary: ClampMax, - ops_int: |tensor: TestTensorInt<2>| tensor.clamp_max(2.) - ); - clone_invariance_test!( - unary: Abs, - ops_int: |tensor: TestTensorInt<2>| tensor.abs() - ); - clone_invariance_test!( - unary: SwapDims, - ops_int: |tensor: TestTensorInt<2>| tensor.swap_dims(0, 1) - ); - clone_invariance_test!( - unary: Transpose, - ops_int: |tensor: TestTensorInt<2>| tensor.transpose() - ); - clone_invariance_test!( - unary: Slice, - ops_int: |tensor: TestTensorInt<2>| tensor.slice([0..12, 12..24]) - ); - clone_invariance_test!( - unary: EqualElem, - ops_int: |tensor: TestTensorInt<2>| tensor.equal_elem(25) - ); - clone_invariance_test!( - unary: GreaterElem, - ops_int: |tensor: TestTensorInt<2>| tensor.greater_elem(25) - ); - clone_invariance_test!( - unary: GreaterEqualElem, - ops_int: |tensor: TestTensorInt<2>| tensor.greater_equal_elem(25) - ); - clone_invariance_test!( - unary: LowerElem, - ops_int: |tensor: TestTensorInt<2>| tensor.lower_elem(25) - ); - clone_invariance_test!( - unary: LowerEqualElem, - ops_int: |tensor: TestTensorInt<2>| tensor.lower_equal_elem(25) - ); - clone_invariance_test!( - unary: Argmax, - ops_int: |tensor: TestTensorInt<2>| tensor.argmax(0) - ); - clone_invariance_test!( - unary: Argmin, - ops_int: |tensor: TestTensorInt<2>| tensor.argmin(0) - ); - clone_invariance_test!( - unary: Max, - ops_int: |tensor: TestTensorInt<2>| tensor.max().unsqueeze() - ); - clone_invariance_test!( - unary: Min, - ops_int: |tensor: TestTensorInt<2>| tensor.min().unsqueeze() - ); - clone_invariance_test!( - unary: MaxDim, - ops_int: |tensor: TestTensorInt<2>| tensor.max_dim(1) - ); - clone_invariance_test!( - unary: MaxDimWithIndices, - ops_int: |tensor: TestTensorInt<2>| tensor.max_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDimWithIndices, - ops_int: |tensor: TestTensorInt<2>| tensor.min_dim_with_indices(1).0 - ); - clone_invariance_test!( - unary: MinDim, - ops_int: |tensor: TestTensorInt<2>| tensor.min_dim(1) - ); - clone_invariance_test!( - unary: Repeat, - ops_int: |tensor: TestTensorInt<2>| { - tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) - } - ); - clone_invariance_test!( - unary: Reshape, - ops_int: |tensor: TestTensorInt<2>| { - let shape = tensor.shape(); - let new_shape = [shape.num_elements(), 1]; - tensor.reshape(new_shape) - } - ); - clone_invariance_test!( - unary: Gatter, - ops_int: |tensor: TestTensorInt<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.gather(0, indices) - } - ); - clone_invariance_test!( - unary: Select, - ops_int: |tensor: TestTensorInt<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - tensor.select(0, indices) - } - ); - clone_invariance_test!( - unary: MaskFill, - ops_int: |tensor: TestTensorInt<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_fill(mask, 77.0) - } - ); - - // Binary ops - clone_invariance_test!( - binary: Add, - ops_int: |lhs: TestTensorInt<2>, rhs: 
TestTensorInt<2>| lhs.add(rhs) - ); - clone_invariance_test!( - binary: Sub, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.sub(rhs) - ); - clone_invariance_test!( - binary: Div, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.div(rhs) - ); - clone_invariance_test!( - binary: Mul, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.mul(rhs) - ); - clone_invariance_test!( - binary: Equal, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.equal(rhs) - ); - clone_invariance_test!( - binary: Greater, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater(rhs) - ); - clone_invariance_test!( - binary: GreaterEqual, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater_equal(rhs) - ); - clone_invariance_test!( - binary: Lower, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower(rhs) - ); - clone_invariance_test!( - binary: LowerEqual, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower_equal(rhs) - ); - clone_invariance_test!( - binary: Cat, - ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| { - let lhs = lhs.reshape([1usize, 32, 32]); - let rhs = rhs.reshape([1usize, 32, 32]); - - TestTensorInt::cat(vec![lhs, rhs], 0).reshape([64, 32]) - } - ); - clone_invariance_test!( - binary: Scatter, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - let shape = tensor.shape(); - let indices = TestTensorInt::ones(shape); - tensor.scatter(0, indices, values) - } - ); - clone_invariance_test!( - binary: SliceAssign, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) - } - ); - clone_invariance_test!( - binary: MaskWhere, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - let mask = tensor.clone().greater_elem(0.5); - tensor.mask_where(mask, values) - } - ); - clone_invariance_test!( - binary: SelectAssign, - ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { - let indices = TestTensorInt::from_ints([1, 2, 0, 5]); - let values = values.select(0, indices.clone()); - tensor.select_assign(0, indices, values) - } - ); - } + mod int { + use super::*; + + // Unary ops + clone_invariance_test!( + unary: AddScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.add_scalar(2.0) + ); + clone_invariance_test!( + unary: SubScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.sub_scalar(2.0) + ); + clone_invariance_test!( + unary: DivScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.div_scalar(2.0) + ); + clone_invariance_test!( + unary: MulScalar, + ops_int: |tensor: TestTensorInt<2>| tensor.mul_scalar(2.0) + ); + clone_invariance_test!( + unary: Neg, + ops_int: |tensor: TestTensorInt<2>| tensor.neg() + ); + clone_invariance_test!( + unary: MeanDim, + ops_int: |tensor: TestTensorInt<2>| tensor.mean_dim(1) + ); + clone_invariance_test!( + unary: SumDim, + ops_int: |tensor: TestTensorInt<2>| tensor.sum_dim(1) + ); + clone_invariance_test!( + unary: Sum, + ops_int: |tensor: TestTensorInt<2>| tensor.sum().unsqueeze() + ); + clone_invariance_test!( + unary: Mean, + ops_int: |tensor: TestTensorInt<2>| tensor.mean().unsqueeze() + ); + clone_invariance_test!( + unary: Clamp, + ops_int: |tensor: TestTensorInt<2>| tensor.clamp(-2., 2.) + ); + clone_invariance_test!( + unary: ClampMin, + ops_int: |tensor: TestTensorInt<2>| tensor.clamp_min(-2.) + ); + clone_invariance_test!( + unary: ClampMax, + ops_int: |tensor: TestTensorInt<2>| tensor.clamp_max(2.) 
+ ); + clone_invariance_test!( + unary: Abs, + ops_int: |tensor: TestTensorInt<2>| tensor.abs() + ); + clone_invariance_test!( + unary: SwapDims, + ops_int: |tensor: TestTensorInt<2>| tensor.swap_dims(0, 1) + ); + clone_invariance_test!( + unary: Transpose, + ops_int: |tensor: TestTensorInt<2>| tensor.transpose() + ); + clone_invariance_test!( + unary: Slice, + ops_int: |tensor: TestTensorInt<2>| tensor.slice([0..12, 12..24]) + ); + clone_invariance_test!( + unary: EqualElem, + ops_int: |tensor: TestTensorInt<2>| tensor.equal_elem(25) + ); + clone_invariance_test!( + unary: GreaterElem, + ops_int: |tensor: TestTensorInt<2>| tensor.greater_elem(25) + ); + clone_invariance_test!( + unary: GreaterEqualElem, + ops_int: |tensor: TestTensorInt<2>| tensor.greater_equal_elem(25) + ); + clone_invariance_test!( + unary: LowerElem, + ops_int: |tensor: TestTensorInt<2>| tensor.lower_elem(25) + ); + clone_invariance_test!( + unary: LowerEqualElem, + ops_int: |tensor: TestTensorInt<2>| tensor.lower_equal_elem(25) + ); + clone_invariance_test!( + unary: Argmax, + ops_int: |tensor: TestTensorInt<2>| tensor.argmax(0) + ); + clone_invariance_test!( + unary: Argmin, + ops_int: |tensor: TestTensorInt<2>| tensor.argmin(0) + ); + clone_invariance_test!( + unary: Max, + ops_int: |tensor: TestTensorInt<2>| tensor.max().unsqueeze() + ); + clone_invariance_test!( + unary: Min, + ops_int: |tensor: TestTensorInt<2>| tensor.min().unsqueeze() + ); + clone_invariance_test!( + unary: MaxDim, + ops_int: |tensor: TestTensorInt<2>| tensor.max_dim(1) + ); + clone_invariance_test!( + unary: MaxDimWithIndices, + ops_int: |tensor: TestTensorInt<2>| tensor.max_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDimWithIndices, + ops_int: |tensor: TestTensorInt<2>| tensor.min_dim_with_indices(1).0 + ); + clone_invariance_test!( + unary: MinDim, + ops_int: |tensor: TestTensorInt<2>| tensor.min_dim(1) + ); + clone_invariance_test!( + unary: Repeat, + ops_int: |tensor: TestTensorInt<2>| { + tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) + } + ); + clone_invariance_test!( + unary: Reshape, + ops_int: |tensor: TestTensorInt<2>| { + let shape = tensor.shape(); + let new_shape = [shape.num_elements(), 1]; + tensor.reshape(new_shape) + } + ); + clone_invariance_test!( + unary: Gatter, + ops_int: |tensor: TestTensorInt<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.gather(0, indices) + } + ); + clone_invariance_test!( + unary: Select, + ops_int: |tensor: TestTensorInt<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + tensor.select(0, indices) + } + ); + clone_invariance_test!( + unary: MaskFill, + ops_int: |tensor: TestTensorInt<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_fill(mask, 77.0) + } + ); + + // Binary ops + clone_invariance_test!( + binary: Add, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.add(rhs) + ); + clone_invariance_test!( + binary: Sub, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.sub(rhs) + ); + clone_invariance_test!( + binary: Div, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.div(rhs) + ); + clone_invariance_test!( + binary: Mul, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.mul(rhs) + ); + clone_invariance_test!( + binary: Equal, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.equal(rhs) + ); + clone_invariance_test!( + binary: Greater, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater(rhs) + ); + 
clone_invariance_test!( + binary: GreaterEqual, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater_equal(rhs) + ); + clone_invariance_test!( + binary: Lower, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower(rhs) + ); + clone_invariance_test!( + binary: LowerEqual, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower_equal(rhs) + ); + clone_invariance_test!( + binary: Cat, + ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| { + let lhs = lhs.reshape([1usize, 32, 32]); + let rhs = rhs.reshape([1usize, 32, 32]); + + TestTensorInt::cat(vec![lhs, rhs], 0).reshape([64, 32]) + } + ); + clone_invariance_test!( + binary: Scatter, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + let shape = tensor.shape(); + let indices = TestTensorInt::ones(shape); + tensor.scatter(0, indices, values) + } + ); + clone_invariance_test!( + binary: SliceAssign, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) + } + ); + clone_invariance_test!( + binary: MaskWhere, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + let mask = tensor.clone().greater_elem(0.5); + tensor.mask_where(mask, values) + } + ); + clone_invariance_test!( + binary: SelectAssign, + ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { + let indices = TestTensorInt::from_ints([1, 2, 0, 5]); + let values = values.select(0, indices.clone()); + tensor.select_assign(0, indices, values) + } + ); + } } diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs index 3108659006..1d0cd93807 100644 --- a/burn-tensor/src/tests/mod.rs +++ b/burn-tensor/src/tests/mod.rs @@ -7,79 +7,79 @@ mod stats; #[allow(missing_docs)] #[macro_export] macro_rules! 
testgen_all { - () => { - // test activation - burn_tensor::testgen_gelu!(); - burn_tensor::testgen_relu!(); - burn_tensor::testgen_softmax!(); - burn_tensor::testgen_sigmoid!(); - burn_tensor::testgen_silu!(); - burn_tensor::testgen_tanh_activation!(); + () => { + // test activation + burn_tensor::testgen_gelu!(); + burn_tensor::testgen_relu!(); + burn_tensor::testgen_softmax!(); + burn_tensor::testgen_sigmoid!(); + burn_tensor::testgen_silu!(); + burn_tensor::testgen_tanh_activation!(); - // test module - burn_tensor::testgen_module_forward!(); - burn_tensor::testgen_module_conv1d!(); - burn_tensor::testgen_module_conv2d!(); - burn_tensor::testgen_module_conv_transpose1d!(); - burn_tensor::testgen_module_conv_transpose2d!(); - burn_tensor::testgen_module_unfold4d!(); - burn_tensor::testgen_module_max_pool1d!(); - burn_tensor::testgen_module_max_pool2d!(); - burn_tensor::testgen_module_avg_pool1d!(); - burn_tensor::testgen_module_avg_pool2d!(); - burn_tensor::testgen_module_adaptive_avg_pool1d!(); - burn_tensor::testgen_module_adaptive_avg_pool2d!(); + // test module + burn_tensor::testgen_module_forward!(); + burn_tensor::testgen_module_conv1d!(); + burn_tensor::testgen_module_conv2d!(); + burn_tensor::testgen_module_conv_transpose1d!(); + burn_tensor::testgen_module_conv_transpose2d!(); + burn_tensor::testgen_module_unfold4d!(); + burn_tensor::testgen_module_max_pool1d!(); + burn_tensor::testgen_module_max_pool2d!(); + burn_tensor::testgen_module_avg_pool1d!(); + burn_tensor::testgen_module_avg_pool2d!(); + burn_tensor::testgen_module_adaptive_avg_pool1d!(); + burn_tensor::testgen_module_adaptive_avg_pool2d!(); - // test ops - burn_tensor::testgen_add!(); - burn_tensor::testgen_aggregation!(); - burn_tensor::testgen_arange!(); - burn_tensor::testgen_arange_step!(); - burn_tensor::testgen_arg!(); - burn_tensor::testgen_cast!(); - burn_tensor::testgen_cat!(); - burn_tensor::testgen_clamp!(); - burn_tensor::testgen_cos!(); - burn_tensor::testgen_create_like!(); - burn_tensor::testgen_div!(); - burn_tensor::testgen_erf!(); - burn_tensor::testgen_exp!(); - burn_tensor::testgen_flatten!(); - burn_tensor::testgen_full!(); - burn_tensor::testgen_gather_scatter!(); - burn_tensor::testgen_init!(); - burn_tensor::testgen_iter_dim!(); - burn_tensor::testgen_log!(); - burn_tensor::testgen_log1p!(); - burn_tensor::testgen_map_comparison!(); - burn_tensor::testgen_mask!(); - burn_tensor::testgen_matmul!(); - burn_tensor::testgen_maxmin!(); - burn_tensor::testgen_mul!(); - burn_tensor::testgen_neg!(); - burn_tensor::testgen_one_hot!(); - burn_tensor::testgen_powf!(); - burn_tensor::testgen_random!(); - burn_tensor::testgen_recip!(); - burn_tensor::testgen_repeat!(); - burn_tensor::testgen_reshape!(); - burn_tensor::testgen_select!(); - burn_tensor::testgen_sin!(); - burn_tensor::testgen_slice!(); - burn_tensor::testgen_sqrt!(); - burn_tensor::testgen_abs!(); - burn_tensor::testgen_squeeze!(); - burn_tensor::testgen_sub!(); - burn_tensor::testgen_tanh!(); - burn_tensor::testgen_transpose!(); + // test ops + burn_tensor::testgen_add!(); + burn_tensor::testgen_aggregation!(); + burn_tensor::testgen_arange!(); + burn_tensor::testgen_arange_step!(); + burn_tensor::testgen_arg!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_clamp!(); + burn_tensor::testgen_cos!(); + burn_tensor::testgen_create_like!(); + burn_tensor::testgen_div!(); + burn_tensor::testgen_erf!(); + burn_tensor::testgen_exp!(); + burn_tensor::testgen_flatten!(); + burn_tensor::testgen_full!(); + 
burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_init!(); + burn_tensor::testgen_iter_dim!(); + burn_tensor::testgen_log!(); + burn_tensor::testgen_log1p!(); + burn_tensor::testgen_map_comparison!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_matmul!(); + burn_tensor::testgen_maxmin!(); + burn_tensor::testgen_mul!(); + burn_tensor::testgen_neg!(); + burn_tensor::testgen_one_hot!(); + burn_tensor::testgen_powf!(); + burn_tensor::testgen_random!(); + burn_tensor::testgen_recip!(); + burn_tensor::testgen_repeat!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_select!(); + burn_tensor::testgen_sin!(); + burn_tensor::testgen_slice!(); + burn_tensor::testgen_sqrt!(); + burn_tensor::testgen_abs!(); + burn_tensor::testgen_squeeze!(); + burn_tensor::testgen_sub!(); + burn_tensor::testgen_tanh!(); + burn_tensor::testgen_transpose!(); - // test stats - burn_tensor::testgen_var!(); - burn_tensor::testgen_cov!(); - burn_tensor::testgen_diagonal!(); - burn_tensor::testgen_display!(); + // test stats + burn_tensor::testgen_var!(); + burn_tensor::testgen_cov!(); + burn_tensor::testgen_diagonal!(); + burn_tensor::testgen_display!(); - // test clone invariance - burn_tensor::testgen_clone_invariance!(); - }; + // test clone invariance + burn_tensor::testgen_clone_invariance!(); + }; } diff --git a/burn-tensor/src/tests/module/adaptive_avgpool1d.rs b/burn-tensor/src/tests/module/adaptive_avgpool1d.rs index 3173613d21..f655ed8496 100644 --- a/burn-tensor/src/tests/module/adaptive_avgpool1d.rs +++ b/burn-tensor/src/tests/module/adaptive_avgpool1d.rs @@ -1,73 +1,73 @@ #[burn_tensor_testgen::testgen(module_adaptive_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_adaptive_avg_pool1d_simple() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 8, - length_out: 4, - }; + #[test] + fn test_adaptive_avg_pool1d_simple() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 8, + length_out: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5, 2.5, 4.5, 6.5], - [8.5, 10.5, 12.5, 14.5], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5, 2.5, 4.5, 6.5], + [8.5, 10.5, 12.5, 14.5], + ]])); + } - #[test] - fn test_adaptive_avg_pool1d_dyn_filter_size() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 7, - length_out: 3, - }; + #[test] + fn test_adaptive_avg_pool1d_dyn_filter_size() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 7, + length_out: 3, + }; - test.assert_output(TestTensor::from_floats([[ - [1.0, 3.0, 5.0], - [8.0, 10.0, 12.0], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [1.0, 3.0, 5.0], + [8.0, 10.0, 12.0], + ]])); + } - #[test] - fn test_adaptive_avg_pool1d_bigger_output() { - let test = AdaptiveAvgPool1dTestCase { - batch_size: 1, - channels: 2, - length: 4, - length_out: 8, - }; + #[test] + fn test_adaptive_avg_pool1d_bigger_output() { + let test = AdaptiveAvgPool1dTestCase { + batch_size: 1, + channels: 2, + length: 4, + length_out: 8, + }; - test.assert_output(TestTensor::from_floats([[ - [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0], - [4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0], + [4.0, 
4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0], + ]])); + } - struct AdaptiveAvgPool1dTestCase { - batch_size: usize, - channels: usize, - length: usize, - length_out: usize, - } + struct AdaptiveAvgPool1dTestCase { + batch_size: usize, + channels: usize, + length: usize, + length_out: usize, + } - impl AdaptiveAvgPool1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = adaptive_avg_pool1d(x, self.length_out); + impl AdaptiveAvgPool1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = adaptive_avg_pool1d(x, self.length_out); - y.into_data().assert_approx_eq(&output.into_data(), 3); + y.into_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/adaptive_avgpool2d.rs b/burn-tensor/src/tests/module/adaptive_avgpool2d.rs index a484cdc9fa..2711948388 100644 --- a/burn-tensor/src/tests/module/adaptive_avgpool2d.rs +++ b/burn-tensor/src/tests/module/adaptive_avgpool2d.rs @@ -1,103 +1,103 @@ #[burn_tensor_testgen::testgen(module_adaptive_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::adaptive_avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::adaptive_avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_adaptive_avg_pool2d_simple() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 8, - width: 6, - height_out: 4, - width_out: 4, - }; + #[test] + fn test_adaptive_avg_pool2d_simple() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 8, + width: 6, + height_out: 4, + width_out: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [3.5000, 4.5000, 6.5000, 7.5000], - [15.5000, 16.5000, 18.5000, 19.5000], - [27.5000, 28.5000, 30.5000, 31.5000], - [39.5000, 40.5000, 42.5000, 43.5000], - ], - [ - [51.5000, 52.5000, 54.5000, 55.5000], - [63.5000, 64.5000, 66.5000, 67.5000], - [75.5000, 76.5000, 78.5000, 79.5000], - [87.5000, 88.5000, 90.5000, 91.5000], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [3.5000, 4.5000, 6.5000, 7.5000], + [15.5000, 16.5000, 18.5000, 19.5000], + [27.5000, 28.5000, 30.5000, 31.5000], + [39.5000, 40.5000, 42.5000, 43.5000], + ], + [ + [51.5000, 52.5000, 54.5000, 55.5000], + [63.5000, 64.5000, 66.5000, 67.5000], + [75.5000, 76.5000, 78.5000, 79.5000], + [87.5000, 88.5000, 90.5000, 91.5000], + ], + ]])); + } - #[test] - fn test_adaptive_avg_pool2d_dyn_filter_size() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 5, - width: 7, - height_out: 3, - width_out: 2, - }; + #[test] + fn test_adaptive_avg_pool2d_dyn_filter_size() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 5, + width: 7, + height_out: 3, + width_out: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]], - [[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]], + [[40.0000, 
43.0000], [50.5000, 53.5000], [61.0000, 64.0000]], + ]])); + } - #[test] - fn test_adaptive_avg_pool2d_bigger_output() { - let test = AdaptiveAvgPool2dTestCase { - batch_size: 1, - channels: 2, - height: 4, - width: 3, - height_out: 5, - width_out: 4, - }; + #[test] + fn test_adaptive_avg_pool2d_bigger_output() { + let test = AdaptiveAvgPool2dTestCase { + batch_size: 1, + channels: 2, + height: 4, + width: 3, + height_out: 5, + width_out: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [0.0000, 0.5000, 1.5000, 2.0000], - [1.5000, 2.0000, 3.0000, 3.5000], - [4.5000, 5.0000, 6.0000, 6.5000], - [7.5000, 8.0000, 9.0000, 9.5000], - [9.0000, 9.5000, 10.5000, 11.0000], - ], - [ - [12.0000, 12.5000, 13.5000, 14.0000], - [13.5000, 14.0000, 15.0000, 15.5000], - [16.5000, 17.0000, 18.0000, 18.5000], - [19.5000, 20.0000, 21.0000, 21.5000], - [21.0000, 21.5000, 22.5000, 23.0000], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [0.0000, 0.5000, 1.5000, 2.0000], + [1.5000, 2.0000, 3.0000, 3.5000], + [4.5000, 5.0000, 6.0000, 6.5000], + [7.5000, 8.0000, 9.0000, 9.5000], + [9.0000, 9.5000, 10.5000, 11.0000], + ], + [ + [12.0000, 12.5000, 13.5000, 14.0000], + [13.5000, 14.0000, 15.0000, 15.5000], + [16.5000, 17.0000, 18.0000, 18.5000], + [19.5000, 20.0000, 21.0000, 21.5000], + [21.0000, 21.5000, 22.5000, 23.0000], + ], + ]])); + } - struct AdaptiveAvgPool2dTestCase { - batch_size: usize, - channels: usize, - height: usize, - width: usize, - height_out: usize, - width_out: usize, - } + struct AdaptiveAvgPool2dTestCase { + batch_size: usize, + channels: usize, + height: usize, + width: usize, + height_out: usize, + width_out: usize, + } - impl AdaptiveAvgPool2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]); + impl AdaptiveAvgPool2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/avgpool1d.rs b/burn-tensor/src/tests/module/avgpool1d.rs index 0706bfc8fc..ce1a95fd55 100644 --- a/burn-tensor/src/tests/module/avgpool1d.rs +++ b/burn-tensor/src/tests/module/avgpool1d.rs @@ -1,88 +1,88 @@ #[burn_tensor_testgen::testgen(module_avg_pool1d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool1d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool1d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool1d_simple() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 1, - kernel_size: 3, - padding: 0, - stride: 1, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_simple() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 1, + kernel_size: 3, + padding: 0, + stride: 1, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[1., 2., 3., 4.]]])); - } + 
test.assert_output(TestTensor::from_floats([[[1., 2., 3., 4.]]])); + } - #[test] - fn test_avg_pool1d_complex() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool1d_complex() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[ - [0.3333, 2.0000, 4.0000], - [4.3333, 8.0000, 10.0000], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.3333, 2.0000, 4.0000], + [4.3333, 8.0000, 10.0000], + ]])); + } - #[test] - fn test_avg_pool1d_complex_dont_count_pad() { - let test = AvgPool1dTestCase { - batch_size: 1, - channels: 2, - kernel_size: 3, - padding: 1, - stride: 2, - length: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool1d_complex_dont_count_pad() { + let test = AvgPool1dTestCase { + batch_size: 1, + channels: 2, + kernel_size: 3, + padding: 1, + stride: 2, + length: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[ - [0.5000, 2.0000, 4.0000], - [6.5000, 8.0000, 10.0000], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0.5000, 2.0000, 4.0000], + [6.5000, 8.0000, 10.0000], + ]])); + } - struct AvgPool1dTestCase { - batch_size: usize, - channels: usize, - kernel_size: usize, - padding: usize, - stride: usize, - length: usize, - count_include_pad: bool, - } + struct AvgPool1dTestCase { + batch_size: usize, + channels: usize, + kernel_size: usize, + padding: usize, + stride: usize, + length: usize, + count_include_pad: bool, + } - impl AvgPool1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.length]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = avg_pool1d( - x, - self.kernel_size, - self.stride, - self.padding, - self.count_include_pad, - ); + impl AvgPool1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.length]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = avg_pool1d( + x, + self.kernel_size, + self.stride, + self.padding, + self.count_include_pad, + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/avgpool2d.rs b/burn-tensor/src/tests/module/avgpool2d.rs index ca9ffcf321..0207014326 100644 --- a/burn-tensor/src/tests/module/avgpool2d.rs +++ b/burn-tensor/src/tests/module/avgpool2d.rs @@ -1,113 +1,113 @@ #[burn_tensor_testgen::testgen(module_avg_pool2d)] mod tests { - use super::*; - use burn_tensor::module::avg_pool2d; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::avg_pool2d; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_avg_pool2d_simple() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - height: 6, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_simple() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 3, 
+ padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + height: 6, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [7., 8., 9., 10.], - [13., 14., 15., 16.], - [19., 20., 21., 22.], - [25., 26., 27., 28.], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [7., 8., 9., 10.], + [13., 14., 15., 16.], + [19., 20., 21., 22.], + [25., 26., 27., 28.], + ]]])); + } - #[test] - fn test_avg_pool2d_complex() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: true, - }; + #[test] + fn test_avg_pool2d_complex() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: true, + }; - test.assert_output(TestTensor::from_floats([[[ - [1.1667, 3.0000, 4.3333, 2.5000], - [3.2500, 7.5000, 9.5000, 5.2500], - [6.2500, 13.5000, 15.5000, 8.2500], - [5.1667, 11.0000, 12.3333, 6.5000], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [1.1667, 3.0000, 4.3333, 2.5000], + [3.2500, 7.5000, 9.5000, 5.2500], + [6.2500, 13.5000, 15.5000, 8.2500], + [5.1667, 11.0000, 12.3333, 6.5000], + ]]])); + } - #[test] - fn test_avg_pool2d_complex_dont_include_pad() { - let test = AvgPool2dTestCase { - batch_size: 1, - channels: 1, - kernel_size_1: 3, - kernel_size_2: 4, - padding_1: 1, - padding_2: 2, - stride_1: 1, - stride_2: 2, - height: 4, - width: 6, - count_include_pad: false, - }; + #[test] + fn test_avg_pool2d_complex_dont_include_pad() { + let test = AvgPool2dTestCase { + batch_size: 1, + channels: 1, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + height: 4, + width: 6, + count_include_pad: false, + }; - test.assert_output(TestTensor::from_floats([[[ - [3.5000, 4.5000, 6.5000, 7.5000], - [6.5000, 7.5000, 9.5000, 10.5000], - [12.5000, 13.5000, 15.5000, 16.5000], - [15.5000, 16.5000, 18.5000, 19.5000], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [3.5000, 4.5000, 6.5000, 7.5000], + [6.5000, 7.5000, 9.5000, 10.5000], + [12.5000, 13.5000, 15.5000, 16.5000], + [15.5000, 16.5000, 18.5000, 19.5000], + ]]])); + } - struct AvgPool2dTestCase { - batch_size: usize, - channels: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - stride_1: usize, - stride_2: usize, - height: usize, - width: usize, - count_include_pad: bool, - } + struct AvgPool2dTestCase { + batch_size: usize, + channels: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + height: usize, + width: usize, + count_include_pad: bool, + } - impl AvgPool2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = avg_pool2d( - x, - [self.kernel_size_1, self.kernel_size_2], - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - self.count_include_pad, - ); + impl AvgPool2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); + let x = 
TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = avg_pool2d( + x, + [self.kernel_size_1, self.kernel_size_2], + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + self.count_include_pad, + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/conv1d.rs b/burn-tensor/src/tests/module/conv1d.rs index 77bd82c10d..662ba8a4bd 100644 --- a/burn-tensor/src/tests/module/conv1d.rs +++ b/burn-tensor/src/tests/module/conv1d.rs @@ -1,135 +1,135 @@ #[burn_tensor_testgen::testgen(module_conv1d)] mod tests { - use super::*; - use burn_tensor::module::conv1d; - use burn_tensor::ops::ConvOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv1d; + use burn_tensor::ops::ConvOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv1d_simple() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv1d_simple() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[43., 67., 82., 49.], [104., 176., 227., 158.]], - [[139., 187., 202., 113.], [392., 584., 635., 414.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[43., 67., 82., 49.], [104., 176., 227., 158.]], + [[139., 187., 202., 113.], [392., 584., 635., 414.]], + ])); + } - #[test] - fn test_conv1d_dilation() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 2, - groups: 1, - length: 4, - }; + #[test] + fn test_conv1d_dilation() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 2, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[62., 38.], [159., 111.]], - [[158., 102.], [447., 367.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[62., 38.], [159., 111.]], + [[158., 102.], [447., 367.]], + ])); + } - #[test] - fn test_conv1d_groups() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - stride: 1, - dilation: 1, - groups: 2, - length: 4, - }; + #[test] + fn test_conv1d_groups() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 1, + groups: 2, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[2., 5., 8., 3.], [42., 63., 75., 47.]], - [[26., 29., 32., 11.], [114., 159., 171., 103.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[2., 5., 8., 3.], [42., 63., 75., 47.]], + [[26., 29., 32., 11.], [114., 159., 171., 103.]], + ])); + } - #[test] - fn test_conv1d_complex() { - let test = Conv1dTestCase { - batch_size: 2, - channels_in: 3, - channels_out: 4, - kernel_size: 3, - padding: 1, - stride: 2, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv1d_complex() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 3, + channels_out: 4, + kernel_size: 3, + padding: 1, + stride: 2, + dilation: 1, + 
groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([ - [[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]], - [[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]], - ])); - } + test.assert_output(TestTensor::from_floats([ + [[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]], + [[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]], + ])); + } - struct Conv1dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size: usize, - padding: usize, - stride: usize, - dilation: usize, - groups: usize, - length: usize, - } + struct Conv1dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size: usize, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + length: usize, + } - impl Conv1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size, - ]); - let weight = TestTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv1d( - x, - weight, - Some(bias), - ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), - ); + impl Conv1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size, + ]); + let weight = TestTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv1d( + x, + weight, + Some(bias), + ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/conv2d.rs b/burn-tensor/src/tests/module/conv2d.rs index 7d92170fdc..ba7292ea39 100644 --- a/burn-tensor/src/tests/module/conv2d.rs +++ b/burn-tensor/src/tests/module/conv2d.rs @@ -1,165 +1,165 @@ #[burn_tensor_testgen::testgen(module_conv2d)] mod tests { - use super::*; - use burn_tensor::module::conv2d; - use burn_tensor::ops::ConvOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv2d; + use burn_tensor::ops::ConvOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv2d_simple() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; + #[test] + fn test_conv2d_simple() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + 
kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [1196., 1796., 1916., 1264.], - [1881., 2793., 2946., 1923.], - [2313., 3405., 3558., 2307.], - [1424., 2072., 2156., 1380.], - ], - [ - [2709., 4173., 4509., 3065.], - [4582., 7006., 7483., 5056.], - [5878., 8914., 9391., 6304.], - [4089., 6177., 6477., 4333.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [1196., 1796., 1916., 1264.], + [1881., 2793., 2946., 1923.], + [2313., 3405., 3558., 2307.], + [1424., 2072., 2156., 1380.], + ], + [ + [2709., 4173., 4509., 3065.], + [4582., 7006., 7483., 5056.], + [5878., 8914., 9391., 6304.], + [4089., 6177., 6477., 4333.], + ], + ]])); + } - #[test] - fn test_conv2d_groups() { - let test = Conv2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 5, - width: 5, - }; + #[test] + fn test_conv2d_groups() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 5, + width: 5, + }; - test.assert_output(TestTensor::from_floats([[ - [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]], - [ - [3724., 3841., 3958.], - [4309., 4426., 4543.], - [4894., 5011., 5128.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]], + [ + [3724., 3841., 3958.], + [4309., 4426., 4543.], + [4894., 5011., 5128.], + ], + ]])); + } - #[test] - fn test_conv2d_complex() { - let test = Conv2dTestCase { - batch_size: 2, - channels_in: 3, - channels_out: 4, - kernel_size_1: 3, - kernel_size_2: 2, - padding_1: 1, - padding_2: 2, - stride_1: 2, - stride_2: 3, - dilation_1: 1, - dilation_2: 2, - groups: 1, - height: 4, - width: 5, - }; + #[test] + fn test_conv2d_complex() { + let test = Conv2dTestCase { + batch_size: 2, + channels_in: 3, + channels_out: 4, + kernel_size_1: 3, + kernel_size_2: 2, + padding_1: 1, + padding_2: 2, + stride_1: 2, + stride_2: 3, + dilation_1: 1, + dilation_2: 2, + groups: 1, + height: 4, + width: 5, + }; - test.assert_output(TestTensor::from_floats([ - [ - [[1845., 3789., 1926.], [3210., 6465., 3228.]], - [[4276., 9082., 4789.], [8071., 16834., 8737.]], - [[6707., 14375., 7652.], [12932., 27203., 14246.]], - [[9138., 19668., 10515.], [17793., 37572., 19755.]], - ], - [ - [[5445., 10629., 5166.], [8070., 15645., 7548.]], - [[14356., 28882., 14509.], [22651., 45454., 22777.]], - [[23267., 47135., 23852.], [37232., 75263., 38006.]], - [[32178., 65388., 33195.], [51813., 105072., 53235.]], - ], - ])); - } + test.assert_output(TestTensor::from_floats([ + [ + [[1845., 3789., 1926.], [3210., 6465., 3228.]], + [[4276., 9082., 4789.], [8071., 16834., 8737.]], + [[6707., 14375., 7652.], [12932., 27203., 14246.]], + [[9138., 19668., 10515.], [17793., 37572., 19755.]], + ], + [ + [[5445., 10629., 5166.], [8070., 15645., 7548.]], + [[14356., 28882., 14509.], [22651., 45454., 22777.]], + [[23267., 47135., 23852.], [37232., 75263., 38006.]], + [[32178., 65388., 33195.], [51813., 105072., 53235.]], + ], + ])); + } - struct Conv2dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - 
kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - stride_1: usize, - stride_2: usize, - dilation_1: usize, - dilation_2: usize, - groups: usize, - height: usize, - width: usize, - } + struct Conv2dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + dilation_1: usize, + dilation_2: usize, + groups: usize, + height: usize, + width: usize, + } - impl Conv2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let shape_weight = Shape::new([ - self.channels_out, - self.channels_in / self.groups, - self.kernel_size_1, - self.kernel_size_2, - ]); - let weight = TestTensor::from_data( - TestTensorInt::arange(0..shape_weight.num_elements()) - .reshape(shape_weight) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv2d( - x, - weight, - Some(bias), - ConvOptions::new( - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - [self.dilation_1, self.dilation_2], - self.groups, - ), - ); + impl Conv2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size_1, + self.kernel_size_2, + ]); + let weight = TestTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements()) + .reshape(shape_weight) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv2d( + x, + weight, + Some(bias), + ConvOptions::new( + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + [self.dilation_1, self.dilation_2], + self.groups, + ), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/conv_transpose1d.rs b/burn-tensor/src/tests/module/conv_transpose1d.rs index d7b487869d..349e1f1cfd 100644 --- a/burn-tensor/src/tests/module/conv_transpose1d.rs +++ b/burn-tensor/src/tests/module/conv_transpose1d.rs @@ -1,146 +1,146 @@ #[burn_tensor_testgen::testgen(module_conv_transpose1d)] mod tests { - use super::*; - use burn_tensor::module::conv_transpose1d; - use burn_tensor::ops::ConvTransposeOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv_transpose1d; + use burn_tensor::ops::ConvTransposeOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv_transpose1d_diff_channels() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 3, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv_transpose1d_diff_channels() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 3, + channels_out: 2, + kernel_size: 3, + 
padding: 1, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [270., 453., 516., 387.], - [352., 589., 679., 505.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [270., 453., 516., 387.], + [352., 589., 679., 505.], + ]])); + } - #[test] - fn test_conv_transpose1d_stride() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 1, - stride: 2, - dilation: 1, - groups: 1, - length: 4, - }; + #[test] + fn test_conv_transpose1d_stride() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + padding_out: 1, + stride: 2, + dilation: 1, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [28., 62., 36., 78., 44., 94., 52., 62.], - [41., 93., 55., 121., 69., 149., 83., 93.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [28., 62., 36., 78., 44., 94., 52., 62.], + [41., 93., 55., 121., 69., 149., 83., 93.], + ]])); + } - #[test] - fn test_conv_transpose1d_dilation() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 0, - stride: 1, - dilation: 2, - groups: 1, - length: 4, - }; + #[test] + fn test_conv_transpose1d_dilation() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + padding_out: 0, + stride: 1, + dilation: 2, + groups: 1, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [30., 64., 78., 76., 94., 52.], - [49., 101., 127., 113., 143., 77.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [30., 64., 78., 76., 94., 52.], + [49., 101., 127., 113., 143., 77.], + ]])); + } - #[test] - fn test_conv_transpose1d_groups() { - let test = ConvTranspose1dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size: 3, - padding: 1, - padding_out: 0, - stride: 1, - dilation: 1, - groups: 2, - length: 4, - }; + #[test] + fn test_conv_transpose1d_groups() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + padding_out: 0, + stride: 1, + dilation: 1, + groups: 2, + length: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [0., 1., 4., 7.], - [32., 59., 71., 59.], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [0., 1., 4., 7.], + [32., 59., 71., 59.], + ]])); + } - struct ConvTranspose1dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size: usize, - padding: usize, - padding_out: usize, - stride: usize, - dilation: usize, - groups: usize, - length: usize, - } + struct ConvTranspose1dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size: usize, + padding: usize, + padding_out: usize, + stride: usize, + dilation: usize, + groups: usize, + length: usize, + } - impl ConvTranspose1dTestCase { - fn assert_output(self, y: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); - let shape_weights = Shape::new([ - self.channels_in, - self.channels_out / self.groups, - self.kernel_size, - ]); - let weights = TestTensor::from_data( - TestTensorInt::arange(0..shape_weights.num_elements()) - .reshape(shape_weights) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - 
TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv_transpose1d( - x, - weights, - Some(bias), - ConvTransposeOptions::new( - [self.stride], - [self.padding], - [self.padding_out], - [self.dilation], - self.groups, - ), - ); + impl ConvTranspose1dTestCase { + fn assert_output(self, y: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); + let shape_weights = Shape::new([ + self.channels_in, + self.channels_out / self.groups, + self.kernel_size, + ]); + let weights = TestTensor::from_data( + TestTensorInt::arange(0..shape_weights.num_elements()) + .reshape(shape_weights) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv_transpose1d( + x, + weights, + Some(bias), + ConvTransposeOptions::new( + [self.stride], + [self.padding], + [self.padding_out], + [self.dilation], + self.groups, + ), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/conv_transpose2d.rs b/burn-tensor/src/tests/module/conv_transpose2d.rs index 4fb76daf44..d8b3a3e05d 100644 --- a/burn-tensor/src/tests/module/conv_transpose2d.rs +++ b/burn-tensor/src/tests/module/conv_transpose2d.rs @@ -1,336 +1,336 @@ #[burn_tensor_testgen::testgen(module_conv_transpose2d)] mod tests { - use super::*; - use burn_tensor::module::conv_transpose2d; - use burn_tensor::ops::ConvTransposeOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::conv_transpose2d; + use burn_tensor::ops::ConvTransposeOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_conv_transpose2d_simple_1() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 1, - channels_out: 1, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_simple_1() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 1, + channels_out: 1, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 0, + padding_out_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[[[5.0, 11.0], [23.0, 29.0]]]])); - } - #[test] - fn test_conv_transpose2d_simple_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 3, - channels_out: 3, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; + test.assert_output(TestTensor::from_floats([[[[5.0, 11.0], [23.0, 29.0]]]])); + } + #[test] + fn test_conv_transpose2d_simple_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 3, + channels_out: 3, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 0, + 
padding_out_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [9855., 15207., 15738., 10797.], - [16290., 25119., 25956., 17793.], - [18486., 28467., 29304., 20061.], - [13593., 20913., 21498., 14703.], - ], - [ - [11854., 18286., 18979., 13012.], - [19612., 30223., 31303., 21439.], - [22456., 34543., 35623., 24355.], - [16456., 25288., 26035., 17782.], - ], - [ - [13853., 21365., 22220., 15227.], - [22934., 35327., 36650., 25085.], - [26426., 40619., 41942., 28649.], - [19319., 29663., 30572., 20861.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [9855., 15207., 15738., 10797.], + [16290., 25119., 25956., 17793.], + [18486., 28467., 29304., 20061.], + [13593., 20913., 21498., 14703.], + ], + [ + [11854., 18286., 18979., 13012.], + [19612., 30223., 31303., 21439.], + [22456., 34543., 35623., 24355.], + [16456., 25288., 26035., 17782.], + ], + [ + [13853., 21365., 22220., 15227.], + [22934., 35327., 36650., 25085.], + [26426., 40619., 41942., 28649.], + [19319., 29663., 30572., 20861.], + ], + ]])); + } - #[test] - fn test_conv_transpose2d_stride_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 1, - channels_out: 1, - kernel_size_1: 2, - kernel_size_2: 2, - padding_1: 0, - padding_2: 0, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 2, - stride_2: 2, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_stride_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 1, + channels_out: 1, + kernel_size_1: 2, + kernel_size_2: 2, + padding_1: 0, + padding_2: 0, + padding_out_1: 0, + padding_out_2: 0, + stride_1: 2, + stride_2: 2, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[[ - [0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 2.0, 3.0], - [0.0, 2.0, 0.0, 3.0], - [4.0, 6.0, 6.0, 9.0], - ]]])); - } + test.assert_output(TestTensor::from_floats([[[ + [0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 2.0, 3.0], + [0.0, 2.0, 0.0, 3.0], + [4.0, 6.0, 6.0, 9.0], + ]]])); + } - #[test] - fn test_conv_transpose2d_dilation_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 1, - padding_out_2: 1, - stride_1: 1, - stride_2: 1, - dilation_1: 2, - dilation_2: 2, - groups: 1, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_dilation_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 1, + padding_out_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 2, + dilation_2: 2, + groups: 1, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [126., 116., 136., 124., 146.], - [108., 88., 114., 92., 120.], - [156., 140., 166., 148., 176.], - [126., 100., 132., 104., 138.], - [186., 164., 196., 172., 206.], - ], - [ - [217., 189., 227., 197., 237.], - [163., 125., 169., 129., 175.], - [247., 213., 257., 221., 267.], - [181., 137., 187., 141., 193.], - [277., 237., 287., 245., 297.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [126., 116., 136., 124., 146.], + [108., 88., 114., 92., 120.], + [156., 140., 166., 148., 176.], + [126., 100., 132., 104., 138.], + [186., 164., 196., 
172., 206.], + ], + [ + [217., 189., 227., 197., 237.], + [163., 125., 169., 129., 175.], + [247., 213., 257., 221., 267.], + [181., 137., 187., 141., 193.], + [277., 237., 287., 245., 297.], + ], + ]])); + } - #[test] - fn test_conv_transpose2d_stride2_out_padding() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 1, - padding_out_2: 1, - stride_1: 2, - stride_2: 2, - dilation_1: 1, - dilation_2: 1, - groups: 1, - height: 4, - width: 4, - }; + #[test] + fn test_conv_transpose2d_stride2_out_padding() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 1, + padding_out_2: 1, + stride_1: 2, + stride_2: 2, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [352., 728., 378., 780., 404., 832., 430., 452.], - [784., 1616., 836., 1720., 888., 1824., 940., 992.], - [456., 936., 482., 988., 508., 1040., 534., 564.], - [992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.], - [560., 1144., 586., 1196., 612., 1248., 638., 676.], - [1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.], - [664., 1352., 690., 1404., 716., 1456., 742., 788.], - [784., 1598., 816., 1662., 848., 1726., 880., 926.], - ], - [ - [497., 1035., 541., 1123., 585., 1211., 629., 651.], - [1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.], - [673., 1387., 717., 1475., 761., 1563., 805., 835.], - [1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.], - [849., 1739., 893., 1827., 937., 1915., 981., 1019.], - [1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.], - [1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.], - [1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [352., 728., 378., 780., 404., 832., 430., 452.], + [784., 1616., 836., 1720., 888., 1824., 940., 992.], + [456., 936., 482., 988., 508., 1040., 534., 564.], + [992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.], + [560., 1144., 586., 1196., 612., 1248., 638., 676.], + [1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.], + [664., 1352., 690., 1404., 716., 1456., 742., 788.], + [784., 1598., 816., 1662., 848., 1726., 880., 926.], + ], + [ + [497., 1035., 541., 1123., 585., 1211., 629., 651.], + [1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.], + [673., 1387., 717., 1475., 761., 1563., 805., 835.], + [1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.], + [849., 1739., 893., 1827., 937., 1915., 981., 1019.], + [1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.], + [1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.], + [1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.], + ], + ]])); + } - #[test] - fn test_conv_transpose2d_groups_2() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 2, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 1, - padding_2: 1, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_groups_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 0, + padding_out_2: 0, + 
stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [[5., 11.], [23., 29.]], - [[236., 258.], [302., 324.]], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [[5., 11.], [23., 29.]], + [[236., 258.], [302., 324.]], + ]])); + } - #[test] - fn test_conv_transpose2d_groups_different_channels() { - let test = ConvTranspose2dTestCase { - batch_size: 1, - channels_in: 2, - channels_out: 6, - kernel_size_1: 3, - kernel_size_2: 3, - padding_1: 0, - padding_2: 0, - padding_out_1: 0, - padding_out_2: 0, - stride_1: 1, - stride_2: 1, - dilation_1: 1, - dilation_2: 1, - groups: 2, - height: 2, - width: 2, - }; + #[test] + fn test_conv_transpose2d_groups_different_channels() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 6, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + padding_out_1: 0, + padding_out_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 2, + width: 2, + }; - test.assert_output(TestTensor::from_floats([[ - [ - [0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00], - [0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01], - [6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01], - [1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01], - ], - [ - [1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01], - [1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01], - [2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01], - [3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01], - ], - [ - [2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01], - [3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01], - [4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01], - [5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01], - ], - [ - [1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02], - [2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02], - [3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02], - [2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02], - ], - [ - [1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02], - [3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02], - [4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02], - [2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02], - ], - [ - [1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02], - [4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02], - [4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02], - [3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02], - ], - ]])); - } + test.assert_output(TestTensor::from_floats([[ + [ + [0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00], + [0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01], + [6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01], + [1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01], + ], + [ + [1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01], + [1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01], + [2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01], + [3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01], + ], + [ + [2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01], + [3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01], + [4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01], + [5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01], + ], + [ + [1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02], + [2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02], + [3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02], + [2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02], + ], + [ + [1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02], + [3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02], + [4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02], + [2.5600e+02, 
5.5600e+02, 5.6900e+02, 3.1200e+02], + ], + [ + [1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02], + [4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02], + [4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02], + [3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02], + ], + ]])); + } - struct ConvTranspose2dTestCase { - batch_size: usize, - channels_in: usize, - channels_out: usize, - kernel_size_1: usize, - kernel_size_2: usize, - padding_1: usize, - padding_2: usize, - padding_out_1: usize, - padding_out_2: usize, - stride_1: usize, - stride_2: usize, - dilation_1: usize, - dilation_2: usize, - groups: usize, - height: usize, - width: usize, - } + struct ConvTranspose2dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size_1: usize, + kernel_size_2: usize, + padding_1: usize, + padding_2: usize, + padding_out_1: usize, + padding_out_2: usize, + stride_1: usize, + stride_2: usize, + dilation_1: usize, + dilation_2: usize, + groups: usize, + height: usize, + width: usize, + } - impl ConvTranspose2dTestCase { - fn assert_output(self, y: TestTensor<4>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let shape_weights = Shape::new([ - self.channels_in, - self.channels_out / self.groups, - self.kernel_size_1, - self.kernel_size_2, - ]); - let weights = TestTensor::from_data( - TestTensorInt::arange(0..shape_weights.num_elements()) - .reshape(shape_weights) - .into_data() - .convert(), - ); - let bias = TestTensor::from_data( - TestTensorInt::arange(0..self.channels_out) - .into_data() - .convert(), - ); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); - let output = conv_transpose2d( - x, - weights, - Some(bias), - ConvTransposeOptions::new( - [self.stride_1, self.stride_2], - [self.padding_1, self.padding_2], - [self.padding_out_1, self.padding_out_2], - [self.dilation_1, self.dilation_2], - self.groups, - ), - ); + impl ConvTranspose2dTestCase { + fn assert_output(self, y: TestTensor<4>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let shape_weights = Shape::new([ + self.channels_in, + self.channels_out / self.groups, + self.kernel_size_1, + self.kernel_size_2, + ]); + let weights = TestTensor::from_data( + TestTensorInt::arange(0..shape_weights.num_elements()) + .reshape(shape_weights) + .into_data() + .convert(), + ); + let bias = TestTensor::from_data( + TestTensorInt::arange(0..self.channels_out) + .into_data() + .convert(), + ); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); + let output = conv_transpose2d( + x, + weights, + Some(bias), + ConvTransposeOptions::new( + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + [self.padding_out_1, self.padding_out_2], + [self.dilation_1, self.dilation_2], + self.groups, + ), + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/module/forward.rs b/burn-tensor/src/tests/module/forward.rs index 2ea81da5e4..7ff629140a 100644 --- a/burn-tensor/src/tests/module/forward.rs +++ b/burn-tensor/src/tests/module/forward.rs @@ -1,20 +1,20 @@ #[burn_tensor_testgen::testgen(module_forward)] mod tests { - use super::*; - use burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; + use super::*; + use 
burn_tensor::{backend::Backend, module::embedding, Data, Int, Tensor}; - #[test] - fn test_embedding_forward() { - let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = Data::from([[0, 1], [1, 1]]); - let weights = Tensor::::from_data(weights); - let indices = Tensor::::from_data(indices); + #[test] + fn test_embedding_forward() { + let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = Data::from([[0, 1], [1, 1]]); + let weights = Tensor::::from_data(weights); + let indices = Tensor::::from_data(indices); - let output = embedding(weights, indices); - let expected = Data::from([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], - ]); - assert_eq!(output.to_data(), expected); - } + let output = embedding(weights, indices); + let expected = Data::from([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], + ]); + assert_eq!(output.to_data(), expected); + } } diff --git a/burn-tensor/src/tests/module/maxpool1d.rs b/burn-tensor/src/tests/module/maxpool1d.rs index 89cff90ec4..97c26129da 100644 --- a/burn-tensor/src/tests/module/maxpool1d.rs +++ b/burn-tensor/src/tests/module/maxpool1d.rs @@ -1,116 +1,116 @@ #[burn_tensor_testgen::testgen(module_max_pool1d)] mod tests { - use super::*; - use burn_tensor::module::{max_pool1d, max_pool1d_with_indices}; - use burn_tensor::{backend::Backend, Data, Tensor}; - - type IntElem = ::IntElem; - - #[test] - fn test_max_pool1d_simple() { - let kernel_size = 3; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[ - [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], - [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], - ]]); - let y = TestTensor::from_floats([[ - [0.9861, 0.5474, 0.4477, 0.8221], - [0.949, 0.949, 0.949, 0.789], - ]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_different_padding_stride_kernel() { - let kernel_size = 3; - let padding = 1; - let stride = 2; - let dilation = 1; - - let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]); - let y = TestTensor::from_floats([[[0.6309, 0.6998]]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_with_neg() { - let kernel_size = 3; - let padding = 1; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]); - let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_with_dilation() { - let kernel_size = 2; - let padding = 1; - let stride = 1; - let dilation = 2; - - let x = TestTensor::from_floats([[ - [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], - [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], - ]]); - let y = TestTensor::from_floats([[ - [0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548], - [0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537], - ]]); - - let output = max_pool1d(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - } - - #[test] - fn test_max_pool1d_with_indices() { - let kernel_size = 2; - let padding = 0; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 
0.5742]]]); - let indices = Data::::from([[[1, 1, 3]]]); - let y = TestTensor::from_floats([[[0.6386, 0.6386, 0.5742]]]); - - let (output, output_indices) = - max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } - - #[test] - fn test_max_pool1d_complex() { - let kernel_size = 4; - let padding = 2; - let stride = 1; - let dilation = 1; - - let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]); - let indices = Data::::from([[[0, 2, 3, 3, 3, 3]]]); - let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]); - - let (output, output_indices) = - max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); - - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } + use super::*; + use burn_tensor::module::{max_pool1d, max_pool1d_with_indices}; + use burn_tensor::{backend::Backend, Data, Tensor}; + + type IntElem = ::IntElem; + + #[test] + fn test_max_pool1d_simple() { + let kernel_size = 3; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[ + [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], + [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], + ]]); + let y = TestTensor::from_floats([[ + [0.9861, 0.5474, 0.4477, 0.8221], + [0.949, 0.949, 0.949, 0.789], + ]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_different_padding_stride_kernel() { + let kernel_size = 3; + let padding = 1; + let stride = 2; + let dilation = 1; + + let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]); + let y = TestTensor::from_floats([[[0.6309, 0.6998]]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_with_neg() { + let kernel_size = 3; + let padding = 1; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]); + let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_with_dilation() { + let kernel_size = 2; + let padding = 1; + let stride = 1; + let dilation = 2; + + let x = TestTensor::from_floats([[ + [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], + [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], + ]]); + let y = TestTensor::from_floats([[ + [0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548], + [0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537], + ]]); + + let output = max_pool1d(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool1d_with_indices() { + let kernel_size = 2; + let padding = 0; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]); + let indices = Data::::from([[[1, 1, 3]]]); + let y = TestTensor::from_floats([[[0.6386, 0.6386, 0.5742]]]); + + let (output, output_indices) = + max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + 
} + + #[test] + fn test_max_pool1d_complex() { + let kernel_size = 4; + let padding = 2; + let stride = 1; + let dilation = 1; + + let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]); + let indices = Data::::from([[[0, 2, 3, 3, 3, 3]]]); + let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]); + + let (output, output_indices) = + max_pool1d_with_indices(x, kernel_size, stride, padding, dilation); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + } } diff --git a/burn-tensor/src/tests/module/maxpool2d.rs b/burn-tensor/src/tests/module/maxpool2d.rs index d70f194cea..47a843fa93 100644 --- a/burn-tensor/src/tests/module/maxpool2d.rs +++ b/burn-tensor/src/tests/module/maxpool2d.rs @@ -1,324 +1,324 @@ #[burn_tensor_testgen::testgen(module_max_pool2d)] mod tests { - use super::*; - use burn_tensor::module::{max_pool2d, max_pool2d_with_indices}; - use burn_tensor::{backend::Backend, Data, Tensor}; + use super::*; + use burn_tensor::module::{max_pool2d, max_pool2d_with_indices}; + use burn_tensor::{backend::Backend, Data, Tensor}; - type IntElem = ::IntElem; + type IntElem = ::IntElem; - #[test] - fn test_max_pool2d_simple() { - let batch_size = 2; - let channels_in = 2; - let kernel_size_1 = 3; - let kernel_size_2 = 3; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_simple() { + let batch_size = 2; + let channels_in = 2; + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([ - [ - [ - [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], - [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], - [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182], - [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392], - [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605], - [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068], - ], - [ - [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473], - [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947], - [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149], - [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426], - [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544], - [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572], - ], - ], - [ - [ - [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910], - [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546], - [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584], - [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441], - [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881], - [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467], - ], - [ - [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952], - [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454], - [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628], - [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527], - [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827], - [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421], - ], - ], - ]); - let y = TestTensor::from_floats([ - [ - [ - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], - [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], - [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], - [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], - ], - [ - [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947], - [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 
0.9149], - [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149], - [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149], - [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025], - [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775], - ], - ], - [ - [ - [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], - [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], - [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546], - [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], - [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], - [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037], - ], - [ - [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378], - [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378], - [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445], - [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128], - [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128], - [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128], - ], - ], - ]); + let x = TestTensor::from_floats([ + [ + [ + [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], + [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], + [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182], + [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392], + [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605], + [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068], + ], + [ + [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473], + [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947], + [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149], + [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426], + [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544], + [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572], + ], + ], + [ + [ + [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910], + [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546], + [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584], + [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441], + [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881], + [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467], + ], + [ + [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952], + [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454], + [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628], + [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527], + [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827], + [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421], + ], + ], + ]); + let y = TestTensor::from_floats([ + [ + [ + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], + [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], + [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], + [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], + ], + [ + [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947], + [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149], + [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149], + [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149], + [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025], + [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775], + ], + ], + [ + [ + [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], + [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], + [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546], + [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], + [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], + [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037], + ], + [ + [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378], + [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378], + [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445], + [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128], + [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128], + [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 
0.6128], + ], + ], + ]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - #[test] - fn test_max_pool2d_different_padding_stride_kernel() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 3; - let kernel_size_2 = 1; - let padding_1 = 1; - let padding_2 = 0; - let stride_1 = 1; - let stride_2 = 2; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_different_padding_stride_kernel() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 3; + let kernel_size_2 = 1; + let padding_1 = 1; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 2; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.6309, 0.6112, 0.6998], - [0.4708, 0.9161, 0.5402], - [0.4577, 0.7397, 0.9870], - [0.6380, 0.4352, 0.5884], - [0.6277, 0.5139, 0.4525], - [0.9333, 0.9846, 0.5006], - ]]]); - let y = TestTensor::from_floats([[[ - [0.6309, 0.6998], - [0.6309, 0.9870], - [0.6380, 0.9870], - [0.6380, 0.9870], - [0.9333, 0.5884], - [0.9333, 0.5006], - ]]]); + let x = TestTensor::from_floats([[[ + [0.6309, 0.6112, 0.6998], + [0.4708, 0.9161, 0.5402], + [0.4577, 0.7397, 0.9870], + [0.6380, 0.4352, 0.5884], + [0.6277, 0.5139, 0.4525], + [0.9333, 0.9846, 0.5006], + ]]]); + let y = TestTensor::from_floats([[[ + [0.6309, 0.6998], + [0.6309, 0.9870], + [0.6380, 0.9870], + [0.6380, 0.9870], + [0.9333, 0.5884], + [0.9333, 0.5006], + ]]]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - #[test] - fn test_max_pool2d_with_neg() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 3; - let kernel_size_2 = 3; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_with_neg() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.6309, 0.6112, 0.6998], - [0.4708, 0.9161, 0.5402], - [0.4577, 0.7397, 0.9870], - [0.6380, 0.4352, 0.5884], - [0.6277, 0.5139, 0.4525], - [0.9333, 0.9846, 0.5006], - ]]]) - .neg(); - let y = TestTensor::from_floats([[[ - [-0.4708, -0.4708, -0.5402], - [-0.4577, -0.4577, -0.5402], - [-0.4352, -0.4352, -0.4352], - [-0.4352, -0.4352, -0.4352], - [-0.4352, -0.4352, -0.4352], - [-0.5139, -0.4525, -0.4525], - ]]]); + let x = TestTensor::from_floats([[[ + [0.6309, 0.6112, 0.6998], + [0.4708, 0.9161, 0.5402], + [0.4577, 0.7397, 0.9870], + [0.6380, 0.4352, 0.5884], + [0.6277, 0.5139, 0.4525], + [0.9333, 0.9846, 0.5006], + ]]]) + .neg(); + let y = TestTensor::from_floats([[[ + [-0.4708, -0.4708, -0.5402], + [-0.4577, -0.4577, -0.5402], + [-0.4352, -0.4352, -0.4352], + [-0.4352, 
-0.4352, -0.4352], + [-0.4352, -0.4352, -0.4352], + [-0.5139, -0.4525, -0.4525], + ]]]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - #[test] - fn test_max_pool2d_with_dilation() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 0; - let padding_2 = 0; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 2; - let dilation_2 = 2; + #[test] + fn test_max_pool2d_with_dilation() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 0; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 2; + let dilation_2 = 2; - let x = TestTensor::from_floats([[[ - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], - [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], - [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], - [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], - [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], - ]]]); - let y = TestTensor::from_floats([[[ - [0.9861, 0.9861, 0.9540, 0.9490], - [0.9861, 0.9861, 0.9540, 0.9490], - [0.9540, 0.9540, 0.9540, 0.9490], - [0.9540, 0.9540, 0.9540, 0.9432], - ]]]); + let x = TestTensor::from_floats([[[ + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], + [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], + [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], + [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], + ]]]); + let y = TestTensor::from_floats([[[ + [0.9861, 0.9861, 0.9540, 0.9490], + [0.9861, 0.9861, 0.9540, 0.9490], + [0.9540, 0.9540, 0.9540, 0.9490], + [0.9540, 0.9540, 0.9540, 0.9432], + ]]]); - let output = max_pool2d( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let output = max_pool2d( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + } - fn test_max_pool2d_with_indices() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 2; - let kernel_size_2 = 2; - let padding_1 = 1; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 1; - let dilation_1 = 1; - let dilation_2 = 1; + fn test_max_pool2d_with_indices() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.2479, 0.6386, 0.3166, 0.5742], - [0.7065, 0.1940, 0.6305, 0.8959], - [0.5416, 0.8602, 0.8129, 0.1662], - [0.3358, 0.3059, 0.8293, 0.0990], - ]]]); - let indices = Data::::from([[[ - [0, 1, 1, 3, 3], - [4, 4, 1, 7, 7], - [4, 9, 9, 7, 7], - [8, 9, 9, 14, 11], - [12, 12, 14, 14, 15], - ]]]); - let y = TestTensor::from_floats([[[ - [0.2479, 0.6386, 0.6386, 0.5742, 0.5742], - [0.7065, 0.7065, 0.6386, 0.8959, 0.8959], - [0.7065, 
0.8602, 0.8602, 0.8959, 0.8959], - [0.5416, 0.8602, 0.8602, 0.8293, 0.1662], - [0.3358, 0.3358, 0.8293, 0.8293, 0.0990], - ]]]); + let x = TestTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]); + let indices = Data::::from([[[ + [0, 1, 1, 3, 3], + [4, 4, 1, 7, 7], + [4, 9, 9, 7, 7], + [8, 9, 9, 14, 11], + [12, 12, 14, 14, 15], + ]]]); + let y = TestTensor::from_floats([[[ + [0.2479, 0.6386, 0.6386, 0.5742, 0.5742], + [0.7065, 0.7065, 0.6386, 0.8959, 0.8959], + [0.7065, 0.8602, 0.8602, 0.8959, 0.8959], + [0.5416, 0.8602, 0.8602, 0.8293, 0.1662], + [0.3358, 0.3358, 0.8293, 0.8293, 0.0990], + ]]]); - let (output, output_indices) = max_pool2d_with_indices( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let (output, output_indices) = max_pool2d_with_indices( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + } - #[test] - fn test_max_pool2d_complex() { - let batch_size = 1; - let channels_in = 1; - let kernel_size_1 = 4; - let kernel_size_2 = 2; - let padding_1 = 2; - let padding_2 = 1; - let stride_1 = 1; - let stride_2 = 2; - let dilation_1 = 1; - let dilation_2 = 1; + #[test] + fn test_max_pool2d_complex() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 4; + let kernel_size_2 = 2; + let padding_1 = 2; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 2; + let dilation_1 = 1; + let dilation_2 = 1; - let x = TestTensor::from_floats([[[ - [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], - [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], - [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], - [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], - [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], - ]]]); - let indices = Data::::from([[[ - [5, 7, 3], - [5, 7, 3], - [5, 16, 3], - [5, 16, 8], - [15, 16, 24], - [15, 16, 24], - ]]]); - let y = TestTensor::from_floats([[[ - [0.9154, 0.9089, 0.8316], - [0.9154, 0.9089, 0.8316], - [0.9154, 0.9963, 0.8316], - [0.9154, 0.9963, 0.8016], - [0.4384, 0.9963, 0.688], - [0.4384, 0.9963, 0.688], - ]]]); - let (output, output_indices) = max_pool2d_with_indices( - x, - [kernel_size_1, kernel_size_2], - [stride_1, stride_2], - [padding_1, padding_2], - [dilation_1, dilation_2], - ); + let x = TestTensor::from_floats([[[ + [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], + [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], + [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], + [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], + [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], + ]]]); + let indices = Data::::from([[[ + [5, 7, 3], + [5, 7, 3], + [5, 16, 3], + [5, 16, 8], + [15, 16, 24], + [15, 16, 24], + ]]]); + let y = TestTensor::from_floats([[[ + [0.9154, 0.9089, 0.8316], + [0.9154, 0.9089, 0.8316], + [0.9154, 0.9963, 0.8316], + [0.9154, 0.9963, 0.8016], + [0.4384, 0.9963, 0.688], + [0.4384, 0.9963, 0.688], + ]]]); + let (output, output_indices) = max_pool2d_with_indices( + x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + [dilation_1, dilation_2], + ); - y.to_data().assert_approx_eq(&output.into_data(), 3); - assert_eq!(indices.value, output_indices.into_data().value); - } 
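// Illustrative sketch, not part of the original change: the expected tensor shapes in the
// max_pool1d / max_pool2d tests above follow the usual sliding-window size rule with floor
// division, which agrees with the expected outputs in these hunks:
//     out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
// e.g. test_max_pool2d_with_dilation: (6 + 0 - 2 * (2 - 1) - 1) / 1 + 1 = 4, hence its 4x4 output.
fn pool_output_size(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}
// assert_eq!(pool_output_size(6, 2, 1, 0, 2), 4);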
+ y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indices.value, output_indices.into_data().value); + } } diff --git a/burn-tensor/src/tests/module/unfold4d.rs b/burn-tensor/src/tests/module/unfold4d.rs index 0ead03612d..afb1df24b8 100644 --- a/burn-tensor/src/tests/module/unfold4d.rs +++ b/burn-tensor/src/tests/module/unfold4d.rs @@ -1,132 +1,132 @@ #[burn_tensor_testgen::testgen(module_unfold4d)] mod tests { - use super::*; - use burn_tensor::module::unfold4d; - use burn_tensor::ops::UnfoldOptions; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::module::unfold4d; + use burn_tensor::ops::UnfoldOptions; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_unfold4d_shape() { - let test = Unfold4dTestCase { - batch_size: 2, - channels_in: 5, - kernel_size: [2, 3], - padding: [0, 0], - stride: [1, 1], - dilation: [1, 1], - height: 3, - width: 4, - }; + #[test] + fn test_unfold4d_shape() { + let test = Unfold4dTestCase { + batch_size: 2, + channels_in: 5, + kernel_size: [2, 3], + padding: [0, 0], + stride: [1, 1], + dilation: [1, 1], + height: 3, + width: 4, + }; - test.assert_shape([2, 30, 4]); - } + test.assert_shape([2, 30, 4]); + } - #[test] - fn test_unfold4d_simple() { - let test = Unfold4dTestCase { - batch_size: 1, - channels_in: 2, - kernel_size: [2, 2], - padding: [0, 0], - stride: [1, 1], - dilation: [1, 1], - height: 4, - width: 4, - }; + #[test] + fn test_unfold4d_simple() { + let test = Unfold4dTestCase { + batch_size: 1, + channels_in: 2, + kernel_size: [2, 2], + padding: [0, 0], + stride: [1, 1], + dilation: [1, 1], + height: 4, + width: 4, + }; - test.assert_output(TestTensor::from_data([[ - [0., 1., 2., 4., 5., 6., 8., 9., 10.], - [1., 2., 3., 5., 6., 7., 9., 10., 11.], - [4., 5., 6., 8., 9., 10., 12., 13., 14.], - [5., 6., 7., 9., 10., 11., 13., 14., 15.], - [16., 17., 18., 20., 21., 22., 24., 25., 26.], - [17., 18., 19., 21., 22., 23., 25., 26., 27.], - [20., 21., 22., 24., 25., 26., 28., 29., 30.], - [21., 22., 23., 25., 26., 27., 29., 30., 31.], - ]])); - } + test.assert_output(TestTensor::from_data([[ + [0., 1., 2., 4., 5., 6., 8., 9., 10.], + [1., 2., 3., 5., 6., 7., 9., 10., 11.], + [4., 5., 6., 8., 9., 10., 12., 13., 14.], + [5., 6., 7., 9., 10., 11., 13., 14., 15.], + [16., 17., 18., 20., 21., 22., 24., 25., 26.], + [17., 18., 19., 21., 22., 23., 25., 26., 27.], + [20., 21., 22., 24., 25., 26., 28., 29., 30.], + [21., 22., 23., 25., 26., 27., 29., 30., 31.], + ]])); + } - #[test] - fn test_unfold4d_complex() { - let test = Unfold4dTestCase { - batch_size: 1, - channels_in: 2, - kernel_size: [2, 3], - padding: [0, 1], - stride: [1, 2], - dilation: [1, 2], - height: 3, - width: 4, - }; + #[test] + fn test_unfold4d_complex() { + let test = Unfold4dTestCase { + batch_size: 1, + channels_in: 2, + kernel_size: [2, 3], + padding: [0, 1], + stride: [1, 2], + dilation: [1, 2], + height: 3, + width: 4, + }; - test.assert_output(TestTensor::from_data([[ - [0., 0.], - [1., 5.], - [3., 7.], - [0., 0.], - [5., 9.], - [7., 11.], - [0., 0.], - [13., 17.], - [15., 19.], - [0., 0.], - [17., 21.], - [19., 23.], - ]])); - } + test.assert_output(TestTensor::from_data([[ + [0., 0.], + [1., 5.], + [3., 7.], + [0., 0.], + [5., 9.], + [7., 11.], + [0., 0.], + [13., 17.], + [15., 19.], + [0., 0.], + [17., 21.], + [19., 23.], + ]])); + } - struct Unfold4dTestCase { - batch_size: usize, - channels_in: usize, - kernel_size: [usize; 2], - padding: [usize; 2], - stride: [usize; 2], - dilation: [usize; 2], - height: usize, - width: usize, - } 
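// Illustrative sketch, not part of the original change: unfold4d flattens each kernel-sized
// patch into one column, so test_unfold4d_shape expects [batch, channels_in * k1 * k2, l1 * l2],
// where l1 and l2 follow the same sliding-window rule as the pooling tests (floor division assumed).
fn unfold_output_shape(
    batch: usize,
    channels: usize,
    kernel: [usize; 2],
    padding: [usize; 2],
    stride: [usize; 2],
    dilation: [usize; 2],
    height: usize,
    width: usize,
) -> [usize; 3] {
    // Number of window positions along one spatial dimension.
    let len = |size: usize, i: usize| {
        (size + 2 * padding[i] - dilation[i] * (kernel[i] - 1) - 1) / stride[i] + 1
    };
    [
        batch,
        channels * kernel[0] * kernel[1],
        len(height, 0) * len(width, 1),
    ]
}
// e.g. unfold_output_shape(2, 5, [2, 3], [0, 0], [1, 1], [1, 1], 3, 4) == [2, 30, 4],
// matching the [2, 30, 4] shape asserted by test_unfold4d_shape.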
+ struct Unfold4dTestCase { + batch_size: usize, + channels_in: usize, + kernel_size: [usize; 2], + padding: [usize; 2], + stride: [usize; 2], + dilation: [usize; 2], + height: usize, + width: usize, + } - impl Unfold4dTestCase { - fn assert_shape(self, expected_shape: [usize; 3]) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); + impl Unfold4dTestCase { + fn assert_shape(self, expected_shape: [usize; 3]) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); - let output = unfold4d( - x, - self.kernel_size, - UnfoldOptions::new(self.stride, self.padding, self.dilation), - ); + let output = unfold4d( + x, + self.kernel_size, + UnfoldOptions::new(self.stride, self.padding, self.dilation), + ); - assert_eq!( - output.shape().dims, - expected_shape, - "Expected shape doesn't match the actual shape" - ); - } + assert_eq!( + output.shape().dims, + expected_shape, + "Expected shape doesn't match the actual shape" + ); + } - fn assert_output(self, expected: TestTensor<3>) { - let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); - let x = TestTensor::from_data( - TestTensorInt::arange(0..shape_x.num_elements()) - .reshape(shape_x) - .into_data() - .convert(), - ); + fn assert_output(self, expected: TestTensor<3>) { + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let x = TestTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements()) + .reshape(shape_x) + .into_data() + .convert(), + ); - let output = unfold4d( - x, - self.kernel_size, - UnfoldOptions::new(self.stride, self.padding, self.dilation), - ); + let output = unfold4d( + x, + self.kernel_size, + UnfoldOptions::new(self.stride, self.padding, self.dilation), + ); - output - .into_data() - .assert_approx_eq(&expected.into_data(), 3); + output + .into_data() + .assert_approx_eq(&expected.into_data(), 3); + } } - } } diff --git a/burn-tensor/src/tests/ops/abs.rs b/burn-tensor/src/tests/ops/abs.rs index ad34a4581a..f87b87a6fb 100644 --- a/burn-tensor/src/tests/ops/abs.rs +++ b/burn-tensor/src/tests/ops/abs.rs @@ -1,27 +1,27 @@ #[burn_tensor_testgen::testgen(abs)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn should_support_abs_ops_float() { - let data = Data::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_abs_ops_float() { + let data = Data::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.abs().into_data(); + let data_actual = tensor.abs().into_data(); - let data_expected = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_abs_ops_int() { - let data = Data::from([[0, -1, 2], [3, 4, -5]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_abs_ops_int() { + let data = Data::from([[0, -1, 2], [3, 4, -5]]); + let tensor = Tensor::::from_data(data); - let data_actual = 
tensor.abs().into_data(); + let data_actual = tensor.abs().into_data(); - let data_expected = Data::from([[0, 1, 2], [3, 4, 5]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0, 1, 2], [3, 4, 5]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/add.rs b/burn-tensor/src/tests/ops/add.rs index 08a013f297..bd45b4376d 100644 --- a/burn-tensor/src/tests/ops/add.rs +++ b/burn-tensor/src/tests/ops/add.rs @@ -1,83 +1,83 @@ #[burn_tensor_testgen::testgen(add)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_add_d2() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_add_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_add_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor + scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_add_d2_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[6, 8, 10], [12, 14, 16]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_add_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 + tensor_2).into_data(); - - let data_expected = Data::from([[3, 5, 7], [6, 8, 10]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_add_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor + scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[2, 3, 4], [5, 6, 7]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_add_d2() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_add_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = 
Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_add_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor + scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_add_d2_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[6, 8, 10], [12, 14, 16]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_add_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[3, 5, 7], [6, 8, 10]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_add_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor + scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[2, 3, 4], [5, 6, 7]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/aggregation.rs b/burn-tensor/src/tests/ops/aggregation.rs index e234b795b2..45b94a89bb 100644 --- a/burn-tensor/src/tests/ops/aggregation.rs +++ b/burn-tensor/src/tests/ops/aggregation.rs @@ -1,125 +1,125 @@ #[burn_tensor_testgen::testgen(aggregation)] mod tests { - use super::*; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Shape, Tensor}; - #[test] - fn test_should_mean() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_mean() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.mean().to_data(); + let data_actual = tensor.mean().to_data(); - data_actual.assert_approx_eq(&Data::from([15.0 / 6.0]), 3); - } + data_actual.assert_approx_eq(&Data::from([15.0 / 6.0]), 3); + } - #[test] - fn test_should_mean_int() { - let tensor = TestTensorInt::from_data([[2, 2, 2], [3, 4, 5]]); + #[test] + fn test_should_mean_int() { + let tensor = TestTensorInt::from_data([[2, 2, 2], [3, 4, 5]]); - let data_actual = tensor.mean().to_data(); + let data_actual = tensor.mean().to_data(); - assert_eq!(data_actual, Data::from([3])); - } + assert_eq!(data_actual, Data::from([3])); + } - #[test] - fn test_should_sum() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_sum() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.sum().to_data(); + let data_actual = tensor.sum().to_data(); - assert_eq!(data_actual, Data::from([15.0])); - } + assert_eq!(data_actual, Data::from([15.0])); + } - #[test] - fn test_should_sum_int() { - let tensor = 
TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + #[test] + fn test_should_sum_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); - let data_actual = tensor.sum().to_data(); + let data_actual = tensor.sum().to_data(); - assert_eq!(data_actual, Data::from([15])); - } + assert_eq!(data_actual, Data::from([15])); + } - #[test] - fn test_should_mean_last_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_mean_last_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.mean_dim(1).to_data(); + let data_actual = tensor.mean_dim(1).to_data(); - data_actual.assert_approx_eq(&Data::from([[3.0 / 3.0], [12.0 / 3.0]]), 3); - } + data_actual.assert_approx_eq(&Data::from([[3.0 / 3.0], [12.0 / 3.0]]), 3); + } - #[test] - fn test_should_sum_last_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_should_sum_last_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_actual = tensor.sum_dim(1).to_data(); + let data_actual = tensor.sum_dim(1).to_data(); - assert_eq!(data_actual, Data::from([[3.0], [12.0]])); - } + assert_eq!(data_actual, Data::from([[3.0], [12.0]])); + } - #[test] - fn test_should_mean_last_dim_int() { - let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + #[test] + fn test_should_mean_last_dim_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); - let data_actual = tensor.mean_dim(1).to_data(); + let data_actual = tensor.mean_dim(1).to_data(); - assert_eq!(data_actual, Data::from([[1], [4]])); - } + assert_eq!(data_actual, Data::from([[1], [4]])); + } - #[test] - fn test_should_sum_last_dim_int() { - let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + #[test] + fn test_should_sum_last_dim_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); - let data_actual = tensor.sum_dim(1).to_data(); + let data_actual = tensor.sum_dim(1).to_data(); - assert_eq!(data_actual, Data::from([[3], [12]])); - } + assert_eq!(data_actual, Data::from([[3], [12]])); + } - #[test] - fn test_should_sum_first_dim() { - let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); + #[test] + fn test_should_sum_first_dim() { + let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); - let data_actual = tensor.sum_dim(0).to_data(); + let data_actual = tensor.sum_dim(0).to_data(); - assert_eq!(data_actual, Data::from([[7.0, 3.0, 5.0]])); - } + assert_eq!(data_actual, Data::from([[7.0, 3.0, 5.0]])); + } - #[test] - fn test_should_mean_first_dim() { - let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); + #[test] + fn test_should_mean_first_dim() { + let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); - let data_actual = tensor.mean_dim(0).to_data(); + let data_actual = tensor.mean_dim(0).to_data(); - assert_eq!(data_actual, Data::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]])); - } + assert_eq!(data_actual, Data::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]])); + } - #[test] - fn test_should_sum_mid_dim_3d_non_contiguous_1() { - let tensor = TestTensor::from_data([ - [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], - [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], - ]); + #[test] + fn test_should_sum_mid_dim_3d_non_contiguous_1() { + let tensor = TestTensor::from_data([ + [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], + [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], + ]); - let data_actual = tensor.swap_dims(0, 
2).sum_dim(1).into_data(); + let data_actual = tensor.swap_dims(0, 2).sum_dim(1).into_data(); - assert_eq!( - data_actual, - Data::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], Shape::new([3, 1, 2])) - ); - } + assert_eq!( + data_actual, + Data::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], Shape::new([3, 1, 2])) + ); + } - #[test] - fn test_should_sum_mid_dim_3d_non_contiguous_2() { - let tensor = TestTensor::from_data([ - [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], - [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], - ]); + #[test] + fn test_should_sum_mid_dim_3d_non_contiguous_2() { + let tensor = TestTensor::from_data([ + [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], + [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], + ]); - let data_actual = tensor.swap_dims(0, 1).sum_dim(1).into_data(); + let data_actual = tensor.swap_dims(0, 1).sum_dim(1).into_data(); - assert_eq!( - data_actual, - Data::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], Shape::new([2, 1, 3])) - ); - } + assert_eq!( + data_actual, + Data::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], Shape::new([2, 1, 3])) + ); + } } diff --git a/burn-tensor/src/tests/ops/arange.rs b/burn-tensor/src/tests/ops/arange.rs index 4552943183..5abe68aff1 100644 --- a/burn-tensor/src/tests/ops/arange.rs +++ b/burn-tensor/src/tests/ops/arange.rs @@ -1,21 +1,21 @@ #[burn_tensor_testgen::testgen(arange)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn test_arange() { - let tensor = Tensor::::arange(2..5); - assert_eq!(tensor.into_data(), Data::from([2, 3, 4])); - } + #[test] + fn test_arange() { + let tensor = Tensor::::arange(2..5); + assert_eq!(tensor.into_data(), Data::from([2, 3, 4])); + } - #[test] - fn test_arange_device() { - let device = ::Device::default(); + #[test] + fn test_arange_device() { + let device = ::Device::default(); - let tensor = Tensor::::arange_device(2..5, &device); - assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4])); - assert_eq!(tensor.device(), device); - } + let tensor = Tensor::::arange_device(2..5, &device); + assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4])); + assert_eq!(tensor.device(), device); + } } diff --git a/burn-tensor/src/tests/ops/arange_step.rs b/burn-tensor/src/tests/ops/arange_step.rs index 0922ac5620..127f234eca 100644 --- a/burn-tensor/src/tests/ops/arange_step.rs +++ b/burn-tensor/src/tests/ops/arange_step.rs @@ -1,46 +1,46 @@ #[burn_tensor_testgen::testgen(arange_step)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn test_arange_step() { - // Test correct sequence of numbers when the range is 0..9 and the step is 1 - let tensor = Tensor::::arange_step(0..9, 1); - assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); - - // Test correct sequence of numbers when the range is 0..3 and the step is 2 - let tensor = Tensor::::arange_step(0..3, 2); - assert_eq!(tensor.into_data(), Data::from([0, 2])); - - // Test correct sequence of numbers when the range is 0..2 and the step is 5 - let tensor = Tensor::::arange_step(0..2, 5); - assert_eq!(tensor.into_data(), Data::from([0])); - } - - #[test] - fn test_arange_step_device() { - let device = ::Device::default(); - - // Test correct sequence of numbers when the range is 0..9 and the step is 1 - let tensor = Tensor::::arange_step_device(0..9, 1, &device); - assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); - - // Test 
correct sequence of numbers when the range is 0..3 and the step is 2 - let tensor = Tensor::::arange_step_device(0..3, 2, &device); - assert_eq!(tensor.into_data(), Data::from([0, 2])); - - // Test correct sequence of numbers when the range is 0..2 and the step is 5 - let tensor = Tensor::::arange_step_device(0..2, 5, &device); - assert_eq!(tensor.clone().into_data(), Data::from([0])); - assert_eq!(tensor.device(), device); - } - - #[test] - #[should_panic] - fn should_panic_when_step_is_zero() { - // Test that arange_step panics when the step is 0 - let _tensor = Tensor::::arange_step(0..3, 0); - } + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_arange_step() { + // Test correct sequence of numbers when the range is 0..9 and the step is 1 + let tensor = Tensor::::arange_step(0..9, 1); + assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); + + // Test correct sequence of numbers when the range is 0..3 and the step is 2 + let tensor = Tensor::::arange_step(0..3, 2); + assert_eq!(tensor.into_data(), Data::from([0, 2])); + + // Test correct sequence of numbers when the range is 0..2 and the step is 5 + let tensor = Tensor::::arange_step(0..2, 5); + assert_eq!(tensor.into_data(), Data::from([0])); + } + + #[test] + fn test_arange_step_device() { + let device = ::Device::default(); + + // Test correct sequence of numbers when the range is 0..9 and the step is 1 + let tensor = Tensor::::arange_step_device(0..9, 1, &device); + assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); + + // Test correct sequence of numbers when the range is 0..3 and the step is 2 + let tensor = Tensor::::arange_step_device(0..3, 2, &device); + assert_eq!(tensor.into_data(), Data::from([0, 2])); + + // Test correct sequence of numbers when the range is 0..2 and the step is 5 + let tensor = Tensor::::arange_step_device(0..2, 5, &device); + assert_eq!(tensor.clone().into_data(), Data::from([0])); + assert_eq!(tensor.device(), device); + } + + #[test] + #[should_panic] + fn should_panic_when_step_is_zero() { + // Test that arange_step panics when the step is 0 + let _tensor = Tensor::::arange_step(0..3, 0); + } } diff --git a/burn-tensor/src/tests/ops/arg.rs b/burn-tensor/src/tests/ops/arg.rs index c954ec5739..fd6f282b76 100644 --- a/burn-tensor/src/tests/ops/arg.rs +++ b/burn-tensor/src/tests/ops/arg.rs @@ -1,71 +1,71 @@ #[burn_tensor_testgen::testgen(arg)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn test_argmax_2d_dim0() { - let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmax_2d_dim0() { + let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmax(0); + let data_actual = tensor.argmax(0); - let data_expected = Data::from([[0, 0, 1]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 0, 1]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmin_2d_dim0() { - let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmin_2d_dim0() { + let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmin(0); + let data_actual = tensor.argmin(0); - let data_expected 
= Data::from([[0, 1, 0]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 1, 0]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmax_2d_dim0_int() { - let data = Data::from([[10, 11, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmax_2d_dim0_int() { + let data = Data::from([[10, 11, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmax(0); + let data_actual = tensor.argmax(0); - let data_expected = Data::from([[0, 0, 1]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 0, 1]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmin_2d_dim0_int() { - let data = Data::from([[10, 11, 2], [30, 4, 5]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmin_2d_dim0_int() { + let data = Data::from([[10, 11, 2], [30, 4, 5]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmin(0); + let data_actual = tensor.argmin(0); - let data_expected = Data::from([[0, 1, 0]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[0, 1, 0]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmax_2d_dim1() { - let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmax_2d_dim1() { + let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmax(1); + let data_actual = tensor.argmax(1); - let data_expected = Data::from([[1], [2]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[1], [2]]); + assert_eq!(data_expected, data_actual.to_data()); + } - #[test] - fn test_argmin_2d_dim1() { - let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn test_argmin_2d_dim1() { + let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.argmin(1); + let data_actual = tensor.argmin(1); - let data_expected = Data::from([[2], [1]]); - assert_eq!(data_expected, data_actual.to_data()); - } + let data_expected = Data::from([[2], [1]]); + assert_eq!(data_expected, data_actual.to_data()); + } } diff --git a/burn-tensor/src/tests/ops/cast.rs b/burn-tensor/src/tests/ops/cast.rs index 057901273f..e6ac40ae74 100644 --- a/burn-tensor/src/tests/ops/cast.rs +++ b/burn-tensor/src/tests/ops/cast.rs @@ -1,43 +1,43 @@ #[burn_tensor_testgen::testgen(cast)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; - - #[test] - fn cast_float_to_int() { - let tensor = Tensor::::from_data([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]); - - let actual = tensor.int().into_data(); - let expected = Data::from([[1, 2, 3], [4, 5, 6]]); - assert_eq!(expected, actual); - } - - #[test] - fn cast_int_to_float_tensor() { - let tensor = Tensor::::from_data([[1, 2, 3], [4, 5, 6]]); - - let actual = tensor.float().into_data(); - let expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - assert_eq!(expected, actual); - } - - #[test] - fn cast_bool_to_int_tensor() { - let tensor = - Tensor::::from_data([[true, false, true], [false, false, true]]); - - let actual = tensor.int().into_data(); - let expected = Data::from([[1, 0, 1], [0, 0, 1]]); - assert_eq!(expected, actual); - } - - #[test] - fn 
cast_bool_to_float_tensor() { - let tensor = - Tensor::::from_data([[true, false, true], [false, false, true]]); - - let actual = tensor.float().into_data(); - let expected = Data::from([[1., 0., 1.], [0., 0., 1.]]); - assert_eq!(expected, actual); - } + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; + + #[test] + fn cast_float_to_int() { + let tensor = Tensor::::from_data([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]); + + let actual = tensor.int().into_data(); + let expected = Data::from([[1, 2, 3], [4, 5, 6]]); + assert_eq!(expected, actual); + } + + #[test] + fn cast_int_to_float_tensor() { + let tensor = Tensor::::from_data([[1, 2, 3], [4, 5, 6]]); + + let actual = tensor.float().into_data(); + let expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + assert_eq!(expected, actual); + } + + #[test] + fn cast_bool_to_int_tensor() { + let tensor = + Tensor::::from_data([[true, false, true], [false, false, true]]); + + let actual = tensor.int().into_data(); + let expected = Data::from([[1, 0, 1], [0, 0, 1]]); + assert_eq!(expected, actual); + } + + #[test] + fn cast_bool_to_float_tensor() { + let tensor = + Tensor::::from_data([[true, false, true], [false, false, true]]); + + let actual = tensor.float().into_data(); + let expected = Data::from([[1., 0., 1.], [0., 0., 1.]]); + assert_eq!(expected, actual); + } } diff --git a/burn-tensor/src/tests/ops/cat.rs b/burn-tensor/src/tests/ops/cat.rs index 519600b619..f01311b2d6 100644 --- a/burn-tensor/src/tests/ops/cat.rs +++ b/burn-tensor/src/tests/ops/cat.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(cat)] mod tests { - use super::*; - use alloc::vec::Vec; - use burn_tensor::{Bool, Data, Int, Tensor}; - #[test] - fn should_support_cat_ops_2d_dim0() { - let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); - let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); - - let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_cat_ops_int() { - let tensor_1 = Tensor::::from_data([[1, 2, 3]]); - let tensor_2 = Tensor::::from_data([[4, 5, 6]]); - - let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[1, 2, 3], [4, 5, 6]]); - assert_eq!(&data_actual, &data_expected); - } - - #[test] - fn should_support_cat_ops_bool() { - let tensor_1 = Tensor::::from_data([[false, true, true]]); - let tensor_2 = Tensor::::from_data([[true, true, false]]); - - let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[false, true, true], [true, true, false]]); - assert_eq!(&data_actual, &data_expected); - } - - #[test] - fn should_support_cat_ops_2d_dim1() { - let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); - let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); - - let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 1).into_data(); - - let data_expected = Data::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_cat_ops_3d() { - let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); - let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); - - let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - - let data_expected = Data::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); - 
data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - #[should_panic] - fn should_panic_when_dimensions_are_not_the_same() { - let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]); - let tensor_2 = TestTensor::from_data([[4.0, 5.0]]); - - TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); - } - - #[test] - #[should_panic] - fn should_panic_when_list_of_vectors_is_empty() { - let tensor: Vec> = vec![]; - TestTensor::cat(tensor, 0).into_data(); - } - - #[test] - #[should_panic] - fn should_panic_when_cat_exceeds_dimension() { - let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); - let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); - - TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data(); - } + use super::*; + use alloc::vec::Vec; + use burn_tensor::{Bool, Data, Int, Tensor}; + #[test] + fn should_support_cat_ops_2d_dim0() { + let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); + let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); + + let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_cat_ops_int() { + let tensor_1 = Tensor::::from_data([[1, 2, 3]]); + let tensor_2 = Tensor::::from_data([[4, 5, 6]]); + + let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[1, 2, 3], [4, 5, 6]]); + assert_eq!(&data_actual, &data_expected); + } + + #[test] + fn should_support_cat_ops_bool() { + let tensor_1 = Tensor::::from_data([[false, true, true]]); + let tensor_2 = Tensor::::from_data([[true, true, false]]); + + let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[false, true, true], [true, true, false]]); + assert_eq!(&data_actual, &data_expected); + } + + #[test] + fn should_support_cat_ops_2d_dim1() { + let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]); + let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]]); + + let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 1).into_data(); + + let data_expected = Data::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_cat_ops_3d() { + let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); + let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); + + let data_actual = TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + + let data_expected = Data::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + #[should_panic] + fn should_panic_when_dimensions_are_not_the_same() { + let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]); + let tensor_2 = TestTensor::from_data([[4.0, 5.0]]); + + TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); + } + + #[test] + #[should_panic] + fn should_panic_when_list_of_vectors_is_empty() { + let tensor: Vec> = vec![]; + TestTensor::cat(tensor, 0).into_data(); + } + + #[test] + #[should_panic] + fn should_panic_when_cat_exceeds_dimension() { + let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); + let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]); + + TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data(); + } } diff --git a/burn-tensor/src/tests/ops/clamp.rs b/burn-tensor/src/tests/ops/clamp.rs 
index 6c8ddd7b85..2acda4c0dd 100644 --- a/burn-tensor/src/tests/ops/clamp.rs +++ b/burn-tensor/src/tests/ops/clamp.rs @@ -1,60 +1,60 @@ #[burn_tensor_testgen::testgen(clamp)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn clamp_min() { - // test float tensor - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.clamp_min(2.0).into_data(); - - let data_expected = Data::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]); - assert_eq!(data_expected, data_actual); - - // test int tensor - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp_min(2).into_data(); - let data_expected = Data::from([[2, 2, 2], [3, 4, 5]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn clamp_max() { - // test float tensor - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.clamp_max(2.0).into_data(); - - let data_expected = Data::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]); - assert_eq!(data_expected, data_actual); - - // test int tensor - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp_max(4).into_data(); - let data_expected = Data::from([[0, 1, 2], [3, 4, 4]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn clamp_min_max() { - // test float tensor - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp(1.0, 4.0).into_data(); - let data_expected = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]); - assert_eq!(data_expected, data_actual); - - // test int tensor - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.clamp(1, 4).into_data(); - let data_expected = Data::from([[1, 1, 2], [3, 4, 4]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn clamp_min() { + // test float tensor + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.clamp_min(2.0).into_data(); + + let data_expected = Data::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]); + assert_eq!(data_expected, data_actual); + + // test int tensor + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp_min(2).into_data(); + let data_expected = Data::from([[2, 2, 2], [3, 4, 5]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn clamp_max() { + // test float tensor + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.clamp_max(2.0).into_data(); + + let data_expected = Data::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]); + assert_eq!(data_expected, data_actual); + + // test int tensor + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp_max(4).into_data(); + let data_expected = Data::from([[0, 1, 2], [3, 4, 4]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn clamp_min_max() { + // test float tensor + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp(1.0, 4.0).into_data(); + let data_expected = Data::from([[1.0, 
1.0, 2.0], [3.0, 4.0, 4.0]]); + assert_eq!(data_expected, data_actual); + + // test int tensor + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.clamp(1, 4).into_data(); + let data_expected = Data::from([[1, 1, 2], [3, 4, 4]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/cos.rs b/burn-tensor/src/tests/ops/cos.rs index b193cd991a..c099b1e4e9 100644 --- a/burn-tensor/src/tests/ops/cos.rs +++ b/burn-tensor/src/tests/ops/cos.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(cos)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_cos_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_cos_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.cos().into_data(); + let data_actual = tensor.cos().into_data(); - let data_expected = Data::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/create_like.rs b/burn-tensor/src/tests/ops/create_like.rs index 80374043d6..ea54aeabc6 100644 --- a/burn-tensor/src/tests/ops/create_like.rs +++ b/burn-tensor/src/tests/ops/create_like.rs @@ -1,49 +1,52 @@ #[burn_tensor_testgen::testgen(create_like)] mod tests { - use super::*; - use burn_tensor::{Data, Distribution, Tensor}; + use super::*; + use burn_tensor::{Data, Distribution, Tensor}; - #[test] - fn should_support_zeros_like() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_zeros_like() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.zeros_like().into_data(); + let data_actual = tensor.zeros_like().into_data(); - let data_expected = Data::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]); + let data_expected = + Data::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_ones_like() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_ones_like() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.ones_like().into_data(); + let data_actual = tensor.ones_like().into_data(); - let data_expected = Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); + let data_expected = + Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_randoms_like() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_randoms_like() { + let tensor = 
TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor - .random_like(Distribution::Uniform(0.99999, 1.)) - .into_data(); + let data_actual = tensor + .random_like(Distribution::Uniform(0.99999, 1.)) + .into_data(); - let data_expected = Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); + let data_expected = + Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/div.rs b/burn-tensor/src/tests/ops/div.rs index a7f17fe19b..fb3e91f070 100644 --- a/burn-tensor/src/tests/ops/div.rs +++ b/burn-tensor/src/tests/ops/div.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(div)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn should_support_div_ops() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 / tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_div_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 / tensor_2).into_data(); - - let data_expected = Data::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_div_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor / scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_div_ops_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[1, 1, 2], [1, 1, 2]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 / tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 1, 1], [3, 4, 2]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_div_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[1, 1, 2], [3, 4, 5]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 / tensor_2).into_data(); - - let data_expected = Data::from([[0, 1, 1], [0, 0, 0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_div_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor / scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 0, 1], [1, 2, 2]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn should_support_div_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 
5.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 / tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_div_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 / tensor_2).into_data(); + + let data_expected = Data::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_div_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor / scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_div_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[1, 1, 2], [1, 1, 2]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 / tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 1, 1], [3, 4, 2]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_div_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[1, 1, 2], [3, 4, 5]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 / tensor_2).into_data(); + + let data_expected = Data::from([[0, 1, 1], [0, 0, 0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_div_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor / scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 0, 1], [1, 2, 2]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/erf.rs b/burn-tensor/src/tests/ops/erf.rs index aeac9ac496..14afb899d9 100644 --- a/burn-tensor/src/tests/ops/erf.rs +++ b/burn-tensor/src/tests/ops/erf.rs @@ -1,30 +1,30 @@ #[burn_tensor_testgen::testgen(erf)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_erf_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_erf_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.erf().into_data(); + let data_actual = tensor.erf().into_data(); - let data_expected = Data::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_erf_ops_with_negative_number() { - let data = Data::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_erf_ops_with_negative_number() { + let data = Data::from([[-0.056, 
-0.043, -0.089], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.erf().into_data(); + let data_actual = tensor.erf().into_data(); - let data_expected = Data::from([ - [-0.06312324, -0.048490416, -0.10016122], - [1.0000, 1.0000, 1.0000], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([ + [-0.06312324, -0.048490416, -0.10016122], + [1.0000, 1.0000, 1.0000], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/exp.rs b/burn-tensor/src/tests/ops/exp.rs index f0c1d203dc..278b7b6a44 100644 --- a/burn-tensor/src/tests/ops/exp.rs +++ b/burn-tensor/src/tests/ops/exp.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(exp)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_exp_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_exp_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.exp().into_data(); + let data_actual = tensor.exp().into_data(); - let data_expected = Data::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/flatten.rs b/burn-tensor/src/tests/ops/flatten.rs index 65f05477ac..7876bfac3d 100644 --- a/burn-tensor/src/tests/ops/flatten.rs +++ b/burn-tensor/src/tests/ops/flatten.rs @@ -1,58 +1,58 @@ #[burn_tensor_testgen::testgen(flatten)] mod tests { - use super::*; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Shape, Tensor}; - /// Test if the function can successfully flatten a 4D tensor to a 1D tensor. - #[test] - fn should_flatten_to_1d() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(0, 3); - let expected_shape = Shape::new([120]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten a 4D tensor to a 1D tensor. + #[test] + fn should_flatten_to_1d() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(0, 3); + let expected_shape = Shape::new([120]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function can successfully flatten the middle dimensions of a 4D tensor. - #[test] - fn should_flatten_middle() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(1, 2); - let expected_shape = Shape::new([2, 12, 5]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten the middle dimensions of a 4D tensor. + #[test] + fn should_flatten_middle() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(1, 2); + let expected_shape = Shape::new([2, 12, 5]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function can successfully flatten the first dimensions of a 4D tensor. 
- #[test] - fn should_flatten_begin() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(0, 2); - let expected_shape = Shape::new([24, 5]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten the first dimensions of a 4D tensor. + #[test] + fn should_flatten_begin() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(0, 2); + let expected_shape = Shape::new([24, 5]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function can successfully flatten the last dimensions of a 4D tensor. - #[test] - fn should_flatten_end() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(1, 3); - let expected_shape = Shape::new([2, 60]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + /// Test if the function can successfully flatten the last dimensions of a 4D tensor. + #[test] + fn should_flatten_end() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(1, 3); + let expected_shape = Shape::new([2, 60]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } - /// Test if the function panics when the start dimension is greater than the end dimension. - #[test] - #[should_panic] - fn should_flatten_panic() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let flattened_tensor: Tensor = tensor.flatten(2, 0); - } + /// Test if the function panics when the start dimension is greater than the end dimension. + #[test] + #[should_panic] + fn should_flatten_panic() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let flattened_tensor: Tensor = tensor.flatten(2, 0); + } - #[test] - #[should_panic] - fn not_enough_destination_dimension() { - let tensor = Tensor::::ones(Shape::new([1, 5, 15])); - let flattened_tensor: Tensor = tensor.flatten(1, 2); - let expected_shape = Shape::new([75]); - assert_eq!(flattened_tensor.shape(), expected_shape); - } + #[test] + #[should_panic] + fn not_enough_destination_dimension() { + let tensor = Tensor::::ones(Shape::new([1, 5, 15])); + let flattened_tensor: Tensor = tensor.flatten(1, 2); + let expected_shape = Shape::new([75]); + assert_eq!(flattened_tensor.shape(), expected_shape); + } } diff --git a/burn-tensor/src/tests/ops/full.rs b/burn-tensor/src/tests/ops/full.rs index d2e4a7abc2..c1de8e8592 100644 --- a/burn-tensor/src/tests/ops/full.rs +++ b/burn-tensor/src/tests/ops/full.rs @@ -1,25 +1,25 @@ #[burn_tensor_testgen::testgen(full)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Shape, Tensor}; - #[test] - fn test_data_full() { - let data_actual = Data::full([2, 3].into(), 2.0); - let data_expected = Data::from([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]); - assert_eq!(data_expected, data_actual); - } + #[test] + fn test_data_full() { + let data_actual = Data::full([2, 3].into(), 2.0); + let data_expected = Data::from([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn test_tensor_full() { - // Test full with f32 - let tensor = Tensor::::full([2, 3], 2.1); - let data_expected = Data::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]); - assert_eq!(data_expected, tensor.into_data()); + #[test] + fn test_tensor_full() { + // Test full with f32 + let tensor = Tensor::::full([2, 3], 2.1); + let data_expected = Data::from([[2.1, 2.1, 
2.1], [2.1, 2.1, 2.1]]); + assert_eq!(data_expected, tensor.into_data()); - // Test full with Int - let int_tensor = Tensor::::full([2, 2], 2); - let data_expected = Data::from([[2, 2], [2, 2]]); - assert_eq!(data_expected, int_tensor.into_data()); - } + // Test full with Int + let int_tensor = Tensor::::full([2, 2], 2); + let data_expected = Data::from([[2, 2], [2, 2]]); + assert_eq!(data_expected, int_tensor.into_data()); + } } diff --git a/burn-tensor/src/tests/ops/gather_scatter.rs b/burn-tensor/src/tests/ops/gather_scatter.rs index 38f0710d34..7b9abc7820 100644 --- a/burn-tensor/src/tests/ops/gather_scatter.rs +++ b/burn-tensor/src/tests/ops/gather_scatter.rs @@ -1,177 +1,177 @@ #[burn_tensor_testgen::testgen(gather_scatter)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_gather_1d_dim0() { - let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]); - let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); + #[test] + fn should_gather_1d_dim0() { + let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]); + let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); - let output = tensor.gather(0, indices); + let output = tensor.gather(0, indices); - assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); - } - - #[test] - fn should_gather_1d_dim0_int() { - let tensor = TestTensorInt::from_ints([5, 6, 7]); - let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); - - let output = tensor.gather(0, indices); - - assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); - } + assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); + } + + #[test] + fn should_gather_1d_dim0_int() { + let tensor = TestTensorInt::from_ints([5, 6, 7]); + let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); + + let output = tensor.gather(0, indices); + + assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); + } - #[test] - fn should_gather_2d_dim0() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]); - - let output = tensor.gather(0, indices); - - assert_eq!( - output.into_data(), - Data::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]) - ); - } - - #[test] - fn should_gather_2d_dim1() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]); - - let output = tensor.gather(1, indices); - - assert_eq!( - output.into_data(), - Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]) - ); - } - - #[test] - fn should_gather_3d_dim1() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); - - let output = tensor.gather(1, indices); - - assert_eq!( - output.into_data(), - Data::from([ - [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], - [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]] - ]) - ); - } - - #[test] - fn should_gather_2d_only_1dim() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]); - - let output = tensor.gather(1, indices); - - assert_eq!(output.into_data(), Data::from([[1.0], [5.0]])); - } - - #[test] - fn should_scatter_1d() { - let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); - let values = TestTensor::from_floats([5.0, 4.0, 3.0]); - let indices = 
TestTensorInt::from_ints([1, 0, 2]); - - let output = tensor.scatter(0, indices, values); - - assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0])); - } - - #[test] - fn should_scatter_1d_int() { - let tensor = TestTensorInt::from_ints([0, 0, 0]); - let values = TestTensorInt::from_ints([5, 4, 3]); - let indices = TestTensorInt::from_ints([1, 0, 2]); - - let output = tensor.scatter(0, indices, values); - - assert_eq!(output.into_data(), Data::from([4, 5, 3])); - } - - #[test] - fn should_scatter_2d_dim0() { - let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); - let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]); - - let output = tensor.scatter(0, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]) - ); - } - - #[test] - fn should_scatter_2d_dim1() { - let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); - let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]); - - let output = tensor.scatter(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]) - ); - } - - #[test] - fn should_scatter_3d_dim1() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - let values = TestTensor::from_floats([ - [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], - ]); - let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); - - let output = tensor.scatter(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([ - [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], - [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]] - ]) - ); - } - - #[test] - fn should_scatter_2d_dim1_diff_shape() { - let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); - let values = TestTensor::from_floats([[1.0], [4.0]]); - let indices = TestTensorInt::from_ints([[1], [2]]); - - let output = tensor.scatter(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]) - ); - } - - #[test] - #[should_panic] - fn scatter_should_panic_on_mismatch_of_shapes() { - let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); - let values = TestTensor::from_floats([5.0, 4.0]); - let indices = TestTensorInt::from_ints([1, 0, 2]); - - tensor.scatter(0, indices, values); - } + #[test] + fn should_gather_2d_dim0() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]); + + let output = tensor.gather(0, indices); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]) + ); + } + + #[test] + fn should_gather_2d_dim1() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]); + + let output = tensor.gather(1, indices); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]) + ); + } + + #[test] + fn should_gather_3d_dim1() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); + + let output = tensor.gather(1, indices); + + 
assert_eq!( + output.into_data(), + Data::from([ + [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], + [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]] + ]) + ); + } + + #[test] + fn should_gather_2d_only_1dim() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]); + + let output = tensor.gather(1, indices); + + assert_eq!(output.into_data(), Data::from([[1.0], [5.0]])); + } + + #[test] + fn should_scatter_1d() { + let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); + let values = TestTensor::from_floats([5.0, 4.0, 3.0]); + let indices = TestTensorInt::from_ints([1, 0, 2]); + + let output = tensor.scatter(0, indices, values); + + assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0])); + } + + #[test] + fn should_scatter_1d_int() { + let tensor = TestTensorInt::from_ints([0, 0, 0]); + let values = TestTensorInt::from_ints([5, 4, 3]); + let indices = TestTensorInt::from_ints([1, 0, 2]); + + let output = tensor.scatter(0, indices, values); + + assert_eq!(output.into_data(), Data::from([4, 5, 3])); + } + + #[test] + fn should_scatter_2d_dim0() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]); + + let output = tensor.scatter(0, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]) + ); + } + + #[test] + fn should_scatter_2d_dim1() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]); + + let output = tensor.scatter(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]) + ); + } + + #[test] + fn should_scatter_3d_dim1() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + let values = TestTensor::from_floats([ + [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ]); + let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); + + let output = tensor.scatter(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([ + [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], + [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]] + ]) + ); + } + + #[test] + fn should_scatter_2d_dim1_diff_shape() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0], [4.0]]); + let indices = TestTensorInt::from_ints([[1], [2]]); + + let output = tensor.scatter(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]) + ); + } + + #[test] + #[should_panic] + fn scatter_should_panic_on_mismatch_of_shapes() { + let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); + let values = TestTensor::from_floats([5.0, 4.0]); + let indices = TestTensorInt::from_ints([1, 0, 2]); + + tensor.scatter(0, indices, values); + } } diff --git a/burn-tensor/src/tests/ops/init.rs b/burn-tensor/src/tests/ops/init.rs index 599d0ef95a..7a89527cd3 100644 --- a/burn-tensor/src/tests/ops/init.rs +++ b/burn-tensor/src/tests/ops/init.rs @@ -1,58 +1,58 @@ #[burn_tensor_testgen::testgen(init)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; + use super::*; + use 
burn_tensor::{Bool, Data, Int, Tensor}; - #[test] - fn should_support_float_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } + #[test] + fn should_support_float_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } - #[test] - fn should_support_int_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } + #[test] + fn should_support_int_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } - #[test] - fn should_support_float_zeros() { - let shape = [2, 2]; - let tensor = Tensor::::zeros(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[0., 0.], [0., 0.]])) - } + #[test] + fn should_support_float_zeros() { + let shape = [2, 2]; + let tensor = Tensor::::zeros(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[0., 0.], [0., 0.]])) + } - #[test] - fn should_support_int_zeros() { - let shape = [2, 2]; - let tensor = Tensor::::zeros(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[0, 0], [0, 0]])) - } + #[test] + fn should_support_int_zeros() { + let shape = [2, 2]; + let tensor = Tensor::::zeros(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[0, 0], [0, 0]])) + } - #[test] - fn should_support_float_ones() { - let shape = [2, 2]; - let tensor = Tensor::::ones(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[1., 1.], [1., 1.]])) - } + #[test] + fn should_support_float_ones() { + let shape = [2, 2]; + let tensor = Tensor::::ones(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[1., 1.], [1., 1.]])) + } - #[test] - fn should_support_int_ones() { - let shape = [2, 2]; - let tensor = Tensor::::ones(shape); - assert_eq!(tensor.shape(), shape.into()); - assert_eq!(tensor.to_data(), Data::from([[1, 1], [1, 1]])) - } + #[test] + fn should_support_int_ones() { + let shape = [2, 2]; + let tensor = Tensor::::ones(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[1, 1], [1, 1]])) + } - #[test] - fn should_support_bool_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } + #[test] + fn should_support_bool_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } } diff --git a/burn-tensor/src/tests/ops/iter_dim.rs b/burn-tensor/src/tests/ops/iter_dim.rs index d12c9ed787..08d581a339 100644 --- a/burn-tensor/src/tests/ops/iter_dim.rs +++ b/burn-tensor/src/tests/ops/iter_dim.rs @@ -1,46 +1,46 @@ #[burn_tensor_testgen::testgen(iter_dim)] mod test { - use super::*; - use burn_tensor::{Data, Int, Tensor}; + use super::*; + use burn_tensor::{Data, Int, Tensor}; - #[test] - fn test_1d_iter_last_item() { - let data = [1, 2, 3, 4]; - let tensor = Tensor::::from_ints(data); - assert_eq!( - Tensor::::from_ints([4]).into_data(), - tensor.iter_dim(0).last().unwrap().into_data() - ) - } + #[test] + fn test_1d_iter_last_item() { + let data = [1, 2, 3, 4]; + let tensor = Tensor::::from_ints(data); + assert_eq!( + Tensor::::from_ints([4]).into_data(), + tensor.iter_dim(0).last().unwrap().into_data() + ) + } - #[test] - 
#[should_panic] - fn test_too_high_dimension() { - Tensor::::zeros([10]).iter_dim(1); - } + #[test] + #[should_panic] + fn test_too_high_dimension() { + Tensor::::zeros([10]).iter_dim(1); + } - #[test] - fn test_transposed() { - let data = [ - [1., 2., 3., 1., 2.], - [4., 5., 6., 1., 2.], - [7., 8., 9., 1., 2.], - ]; - let tensor = Tensor::::from_floats(data); - let lhs = tensor.clone().slice([1..2, 0..5]); - let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap(); - assert_eq!(lhs.into_data().value, rhs.into_data().value); - } + #[test] + fn test_transposed() { + let data = [ + [1., 2., 3., 1., 2.], + [4., 5., 6., 1., 2.], + [7., 8., 9., 1., 2.], + ]; + let tensor = Tensor::::from_floats(data); + let lhs = tensor.clone().slice([1..2, 0..5]); + let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap(); + assert_eq!(lhs.into_data().value, rhs.into_data().value); + } - fn test_iteration_over_low_dim() { - let data = [[ - [1., 2., 3., 1., 2.], - [4., 5., 6., 1., 2.], - [7., 8., 9., 1., 2.], - ]; 5]; - let tensor = Tensor::::from_floats(data); - let lhs = tensor.iter_dim(2).nth(1).unwrap(); - let rhs = Data::from([2., 5., 8.]); - assert_eq!(lhs.into_data().value, rhs.value); - } + fn test_iteration_over_low_dim() { + let data = [[ + [1., 2., 3., 1., 2.], + [4., 5., 6., 1., 2.], + [7., 8., 9., 1., 2.], + ]; 5]; + let tensor = Tensor::::from_floats(data); + let lhs = tensor.iter_dim(2).nth(1).unwrap(); + let rhs = Data::from([2., 5., 8.]); + assert_eq!(lhs.into_data().value, rhs.value); + } } diff --git a/burn-tensor/src/tests/ops/log.rs b/burn-tensor/src/tests/ops/log.rs index 4532643487..f71387317a 100644 --- a/burn-tensor/src/tests/ops/log.rs +++ b/burn-tensor/src/tests/ops/log.rs @@ -1,19 +1,19 @@ #[burn_tensor_testgen::testgen(log)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_log_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_log_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.log().into_data(); + let data_actual = tensor.log().into_data(); - let data_expected = Data::from([ - [-f32::INFINITY, 0.0, core::f32::consts::LN_2], - [1.0986, 1.3862, 1.6094], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([ + [-f32::INFINITY, 0.0, core::f32::consts::LN_2], + [1.0986, 1.3862, 1.6094], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/log1p.rs b/burn-tensor/src/tests/ops/log1p.rs index fe6f2b5b4e..346e01b2e3 100644 --- a/burn-tensor/src/tests/ops/log1p.rs +++ b/burn-tensor/src/tests/ops/log1p.rs @@ -1,19 +1,19 @@ #[burn_tensor_testgen::testgen(log1p)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_exp_log1p() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_exp_log1p() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.log1p().into_data(); + let data_actual = tensor.log1p().into_data(); - let data_expected = Data::from([ - [0.0, core::f32::consts::LN_2, 1.0986], - [1.3862, 1.6094, 1.7917], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let 
data_expected = Data::from([ + [0.0, core::f32::consts::LN_2, 1.0986], + [1.3862, 1.6094, 1.7917], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/map_comparison.rs b/burn-tensor/src/tests/ops/map_comparison.rs index 0906c2bc5b..76eee5dc02 100644 --- a/burn-tensor/src/tests/ops/map_comparison.rs +++ b/burn-tensor/src/tests/ops/map_comparison.rs @@ -1,308 +1,308 @@ #[burn_tensor_testgen::testgen(map_comparison)] mod tests { - use super::*; - use burn_tensor::{ - backend::Backend, BasicOps, Bool, Data, Element, Float, Int, Numeric, Tensor, TensorKind, - }; - - type IntElem = ::IntElem; - type FloatElem = ::FloatElem; - - #[test] - fn test_equal() { - equal::() - } - - #[test] - fn test_int_equal() { - equal::() - } - - #[test] - fn test_equal_elem() { - equal_elem::() - } - - #[test] - fn test_int_equal_elem() { - equal_elem::() - } - - #[test] - fn test_greater_elem() { - greater_elem::() - } - - #[test] - fn test_int_greater_elem() { - greater_elem::() - } - - #[test] - fn test_greater_equal_elem() { - greater_equal_elem::() - } - - #[test] - fn test_int_greater_equal_elem() { - greater_equal_elem::() - } - - #[test] - fn test_greater() { - greater::() - } - - #[test] - fn test_int_greater() { - greater::() - } - - #[test] - fn test_greater_equal() { - greater_equal::() - } - - #[test] - fn test_int_greater_equal() { - greater_equal::() - } - - #[test] - fn test_lower_elem() { - lower_elem::() - } - - #[test] - fn test_int_lower_elem() { - lower_elem::() - } - - #[test] - fn test_lower_equal_elem() { - lower_equal_elem::() - } - - #[test] - fn test_int_lower_equal_elem() { - lower_equal_elem::() - } - - #[test] - fn test_lower() { - lower::() - } - - #[test] - fn test_int_lower() { - lower::() - } - - #[test] - fn test_lower_equal() { - lower_equal::() - } - - #[test] - fn test_int_lower_equal() { - lower_equal::() - } - - fn equal() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.equal(tensor_2); - - let data_expected = Data::from([[false, true, false], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn equal_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().equal_elem(2); - let data_actual_inplace = tensor_1.equal_elem(2); - - let data_expected = Data::from([[false, false, true], [false, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn greater_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().greater_elem(4); - let data_actual_inplace = tensor_1.greater_elem(4); - - let data_expected = Data::from([[false, false, false], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, 
data_actual_inplace.into_data()); - } - - fn greater_equal_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0); - let data_actual_inplace = tensor_1.greater_equal_elem(4.0); - - let data_expected = Data::from([[false, false, false], [false, true, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn greater() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); - let data_actual_inplace = tensor_1.greater(tensor_2); - - let data_expected = Data::from([[false, false, true], [false, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn greater_equal() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.greater_equal(tensor_2); - - let data_expected = Data::from([[false, true, true], [false, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn lower_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().lower_elem(4.0); - let data_actual_inplace = tensor_1.lower_elem(4.0); - - let data_expected = Data::from([[true, true, true], [true, false, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn lower_equal_elem() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0); - let data_actual_inplace = tensor_1.lower_equal_elem(4.0); - - let data_expected = Data::from([[true, true, true], [true, true, false]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - fn lower() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone()); - let data_actual_inplace = tensor_1.lower(tensor_2); - - let data_expected = Data::from([[true, false, false], [true, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, 
data_actual_inplace.into_data()); - } - - fn lower_equal() - where - K: Numeric + BasicOps, - E: Element, - { - let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); - let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.lower_equal(tensor_2); - - let data_expected = Data::from([[true, true, false], [true, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - #[test] - fn should_support_bool_equal() { - let data_1 = Data::from([[false, true, true], [true, false, true]]); - let data_2 = Data::from([[false, false, true], [false, true, true]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); - let data_actual_inplace = tensor_1.equal(tensor_2); - - let data_expected = Data::from([[true, false, true], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } - - #[test] - fn should_support_bool_not() { - let data_1 = Data::from([[false, true, true], [true, true, false]]); - let tensor_1 = Tensor::::from_data(data_1); - - let data_actual_cloned = tensor_1.clone().bool_not(); - let data_actual_inplace = tensor_1.bool_not(); - - let data_expected = Data::from([[true, false, false], [false, false, true]]); - assert_eq!(data_expected, data_actual_cloned.into_data()); - assert_eq!(data_expected, data_actual_inplace.into_data()); - } + use super::*; + use burn_tensor::{ + backend::Backend, BasicOps, Bool, Data, Element, Float, Int, Numeric, Tensor, TensorKind, + }; + + type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + + #[test] + fn test_equal() { + equal::() + } + + #[test] + fn test_int_equal() { + equal::() + } + + #[test] + fn test_equal_elem() { + equal_elem::() + } + + #[test] + fn test_int_equal_elem() { + equal_elem::() + } + + #[test] + fn test_greater_elem() { + greater_elem::() + } + + #[test] + fn test_int_greater_elem() { + greater_elem::() + } + + #[test] + fn test_greater_equal_elem() { + greater_equal_elem::() + } + + #[test] + fn test_int_greater_equal_elem() { + greater_equal_elem::() + } + + #[test] + fn test_greater() { + greater::() + } + + #[test] + fn test_int_greater() { + greater::() + } + + #[test] + fn test_greater_equal() { + greater_equal::() + } + + #[test] + fn test_int_greater_equal() { + greater_equal::() + } + + #[test] + fn test_lower_elem() { + lower_elem::() + } + + #[test] + fn test_int_lower_elem() { + lower_elem::() + } + + #[test] + fn test_lower_equal_elem() { + lower_equal_elem::() + } + + #[test] + fn test_int_lower_equal_elem() { + lower_equal_elem::() + } + + #[test] + fn test_lower() { + lower::() + } + + #[test] + fn test_int_lower() { + lower::() + } + + #[test] + fn test_lower_equal() { + lower_equal::() + } + + #[test] + fn test_int_lower_equal() { + lower_equal::() + } + + fn equal() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = 
tensor_1.clone().equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.equal(tensor_2); + + let data_expected = Data::from([[false, true, false], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn equal_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().equal_elem(2); + let data_actual_inplace = tensor_1.equal_elem(2); + + let data_expected = Data::from([[false, false, true], [false, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().greater_elem(4); + let data_actual_inplace = tensor_1.greater_elem(4); + + let data_expected = Data::from([[false, false, false], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater_equal_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0); + let data_actual_inplace = tensor_1.greater_equal_elem(4.0); + + let data_expected = Data::from([[false, false, false], [false, true, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); + let data_actual_inplace = tensor_1.greater(tensor_2); + + let data_expected = Data::from([[false, false, true], [false, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn greater_equal() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.greater_equal(tensor_2); + + let data_expected = Data::from([[false, true, true], [false, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().lower_elem(4.0); + let data_actual_inplace = tensor_1.lower_elem(4.0); + + let data_expected = Data::from([[true, true, 
true], [true, false, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower_equal_elem() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0); + let data_actual_inplace = tensor_1.lower_equal_elem(4.0); + + let data_expected = Data::from([[true, true, true], [true, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone()); + let data_actual_inplace = tensor_1.lower(tensor_2); + + let data_expected = Data::from([[true, false, false], [true, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + fn lower_equal() + where + K: Numeric + BasicOps, + E: Element, + { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert(); + let data_2 = Data::::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]).convert(); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.lower_equal(tensor_2); + + let data_expected = Data::from([[true, true, false], [true, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn should_support_bool_equal() { + let data_1 = Data::from([[false, true, true], [true, false, true]]); + let data_2 = Data::from([[false, false, true], [false, true, true]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.equal(tensor_2); + + let data_expected = Data::from([[true, false, true], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn should_support_bool_not() { + let data_1 = Data::from([[false, true, true], [true, true, false]]); + let tensor_1 = Tensor::::from_data(data_1); + + let data_actual_cloned = tensor_1.clone().bool_not(); + let data_actual_inplace = tensor_1.bool_not(); + + let data_expected = Data::from([[true, false, false], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } } diff --git a/burn-tensor/src/tests/ops/mask.rs b/burn-tensor/src/tests/ops/mask.rs index 6a815ef2ff..735c31127e 100644 --- a/burn-tensor/src/tests/ops/mask.rs +++ b/burn-tensor/src/tests/ops/mask.rs @@ -1,55 +1,55 @@ #[burn_tensor_testgen::testgen(mask)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; - - #[test] - fn should_support_mask_where_ops() { - let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); - let mask = - 
Tensor::::from_bool(Data::from([[true, false], [false, true]])); - let value = Tensor::::from_data(Data::from([[1.8, 2.8], [3.8, 4.8]])); - - let data_actual = tensor.mask_where(mask, value).into_data(); - - let data_expected = Data::from([[1.8, 7.0], [2.0, 4.8]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mask_fill_ops() { - let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); - let mask = - Tensor::::from_bool(Data::from([[true, false], [false, true]])); - - let data_actual = tensor.mask_fill(mask, 2.0).to_data(); - - let data_expected = Data::from([[2.0, 7.0], [2.0, 2.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_int_mask_where_ops() { - let tensor = Tensor::::from_data([[1, 7], [2, 3]]); - let mask = - Tensor::::from_bool(Data::from([[true, false], [false, true]])); - let value = Tensor::::from_data(Data::from([[8, 9], [10, 11]])); - - let data_actual = tensor.mask_where(mask, value).into_data(); - - let data_expected = Data::from([[8, 7], [2, 11]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_int_mask_fill_ops() { - let tensor = Tensor::::from_data([[1, 7], [2, 3]]); - let mask = - Tensor::::from_bool(Data::from([[true, false], [false, true]])); - - let data_actual = tensor.mask_fill(mask, 9).to_data(); - - let data_expected = Data::from([[9, 7], [2, 9]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; + + #[test] + fn should_support_mask_where_ops() { + let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + let value = Tensor::::from_data(Data::from([[1.8, 2.8], [3.8, 4.8]])); + + let data_actual = tensor.mask_where(mask, value).into_data(); + + let data_expected = Data::from([[1.8, 7.0], [2.0, 4.8]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mask_fill_ops() { + let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + + let data_actual = tensor.mask_fill(mask, 2.0).to_data(); + + let data_expected = Data::from([[2.0, 7.0], [2.0, 2.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_int_mask_where_ops() { + let tensor = Tensor::::from_data([[1, 7], [2, 3]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + let value = Tensor::::from_data(Data::from([[8, 9], [10, 11]])); + + let data_actual = tensor.mask_where(mask, value).into_data(); + + let data_expected = Data::from([[8, 7], [2, 11]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_int_mask_fill_ops() { + let tensor = Tensor::::from_data([[1, 7], [2, 3]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + + let data_actual = tensor.mask_fill(mask, 9).to_data(); + + let data_expected = Data::from([[9, 7], [2, 9]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/matmul.rs b/burn-tensor/src/tests/ops/matmul.rs index adc0090c5f..6edf59067c 100644 --- a/burn-tensor/src/tests/ops/matmul.rs +++ b/burn-tensor/src/tests/ops/matmul.rs @@ -1,105 +1,108 @@ #[burn_tensor_testgen::testgen(matmul)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; - - #[test] - fn test_matmul_d2() { - let tensor_1 = TestTensor::from_floats([[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]]); - let tensor_2 = 
TestTensor::from_floats([[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]) - ); - } - - #[test] - fn test_matmul_d3() { - let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); - let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[[18.0, 28.0], [14.0, 23.0]]]) - ); - } - - #[test] - fn test_matmul_broadcast_1() { - let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); - let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]) - ); - } - - #[test] - fn test_matmul_simple_1() { - let tensor_1 = TestTensor::from_floats([[5.0, 14.0], [14.0, 50.0]]); - let tensor_2 = TestTensor::from_floats([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]]) - ); - } - - #[test] - fn test_matmul_simple_2() { - let tensor_1 = TestTensor::from_floats([[1.0, 2.0, 3.0, 4.0]]); - let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!(tensor_3.into_data(), Data::from([[50.0]])); - } - - #[test] - fn test_matmul_simple_3() { - let tensor_1 = - TestTensor::from_floats([[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]); - let tensor_2 = TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([ - [9., 18., 27., 36.], - [12., 24., 36., 48.], - [15., 30., 45., 60.], - [18., 36., 54., 72.] - ]) - ); - } - - #[test] - #[should_panic] - fn should_panic_when_inner_dimensions_are_not_equal() { - let tensor_1 = TestTensor::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]); - let tensor_2 = TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); - - let tensor_3 = tensor_1.matmul(tensor_2); - - assert_eq!( - tensor_3.into_data(), - Data::from([ - [9., 18., 27., 36.], - [12., 24., 36., 48.], - [15., 30., 45., 60.], - [18., 36., 54., 72.] 
- ]) - ); - } + use super::*; + use burn_tensor::{Data, Tensor}; + + #[test] + fn test_matmul_d2() { + let tensor_1 = TestTensor::from_floats([[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]]); + let tensor_2 = TestTensor::from_floats([[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]) + ); + } + + #[test] + fn test_matmul_d3() { + let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); + let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[[18.0, 28.0], [14.0, 23.0]]]) + ); + } + + #[test] + fn test_matmul_broadcast_1() { + let tensor_1 = TestTensor::from_floats([[[1.0, 7.0], [2.0, 3.0]]]); + let tensor_2 = + TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]) + ); + } + + #[test] + fn test_matmul_simple_1() { + let tensor_1 = TestTensor::from_floats([[5.0, 14.0], [14.0, 50.0]]); + let tensor_2 = TestTensor::from_floats([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]]) + ); + } + + #[test] + fn test_matmul_simple_2() { + let tensor_1 = TestTensor::from_floats([[1.0, 2.0, 3.0, 4.0]]); + let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!(tensor_3.into_data(), Data::from([[50.0]])); + } + + #[test] + fn test_matmul_simple_3() { + let tensor_1 = + TestTensor::from_floats([[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]); + let tensor_2 = + TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([ + [9., 18., 27., 36.], + [12., 24., 36., 48.], + [15., 30., 45., 60.], + [18., 36., 54., 72.] + ]) + ); + } + + #[test] + #[should_panic] + fn should_panic_when_inner_dimensions_are_not_equal() { + let tensor_1 = TestTensor::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]); + let tensor_2 = + TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); + + let tensor_3 = tensor_1.matmul(tensor_2); + + assert_eq!( + tensor_3.into_data(), + Data::from([ + [9., 18., 27., 36.], + [12., 24., 36., 48.], + [15., 30., 45., 60.], + [18., 36., 54., 72.] 
+ ]) + ); + } } diff --git a/burn-tensor/src/tests/ops/maxmin.rs b/burn-tensor/src/tests/ops/maxmin.rs index 3cbda1fd0b..0dea58522d 100644 --- a/burn-tensor/src/tests/ops/maxmin.rs +++ b/burn-tensor/src/tests/ops/maxmin.rs @@ -1,51 +1,51 @@ #[burn_tensor_testgen::testgen(maxmin)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn test_max_dim_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_max_dim_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let output_actual = tensor.max_dim(1); + let output_actual = tensor.max_dim(1); - let output_expected = Data::from([[2.], [5.]]); - assert_eq!(output_expected, output_actual.into_data()); - } + let output_expected = Data::from([[2.], [5.]]); + assert_eq!(output_expected, output_actual.into_data()); + } - #[test] - fn test_max_dim_with_indices_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_max_dim_with_indices_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let (output_actual, index_actual) = tensor.max_dim_with_indices(1); + let (output_actual, index_actual) = tensor.max_dim_with_indices(1); - let output_expected = Data::from([[2.], [5.]]); - let index_expected = Data::from([[2], [2]]); + let output_expected = Data::from([[2.], [5.]]); + let index_expected = Data::from([[2], [2]]); - assert_eq!(output_expected, output_actual.into_data()); - assert_eq!(index_expected, index_actual.into_data()); - } + assert_eq!(output_expected, output_actual.into_data()); + assert_eq!(index_expected, index_actual.into_data()); + } - #[test] - fn test_min_dim_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_min_dim_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let output_actual = tensor.min_dim(1); + let output_actual = tensor.min_dim(1); - let output_expected = Data::from([[0.], [3.]]); - assert_eq!(output_expected, output_actual.into_data()); - } + let output_expected = Data::from([[0.], [3.]]); + assert_eq!(output_expected, output_actual.into_data()); + } - #[test] - fn test_min_dim_with_indices_2d() { - let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + #[test] + fn test_min_dim_with_indices_2d() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let (output_actual, index_actual) = tensor.min_dim_with_indices(1); + let (output_actual, index_actual) = tensor.min_dim_with_indices(1); - let output_expected = Data::from([[0.], [3.]]); - let index_expected = Data::from([[0], [0]]); + let output_expected = Data::from([[0.], [3.]]); + let index_expected = Data::from([[0], [0]]); - assert_eq!(output_expected, output_actual.into_data()); - assert_eq!(index_expected, index_actual.into_data()); - } + assert_eq!(output_expected, output_actual.into_data()); + assert_eq!(index_expected, index_actual.into_data()); + } } diff --git a/burn-tensor/src/tests/ops/mul.rs b/burn-tensor/src/tests/ops/mul.rs index f3983a4005..81337b808f 100644 --- a/burn-tensor/src/tests/ops/mul.rs +++ b/burn-tensor/src/tests/ops/mul.rs @@ -1,85 +1,85 @@ #[burn_tensor_testgen::testgen(mul)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn should_support_mul_ops() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[0.0, 1.0, 
2.0], [3.0, 4.0, 5.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 * tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_mul_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 * tensor_2).into_data(); - - let data_expected = Data::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mul_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor * scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mul_ops_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[0, 1, 2], [3, 4, 5]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let output = tensor_1 * tensor_2; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 1, 4], [9, 16, 25]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_mul_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 * tensor_2).into_data(); - - let data_expected = Data::from([[0, 4, 10], [0, 7, 16]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_mul_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor * scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[0, 2, 4], [6, 8, 10]]); - assert_eq!(data_expected, data_actual); - } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn should_support_mul_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 * tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_mul_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 * tensor_2).into_data(); + + let data_expected = Data::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mul_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor * scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); + assert_eq!(data_expected, 
data_actual); + } + + #[test] + fn should_support_mul_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 * tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 1, 4], [9, 16, 25]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_mul_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 * tensor_2).into_data(); + + let data_expected = Data::from([[0, 4, 10], [0, 7, 16]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mul_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor * scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 2, 4], [6, 8, 10]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/neg.rs b/burn-tensor/src/tests/ops/neg.rs index 3393418c41..bcea87b40c 100644 --- a/burn-tensor/src/tests/ops/neg.rs +++ b/burn-tensor/src/tests/ops/neg.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(neg)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_neg_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.neg().into_data(); + let data_actual = tensor.neg().into_data(); - let data_expected = Data::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/one_hot.rs b/burn-tensor/src/tests/ops/one_hot.rs index 7dc96e7547..1dd09382d9 100644 --- a/burn-tensor/src/tests/ops/one_hot.rs +++ b/burn-tensor/src/tests/ops/one_hot.rs @@ -1,32 +1,32 @@ #[burn_tensor_testgen::testgen(one_hot)] mod tests { - use super::*; - use burn_tensor::{Data, Int}; + use super::*; + use burn_tensor::{Data, Int}; - #[test] - fn should_support_one_hot() { - let tensor = TestTensor::<1>::one_hot(0, 5); - assert_eq!(tensor.to_data(), Data::from([1., 0., 0., 0., 0.])); + #[test] + fn should_support_one_hot() { + let tensor = TestTensor::<1>::one_hot(0, 5); + assert_eq!(tensor.to_data(), Data::from([1., 0., 0., 0., 0.])); - let tensor = TestTensor::<1>::one_hot(1, 5); - assert_eq!(tensor.to_data(), Data::from([0., 1., 0., 0., 0.])); + let tensor = TestTensor::<1>::one_hot(1, 5); + assert_eq!(tensor.to_data(), Data::from([0., 1., 0., 0., 0.])); - let tensor = TestTensor::<1>::one_hot(4, 5); - assert_eq!(tensor.to_data(), Data::from([0., 0., 0., 0., 1.])); + let tensor = TestTensor::<1>::one_hot(4, 5); + assert_eq!(tensor.to_data(), Data::from([0., 0., 0., 0., 1.])); - let tensor = TestTensor::<1>::one_hot(1, 2); - assert_eq!(tensor.to_data(), Data::from([0., 1.])); - } + let tensor = TestTensor::<1>::one_hot(1, 2); + assert_eq!(tensor.to_data(), Data::from([0., 1.])); + } - #[test] - #[should_panic] 
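// Illustrative sketch, not from the patch itself: the comparison and mask tests above
// exercise a common pattern in which an element-wise comparison produces a Bool tensor
// that then drives mask_fill / mask_where. The sketch below uses only methods shown in
// these tests, but assumes a generic burn backend `B` in place of the test-harness
// aliases (TestTensor / TestBackend), so the exact bounds are an approximation.
use burn_tensor::{backend::Backend, Bool, Tensor};

fn clamp_above<B: Backend>(x: Tensor<B, 2>, threshold: f32, fill: f32) -> Tensor<B, 2> {
    // greater_elem yields a Tensor<B, 2, Bool>; mask_fill writes `fill` wherever the
    // mask is true and keeps the original element elsewhere.
    let mask: Tensor<B, 2, Bool> = x.clone().greater_elem(threshold);
    x.mask_fill(mask, fill)
}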
- fn should_panic_when_index_exceeds_number_of_classes() { - let tensor = TestTensor::<1>::one_hot(1, 1); - } + #[test] + #[should_panic] + fn should_panic_when_index_exceeds_number_of_classes() { + let tensor = TestTensor::<1>::one_hot(1, 1); + } - #[test] - #[should_panic] - fn should_panic_when_number_of_classes_is_zero() { - let tensor = TestTensor::<1>::one_hot(0, 0); - } + #[test] + #[should_panic] + fn should_panic_when_number_of_classes_is_zero() { + let tensor = TestTensor::<1>::one_hot(0, 0); + } } diff --git a/burn-tensor/src/tests/ops/powf.rs b/burn-tensor/src/tests/ops/powf.rs index b98c56578f..59f9abca5d 100644 --- a/burn-tensor/src/tests/ops/powf.rs +++ b/burn-tensor/src/tests/ops/powf.rs @@ -1,49 +1,49 @@ #[burn_tensor_testgen::testgen(powf)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_powf_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_powf_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(0.71).into_data(); + let data_actual = tensor.powf(0.71).into_data(); - let data_expected = Data::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_neg_power() { - let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_power() { + let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(-0.33).into_data(); + let data_actual = tensor.powf(-0.33).into_data(); - let data_expected = Data::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_neg_values_with_even_power() { - let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_values_with_even_power() { + let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(4.0).into_data(); + let data_actual = tensor.powf(4.0).into_data(); - let data_expected = Data::from([[0.0, 1.0, 16.0], [81.0, 256.0, 625.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 1.0, 16.0], [81.0, 256.0, 625.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn should_support_neg_values_with_odd_power() { - let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_neg_values_with_odd_power() { + let data = Data::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.powf(3.0).into_data(); + let data_actual = tensor.powf(3.0).into_data(); - let data_expected = Data::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = 
Data::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/random.rs b/burn-tensor/src/tests/ops/random.rs index 1bc15a1c3a..5aeccfa7a4 100644 --- a/burn-tensor/src/tests/ops/random.rs +++ b/burn-tensor/src/tests/ops/random.rs @@ -1,27 +1,27 @@ #[burn_tensor_testgen::testgen(random)] mod tests { - use super::*; - use burn_tensor::{Distribution, Tensor}; + use super::*; + use burn_tensor::{Distribution, Tensor}; - #[test] - fn rand_default() { - let tensor = Tensor::::random([20], Distribution::Default); + #[test] + fn rand_default() { + let tensor = Tensor::::random([20], Distribution::Default); - // check that the tensor is within the range of [0..1) (1 is exclusive) - tensor.into_data().assert_within_range(0.0..1.0); - } + // check that the tensor is within the range of [0..1) (1 is exclusive) + tensor.into_data().assert_within_range(0.0..1.0); + } - #[test] - fn rand_uniform() { - let tensor = Tensor::::random([20], Distribution::Uniform(4., 5.)); + #[test] + fn rand_uniform() { + let tensor = Tensor::::random([20], Distribution::Uniform(4., 5.)); - tensor.into_data().assert_within_range(4.0..5.0); - } + tensor.into_data().assert_within_range(4.0..5.0); + } - #[test] - fn rand_bernoulli() { - let tensor = Tensor::::random([20], Distribution::Bernoulli(1.)); + #[test] + fn rand_bernoulli() { + let tensor = Tensor::::random([20], Distribution::Bernoulli(1.)); - assert_eq!(tensor.into_data(), [1.; 20].into()); - } + assert_eq!(tensor.into_data(), [1.; 20].into()); + } } diff --git a/burn-tensor/src/tests/ops/recip.rs b/burn-tensor/src/tests/ops/recip.rs index 9e700d67fa..70395fd60b 100644 --- a/burn-tensor/src/tests/ops/recip.rs +++ b/burn-tensor/src/tests/ops/recip.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(recip)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_recip_ops() { - let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_recip_ops() { + let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.recip().into_data(); + let data_actual = tensor.recip().into_data(); - let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/repeat.rs b/burn-tensor/src/tests/ops/repeat.rs index d436b11330..8725decb26 100644 --- a/burn-tensor/src/tests/ops/repeat.rs +++ b/burn-tensor/src/tests/ops/repeat.rs @@ -1,21 +1,21 @@ #[burn_tensor_testgen::testgen(repeat)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_repeat_ops() { - let data = Data::from([[0.0, 1.0, 2.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_repeat_ops() { + let data = Data::from([[0.0, 1.0, 2.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.repeat(0, 4).into_data(); + let data_actual = tensor.repeat(0, 4).into_data(); - let data_expected = Data::from([ - [0.0, 1.0, 2.0], - [0.0, 1.0, 2.0], - [0.0, 1.0, 2.0], - [0.0, 1.0, 2.0], - ]); - assert_eq!(data_expected, data_actual); - } + let data_expected = 
Data::from([ + [0.0, 1.0, 2.0], + [0.0, 1.0, 2.0], + [0.0, 1.0, 2.0], + [0.0, 1.0, 2.0], + ]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/reshape.rs b/burn-tensor/src/tests/ops/reshape.rs index 6de805dc7c..9c02dd5132 100644 --- a/burn-tensor/src/tests/ops/reshape.rs +++ b/burn-tensor/src/tests/ops/reshape.rs @@ -1,88 +1,88 @@ #[burn_tensor_testgen::testgen(reshape)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; - #[test] - fn should_support_reshape_1d() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([1, 3]).into_data(); - let data_expected = Data::from([[0.0, 1.0, 2.0]]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([1, 3]).into_data(); + let data_expected = Data::from([[0.0, 1.0, 2.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_reshape_int() { - let data = Data::from([0, 1, 2]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_int() { + let data = Data::from([0, 1, 2]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([1, 3]).into_data(); - let data_expected = Data::from([[0, 1, 2]]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([1, 3]).into_data(); + let data_expected = Data::from([[0, 1, 2]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_reshape_bool() { - let data = Data::from([false, true, false]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_bool() { + let data = Data::from([false, true, false]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([1, 3]).into_data(); - let data_expected = Data::from([[false, true, false]]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([1, 3]).into_data(); + let data_expected = Data::from([[false, true, false]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_reshape_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_reshape_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.clone().reshape([6]).into_data(); - let data_expected = Data::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); - assert_eq!(data_expected, data_actual); - } + let data_actual = tensor.clone().reshape([6]).into_data(); + let data_expected = Data::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_dim_infererence() { - let data = Data::from([ - [0.0, 1.0, 2.0], - [3.0, 4.0, 5.0], - [6.0, 7.0, 8.0], - [9.0, 10.0, 11.0], - ]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_dim_infererence() { + let data = Data::from([ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + ]); + let tensor = Tensor::::from_data(data); - // Infer the dimension via -1 - let reshaped = tensor.clone().reshape([2, -1]); - assert_eq!(reshaped.shape(), [2, 6].into()); + // Infer the dimension via -1 + let reshaped = tensor.clone().reshape([2, 
-1]); + assert_eq!(reshaped.shape(), [2, 6].into()); - // Infer the dimension via 0 (keep from the source) and -1 (infer) - let reshaped = reshaped.reshape([0, 2, -1]); - assert_eq!(reshaped.shape(), [2, 2, 3].into()); + // Infer the dimension via 0 (keep from the source) and -1 (infer) + let reshaped = reshaped.reshape([0, 2, -1]); + assert_eq!(reshaped.shape(), [2, 2, 3].into()); - // This is effectively as if we did a flatten - let reshaped = tensor.clone().reshape([-1]); - assert_eq!(reshaped.shape(), [12].into()); + // This is effectively as if we did a flatten + let reshaped = tensor.clone().reshape([-1]); + assert_eq!(reshaped.shape(), [12].into()); - // Keeping the first dimension the same (using 0) - let reshaped = tensor.clone().reshape([0, 3]); - assert_eq!(reshaped.shape(), [4, 3].into()); - } + // Keeping the first dimension the same (using 0) + let reshaped = tensor.clone().reshape([0, 3]); + assert_eq!(reshaped.shape(), [4, 3].into()); + } - #[test] - #[should_panic] - fn multiple_neg_ones() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.reshape([-1, -1]).into_data(); - } + #[test] + #[should_panic] + fn multiple_neg_ones() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.reshape([-1, -1]).into_data(); + } - #[test] - #[should_panic] - fn neg_value() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.reshape([-2, -1]).into_data(); - } + #[test] + #[should_panic] + fn neg_value() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.reshape([-2, -1]).into_data(); + } } diff --git a/burn-tensor/src/tests/ops/select.rs b/burn-tensor/src/tests/ops/select.rs index 823168618a..d62fdaeae7 100644 --- a/burn-tensor/src/tests/ops/select.rs +++ b/burn-tensor/src/tests/ops/select.rs @@ -1,128 +1,128 @@ #[burn_tensor_testgen::testgen(select)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_select_1d() { - let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - let output = tensor.select(0, indices); - - assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); - } - - #[test] - fn should_select_1d_int() { - let tensor = TestTensorInt::from_data([5, 6, 7]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - let output = tensor.select(0, indices); - - assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); - } - - #[test] - fn should_select_2d_dim0_same_num_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data(([1, 0])); - - let output = tensor.select(0, indices); - - assert_eq!( - output.into_data(), - Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]) - ); - } - - #[test] - fn should_select_2d_dim0_more_num_dim() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data([1, 0, 1, 1]); - - let output = tensor.select(0, indices); - - assert_eq!( - output.into_data(), - Data::from([ - [3.0, 4.0, 5.0], - [0.0, 1.0, 2.0], - [3.0, 4.0, 5.0], - [3.0, 4.0, 5.0] - ]) - ); - } - - #[test] - fn should_select_2d_dim1() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - 
let output = tensor.select(1, indices); - - assert_eq!( - output.into_data(), - Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]) - ); - } - - #[test] - fn should_select_assign_1d() { - let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); - let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0]); - let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - - let output = tensor.select_assign(0, indices, values); - - assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0])); - } - - #[test] - fn should_select_assign_1d_int() { - let tensor = TestTensorInt::from_data([7, 8, 9]); - let values = TestTensorInt::from_data([5, 4, 3, 2, 1]); - let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - - let output = tensor.select_assign(0, indices, values); - - assert_eq!(output.into_data(), Data::from([10, 19, 10])); - } - - #[test] - fn should_select_assign_2d_dim0() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_data(Data::from([1, 0])); - - let output = tensor.select_assign(0, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]) - ); - } - - #[test] - fn should_select_assign_2d_dim1() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); - let indices = TestTensorInt::from_data(Data::from([1, 0, 2])); - - let output = tensor.select_assign(1, indices, values); - - assert_eq!( - output.into_data(), - Data::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]]) - ); - } - - #[test] - #[should_panic] - fn should_select_panic_invalid_dimension() { - let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); - - tensor.select(10, indices); - } + #[test] + fn should_select_1d() { + let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + let output = tensor.select(0, indices); + + assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); + } + + #[test] + fn should_select_1d_int() { + let tensor = TestTensorInt::from_data([5, 6, 7]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + let output = tensor.select(0, indices); + + assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); + } + + #[test] + fn should_select_2d_dim0_same_num_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data(([1, 0])); + + let output = tensor.select(0, indices); + + assert_eq!( + output.into_data(), + Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]) + ); + } + + #[test] + fn should_select_2d_dim0_more_num_dim() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data([1, 0, 1, 1]); + + let output = tensor.select(0, indices); + + assert_eq!( + output.into_data(), + Data::from([ + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [3.0, 4.0, 5.0] + ]) + ); + } + + #[test] + fn should_select_2d_dim1() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + let output = tensor.select(1, indices); + + assert_eq!( + output.into_data(), + Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]) + ); + } + + #[test] + fn 
should_select_assign_1d() { + let tensor = TestTensor::from_data([0.0, 1.0, 2.0]); + let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0]); + let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); + + let output = tensor.select_assign(0, indices, values); + + assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0])); + } + + #[test] + fn should_select_assign_1d_int() { + let tensor = TestTensorInt::from_data([7, 8, 9]); + let values = TestTensorInt::from_data([5, 4, 3, 2, 1]); + let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); + + let output = tensor.select_assign(0, indices, values); + + assert_eq!(output.into_data(), Data::from([10, 19, 10])); + } + + #[test] + fn should_select_assign_2d_dim0() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_data(Data::from([1, 0])); + + let output = tensor.select_assign(0, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]) + ); + } + + #[test] + fn should_select_assign_2d_dim1() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indices = TestTensorInt::from_data(Data::from([1, 0, 2])); + + let output = tensor.select_assign(1, indices, values); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]]) + ); + } + + #[test] + #[should_panic] + fn should_select_panic_invalid_dimension() { + let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + tensor.select(10, indices); + } } diff --git a/burn-tensor/src/tests/ops/sin.rs b/burn-tensor/src/tests/ops/sin.rs index 24518d0f96..c7f9685947 100644 --- a/burn-tensor/src/tests/ops/sin.rs +++ b/burn-tensor/src/tests/ops/sin.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(sin)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_sin_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_sin_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.sin().into_data(); + let data_actual = tensor.sin().into_data(); - let data_expected = Data::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/slice.rs b/burn-tensor/src/tests/ops/slice.rs index 84ccd02d6e..7db32a6722 100644 --- a/burn-tensor/src/tests/ops/slice.rs +++ b/burn-tensor/src/tests/ops/slice.rs @@ -1,150 +1,150 @@ #[burn_tensor_testgen::testgen(slice)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_full_sliceing_1d() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + fn should_support_full_sliceing_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([0..3]).into_data(); + let data_actual = 
tensor.slice([0..3]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - fn should_support_partial_sliceing_1d() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_partial_sliceing_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.slice([1..3]).into_data(); + let data_actual = tensor.slice([1..3]).into_data(); - let data_expected = Data::from([1.0, 2.0]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([1.0, 2.0]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_full_sliceing_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + fn should_support_full_sliceing_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual_1 = tensor.clone().slice([0..2]).into_data(); - let data_actual_2 = tensor.slice([0..2, 0..3]).into_data(); + let data_actual_1 = tensor.clone().slice([0..2]).into_data(); + let data_actual_2 = tensor.slice([0..2, 0..3]).into_data(); - assert_eq!(data, data_actual_1); - assert_eq!(data, data_actual_2); - } + assert_eq!(data, data_actual_1); + assert_eq!(data, data_actual_2); + } - #[test] - fn should_support_partial_sliceing_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_partial_sliceing_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.slice([0..2, 0..2]).into_data(); + let data_actual = tensor.slice([0..2, 0..2]).into_data(); - let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_partial_sliceing_3d() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_partial_sliceing_3d() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.slice([1..2, 1..2, 0..2]).into_data(); + let data_actual = tensor.slice([1..2, 1..2, 0..2]).into_data(); - let data_expected = Data::from([[[9.0, 10.0]]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[[9.0, 10.0]]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_partial_sliceing_3d_non_contiguous() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); + #[test] + fn should_support_partial_sliceing_3d_non_contiguous() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); - let data_actual = tensor.transpose().slice([1..2, 1..2, 0..2]).into_data(); + let data_actual = tensor.transpose().slice([1..2, 1..2, 0..2]).into_data(); - let data_expected = Data::from([[[7.0, 10.0]]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[[7.0, 10.0]]]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_slice_assign_1d() { - let data = Data::from([0.0, 1.0, 
2.0]); - let data_assigned = Data::from([10.0, 5.0]); + #[test] + fn should_support_slice_assign_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let data_assigned = Data::from([10.0, 5.0]); - let tensor = Tensor::::from_data(data); - let tensor_assigned = Tensor::::from_data(data_assigned); + let tensor = Tensor::::from_data(data); + let tensor_assigned = Tensor::::from_data(data_assigned); - let data_actual = tensor.slice_assign([0..2], tensor_assigned).into_data(); + let data_actual = tensor.slice_assign([0..2], tensor_assigned).into_data(); - let data_expected = Data::from([10.0, 5.0, 2.0]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([10.0, 5.0, 2.0]); + assert_eq!(data_expected, data_actual); + } - #[test] - fn should_support_slice_assign_2d() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_assigned = Data::from([[10.0, 5.0]]); + #[test] + fn should_support_slice_assign_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_assigned = Data::from([[10.0, 5.0]]); - let tensor = Tensor::::from_data(data); - let tensor_assigned = Tensor::::from_data(data_assigned); + let tensor = Tensor::::from_data(data); + let tensor_assigned = Tensor::::from_data(data_assigned); - let data_actual = tensor - .slice_assign([1..2, 0..2], tensor_assigned) - .into_data(); + let data_actual = tensor + .slice_assign([1..2, 0..2], tensor_assigned) + .into_data(); - let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); - assert_eq!(data_expected, data_actual); - } + let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); + assert_eq!(data_expected, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_exceeds_dimension() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn should_panic_when_slice_exceeds_dimension() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([0..4]).into_data(); + let data_actual = tensor.slice([0..4]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_with_too_many_dimensions() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn should_panic_when_slice_with_too_many_dimensions() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([0..1, 0..1]).into_data(); + let data_actual = tensor.slice([0..1, 0..1]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_is_desc() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn should_panic_when_slice_is_desc() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - #[allow(clippy::reversed_empty_ranges)] - let data_actual = tensor.slice([2..1]).into_data(); + #[allow(clippy::reversed_empty_ranges)] + let data_actual = tensor.slice([2..1]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } - #[test] - #[should_panic] - fn should_panic_when_slice_is_equal() { - let data = Data::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone()); + #[test] + #[should_panic] + fn 
should_panic_when_slice_is_equal() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = Tensor::::from_data(data.clone()); - let data_actual = tensor.slice([1..1]).into_data(); + let data_actual = tensor.slice([1..1]).into_data(); - assert_eq!(data, data_actual); - } + assert_eq!(data, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/sqrt.rs b/burn-tensor/src/tests/ops/sqrt.rs index 593e399c1a..3044b1d9bc 100644 --- a/burn-tensor/src/tests/ops/sqrt.rs +++ b/burn-tensor/src/tests/ops/sqrt.rs @@ -1,17 +1,17 @@ #[burn_tensor_testgen::testgen(sqrt)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; - use core::f32::consts::SQRT_2; + use super::*; + use burn_tensor::{Data, Tensor}; + use core::f32::consts::SQRT_2; - #[test] - fn should_support_sqrt_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_sqrt_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.sqrt().into_data(); + let data_actual = tensor.sqrt().into_data(); - let data_expected = Data::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/squeeze.rs b/burn-tensor/src/tests/ops/squeeze.rs index 8d1e9fb3cf..d8de064bdd 100644 --- a/burn-tensor/src/tests/ops/squeeze.rs +++ b/burn-tensor/src/tests/ops/squeeze.rs @@ -1,37 +1,37 @@ #[burn_tensor_testgen::testgen(squeeze)] mod tests { - use super::*; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::{Data, Shape, Tensor}; - /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor. - #[test] - fn should_squeeze() { - let tensor = Tensor::::ones(Shape::new([2, 1, 4])); - let squeezed_tensor: Tensor = tensor.squeeze(1); - let expected_shape = Shape::new([2, 4]); - assert_eq!(squeezed_tensor.shape(), expected_shape); - } - /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor. - #[test] - fn should_squeeze_first() { - let tensor = Tensor::::ones(Shape::new([1, 3, 4, 5])); - let squeezed_tensor: Tensor = tensor.squeeze(0); - let expected_shape = Shape::new([3, 4, 5]); - assert_eq!(squeezed_tensor.shape(), expected_shape); - } - /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor. - #[test] - fn should_squeeze_last() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 1])); - let squeezed_tensor: Tensor = tensor.squeeze(3); - let expected_shape = Shape::new([2, 3, 4]); - assert_eq!(squeezed_tensor.shape(), expected_shape); - } - /// Test if the function panics when the squeezed dimension is not of size 1. - #[test] - #[should_panic] - fn should_squeeze_panic() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); - let squeezed_tensor: Tensor = tensor.squeeze(2); - } + /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor. + #[test] + fn should_squeeze() { + let tensor = Tensor::::ones(Shape::new([2, 1, 4])); + let squeezed_tensor: Tensor = tensor.squeeze(1); + let expected_shape = Shape::new([2, 4]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor. 
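// Illustrative sketch, not from the patch itself: the slice tests above show that
// `slice` takes one range per dimension and that `slice_assign` writes a smaller
// tensor into the selected region, panicking on out-of-bounds, descending, or empty
// ranges. A minimal sketch of the pattern, assuming a generic burn backend `B`
// instead of the TestTensor alias used by the tests:
use burn_tensor::{backend::Backend, Tensor};

fn overwrite_first_row<B: Backend>(base: Tensor<B, 2>, row: Tensor<B, 2>) -> Tensor<B, 2> {
    // `row` is expected to have shape [1, cols]; the ranges must stay in bounds and
    // be ascending, mirroring the should_panic tests above.
    let cols = base.dims()[1];
    base.slice_assign([0..1, 0..cols], row)
}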
+ #[test] + fn should_squeeze_first() { + let tensor = Tensor::::ones(Shape::new([1, 3, 4, 5])); + let squeezed_tensor: Tensor = tensor.squeeze(0); + let expected_shape = Shape::new([3, 4, 5]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor. + #[test] + fn should_squeeze_last() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 1])); + let squeezed_tensor: Tensor = tensor.squeeze(3); + let expected_shape = Shape::new([2, 3, 4]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function panics when the squeezed dimension is not of size 1. + #[test] + #[should_panic] + fn should_squeeze_panic() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let squeezed_tensor: Tensor = tensor.squeeze(2); + } } diff --git a/burn-tensor/src/tests/ops/sub.rs b/burn-tensor/src/tests/ops/sub.rs index d2d1adba54..3293379abd 100644 --- a/burn-tensor/src/tests/ops/sub.rs +++ b/burn-tensor/src/tests/ops/sub.rs @@ -1,83 +1,83 @@ #[burn_tensor_testgen::testgen(sub)] mod tests { - use super::*; - use burn_tensor::{Data, Int, Tensor}; - - #[test] - fn should_support_sub_ops() { - let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); - let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_sub_broadcast() { - let data_1 = Data::from([[0.0, 1.0, 2.0]]); - let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - let data_expected = Data::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_sub_scalar_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let scalar = 2.0; - let tensor = Tensor::::from_data(data); - - let output = tensor - scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_sub_ops_int() { - let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); - let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); - let data_expected = Data::from([[-6, -6, -6], [-6, -6, -6]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - assert_eq!(data_expected, data_actual); - } - - #[test] - fn test_sub_broadcast_int() { - let data_1 = Data::from([[0, 1, 2]]); - let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - - let data_actual = (tensor_1 - tensor_2).into_data(); - - let data_expected = Data::from([[-3, -3, -3], [-6, -6, -6]]); - assert_eq!(data_expected, data_actual); - } - - #[test] - fn should_support_sub_scalar_ops_int() { - let data = Data::from([[0, 1, 2], [3, 4, 5]]); - let scalar = 2; - let tensor = Tensor::::from_data(data); - - let output = tensor - scalar; - - let data_actual = output.into_data(); - let data_expected = Data::from([[-2, -1, 0], [1, 2, 3]]); - assert_eq!(data_expected, data_actual); 
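// Illustrative sketch, not from the patch itself: the reshape and squeeze tests above
// rely on two shape conventions. In `reshape`, -1 infers one dimension from the total
// element count and 0 keeps the corresponding source dimension; `squeeze(dim)` removes
// a dimension that must be of size 1 and panics otherwise. A short sketch, assuming a
// generic burn backend `B`:
use burn_tensor::{backend::Backend, Tensor};

fn flatten_trailing<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 2> {
    // Keep the first dimension (0) and infer the second (-1), e.g. [2, 3, 4] -> [2, 12].
    x.reshape([0, -1])
}

fn drop_middle_dim<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 2> {
    // Panics if dimension 1 is not of size 1, mirroring should_squeeze_panic above.
    x.squeeze(1)
}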
- } + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn should_support_sub_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); + let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_sub_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + let data_expected = Data::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_sub_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = Tensor::::from_data(data); + + let output = tensor - scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_sub_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); + let data_expected = Data::from([[-6, -6, -6], [-6, -6, -6]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_sub_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + let data_expected = Data::from([[-3, -3, -3], [-6, -6, -6]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_sub_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor - scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[-2, -1, 0], [1, 2, 3]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/tanh.rs b/burn-tensor/src/tests/ops/tanh.rs index 5d95b948f7..b65b1be7ab 100644 --- a/burn-tensor/src/tests/ops/tanh.rs +++ b/burn-tensor/src/tests/ops/tanh.rs @@ -1,16 +1,16 @@ #[burn_tensor_testgen::testgen(tanh)] mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::{Data, Tensor}; - #[test] - fn should_support_tanh_ops() { - let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data); + #[test] + fn should_support_tanh_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = Tensor::::from_data(data); - let data_actual = tensor.tanh().into_data(); + let data_actual = tensor.tanh().into_data(); - let data_expected = Data::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/ops/transpose.rs 
b/burn-tensor/src/tests/ops/transpose.rs index fcd5566aff..83ae67b045 100644 --- a/burn-tensor/src/tests/ops/transpose.rs +++ b/burn-tensor/src/tests/ops/transpose.rs @@ -1,93 +1,97 @@ #[burn_tensor_testgen::testgen(transpose)] mod tests { - use super::*; - use burn_tensor::{Bool, Data, Int, Tensor}; - - #[test] - fn should_support_transpose_ops() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - - let data_actual = tensor.transpose().into_data(); - - let data_expected = Data::from([ - [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], - [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_swap_dims() { - let tensor = TestTensor::from_floats([ - [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], - [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], - ]); - - let data_actual = tensor.swap_dims(0, 2).into_data(); - - let data_expected = Data::from([ - [[0.0, 6.0], [3.0, 9.0]], - [[1.0, 7.0], [4.0, 10.0]], - [[2.0, 8.0], [5.0, 11.0]], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn should_support_transpose_ops_int() { - let tensor = - Tensor::::from_data([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]); - - let data_actual = tensor.transpose().into_data(); - - let data_expected = Data::from([[[0, 3], [1, 4], [2, 5]], [[6, 9], [7, 10], [8, 11]]]); - assert_eq!(&data_expected, &data_actual); - } - - #[test] - fn should_support_swap_dims_int() { - let tensor = - Tensor::::from_data([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]); - - let data_actual = tensor.swap_dims(0, 2).into_data(); - - let data_expected = Data::from([[[0, 6], [3, 9]], [[1, 7], [4, 10]], [[2, 8], [5, 11]]]); - assert_eq!(&data_expected, &data_actual); - } - - #[test] - fn should_support_transpose_bool() { - let tensor = Tensor::::from_data([ - [[false, true, false], [false, false, false]], - [[false, false, true], [false, false, true]], - ]); - - let data_actual = tensor.transpose().into_data(); - - let data_expected = Data::from([ - [[false, false], [true, false], [false, false]], - [[false, false], [false, false], [true, true]], - ]); - assert_eq!(&data_expected, &data_actual); - } - - #[test] - fn should_support_swap_dims_bool() { - let tensor = Tensor::::from_data([ - [[false, true, false], [false, false, false]], - [[false, false, true], [false, false, true]], - ]); - - let data_actual = tensor.swap_dims(0, 2).into_data(); - - let data_expected = Data::from([ - [[false, false], [false, false]], - [[true, false], [false, false]], - [[false, true], [false, true]], - ]); - assert_eq!(&data_expected, &data_actual); - } + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; + + #[test] + fn should_support_transpose_ops() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + + let data_actual = tensor.transpose().into_data(); + + let data_expected = Data::from([ + [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], + [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_swap_dims() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + + let data_actual = tensor.swap_dims(0, 2).into_data(); + + let data_expected = Data::from([ + [[0.0, 6.0], [3.0, 9.0]], + [[1.0, 7.0], [4.0, 10.0]], + [[2.0, 8.0], [5.0, 11.0]], + ]); + 
data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn should_support_transpose_ops_int() { + let tensor = Tensor::::from_data([ + [[0, 1, 2], [3, 4, 5]], + [[6, 7, 8], [9, 10, 11]], + ]); + + let data_actual = tensor.transpose().into_data(); + + let data_expected = Data::from([[[0, 3], [1, 4], [2, 5]], [[6, 9], [7, 10], [8, 11]]]); + assert_eq!(&data_expected, &data_actual); + } + + #[test] + fn should_support_swap_dims_int() { + let tensor = Tensor::::from_data([ + [[0, 1, 2], [3, 4, 5]], + [[6, 7, 8], [9, 10, 11]], + ]); + + let data_actual = tensor.swap_dims(0, 2).into_data(); + + let data_expected = Data::from([[[0, 6], [3, 9]], [[1, 7], [4, 10]], [[2, 8], [5, 11]]]); + assert_eq!(&data_expected, &data_actual); + } + + #[test] + fn should_support_transpose_bool() { + let tensor = Tensor::::from_data([ + [[false, true, false], [false, false, false]], + [[false, false, true], [false, false, true]], + ]); + + let data_actual = tensor.transpose().into_data(); + + let data_expected = Data::from([ + [[false, false], [true, false], [false, false]], + [[false, false], [false, false], [true, true]], + ]); + assert_eq!(&data_expected, &data_actual); + } + + #[test] + fn should_support_swap_dims_bool() { + let tensor = Tensor::::from_data([ + [[false, true, false], [false, false, false]], + [[false, false, true], [false, false, true]], + ]); + + let data_actual = tensor.swap_dims(0, 2).into_data(); + + let data_expected = Data::from([ + [[false, false], [false, false]], + [[true, false], [false, false]], + [[false, true], [false, true]], + ]); + assert_eq!(&data_expected, &data_actual); + } } diff --git a/burn-tensor/src/tests/stats/cov.rs b/burn-tensor/src/tests/stats/cov.rs index e93ecd5c29..f6fe8a91c5 100644 --- a/burn-tensor/src/tests/stats/cov.rs +++ b/burn-tensor/src/tests/stats/cov.rs @@ -1,61 +1,61 @@ #[burn_tensor_testgen::testgen(cov)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Tensor}; - - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; - - #[test] - fn test_cov_1() { - let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.cov(1, 1).into_data(); - - let data_expected = Data::from([[2.4892, -1.7333], [-1.7333, 15.3333]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_cov_4() { - let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.cov(1, 0).into_data(); - - let data_expected = Data::from([[1.8668, -1.2999], [-1.2999, 11.5]]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_cov_2() { - let data = Data::from([[0.5, 1.8], [0.2, -2.0], [3.0, -4.0], [5.0, 0.0]]); - let tensor = Tensor::::from_data(data); - - let data_actual = tensor.cov(1, 1).into_data(); - - let data_expected = Data::from([ - [0.845, -1.43, -4.55, -3.25], - [-1.43, 2.42, 7.7, 5.5], - [-4.55, 7.7, 24.5, 17.5], - [-3.25, 5.5, 17.5, 12.5], - ]); - data_expected.assert_approx_eq(&data_actual, 3); - } - - #[test] - fn test_cov_3() { - let data = Data::from([ - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], - ]); - let tensor = Tensor::::from_data(data); - let data_actual = tensor.cov(0, 1).into_data(); - let data_expected = Tensor::::zeros([4, 4, 4]).to_data(); - 
data_expected.assert_approx_eq(&data_actual, 3); - } + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Tensor}; + + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; + + #[test] + fn test_cov_1() { + let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.cov(1, 1).into_data(); + + let data_expected = Data::from([[2.4892, -1.7333], [-1.7333, 15.3333]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_cov_4() { + let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.cov(1, 0).into_data(); + + let data_expected = Data::from([[1.8668, -1.2999], [-1.2999, 11.5]]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_cov_2() { + let data = Data::from([[0.5, 1.8], [0.2, -2.0], [3.0, -4.0], [5.0, 0.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.cov(1, 1).into_data(); + + let data_expected = Data::from([ + [0.845, -1.43, -4.55, -3.25], + [-1.43, 2.42, 7.7, 5.5], + [-4.55, 7.7, 24.5, 17.5], + [-3.25, 5.5, 17.5, 12.5], + ]); + data_expected.assert_approx_eq(&data_actual, 3); + } + + #[test] + fn test_cov_3() { + let data = Data::from([ + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], + ]); + let tensor = Tensor::::from_data(data); + let data_actual = tensor.cov(0, 1).into_data(); + let data_expected = Tensor::::zeros([4, 4, 4]).to_data(); + data_expected.assert_approx_eq(&data_actual, 3); + } } diff --git a/burn-tensor/src/tests/stats/diagonal.rs b/burn-tensor/src/tests/stats/diagonal.rs index c62f0f45f8..f7fda216ad 100644 --- a/burn-tensor/src/tests/stats/diagonal.rs +++ b/burn-tensor/src/tests/stats/diagonal.rs @@ -1,18 +1,18 @@ #[burn_tensor_testgen::testgen(diagonal)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Tensor}; - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; - #[test] - fn test_diagonal() { - let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]; - let lhs = Tensor::::from_floats(data); - let rhs = Tensor::::diagonal(3); - lhs.to_data().assert_approx_eq(&rhs.to_data(), 3); - } + #[test] + fn test_diagonal() { + let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]; + let lhs = Tensor::::from_floats(data); + let rhs = Tensor::::diagonal(3); + lhs.to_data().assert_approx_eq(&rhs.to_data(), 3); + } } diff --git a/burn-tensor/src/tests/stats/display.rs b/burn-tensor/src/tests/stats/display.rs index 33c8aa3c44..bb4238038f 100644 --- a/burn-tensor/src/tests/stats/display.rs +++ b/burn-tensor/src/tests/stats/display.rs @@ -1,21 +1,21 @@ #[burn_tensor_testgen::testgen(display)] mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Shape, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Shape, Tensor}; - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; - #[test] - fn test_display_2d_int_tensor() { - let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]); - let tensor_int: burn_tensor::Tensor = - 
Tensor::from_data(int_data); + #[test] + fn test_display_2d_int_tensor() { + let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]); + let tensor_int: burn_tensor::Tensor = + Tensor::from_data(int_data); - let output = format!("{}", tensor_int); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor_int); + let expected = format!( + r#"Tensor {{ data: [[1, 2, 3], [4, 5, 6], @@ -26,22 +26,22 @@ mod tests { kind: "Int", dtype: "{dtype}", }}"#, - tensor_int.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor_int.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_2d_float_tensor() { - let float_data = Data::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]); - let tensor_float: burn_tensor::Tensor = - Tensor::from_data(float_data); + #[test] + fn test_display_2d_float_tensor() { + let float_data = Data::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]); + let tensor_float: burn_tensor::Tensor = + Tensor::from_data(float_data); - let output = format!("{}", tensor_float); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor_float); + let expected = format!( + r#"Tensor {{ data: [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], @@ -52,26 +52,26 @@ mod tests { kind: "Float", dtype: "{dtype}", }}"#, - tensor_float.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor_float.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_2d_bool_tensor() { - let bool_data = Data::from([ - [true, false, true], - [false, true, false], - [false, true, true], - ]); - let tensor_bool: burn_tensor::Tensor = - Tensor::from_data(bool_data); + #[test] + fn test_display_2d_bool_tensor() { + let bool_data = Data::from([ + [true, false, true], + [false, true, false], + [false, true, true], + ]); + let tensor_bool: burn_tensor::Tensor = + Tensor::from_data(bool_data); - let output = format!("{}", tensor_bool); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor_bool); + let expected = format!( + r#"Tensor {{ data: [[true, false, true], [false, true, false], @@ -82,23 +82,23 @@ mod tests { kind: "Bool", dtype: "bool", }}"#, - tensor_bool.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor_bool.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_3d_tensor() { - let data = Data::from([ - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], - [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], - ]); - let tensor: burn_tensor::Tensor = Tensor::from_data(data); + #[test] + fn test_display_3d_tensor() { + let data = Data::from([ + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], + ]); + let tensor: burn_tensor::Tensor = Tensor::from_data(data); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[1, 2, 3, 4], [5, 6, 7, 8], @@ -112,25 +112,25 @@ mod tests { kind: "Int", dtype: "{dtype}", }}"#, - tensor.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, 
expected); + } - #[test] - fn test_display_4d_tensor() { - let data = Data::from([ - [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], - [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]], - ]); + #[test] + fn test_display_4d_tensor() { + let data = Data::from([ + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]], + ]); - let tensor: burn_tensor::Tensor = Tensor::from_data(data); + let tensor: burn_tensor::Tensor = Tensor::from_data(data); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[1, 2, 3], [4, 5, 6]], @@ -146,21 +146,21 @@ mod tests { kind: "Int", dtype: "{dtype}", }}"#, - tensor.device(), - TestBackend::name(), - dtype = core::any::type_name::(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + dtype = core::any::type_name::(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_tensor_summarize_1() { - let tensor: burn_tensor::Tensor = - Tensor::zeros(Shape::new([2, 2, 2, 1000])); + #[test] + fn test_display_tensor_summarize_1() { + let tensor: burn_tensor::Tensor = + Tensor::zeros(Shape::new([2, 2, 2, 1000])); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]], @@ -176,20 +176,20 @@ mod tests { kind: "Float", dtype: "f32", }}"#, - tensor.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_tensor_summarize_2() { - let tensor: burn_tensor::Tensor = - Tensor::zeros(Shape::new([2, 2, 20, 100])); + #[test] + fn test_display_tensor_summarize_2() { + let tensor: burn_tensor::Tensor = + Tensor::zeros(Shape::new([2, 2, 20, 100])); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], @@ -225,20 +225,20 @@ mod tests { kind: "Float", dtype: "f32", }}"#, - tensor.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } - #[test] - fn test_display_tensor_summarize_3() { - let tensor: burn_tensor::Tensor = - Tensor::zeros(Shape::new([2, 2, 200, 6])); + #[test] + fn test_display_tensor_summarize_3() { + let tensor: burn_tensor::Tensor = + Tensor::zeros(Shape::new([2, 2, 200, 6])); - let output = format!("{}", tensor); - let expected = format!( - r#"Tensor {{ + let output = format!("{}", tensor); + let expected = format!( + r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], @@ -274,9 +274,9 @@ mod tests { kind: "Float", dtype: "f32", }}"#, - tensor.device(), - TestBackend::name(), - ); - assert_eq!(output, expected); - } + tensor.device(), + TestBackend::name(), + ); + assert_eq!(output, expected); + } } diff --git a/burn-tensor/src/tests/stats/var.rs b/burn-tensor/src/tests/stats/var.rs index dfde1862e8..ac6ccf2f12 100644 --- a/burn-tensor/src/tests/stats/var.rs +++ b/burn-tensor/src/tests/stats/var.rs @@ -1,55 +1,55 @@ #[burn_tensor_testgen::testgen(var)] mod tests { - use super::*; - use 
burn_tensor::backend::Backend; - use burn_tensor::{Data, Tensor}; + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Tensor}; - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; + type FloatElem = ::FloatElem; + type IntElem = ::IntElem; - #[test] - fn test_var() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let data_actual = tensor.var(1).into_data(); + let data_actual = tensor.var(1).into_data(); - let data_expected = Data::from([[2.4892], [15.3333]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[2.4892], [15.3333]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn test_var_mean() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var_mean() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let (var, mean) = tensor.var_mean(1); + let (var, mean) = tensor.var_mean(1); - let var_expected = Data::from([[2.4892], [15.3333]]); - let mean_expected = Data::from([[0.125], [1.]]); + let var_expected = Data::from([[2.4892], [15.3333]]); + let mean_expected = Data::from([[0.125], [1.]]); - var_expected.assert_approx_eq(&(var.into_data()), 3); - mean_expected.assert_approx_eq(&(mean.into_data()), 3); - } + var_expected.assert_approx_eq(&(var.into_data()), 3); + mean_expected.assert_approx_eq(&(mean.into_data()), 3); + } - #[test] - fn test_var_bias() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var_bias() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let data_actual = tensor.var_bias(1).into_data(); + let data_actual = tensor.var_bias(1).into_data(); - let data_expected = Data::from([[1.86688], [11.5]]); - data_expected.assert_approx_eq(&data_actual, 3); - } + let data_expected = Data::from([[1.86688], [11.5]]); + data_expected.assert_approx_eq(&data_actual, 3); + } - #[test] - fn test_var_mean_bias() { - let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); + #[test] + fn test_var_mean_bias() { + let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let (var, mean) = tensor.var_mean_bias(1); + let (var, mean) = tensor.var_mean_bias(1); - let var_expected = Data::from([[1.86688], [11.5]]); - let mean_expected = Data::from([[0.125], [1.]]); + let var_expected = Data::from([[1.86688], [11.5]]); + let mean_expected = Data::from([[0.125], [1.]]); - var_expected.assert_approx_eq(&(var.into_data()), 3); - mean_expected.assert_approx_eq(&(mean.into_data()), 3); - } + var_expected.assert_approx_eq(&(var.into_data()), 3); + mean_expected.assert_approx_eq(&(mean.into_data()), 3); + } } diff --git a/burn-train/src/checkpoint/async_checkpoint.rs b/burn-train/src/checkpoint/async_checkpoint.rs index 805f1dc014..83ec9fcfff 100644 --- a/burn-train/src/checkpoint/async_checkpoint.rs +++ b/burn-train/src/checkpoint/async_checkpoint.rs @@ -3,122 +3,118 @@ use burn_core::record::Record; use std::sync::mpsc; enum Message { - Restore(usize, mpsc::SyncSender>), - Save(usize, R), - Delete(usize), - End, + Restore(usize, mpsc::SyncSender>), + Save(usize, R), + Delete(usize), + End, } #[derive(new)] struct CheckpointerThread { - checkpointer: C, - receiver: mpsc::Receiver>, + checkpointer: C, + receiver: 
mpsc::Receiver>, } impl, R: Record> CheckpointerThread { - fn run(self) { - for item in self.receiver.iter() { - match item { - Message::Restore(epoch, callback) => { - let record = self.checkpointer.restore(epoch); - callback - .send(record) - .expect("Can send response through callback channel."); + fn run(self) { + for item in self.receiver.iter() { + match item { + Message::Restore(epoch, callback) => { + let record = self.checkpointer.restore(epoch); + callback + .send(record) + .expect("Can send response through callback channel."); + } + Message::Save(epoch, state) => self + .checkpointer + .save(epoch, state) + .expect("Can save the state."), + Message::Delete(epoch) => self + .checkpointer + .delete(epoch) + .expect("Can delete the state."), + Message::End => { + return; + } + }; } - Message::Save(epoch, state) => self - .checkpointer - .save(epoch, state) - .expect("Can save the state."), - Message::Delete(epoch) => self - .checkpointer - .delete(epoch) - .expect("Can delete the state."), - Message::End => { - return; - } - }; } - } } /// Async checkpointer. pub struct AsyncCheckpointer { - sender: mpsc::SyncSender>, - handler: Option>, + sender: mpsc::SyncSender>, + handler: Option>, } impl AsyncCheckpointer { - /// Create a new async checkpointer. - /// - /// # Arguments - /// - /// * `checkpointer` - The checkpointer. - /// - /// # Returns - /// - /// The async checkpointer. - pub fn new(checkpointer: C) -> Self - where - C: Checkpointer + Send + 'static, - { - // Only on checkpoint can be done in advance. - let (sender, receiver) = mpsc::sync_channel(0); - let thread = CheckpointerThread::new(checkpointer, receiver); - let handler = Some(std::thread::spawn(move || thread.run())); + /// Create a new async checkpointer. + /// + /// # Arguments + /// + /// * `checkpointer` - The checkpointer. + /// + /// # Returns + /// + /// The async checkpointer. + pub fn new(checkpointer: C) -> Self + where + C: Checkpointer + Send + 'static, + { + // Only on checkpoint can be done in advance. 
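+        // Note: a capacity of 0 creates a rendezvous channel, so `send` blocks until
+        // the checkpointer thread has taken the message. At most one checkpoint
+        // request can therefore be prepared while a previous one is still being handled.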
+ let (sender, receiver) = mpsc::sync_channel(0); + let thread = CheckpointerThread::new(checkpointer, receiver); + let handler = Some(std::thread::spawn(move || thread.run())); - Self { sender, handler } - } + Self { sender, handler } + } } impl Checkpointer for AsyncCheckpointer where - R: Record + 'static, + R: Record + 'static, { - fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { - self - .sender - .send(Message::Save(epoch, record)) - .expect("Can send message to checkpointer thread."); + fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { + self.sender + .send(Message::Save(epoch, record)) + .expect("Can send message to checkpointer thread."); - Ok(()) - } + Ok(()) + } - fn restore(&self, epoch: usize) -> Result { - let (sender, receiver) = mpsc::sync_channel(1); - self - .sender - .send(Message::Restore(epoch, sender)) - .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; + fn restore(&self, epoch: usize) -> Result { + let (sender, receiver) = mpsc::sync_channel(1); + self.sender + .send(Message::Restore(epoch, sender)) + .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; - if let Ok(record) = receiver.recv() { - return record; - }; + if let Ok(record) = receiver.recv() { + return record; + }; - Err(CheckpointerError::Unknown("Channel error.".to_string())) - } + Err(CheckpointerError::Unknown("Channel error.".to_string())) + } - fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { - self - .sender - .send(Message::Delete(epoch)) - .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; + fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { + self.sender + .send(Message::Delete(epoch)) + .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; - Ok(()) - } + Ok(()) + } } impl Drop for AsyncCheckpointer { - fn drop(&mut self) { - self - .sender - .send(Message::End) - .expect("Can send the end message to the checkpointer thread."); - let handler = self.handler.take(); + fn drop(&mut self) { + self.sender + .send(Message::End) + .expect("Can send the end message to the checkpointer thread."); + let handler = self.handler.take(); - if let Some(handler) = handler { - handler - .join() - .expect("The checkpointer thread should stop."); + if let Some(handler) = handler { + handler + .join() + .expect("The checkpointer thread should stop."); + } } - } } diff --git a/burn-train/src/checkpoint/base.rs b/burn-train/src/checkpoint/base.rs index 2104db82fb..61a2dca986 100644 --- a/burn-train/src/checkpoint/base.rs +++ b/burn-train/src/checkpoint/base.rs @@ -3,37 +3,37 @@ use burn_core::record::{Record, RecorderError}; /// The error type for checkpointer. #[derive(Debug)] pub enum CheckpointerError { - /// IO error. - IOError(std::io::Error), + /// IO error. + IOError(std::io::Error), - /// Recorder error. - RecorderError(RecorderError), + /// Recorder error. + RecorderError(RecorderError), - /// Other errors. - Unknown(String), + /// Other errors. + Unknown(String), } /// The trait for checkpointer. pub trait Checkpointer { - /// Save the record. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. - /// * `record` - The record. - fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>; + /// Save the record. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + /// * `record` - The record. + fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>; - /// Delete the record at the given epoch if present. 
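The `AsyncCheckpointer` above offloads every save, restore, and delete to a dedicated thread and talks to it over a bounded channel. The standalone sketch below reproduces just that worker pattern with the standard library; the `Message` enum mirrors the one in the patch, while the function name and the use of `String` in place of the real `Record` type are illustrative only.

use std::sync::mpsc;
use std::thread;

// Simplified message set: save a record for an epoch, or shut the worker down.
enum Message {
    Save(usize, String),
    End,
}

fn spawn_checkpoint_worker() -> (mpsc::SyncSender<Message>, thread::JoinHandle<()>) {
    // Capacity 0: the sender blocks until the worker picks the message up.
    let (sender, receiver) = mpsc::sync_channel(0);

    let handle = thread::spawn(move || {
        for message in receiver.iter() {
            match message {
                Message::Save(epoch, record) => {
                    // A real implementation would call `checkpointer.save(epoch, record)`.
                    println!("saving '{record}' for epoch {epoch}");
                }
                Message::End => return,
            }
        }
    });

    (sender, handle)
}

fn main() {
    let (sender, handle) = spawn_checkpoint_worker();
    sender.send(Message::Save(1, "model".to_string())).unwrap();

    // Mirrors the `Drop` implementation above: ask the worker to stop, then join it.
    sender.send(Message::End).unwrap();
    handle.join().unwrap();
}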
- fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>; + /// Delete the record at the given epoch if present. + fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>; - /// Restore the record. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. - /// - /// # Returns - /// - /// The record. - fn restore(&self, epoch: usize) -> Result; + /// Restore the record. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + /// + /// # Returns + /// + /// The record. + fn restore(&self, epoch: usize) -> Result; } diff --git a/burn-train/src/checkpoint/file.rs b/burn-train/src/checkpoint/file.rs index afc5e88542..24c9c717f3 100644 --- a/burn-train/src/checkpoint/file.rs +++ b/burn-train/src/checkpoint/file.rs @@ -3,69 +3,68 @@ use burn_core::record::{FileRecorder, Record}; /// The file checkpointer. pub struct FileCheckpointer { - directory: String, - name: String, - recorder: FR, + directory: String, + name: String, + recorder: FR, } impl FileCheckpointer { - /// Creates a new file checkpointer. - /// - /// # Arguments - /// - /// * `recorder` - The file recorder. - /// * `directory` - The directory to save the checkpoints. - /// * `name` - The name of the checkpoint. - pub fn new(recorder: FR, directory: &str, name: &str) -> Self { - std::fs::create_dir_all(directory).ok(); + /// Creates a new file checkpointer. + /// + /// # Arguments + /// + /// * `recorder` - The file recorder. + /// * `directory` - The directory to save the checkpoints. + /// * `name` - The name of the checkpoint. + pub fn new(recorder: FR, directory: &str, name: &str) -> Self { + std::fs::create_dir_all(directory).ok(); - Self { - directory: directory.to_string(), - name: name.to_string(), - recorder, + Self { + directory: directory.to_string(), + name: name.to_string(), + recorder, + } + } + fn path_for_epoch(&self, epoch: usize) -> String { + format!("{}/{}-{}", self.directory, self.name, epoch) } - } - fn path_for_epoch(&self, epoch: usize) -> String { - format!("{}/{}-{}", self.directory, self.name, epoch) - } } impl Checkpointer for FileCheckpointer where - R: Record, - FR: FileRecorder, + R: Record, + FR: FileRecorder, { - fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { - let file_path = self.path_for_epoch(epoch); - log::info!("Saving checkpoint {} to {}", epoch, file_path); + fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { + let file_path = self.path_for_epoch(epoch); + log::info!("Saving checkpoint {} to {}", epoch, file_path); - self - .recorder - .record(record, file_path.into()) - .map_err(CheckpointerError::RecorderError)?; + self.recorder + .record(record, file_path.into()) + .map_err(CheckpointerError::RecorderError)?; - Ok(()) - } + Ok(()) + } - fn restore(&self, epoch: usize) -> Result { - let file_path = self.path_for_epoch(epoch); - log::info!("Restoring checkpoint {} from {}", epoch, file_path); - let record = self - .recorder - .load(file_path.into()) - .map_err(CheckpointerError::RecorderError)?; + fn restore(&self, epoch: usize) -> Result { + let file_path = self.path_for_epoch(epoch); + log::info!("Restoring checkpoint {} from {}", epoch, file_path); + let record = self + .recorder + .load(file_path.into()) + .map_err(CheckpointerError::RecorderError)?; - Ok(record) - } + Ok(record) + } - fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { - let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),); + fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { + let 
file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),); - if std::path::Path::new(&file_to_remove).exists() { - log::info!("Removing checkpoint {}", file_to_remove); - std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?; - } + if std::path::Path::new(&file_to_remove).exists() { + log::info!("Removing checkpoint {}", file_to_remove); + std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?; + } - Ok(()) - } + Ok(()) + } } diff --git a/burn-train/src/checkpoint/strategy/base.rs b/burn-train/src/checkpoint/strategy/base.rs index 7d4131718d..f16acfeb41 100644 --- a/burn-train/src/checkpoint/strategy/base.rs +++ b/burn-train/src/checkpoint/strategy/base.rs @@ -5,30 +5,30 @@ use crate::metric::store::EventStoreClient; /// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer). #[derive(Clone, PartialEq, Debug)] pub enum CheckpointingAction { - /// Delete the given epoch. - Delete(usize), - /// Save the current record. - Save, + /// Delete the given epoch. + Delete(usize), + /// Save the current record. + Save, } /// Define when checkpoint should be saved and deleted. pub trait CheckpointingStrategy { - /// Based on the epoch, determine if the checkpoint should be saved. - fn checkpointing( - &mut self, - epoch: usize, - collector: &EventStoreClient, - ) -> Vec; + /// Based on the epoch, determine if the checkpoint should be saved. + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec; } // We make dyn box implement the checkpointing strategy so that it can be used with generic, but // still be dynamic. impl CheckpointingStrategy for Box { - fn checkpointing( - &mut self, - epoch: usize, - collector: &EventStoreClient, - ) -> Vec { - self.deref_mut().checkpointing(epoch, collector) - } + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec { + self.deref_mut().checkpointing(epoch, collector) + } } diff --git a/burn-train/src/checkpoint/strategy/composed.rs b/burn-train/src/checkpoint/strategy/composed.rs index 38d8354aad..8029c9ed78 100644 --- a/burn-train/src/checkpoint/strategy/composed.rs +++ b/burn-train/src/checkpoint/strategy/composed.rs @@ -6,144 +6,141 @@ use std::collections::HashSet; /// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an /// epoch to be deleted. pub struct ComposedCheckpointingStrategy { - strategies: Vec>, - deleted: Vec>, + strategies: Vec>, + deleted: Vec>, } /// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones. #[derive(Default)] pub struct ComposedCheckpointingStrategyBuilder { - strategies: Vec>, + strategies: Vec>, } impl ComposedCheckpointingStrategyBuilder { - /// Add a new [checkpointing strategy](CheckpointingStrategy). - #[allow(clippy::should_implement_trait)] - pub fn add(mut self, strategy: S) -> Self - where - S: CheckpointingStrategy + 'static, - { - self.strategies.push(Box::new(strategy)); - self - } - - /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy). - pub fn build(self) -> ComposedCheckpointingStrategy { - ComposedCheckpointingStrategy::new(self.strategies) - } -} + /// Add a new [checkpointing strategy](CheckpointingStrategy). 
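Because `CheckpointingStrategy` reduces checkpoint management to a single method returning `Save`/`Delete` actions for the current epoch, custom policies stay small. A hedged sketch follows; the `burn_train::checkpoint` and `burn_train::metric::store` import paths are inferred from the module layout in this patch and may not match the crate's public re-exports exactly.

use burn_train::checkpoint::{CheckpointingAction, CheckpointingStrategy};
use burn_train::metric::store::EventStoreClient;

/// Illustrative policy: save every other epoch and never delete anything.
struct SaveEveryOtherEpoch;

impl CheckpointingStrategy for SaveEveryOtherEpoch {
    fn checkpointing(
        &mut self,
        epoch: usize,
        _store: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        if epoch % 2 == 0 {
            vec![CheckpointingAction::Save]
        } else {
            Vec::new()
        }
    }
}

A strategy like this can then be handed to the learner through `with_checkpointing_strategy`, which appears later in this patch.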
+ #[allow(clippy::should_implement_trait)] + pub fn add(mut self, strategy: S) -> Self + where + S: CheckpointingStrategy + 'static, + { + self.strategies.push(Box::new(strategy)); + self + } -impl ComposedCheckpointingStrategy { - fn new(strategies: Vec>) -> Self { - Self { - deleted: strategies.iter().map(|_| HashSet::new()).collect(), - strategies, + /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy). + pub fn build(self) -> ComposedCheckpointingStrategy { + ComposedCheckpointingStrategy::new(self.strategies) } - } - /// Create a new builder which help compose multiple - /// [checkpointing strategies](CheckpointingStrategy). - pub fn builder() -> ComposedCheckpointingStrategyBuilder { - ComposedCheckpointingStrategyBuilder::default() - } } -impl CheckpointingStrategy for ComposedCheckpointingStrategy { - fn checkpointing( - &mut self, - epoch: usize, - collector: &EventStoreClient, - ) -> Vec { - let mut saved = false; - let mut actions = Vec::new(); - let mut epochs_to_check = Vec::new(); - - for (i, strategy) in self.strategies.iter_mut().enumerate() { - let actions = strategy.checkpointing(epoch, collector); - // We assume that the strategy would not want the current epoch to be saved. - // So we flag it as deleted. - if actions.is_empty() { - self - .deleted - .get_mut(i) - .expect("As many 'deleted' as 'strategies'.") - .insert(epoch); - } - - for action in actions { - match action { - CheckpointingAction::Delete(epoch) => { - self - .deleted - .get_mut(i) - .expect("As many 'deleted' as 'strategies'.") - .insert(epoch); - epochs_to_check.push(epoch); - } - CheckpointingAction::Save => saved = true, +impl ComposedCheckpointingStrategy { + fn new(strategies: Vec>) -> Self { + Self { + deleted: strategies.iter().map(|_| HashSet::new()).collect(), + strategies, } - } } - - if saved { - actions.push(CheckpointingAction::Save); + /// Create a new builder which help compose multiple + /// [checkpointing strategies](CheckpointingStrategy). + pub fn builder() -> ComposedCheckpointingStrategyBuilder { + ComposedCheckpointingStrategyBuilder::default() } +} - for epoch in epochs_to_check.into_iter() { - let mut num_true = 0; - for i in 0..self.strategies.len() { - if self - .deleted - .get(i) - .expect("Ad many 'deleted' as 'strategies'.") - .contains(&epoch) - { - num_true += 1; +impl CheckpointingStrategy for ComposedCheckpointingStrategy { + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec { + let mut saved = false; + let mut actions = Vec::new(); + let mut epochs_to_check = Vec::new(); + + for (i, strategy) in self.strategies.iter_mut().enumerate() { + let actions = strategy.checkpointing(epoch, collector); + // We assume that the strategy would not want the current epoch to be saved. + // So we flag it as deleted. 
+ if actions.is_empty() { + self.deleted + .get_mut(i) + .expect("As many 'deleted' as 'strategies'.") + .insert(epoch); + } + + for action in actions { + match action { + CheckpointingAction::Delete(epoch) => { + self.deleted + .get_mut(i) + .expect("As many 'deleted' as 'strategies'.") + .insert(epoch); + epochs_to_check.push(epoch); + } + CheckpointingAction::Save => saved = true, + } + } } - } - if num_true == self.strategies.len() { - actions.push(CheckpointingAction::Delete(epoch)); + if saved { + actions.push(CheckpointingAction::Save); + } - for i in 0..self.strategies.len() { - self - .deleted - .get_mut(i) - .expect("As many 'deleted' as 'strategies'.") - .remove(&epoch); + for epoch in epochs_to_check.into_iter() { + let mut num_true = 0; + for i in 0..self.strategies.len() { + if self + .deleted + .get(i) + .expect("Ad many 'deleted' as 'strategies'.") + .contains(&epoch) + { + num_true += 1; + } + } + + if num_true == self.strategies.len() { + actions.push(CheckpointingAction::Delete(epoch)); + + for i in 0..self.strategies.len() { + self.deleted + .get_mut(i) + .expect("As many 'deleted' as 'strategies'.") + .remove(&epoch); + } + } } - } - } - actions - } + actions + } } #[cfg(test)] mod tests { - use super::*; - use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore}; - - #[test] - fn should_delete_when_both_deletes() { - let store = EventStoreClient::new(LogEventStore::default()); - let mut strategy = ComposedCheckpointingStrategy::builder() - .add(KeepLastNCheckpoints::new(1)) - .add(KeepLastNCheckpoints::new(2)) - .build(); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(1, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(2, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], - strategy.checkpointing(3, &store) - ); - } + use super::*; + use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore}; + + #[test] + fn should_delete_when_both_deletes() { + let store = EventStoreClient::new(LogEventStore::default()); + let mut strategy = ComposedCheckpointingStrategy::builder() + .add(KeepLastNCheckpoints::new(1)) + .add(KeepLastNCheckpoints::new(2)) + .build(); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(1, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(2, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], + strategy.checkpointing(3, &store) + ); + } } diff --git a/burn-train/src/checkpoint/strategy/lastn.rs b/burn-train/src/checkpoint/strategy/lastn.rs index f810f0dc8d..66f5df91bf 100644 --- a/burn-train/src/checkpoint/strategy/lastn.rs +++ b/burn-train/src/checkpoint/strategy/lastn.rs @@ -7,46 +7,50 @@ use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient}; /// resumed even if something goes wrong. 
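As the composed-strategy test above demonstrates, a checkpoint is only deleted once every composed strategy has flagged that epoch, so the most conservative retention policy wins. A short construction sketch, assuming the types are reachable under `burn_train::checkpoint`:

use burn_train::checkpoint::{ComposedCheckpointingStrategy, KeepLastNCheckpoints};

fn main() {
    // Epoch `n` is deleted only after both the 1-checkpoint and the 2-checkpoint
    // retention windows have moved past it, i.e. the larger window decides.
    let _strategy = ComposedCheckpointingStrategy::builder()
        .add(KeepLastNCheckpoints::new(1))
        .add(KeepLastNCheckpoints::new(2))
        .build();
}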
#[derive(new)] pub struct KeepLastNCheckpoints { - num_keep: usize, + num_keep: usize, } impl CheckpointingStrategy for KeepLastNCheckpoints { - fn checkpointing(&mut self, epoch: usize, _store: &EventStoreClient) -> Vec { - let mut actions = vec![CheckpointingAction::Save]; - - if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) { - if epoch > 0 { - actions.push(CheckpointingAction::Delete(epoch)); - } + fn checkpointing( + &mut self, + epoch: usize, + _store: &EventStoreClient, + ) -> Vec { + let mut actions = vec![CheckpointingAction::Save]; + + if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) { + if epoch > 0 { + actions.push(CheckpointingAction::Delete(epoch)); + } + } + + actions } - - actions - } } #[cfg(test)] mod tests { - use super::*; - use crate::metric::store::LogEventStore; - - #[test] - fn should_always_delete_lastn_epoch_if_higher_than_one() { - let mut strategy = KeepLastNCheckpoints::new(2); - let store = EventStoreClient::new(LogEventStore::default()); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(1, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(2, &store) - ); - - assert_eq!( - vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], - strategy.checkpointing(3, &store) - ); - } + use super::*; + use crate::metric::store::LogEventStore; + + #[test] + fn should_always_delete_lastn_epoch_if_higher_than_one() { + let mut strategy = KeepLastNCheckpoints::new(2); + let store = EventStoreClient::new(LogEventStore::default()); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(1, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(2, &store) + ); + + assert_eq!( + vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], + strategy.checkpointing(3, &store) + ); + } } diff --git a/burn-train/src/checkpoint/strategy/metric.rs b/burn-train/src/checkpoint/strategy/metric.rs index a746241066..f2aa58efeb 100644 --- a/burn-train/src/checkpoint/strategy/metric.rs +++ b/burn-train/src/checkpoint/strategy/metric.rs @@ -1,129 +1,133 @@ use super::CheckpointingStrategy; use crate::{ - checkpoint::CheckpointingAction, - metric::{ - store::{Aggregate, Direction, EventStoreClient, Split}, - Metric, - }, + checkpoint::CheckpointingAction, + metric::{ + store::{Aggregate, Direction, EventStoreClient, Split}, + Metric, + }, }; /// Keep the best checkpoint based on a metric. pub struct MetricCheckpointingStrategy { - current: Option, - aggregate: Aggregate, - direction: Direction, - split: Split, - name: String, + current: Option, + aggregate: Aggregate, + direction: Direction, + split: Split, + name: String, } impl MetricCheckpointingStrategy { - /// Create a new metric strategy. - pub fn new(aggregate: Aggregate, direction: Direction, split: Split) -> Self - where - M: Metric, - { - Self { - current: None, - name: M::NAME.to_string(), - aggregate, - direction, - split, + /// Create a new metric strategy. 
+ pub fn new(aggregate: Aggregate, direction: Direction, split: Split) -> Self + where + M: Metric, + { + Self { + current: None, + name: M::NAME.to_string(), + aggregate, + direction, + split, + } } - } } impl CheckpointingStrategy for MetricCheckpointingStrategy { - fn checkpointing(&mut self, epoch: usize, store: &EventStoreClient) -> Vec { - let best_epoch = match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) - { - Some(epoch_best) => epoch_best, - None => epoch, - }; - - let mut actions = Vec::new(); - - if let Some(current) = self.current { - if current != best_epoch { - actions.push(CheckpointingAction::Delete(current)); - } + fn checkpointing( + &mut self, + epoch: usize, + store: &EventStoreClient, + ) -> Vec { + let best_epoch = + match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) { + Some(epoch_best) => epoch_best, + None => epoch, + }; + + let mut actions = Vec::new(); + + if let Some(current) = self.current { + if current != best_epoch { + actions.push(CheckpointingAction::Delete(current)); + } + } + + if best_epoch == epoch { + actions.push(CheckpointingAction::Save); + } + + self.current = Some(best_epoch); + + actions } - - if best_epoch == epoch { - actions.push(CheckpointingAction::Save); - } - - self.current = Some(best_epoch); - - actions - } } #[cfg(test)] mod tests { - use crate::{ - logger::InMemoryMetricLogger, - metric::{ - processor::{ - test_utils::{end_epoch, process_train}, - Metrics, MinimalEventProcessor, - }, - store::LogEventStore, - LossMetric, - }, - TestBackend, - }; - use std::sync::Arc; - - use super::*; - - #[test] - fn always_keep_the_best_epoch() { - let mut store = LogEventStore::default(); - let mut strategy = MetricCheckpointingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Train, - ); - let mut metrics = Metrics::::default(); - // Register an in memory logger. - store.register_logger_train(InMemoryMetricLogger::default()); - // Register the loss metric. - metrics.register_train_metric_numeric(LossMetric::::new()); - let store = Arc::new(EventStoreClient::new(store)); - let mut processor = MinimalEventProcessor::new(metrics, store.clone()); - - // Two points for the first epoch. Mean 0.75 - let mut epoch = 1; - process_train(&mut processor, 1.0, epoch); - process_train(&mut processor, 0.5, epoch); - end_epoch(&mut processor, epoch); - - // Should save the current record. - assert_eq!( - vec![CheckpointingAction::Save], - strategy.checkpointing(epoch, &store) - ); - - // Two points for the second epoch. Mean 0.4 - epoch += 1; - process_train(&mut processor, 0.5, epoch); - process_train(&mut processor, 0.3, epoch); - end_epoch(&mut processor, epoch); - - // Should save the current record and delete the pervious one. - assert_eq!( - vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], - strategy.checkpointing(epoch, &store) - ); - - // Two points for the last epoch. Mean 2.0 - epoch += 1; - process_train(&mut processor, 1.0, epoch); - process_train(&mut processor, 3.0, epoch); - end_epoch(&mut processor, epoch); - - // Should not delete the previous record, since it's the best one, and should not save a - // new one. 
- assert!(strategy.checkpointing(epoch, &store).is_empty()); - } + use crate::{ + logger::InMemoryMetricLogger, + metric::{ + processor::{ + test_utils::{end_epoch, process_train}, + Metrics, MinimalEventProcessor, + }, + store::LogEventStore, + LossMetric, + }, + TestBackend, + }; + use std::sync::Arc; + + use super::*; + + #[test] + fn always_keep_the_best_epoch() { + let mut store = LogEventStore::default(); + let mut strategy = MetricCheckpointingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Train, + ); + let mut metrics = Metrics::::default(); + // Register an in memory logger. + store.register_logger_train(InMemoryMetricLogger::default()); + // Register the loss metric. + metrics.register_train_metric_numeric(LossMetric::::new()); + let store = Arc::new(EventStoreClient::new(store)); + let mut processor = MinimalEventProcessor::new(metrics, store.clone()); + + // Two points for the first epoch. Mean 0.75 + let mut epoch = 1; + process_train(&mut processor, 1.0, epoch); + process_train(&mut processor, 0.5, epoch); + end_epoch(&mut processor, epoch); + + // Should save the current record. + assert_eq!( + vec![CheckpointingAction::Save], + strategy.checkpointing(epoch, &store) + ); + + // Two points for the second epoch. Mean 0.4 + epoch += 1; + process_train(&mut processor, 0.5, epoch); + process_train(&mut processor, 0.3, epoch); + end_epoch(&mut processor, epoch); + + // Should save the current record and delete the pervious one. + assert_eq!( + vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], + strategy.checkpointing(epoch, &store) + ); + + // Two points for the last epoch. Mean 2.0 + epoch += 1; + process_train(&mut processor, 1.0, epoch); + process_train(&mut processor, 3.0, epoch); + end_epoch(&mut processor, epoch); + + // Should not delete the previous record, since it's the best one, and should not save a + // new one. + assert!(strategy.checkpointing(epoch, &store).is_empty()); + } } diff --git a/burn-train/src/components.rs b/burn-train/src/components.rs index 16cbe9ad80..3eddb93b0f 100644 --- a/burn-train/src/components.rs +++ b/burn-train/src/components.rs @@ -1,71 +1,71 @@ use crate::{ - checkpoint::{Checkpointer, CheckpointingStrategy}, - metric::processor::EventProcessor, + checkpoint::{Checkpointer, CheckpointingStrategy}, + metric::processor::EventProcessor, }; use burn_core::{ - lr_scheduler::LrScheduler, - module::{AutodiffModule, Module}, - optim::Optimizer, - tensor::backend::AutodiffBackend, + lr_scheduler::LrScheduler, + module::{AutodiffModule, Module}, + optim::Optimizer, + tensor::backend::AutodiffBackend, }; use std::marker::PhantomData; /// All components necessary to train a model grouped in one trait. pub trait LearnerComponents { - /// The backend in used for the training. - type Backend: AutodiffBackend; - /// The learning rate scheduler used for the training. - type LrScheduler: LrScheduler; - /// The model to train. - type Model: AutodiffModule + core::fmt::Display + 'static; - /// The optimizer used for the training. - type Optimizer: Optimizer; - /// The checkpointer used for the model. - type CheckpointerModel: Checkpointer<>::Record>; - /// The checkpointer used for the optimizer. - type CheckpointerOptimizer: Checkpointer< - >::Record, - >; - /// The checkpointer used for the scheduler. - type CheckpointerLrScheduler: Checkpointer<::Record>; - type EventProcessor: EventProcessor + 'static; - /// The strategy to save and delete checkpoints. 
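The metric-based strategy whose changes end just above keeps only the best epoch according to a logged metric. Below is a hedged construction sketch mirroring the default that `LearnerBuilder::new` installs later in this patch; the import paths and the `Backend` bound are assumptions based on this diff's `use` statements.

use burn_core::tensor::backend::Backend;
use burn_train::checkpoint::MetricCheckpointingStrategy;
use burn_train::metric::store::{Aggregate, Direction, Split};
use burn_train::metric::LossMetric;

/// Keep the checkpoint whose mean validation loss is lowest.
fn best_valid_loss_strategy<B: Backend>() -> MetricCheckpointingStrategy {
    MetricCheckpointingStrategy::new::<LossMetric<B>>(
        Aggregate::Mean,
        Direction::Lowest,
        Split::Valid,
    )
}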
- type CheckpointerStrategy: CheckpointingStrategy; + /// The backend in used for the training. + type Backend: AutodiffBackend; + /// The learning rate scheduler used for the training. + type LrScheduler: LrScheduler; + /// The model to train. + type Model: AutodiffModule + core::fmt::Display + 'static; + /// The optimizer used for the training. + type Optimizer: Optimizer; + /// The checkpointer used for the model. + type CheckpointerModel: Checkpointer<>::Record>; + /// The checkpointer used for the optimizer. + type CheckpointerOptimizer: Checkpointer< + >::Record, + >; + /// The checkpointer used for the scheduler. + type CheckpointerLrScheduler: Checkpointer<::Record>; + type EventProcessor: EventProcessor + 'static; + /// The strategy to save and delete checkpoints. + type CheckpointerStrategy: CheckpointingStrategy; } /// Concrete type that implements [training components trait](TrainingComponents). pub struct LearnerComponentsMarker { - _backend: PhantomData, - _lr_scheduler: PhantomData, - _model: PhantomData, - _optimizer: PhantomData, - _checkpointer_model: PhantomData, - _checkpointer_optim: PhantomData, - _checkpointer_scheduler: PhantomData, - _event_processor: PhantomData, - _strategy: S, + _backend: PhantomData, + _lr_scheduler: PhantomData, + _model: PhantomData, + _optimizer: PhantomData, + _checkpointer_model: PhantomData, + _checkpointer_optim: PhantomData, + _checkpointer_scheduler: PhantomData, + _event_processor: PhantomData, + _strategy: S, } impl LearnerComponents - for LearnerComponentsMarker + for LearnerComponentsMarker where - B: AutodiffBackend, - LR: LrScheduler, - M: AutodiffModule + core::fmt::Display + 'static, - O: Optimizer, - CM: Checkpointer, - CO: Checkpointer, - CS: Checkpointer, - EP: EventProcessor + 'static, - S: CheckpointingStrategy, + B: AutodiffBackend, + LR: LrScheduler, + M: AutodiffModule + core::fmt::Display + 'static, + O: Optimizer, + CM: Checkpointer, + CO: Checkpointer, + CS: Checkpointer, + EP: EventProcessor + 'static, + S: CheckpointingStrategy, { - type Backend = B; - type LrScheduler = LR; - type Model = M; - type Optimizer = O; - type CheckpointerModel = CM; - type CheckpointerOptimizer = CO; - type CheckpointerLrScheduler = CS; - type EventProcessor = EP; - type CheckpointerStrategy = S; + type Backend = B; + type LrScheduler = LR; + type Model = M; + type Optimizer = O; + type CheckpointerModel = CM; + type CheckpointerOptimizer = CO; + type CheckpointerLrScheduler = CS; + type EventProcessor = EP; + type CheckpointerStrategy = S; } diff --git a/burn-train/src/learner/base.rs b/burn-train/src/learner/base.rs index ea6f0ff245..55a5515ef7 100644 --- a/burn-train/src/learner/base.rs +++ b/burn-train/src/learner/base.rs @@ -13,121 +13,115 @@ use std::sync::Arc; /// /// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct. 
pub struct Learner { - pub(crate) model: LC::Model, - pub(crate) optim: LC::Optimizer, - pub(crate) lr_scheduler: LC::LrScheduler, - pub(crate) num_epochs: usize, - pub(crate) checkpoint: Option, - pub(crate) grad_accumulation: Option, - pub(crate) checkpointer: Option>, - pub(crate) devices: Vec<::Device>, - pub(crate) interrupter: TrainingInterrupter, - pub(crate) early_stopping: Option>, - pub(crate) event_processor: LC::EventProcessor, - pub(crate) event_store: Arc, + pub(crate) model: LC::Model, + pub(crate) optim: LC::Optimizer, + pub(crate) lr_scheduler: LC::LrScheduler, + pub(crate) num_epochs: usize, + pub(crate) checkpoint: Option, + pub(crate) grad_accumulation: Option, + pub(crate) checkpointer: Option>, + pub(crate) devices: Vec<::Device>, + pub(crate) interrupter: TrainingInterrupter, + pub(crate) early_stopping: Option>, + pub(crate) event_processor: LC::EventProcessor, + pub(crate) event_store: Arc, } #[derive(new)] pub(crate) struct LearnerCheckpointer { - model: LC::CheckpointerModel, - optim: LC::CheckpointerOptimizer, - lr_scheduler: LC::CheckpointerLrScheduler, - strategy: LC::CheckpointerStrategy, + model: LC::CheckpointerModel, + optim: LC::CheckpointerOptimizer, + lr_scheduler: LC::CheckpointerLrScheduler, + strategy: LC::CheckpointerStrategy, } impl LearnerCheckpointer { - pub(crate) fn checkpoint( - &mut self, - model: &LC::Model, - optim: &LC::Optimizer, - scheduler: &LC::LrScheduler, - epoch: usize, - store: &EventStoreClient, - ) { - let actions = self.strategy.checkpointing(epoch, store); + pub(crate) fn checkpoint( + &mut self, + model: &LC::Model, + optim: &LC::Optimizer, + scheduler: &LC::LrScheduler, + epoch: usize, + store: &EventStoreClient, + ) { + let actions = self.strategy.checkpointing(epoch, store); - for action in actions { - match action { - CheckpointingAction::Delete(epoch) => { - self - .model - .delete(epoch) - .expect("Can delete model checkpoint."); - self - .optim - .delete(epoch) - .expect("Can delete optimizer checkpoint."); - self - .lr_scheduler - .delete(epoch) - .expect("Can delete learning rate scheduler checkpoint."); - } - CheckpointingAction::Save => { - self - .model - .save(epoch, model.clone().into_record()) - .expect("Can save model checkpoint."); - self - .optim - .save(epoch, optim.to_record()) - .expect("Can save optimizer checkpoint."); - self - .lr_scheduler - .save(epoch, scheduler.to_record()) - .expect("Can save learning rate scheduler checkpoint."); + for action in actions { + match action { + CheckpointingAction::Delete(epoch) => { + self.model + .delete(epoch) + .expect("Can delete model checkpoint."); + self.optim + .delete(epoch) + .expect("Can delete optimizer checkpoint."); + self.lr_scheduler + .delete(epoch) + .expect("Can delete learning rate scheduler checkpoint."); + } + CheckpointingAction::Save => { + self.model + .save(epoch, model.clone().into_record()) + .expect("Can save model checkpoint."); + self.optim + .save(epoch, optim.to_record()) + .expect("Can save optimizer checkpoint."); + self.lr_scheduler + .save(epoch, scheduler.to_record()) + .expect("Can save learning rate scheduler checkpoint."); + } + } } - } } - } - pub(crate) fn load_checkpoint( - &self, - model: LC::Model, - optim: LC::Optimizer, - scheduler: LC::LrScheduler, - epoch: usize, - ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) { - let record = self - .model - .restore(epoch) - .expect("Can load model checkpoint."); - let model = model.load_record(record); + pub(crate) fn load_checkpoint( + &self, + model: LC::Model, + optim: 
LC::Optimizer, + scheduler: LC::LrScheduler, + epoch: usize, + ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) { + let record = self + .model + .restore(epoch) + .expect("Can load model checkpoint."); + let model = model.load_record(record); - let record = self - .optim - .restore(epoch) - .expect("Can load optimizer checkpoint."); - let optim = optim.load_record(record); + let record = self + .optim + .restore(epoch) + .expect("Can load optimizer checkpoint."); + let optim = optim.load_record(record); - let record = self - .lr_scheduler - .restore(epoch) - .expect("Can load learning rate scheduler checkpoint."); - let scheduler = scheduler.load_record(record); + let record = self + .lr_scheduler + .restore(epoch) + .expect("Can load learning rate scheduler checkpoint."); + let scheduler = scheduler.load_record(record); - (model, optim, scheduler) - } + (model, optim, scheduler) + } } #[derive(Clone, Default)] /// A handle that allows aborting the training process early. pub struct TrainingInterrupter { - state: Arc, + state: Arc, } impl TrainingInterrupter { - /// Create a new instance. - pub fn new() -> Self { - Self::default() - } + /// Create a new instance. + pub fn new() -> Self { + Self::default() + } - /// Notify the learner that it should stop. - pub fn stop(&self) { - self.state.store(true, Ordering::Relaxed); - } + /// Notify the learner that it should stop. + pub fn stop(&self) { + self.state.store(true, Ordering::Relaxed); + } - /// True if .stop() has been called. - pub fn should_stop(&self) -> bool { - self.state.load(Ordering::Relaxed) - } + /// True if .stop() has been called. + pub fn should_stop(&self) -> bool { + self.state.load(Ordering::Relaxed) + } } diff --git a/burn-train/src/learner/builder.rs b/burn-train/src/learner/builder.rs index b78760eae7..99e6d90bea 100644 --- a/burn-train/src/learner/builder.rs +++ b/burn-train/src/learner/builder.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use super::log::install_file_logger; use super::Learner; use crate::checkpoint::{ - AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, - KeepLastNCheckpoints, MetricCheckpointingStrategy, + AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, + KeepLastNCheckpoints, MetricCheckpointingStrategy, }; use crate::components::LearnerComponentsMarker; use crate::learner::base::TrainingInterrupter; @@ -24,317 +24,319 @@ use burn_core::tensor::backend::AutodiffBackend; /// Struct to configure and create a [learner](Learner). pub struct LearnerBuilder where - T: Send + Sync + 'static, - V: Send + Sync + 'static, - B: AutodiffBackend, - M: AutodiffModule, - O: Optimizer, - S: LrScheduler, + T: Send + Sync + 'static, + V: Send + Sync + 'static, + B: AutodiffBackend, + M: AutodiffModule, + O: Optimizer, + S: LrScheduler, { - // Not that complex and very convenient when the traits are - // already constrained correctly. Extracting in another type - // would be more complex. - #[allow(clippy::type_complexity)] - checkpointers: Option<( - AsyncCheckpointer, - AsyncCheckpointer, - AsyncCheckpointer, - )>, - num_epochs: usize, - checkpoint: Option, - directory: String, - grad_accumulation: Option, - devices: Vec, - renderer: Option>, - metrics: Metrics, - event_store: LogEventStore, - interrupter: TrainingInterrupter, - log_to_file: bool, - num_loggers: usize, - checkpointer_strategy: Box, - early_stopping: Option>, + // Not that complex and very convenient when the traits are + // already constrained correctly. 
Extracting in another type + // would be more complex. + #[allow(clippy::type_complexity)] + checkpointers: Option<( + AsyncCheckpointer, + AsyncCheckpointer, + AsyncCheckpointer, + )>, + num_epochs: usize, + checkpoint: Option, + directory: String, + grad_accumulation: Option, + devices: Vec, + renderer: Option>, + metrics: Metrics, + event_store: LogEventStore, + interrupter: TrainingInterrupter, + log_to_file: bool, + num_loggers: usize, + checkpointer_strategy: Box, + early_stopping: Option>, } impl LearnerBuilder where - B: AutodiffBackend, - T: Send + Sync + 'static, - V: Send + Sync + 'static, - M: AutodiffModule + core::fmt::Display + 'static, - O: Optimizer, - S: LrScheduler, + B: AutodiffBackend, + T: Send + Sync + 'static, + V: Send + Sync + 'static, + M: AutodiffModule + core::fmt::Display + 'static, + O: Optimizer, + S: LrScheduler, { - /// Creates a new learner builder. - /// - /// # Arguments - /// - /// * `directory` - The directory to save the checkpoints. - pub fn new(directory: &str) -> Self { - Self { - num_epochs: 1, - checkpoint: None, - checkpointers: None, - directory: directory.to_string(), - grad_accumulation: None, - devices: vec![B::Device::default()], - metrics: Metrics::default(), - event_store: LogEventStore::default(), - renderer: None, - interrupter: TrainingInterrupter::new(), - log_to_file: true, - num_loggers: 0, - checkpointer_strategy: Box::new( - ComposedCheckpointingStrategy::builder() - .add(KeepLastNCheckpoints::new(2)) - .add(MetricCheckpointingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Valid, - )) - .build(), - ), - early_stopping: None, + /// Creates a new learner builder. + /// + /// # Arguments + /// + /// * `directory` - The directory to save the checkpoints. + pub fn new(directory: &str) -> Self { + Self { + num_epochs: 1, + checkpoint: None, + checkpointers: None, + directory: directory.to_string(), + grad_accumulation: None, + devices: vec![B::Device::default()], + metrics: Metrics::default(), + event_store: LogEventStore::default(), + renderer: None, + interrupter: TrainingInterrupter::new(), + log_to_file: true, + num_loggers: 0, + checkpointer_strategy: Box::new( + ComposedCheckpointingStrategy::builder() + .add(KeepLastNCheckpoints::new(2)) + .add(MetricCheckpointingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + )) + .build(), + ), + early_stopping: None, + } } - } - /// Replace the default metric loggers with the provided ones. - /// - /// # Arguments - /// - /// * `logger_train` - The training logger. - /// * `logger_valid` - The validation logger. - pub fn metric_loggers(mut self, logger_train: MT, logger_valid: MV) -> Self - where - MT: MetricLogger + 'static, - MV: MetricLogger + 'static, - { - self.event_store.register_logger_train(logger_train); - self.event_store.register_logger_valid(logger_valid); - self.num_loggers += 1; - self - } - - /// Update the checkpointing_strategy. - pub fn with_checkpointing_strategy(&mut self, strategy: CS) - where - CS: CheckpointingStrategy + 'static, - { - self.checkpointer_strategy = Box::new(strategy); - } + /// Replace the default metric loggers with the provided ones. + /// + /// # Arguments + /// + /// * `logger_train` - The training logger. + /// * `logger_valid` - The validation logger. 
+ pub fn metric_loggers(mut self, logger_train: MT, logger_valid: MV) -> Self + where + MT: MetricLogger + 'static, + MV: MetricLogger + 'static, + { + self.event_store.register_logger_train(logger_train); + self.event_store.register_logger_valid(logger_valid); + self.num_loggers += 1; + self + } - /// Replace the default CLI renderer with a custom one. - /// - /// # Arguments - /// - /// * `renderer` - The custom renderer. - pub fn renderer(mut self, renderer: MR) -> Self - where - MR: MetricsRenderer + 'static, - { - self.renderer = Some(Box::new(renderer)); - self - } + /// Update the checkpointing_strategy. + pub fn with_checkpointing_strategy(&mut self, strategy: CS) + where + CS: CheckpointingStrategy + 'static, + { + self.checkpointer_strategy = Box::new(strategy); + } - /// Register a training metric. - pub fn metric_train(mut self, metric: Me) -> Self - where - T: Adaptor, - { - self.metrics.register_metric_train(metric); - self - } + /// Replace the default CLI renderer with a custom one. + /// + /// # Arguments + /// + /// * `renderer` - The custom renderer. + pub fn renderer(mut self, renderer: MR) -> Self + where + MR: MetricsRenderer + 'static, + { + self.renderer = Some(Box::new(renderer)); + self + } - /// Register a validation metric. - pub fn metric_valid(mut self, metric: Me) -> Self - where - V: Adaptor, - { - self.metrics.register_valid_metric(metric); - self - } + /// Register a training metric. + pub fn metric_train(mut self, metric: Me) -> Self + where + T: Adaptor, + { + self.metrics.register_metric_train(metric); + self + } - /// Enable gradients accumulation. - /// - /// # Notes - /// - /// When you enable gradients accumulation, the gradients object used by the optimizer will be - /// the sum of all gradients generated by each backward pass. It might be a good idea to - /// reduce the learning to compensate. - /// - /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation` - /// amount. - pub fn grads_accumulation(mut self, accumulation: usize) -> Self { - self.grad_accumulation = Some(accumulation); - self - } + /// Register a validation metric. + pub fn metric_valid(mut self, metric: Me) -> Self + where + V: Adaptor, + { + self.metrics.register_valid_metric(metric); + self + } - /// Register a [numeric](crate::metric::Numeric) training [metric](Metric). - pub fn metric_train_numeric(mut self, metric: Me) -> Self - where - Me: Metric + crate::metric::Numeric + 'static, - T: Adaptor, - { - self.metrics.register_train_metric_numeric(metric); - self - } + /// Enable gradients accumulation. + /// + /// # Notes + /// + /// When you enable gradients accumulation, the gradients object used by the optimizer will be + /// the sum of all gradients generated by each backward pass. It might be a good idea to + /// reduce the learning to compensate. + /// + /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation` + /// amount. + pub fn grads_accumulation(mut self, accumulation: usize) -> Self { + self.grad_accumulation = Some(accumulation); + self + } - /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric). - pub fn metric_valid_numeric( - mut self, - metric: Me, - ) -> Self - where - V: Adaptor, - { - self.metrics.register_valid_metric_numeric(metric); - self - } + /// Register a [numeric](crate::metric::Numeric) training [metric](Metric). 
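A usage note for `with_checkpointing_strategy`: the default assembled in `new` keeps the two most recent checkpoints plus the best one by mean validation loss, and unlike the other builder methods this one takes `&mut self` rather than returning `Self`. A hedged fragment swapping in a different composition (import paths and the `AccuracyMetric` bound are assumptions; `builder` is an existing mutable `LearnerBuilder`):

```rust
// Fragment only: relies on a `mut builder: LearnerBuilder<...>` in scope and on
// `AccuracyMetric` satisfying the same bounds as the default `LossMetric`.
use burn_train::checkpoint::{
    ComposedCheckpointingStrategy, KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use burn_train::metric::store::{Aggregate, Direction, Split};
use burn_train::metric::AccuracyMetric;

builder.with_checkpointing_strategy(
    ComposedCheckpointingStrategy::builder()
        // Keep a few more recent checkpoints than the default of 2.
        .add(KeepLastNCheckpoints::new(5))
        // Also keep the checkpoint with the best mean validation accuracy.
        .add(MetricCheckpointingStrategy::new::<AccuracyMetric<B>>(
            Aggregate::Mean,
            Direction::Highest,
            Split::Valid,
        ))
        .build(),
);
```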
+ pub fn metric_train_numeric(mut self, metric: Me) -> Self + where + Me: Metric + crate::metric::Numeric + 'static, + T: Adaptor, + { + self.metrics.register_train_metric_numeric(metric); + self + } - /// The number of epochs the training should last. - pub fn num_epochs(mut self, num_epochs: usize) -> Self { - self.num_epochs = num_epochs; - self - } + /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric). + pub fn metric_valid_numeric( + mut self, + metric: Me, + ) -> Self + where + V: Adaptor, + { + self.metrics.register_valid_metric_numeric(metric); + self + } - /// Run the training loop on multiple devices. - pub fn devices(mut self, devices: Vec) -> Self { - self.devices = devices; - self - } + /// The number of epochs the training should last. + pub fn num_epochs(mut self, num_epochs: usize) -> Self { + self.num_epochs = num_epochs; + self + } - /// The epoch from which the training must resume. - pub fn checkpoint(mut self, checkpoint: usize) -> Self { - self.checkpoint = Some(checkpoint); - self - } + /// Run the training loop on multiple devices. + pub fn devices(mut self, devices: Vec) -> Self { + self.devices = devices; + self + } - /// Provides a handle that can be used to interrupt training. - pub fn interrupter(&self) -> TrainingInterrupter { - self.interrupter.clone() - } + /// The epoch from which the training must resume. + pub fn checkpoint(mut self, checkpoint: usize) -> Self { + self.checkpoint = Some(checkpoint); + self + } - /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the - /// conditions are meet. - pub fn early_stopping(mut self, strategy: Strategy) -> Self - where - Strategy: EarlyStoppingStrategy + 'static, - { - self.early_stopping = Some(Box::new(strategy)); - self - } + /// Provides a handle that can be used to interrupt training. + pub fn interrupter(&self) -> TrainingInterrupter { + self.interrupter.clone() + } - /// By default, Rust logs are captured and written into - /// `experiment.log`. If disabled, standard Rust log handling - /// will apply. - pub fn log_to_file(mut self, enabled: bool) -> Self { - self.log_to_file = enabled; - self - } + /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the + /// conditions are meet. + pub fn early_stopping(mut self, strategy: Strategy) -> Self + where + Strategy: EarlyStoppingStrategy + 'static, + { + self.early_stopping = Some(Box::new(strategy)); + self + } - /// Register a checkpointer that will save the [optimizer](Optimizer), the - /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files. - pub fn with_file_checkpointer(mut self, recorder: FR) -> Self - where - FR: FileRecorder + 'static, - O::Record: 'static, - M::Record: 'static, - S::Record: 'static, - { - let checkpointer_model = FileCheckpointer::new( - recorder.clone(), - format!("{}/checkpoint", self.directory).as_str(), - "model", - ); - let checkpointer_optimizer = FileCheckpointer::new( - recorder.clone(), - format!("{}/checkpoint", self.directory).as_str(), - "optim", - ); - let checkpointer_scheduler = FileCheckpointer::new( - recorder, - format!("{}/checkpoint", self.directory).as_str(), - "scheduler", - ); + /// By default, Rust logs are captured and written into + /// `experiment.log`. If disabled, standard Rust log handling + /// will apply. 
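`interrupter()` hands out the `TrainingInterrupter` defined in `base.rs` above: a clonable wrapper around an `Arc<AtomicBool>`, so it can be moved into a signal handler, UI, or watchdog thread and flipped while `fit` is running. A small hedged sketch (the crate-root import path is an assumption; in practice the handle would come from `LearnerBuilder::interrupter()` so it shares the flag with the training loop):

```rust
use std::{thread, time::Duration};

use burn_train::TrainingInterrupter; // import path assumed

fn main() {
    // In real code: `let interrupter = builder.interrupter();`
    let interrupter = TrainingInterrupter::new();
    let watchdog = interrupter.clone();

    // Hypothetical wall-clock budget: request a graceful stop after one hour.
    thread::spawn(move || {
        thread::sleep(Duration::from_secs(60 * 60));
        watchdog.stop();
    });

    // The training loop polls this between iterations and breaks when true.
    assert!(!interrupter.should_stop());
}
```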
+ pub fn log_to_file(mut self, enabled: bool) -> Self { + self.log_to_file = enabled; + self + } - self.checkpointers = Some(( - AsyncCheckpointer::new(checkpointer_model), - AsyncCheckpointer::new(checkpointer_optimizer), - AsyncCheckpointer::new(checkpointer_scheduler), - )); + /// Register a checkpointer that will save the [optimizer](Optimizer), the + /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files. + pub fn with_file_checkpointer(mut self, recorder: FR) -> Self + where + FR: FileRecorder + 'static, + O::Record: 'static, + M::Record: 'static, + S::Record: 'static, + { + let checkpointer_model = FileCheckpointer::new( + recorder.clone(), + format!("{}/checkpoint", self.directory).as_str(), + "model", + ); + let checkpointer_optimizer = FileCheckpointer::new( + recorder.clone(), + format!("{}/checkpoint", self.directory).as_str(), + "optim", + ); + let checkpointer_scheduler = FileCheckpointer::new( + recorder, + format!("{}/checkpoint", self.directory).as_str(), + "scheduler", + ); - self - } + self.checkpointers = Some(( + AsyncCheckpointer::new(checkpointer_model), + AsyncCheckpointer::new(checkpointer_optimizer), + AsyncCheckpointer::new(checkpointer_scheduler), + )); - /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer). - /// The [learning rate scheduler](LrScheduler) can also be a simple - /// [learning rate](burn_core::LearningRate). - #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and - // creates a clean learner. - pub fn build( - mut self, - model: M, - optim: O, - lr_scheduler: S, - ) -> Learner< - LearnerComponentsMarker< - B, - S, - M, - O, - AsyncCheckpointer, - AsyncCheckpointer, - AsyncCheckpointer, - FullEventProcessor, - Box, - >, - > - where - M::Record: 'static, - O::Record: 'static, - S::Record: 'static, - { - if self.log_to_file { - self.init_logger(); + self } - let renderer = self - .renderer - .unwrap_or_else(|| Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))); - let directory = &self.directory; - if self.num_loggers == 0 { - self - .event_store - .register_logger_train(FileMetricLogger::new(format!("{directory}/train").as_str())); - self - .event_store - .register_logger_valid(FileMetricLogger::new(format!("{directory}/valid").as_str())); - } + /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer). + /// The [learning rate scheduler](LrScheduler) can also be a simple + /// [learning rate](burn_core::LearningRate). + #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and + // creates a clean learner. 
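`with_file_checkpointer` below wires three `FileCheckpointer`s (model, optimizer, scheduler) under `<directory>/checkpoint`, each wrapped in an `AsyncCheckpointer` so saving does not block training. A hedged fragment pairing it with `checkpoint` to resume a run; `CompactRecorder` is an assumption (any `FileRecorder` with `'static` records works):

```rust
// Fragment only: `builder` is an existing LearnerBuilder.
use burn_core::record::CompactRecorder; // recorder type assumed

let builder = builder
    .with_file_checkpointer(CompactRecorder::new())
    // Resume from the state saved at the end of epoch 10; assumes the
    // corresponding files already exist under `<directory>/checkpoint`.
    .checkpoint(10);
```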
+ pub fn build( + mut self, + model: M, + optim: O, + lr_scheduler: S, + ) -> Learner< + LearnerComponentsMarker< + B, + S, + M, + O, + AsyncCheckpointer, + AsyncCheckpointer, + AsyncCheckpointer, + FullEventProcessor, + Box, + >, + > + where + M::Record: 'static, + O::Record: 'static, + S::Record: 'static, + { + if self.log_to_file { + self.init_logger(); + } + let renderer = self.renderer.unwrap_or_else(|| { + Box::new(default_renderer(self.interrupter.clone(), self.checkpoint)) + }); + let directory = &self.directory; - let event_store = Arc::new(EventStoreClient::new(self.event_store)); - let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone()); + if self.num_loggers == 0 { + self.event_store + .register_logger_train(FileMetricLogger::new( + format!("{directory}/train").as_str(), + )); + self.event_store + .register_logger_valid(FileMetricLogger::new( + format!("{directory}/valid").as_str(), + )); + } - let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { - LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy) - }); + let event_store = Arc::new(EventStoreClient::new(self.event_store)); + let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone()); - Learner { - model, - optim, - lr_scheduler, - checkpointer, - num_epochs: self.num_epochs, - event_processor, - event_store, - checkpoint: self.checkpoint, - grad_accumulation: self.grad_accumulation, - devices: self.devices, - interrupter: self.interrupter, - early_stopping: self.early_stopping, + let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { + LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy) + }); + + Learner { + model, + optim, + lr_scheduler, + checkpointer, + num_epochs: self.num_epochs, + event_processor, + event_store, + checkpoint: self.checkpoint, + grad_accumulation: self.grad_accumulation, + devices: self.devices, + interrupter: self.interrupter, + early_stopping: self.early_stopping, + } } - } - fn init_logger(&self) { - let file_path = format!("{}/experiment.log", self.directory); - install_file_logger(file_path.as_str()); - } + fn init_logger(&self) { + let file_path = format!("{}/experiment.log", self.directory); + install_file_logger(file_path.as_str()); + } } diff --git a/burn-train/src/learner/classification.rs b/burn-train/src/learner/classification.rs index ecba5bed3e..f6b415fa29 100644 --- a/burn-train/src/learner/classification.rs +++ b/burn-train/src/learner/classification.rs @@ -5,24 +5,24 @@ use burn_core::tensor::{Int, Tensor}; /// Simple classification output adapted for multiple metrics. #[derive(new)] pub struct ClassificationOutput { - /// The loss. - pub loss: Tensor, + /// The loss. + pub loss: Tensor, - /// The output. - pub output: Tensor, + /// The output. + pub output: Tensor, - /// The targets. - pub targets: Tensor, + /// The targets. 
+ pub targets: Tensor, } impl Adaptor> for ClassificationOutput { - fn adapt(&self) -> AccuracyInput { - AccuracyInput::new(self.output.clone(), self.targets.clone()) - } + fn adapt(&self) -> AccuracyInput { + AccuracyInput::new(self.output.clone(), self.targets.clone()) + } } impl Adaptor> for ClassificationOutput { - fn adapt(&self) -> LossInput { - LossInput::new(self.loss.clone()) - } + fn adapt(&self) -> LossInput { + LossInput::new(self.loss.clone()) + } } diff --git a/burn-train/src/learner/early_stopping.rs b/burn-train/src/learner/early_stopping.rs index f8c0f0c5a2..641d49551b 100644 --- a/burn-train/src/learner/early_stopping.rs +++ b/burn-train/src/learner/early_stopping.rs @@ -1,209 +1,209 @@ use crate::metric::{ - store::{Aggregate, Direction, EventStoreClient, Split}, - Metric, + store::{Aggregate, Direction, EventStoreClient, Split}, + Metric, }; /// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow. pub enum StoppingCondition { - /// When no improvement has happened since the given number of epochs. - NoImprovementSince { - /// The number of epochs allowed to worsen before it gets better. - n_epochs: usize, - }, + /// When no improvement has happened since the given number of epochs. + NoImprovementSince { + /// The number of epochs allowed to worsen before it gets better. + n_epochs: usize, + }, } /// A strategy that checks if the training should be stopped. pub trait EarlyStoppingStrategy { - /// Update its current state and returns if the training should be stopped. - fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool; + /// Update its current state and returns if the training should be stopped. + fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool; } /// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected /// during training or validation. 
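`ClassificationOutput` plugs into the metric system purely through these small `Adaptor` impls: each metric declares its input type, and the step output only has to know how to produce it. A hedged sketch of the same pattern for a custom output type, reusing `LossInput` from above (import paths assumed):

```rust
use burn_core::tensor::{backend::Backend, Tensor};
use burn_train::metric::{Adaptor, LossInput}; // import paths assumed

/// Hypothetical output of a custom training/validation step.
pub struct MyStepOutput<B: Backend> {
    pub loss: Tensor<B, 1>,
}

impl<B: Backend> Adaptor<LossInput<B>> for MyStepOutput<B> {
    fn adapt(&self) -> LossInput<B> {
        // LossMetric only needs the loss tensor, so that is all we hand over.
        LossInput::new(self.loss.clone())
    }
}
```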
pub struct MetricEarlyStoppingStrategy { - condition: StoppingCondition, - metric_name: String, - aggregate: Aggregate, - direction: Direction, - split: Split, - best_epoch: usize, - best_value: f64, + condition: StoppingCondition, + metric_name: String, + aggregate: Aggregate, + direction: Direction, + split: Split, + best_epoch: usize, + best_value: f64, } impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy { - fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool { - let current_value = - match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) { - Some(value) => value, - None => { - log::warn!("Can't find metric for early stopping."); - return false; + fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool { + let current_value = + match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) { + Some(value) => value, + None => { + log::warn!("Can't find metric for early stopping."); + return false; + } + }; + + let is_best = match self.direction { + Direction::Lowest => current_value < self.best_value, + Direction::Highest => current_value > self.best_value, + }; + + if is_best { + log::info!( + "New best epoch found {} {}: {}", + epoch, + self.metric_name, + current_value + ); + self.best_value = current_value; + self.best_epoch = epoch; + return false; } - }; - let is_best = match self.direction { - Direction::Lowest => current_value < self.best_value, - Direction::Highest => current_value > self.best_value, - }; + match self.condition { + StoppingCondition::NoImprovementSince { n_epochs } => { + let should_stop = epoch - self.best_epoch >= n_epochs; - if is_best { - log::info!( - "New best epoch found {} {}: {}", - epoch, - self.metric_name, - current_value - ); - self.best_value = current_value; - self.best_epoch = epoch; - return false; - } + if should_stop { + log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value); + } - match self.condition { - StoppingCondition::NoImprovementSince { n_epochs } => { - let should_stop = epoch - self.best_epoch >= n_epochs; - - if should_stop { - log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value); + should_stop + } } - - should_stop - } } - } } impl MetricEarlyStoppingStrategy { - /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected - /// during training or validation. - /// - /// # Notes - /// - /// The metric should be registered for early stopping to work, otherwise no data is collected. - pub fn new( - aggregate: Aggregate, - direction: Direction, - split: Split, - condition: StoppingCondition, - ) -> Self { - let init_value = match direction { - Direction::Lowest => f64::MAX, - Direction::Highest => f64::MIN, - }; - - Self { - metric_name: Me::NAME.to_string(), - condition, - aggregate, - direction, - split, - best_epoch: 1, - best_value: init_value, + /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected + /// during training or validation. + /// + /// # Notes + /// + /// The metric should be registered for early stopping to work, otherwise no data is collected. 
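Wiring this strategy into the builder mirrors the test at the bottom of this file: choose the metric, how it is aggregated, which direction counts as an improvement, which split to read, and the stopping condition. A hedged fragment (the metric must also be registered on that split, as the note above says; `B` is the learner's autodiff backend):

```rust
// Fragment only: `builder` is an existing LearnerBuilder.
// Stop when the mean validation loss has not improved for 3 consecutive epochs.
let builder = builder.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
    Aggregate::Mean,
    Direction::Lowest,
    Split::Valid,
    StoppingCondition::NoImprovementSince { n_epochs: 3 },
));
```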
+ pub fn new( + aggregate: Aggregate, + direction: Direction, + split: Split, + condition: StoppingCondition, + ) -> Self { + let init_value = match direction { + Direction::Lowest => f64::MAX, + Direction::Highest => f64::MIN, + }; + + Self { + metric_name: Me::NAME.to_string(), + condition, + aggregate, + direction, + split, + best_epoch: 1, + best_value: init_value, + } } - } } #[cfg(test)] mod tests { - use std::sync::Arc; - - use crate::{ - logger::InMemoryMetricLogger, - metric::{ - processor::{ - test_utils::{end_epoch, process_train}, - Metrics, MinimalEventProcessor, - }, - store::LogEventStore, - LossMetric, - }, - TestBackend, - }; - - use super::*; - - #[test] - fn never_early_stop_while_it_is_improving() { - test_early_stopping( - 1, - &[ - (&[0.5, 0.3], false, "Should not stop first epoch"), - (&[0.4, 0.3], false, "Should not stop when improving"), - (&[0.3, 0.3], false, "Should not stop when improving"), - (&[0.2, 0.3], false, "Should not stop when improving"), - ], - ); - } - - #[test] - fn early_stop_when_no_improvement_since_two_epochs() { - test_early_stopping( - 2, - &[ - (&[1.0, 0.5], false, "Should not stop first epoch"), - (&[0.5, 0.3], false, "Should not stop when improving"), - ( - &[1.0, 3.0], - false, - "Should not stop first time it gets worse", - ), - ( - &[1.0, 2.0], - true, - "Should stop since two following epochs didn't improve", - ), - ], - ); - } - - #[test] - fn early_stop_when_stays_equal() { - test_early_stopping( - 2, - &[ - (&[0.5, 0.3], false, "Should not stop first epoch"), - ( - &[0.5, 0.3], - false, - "Should not stop first time it stars the same", - ), - ( - &[0.5, 0.3], - true, - "Should stop since two following epochs didn't improve", - ), - ], - ); - } - - fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) { - let mut early_stopping = MetricEarlyStoppingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Train, - StoppingCondition::NoImprovementSince { n_epochs }, - ); - let mut store = LogEventStore::default(); - let mut metrics = Metrics::::default(); - - store.register_logger_train(InMemoryMetricLogger::default()); - metrics.register_train_metric_numeric(LossMetric::::new()); - - let store = Arc::new(EventStoreClient::new(store)); - let mut processor = MinimalEventProcessor::new(metrics, store.clone()); - - let mut epoch = 1; - for (points, should_start, comment) in data { - for point in points.iter() { - process_train(&mut processor, *point, epoch); - } - end_epoch(&mut processor, epoch); - - assert_eq!( - *should_start, - early_stopping.should_stop(epoch, &store), - "{comment}" - ); - epoch += 1; + use std::sync::Arc; + + use crate::{ + logger::InMemoryMetricLogger, + metric::{ + processor::{ + test_utils::{end_epoch, process_train}, + Metrics, MinimalEventProcessor, + }, + store::LogEventStore, + LossMetric, + }, + TestBackend, + }; + + use super::*; + + #[test] + fn never_early_stop_while_it_is_improving() { + test_early_stopping( + 1, + &[ + (&[0.5, 0.3], false, "Should not stop first epoch"), + (&[0.4, 0.3], false, "Should not stop when improving"), + (&[0.3, 0.3], false, "Should not stop when improving"), + (&[0.2, 0.3], false, "Should not stop when improving"), + ], + ); + } + + #[test] + fn early_stop_when_no_improvement_since_two_epochs() { + test_early_stopping( + 2, + &[ + (&[1.0, 0.5], false, "Should not stop first epoch"), + (&[0.5, 0.3], false, "Should not stop when improving"), + ( + &[1.0, 3.0], + false, + "Should not stop first time it gets worse", + ), + ( + &[1.0, 2.0], + true, + 
"Should stop since two following epochs didn't improve", + ), + ], + ); + } + + #[test] + fn early_stop_when_stays_equal() { + test_early_stopping( + 2, + &[ + (&[0.5, 0.3], false, "Should not stop first epoch"), + ( + &[0.5, 0.3], + false, + "Should not stop first time it stars the same", + ), + ( + &[0.5, 0.3], + true, + "Should stop since two following epochs didn't improve", + ), + ], + ); + } + + fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) { + let mut early_stopping = MetricEarlyStoppingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Train, + StoppingCondition::NoImprovementSince { n_epochs }, + ); + let mut store = LogEventStore::default(); + let mut metrics = Metrics::::default(); + + store.register_logger_train(InMemoryMetricLogger::default()); + metrics.register_train_metric_numeric(LossMetric::::new()); + + let store = Arc::new(EventStoreClient::new(store)); + let mut processor = MinimalEventProcessor::new(metrics, store.clone()); + + let mut epoch = 1; + for (points, should_start, comment) in data { + for point in points.iter() { + process_train(&mut processor, *point, epoch); + } + end_epoch(&mut processor, epoch); + + assert_eq!( + *should_start, + early_stopping.should_stop(epoch, &store), + "{comment}" + ); + epoch += 1; + } } - } } diff --git a/burn-train/src/learner/epoch.rs b/burn-train/src/learner/epoch.rs index eca2ff0b49..06eba7b7d9 100644 --- a/burn-train/src/learner/epoch.rs +++ b/burn-train/src/learner/epoch.rs @@ -1,6 +1,6 @@ use burn_core::{ - data::dataloader::DataLoader, lr_scheduler::LrScheduler, module::AutodiffModule, - optim::GradientsAccumulator, tensor::backend::Backend, + data::dataloader::DataLoader, lr_scheduler::LrScheduler, module::AutodiffModule, + optim::GradientsAccumulator, tensor::backend::Backend, }; use std::sync::Arc; @@ -11,237 +11,237 @@ use crate::{MultiDevicesTrainStep, TrainStep, ValidStep}; /// A validation epoch. #[derive(new)] pub struct ValidEpoch { - dataloader: Arc>, - epoch: usize, - epoch_total: usize, + dataloader: Arc>, + epoch: usize, + epoch_total: usize, } /// A training epoch. #[derive(new)] pub struct TrainEpoch { - dataloader: Arc>, - epoch: usize, - epoch_total: usize, - grad_accumulation: Option, + dataloader: Arc>, + epoch: usize, + epoch_total: usize, + grad_accumulation: Option, } impl ValidEpoch { - /// Runs the validation epoch. - /// - /// # Arguments - /// - /// * `model` - The model to validate. - /// * `processor` - The event processor to use. - pub fn run( - &self, - model: &LC::Model, - processor: &mut LC::EventProcessor, - interrupter: &TrainingInterrupter, - ) where - LC::EventProcessor: EventProcessor, - >::InnerModule: ValidStep, - { - log::info!("Executing validation step for epoch {}", self.epoch); - let model = model.valid(); - - let mut iterator = self.dataloader.iter(); - let mut iteration = 0; - - while let Some(item) = iterator.next() { - let progress = iterator.progress(); - iteration += 1; - - let item = model.step(item); - let item = LearnerItem::new( - item, - progress, - self.epoch, - self.epoch_total, - iteration, - None, - ); - - processor.process_valid(Event::ProcessedItem(item)); - - if interrupter.should_stop() { - log::info!("Training interrupted."); - break; - } + /// Runs the validation epoch. + /// + /// # Arguments + /// + /// * `model` - The model to validate. + /// * `processor` - The event processor to use. 
+ pub fn run( + &self, + model: &LC::Model, + processor: &mut LC::EventProcessor, + interrupter: &TrainingInterrupter, + ) where + LC::EventProcessor: EventProcessor, + >::InnerModule: ValidStep, + { + log::info!("Executing validation step for epoch {}", self.epoch); + let model = model.valid(); + + let mut iterator = self.dataloader.iter(); + let mut iteration = 0; + + while let Some(item) = iterator.next() { + let progress = iterator.progress(); + iteration += 1; + + let item = model.step(item); + let item = LearnerItem::new( + item, + progress, + self.epoch, + self.epoch_total, + iteration, + None, + ); + + processor.process_valid(Event::ProcessedItem(item)); + + if interrupter.should_stop() { + log::info!("Training interrupted."); + break; + } + } + processor.process_valid(Event::EndEpoch(self.epoch)); } - processor.process_valid(Event::EndEpoch(self.epoch)); - } } impl TrainEpoch { - /// Runs the training epoch. - /// - /// # Arguments - /// - /// * `model` - The model to train. - /// * `optim` - The optimizer to use. - /// * `scheduler` - The learning rate scheduler to use. - /// * `processor` - The event processor to use. - /// - /// # Returns - /// - /// The trained model and the optimizer. - pub fn run( - &self, - mut model: LC::Model, - mut optim: LC::Optimizer, - scheduler: &mut LC::LrScheduler, - processor: &mut LC::EventProcessor, - interrupter: &TrainingInterrupter, - ) -> (LC::Model, LC::Optimizer) - where - LC::EventProcessor: EventProcessor, - LC::Model: TrainStep, - { - log::info!("Executing training step for epoch {}", self.epoch,); - - let mut iterator = self.dataloader.iter(); - let mut iteration = 0; - let mut accumulator = GradientsAccumulator::new(); - let mut accumulation_current = 0; - - while let Some(item) = iterator.next() { - iteration += 1; - let lr = scheduler.step(); - log::info!("Iteration {}", iteration); - - let progress = iterator.progress(); - let item = model.step(item); - - match self.grad_accumulation { - Some(accumulation) => { - accumulator.accumulate(&model, item.grads); - accumulation_current += 1; - - if accumulation <= accumulation_current { - let grads = accumulator.grads(); - model = model.optimize(&mut optim, lr, grads); - accumulation_current = 0; - } + /// Runs the training epoch. + /// + /// # Arguments + /// + /// * `model` - The model to train. + /// * `optim` - The optimizer to use. + /// * `scheduler` - The learning rate scheduler to use. + /// * `processor` - The event processor to use. + /// + /// # Returns + /// + /// The trained model and the optimizer. 
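In `run` below, the accumulation branch only calls `optimize` once the counter reaches the configured value, so with `grads_accumulation(4)` the optimizer steps on iterations 4, 8, 12, … while the gradients in between are summed. A tiny self-contained sketch of just that counter logic (plain Rust; the burn calls are left as comments):

```rust
fn main() {
    // Hypothetical setup: accumulate gradients over 4 mini-batches per step.
    let accumulation = 4;
    let mut accumulation_current = 0;

    for iteration in 1..=10 {
        // accumulator.accumulate(&model, item.grads); would happen here.
        accumulation_current += 1;

        if accumulation <= accumulation_current {
            // model = model.optimize(&mut optim, lr, accumulator.grads());
            println!("iteration {iteration}: optimizer step");
            accumulation_current = 0;
        }
    }
    // Prints steps at iterations 4 and 8; iterations 9-10 remain accumulated
    // until the next multiple is reached (or the epoch ends).
}
```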
+ pub fn run( + &self, + mut model: LC::Model, + mut optim: LC::Optimizer, + scheduler: &mut LC::LrScheduler, + processor: &mut LC::EventProcessor, + interrupter: &TrainingInterrupter, + ) -> (LC::Model, LC::Optimizer) + where + LC::EventProcessor: EventProcessor, + LC::Model: TrainStep, + { + log::info!("Executing training step for epoch {}", self.epoch,); + + let mut iterator = self.dataloader.iter(); + let mut iteration = 0; + let mut accumulator = GradientsAccumulator::new(); + let mut accumulation_current = 0; + + while let Some(item) = iterator.next() { + iteration += 1; + let lr = scheduler.step(); + log::info!("Iteration {}", iteration); + + let progress = iterator.progress(); + let item = model.step(item); + + match self.grad_accumulation { + Some(accumulation) => { + accumulator.accumulate(&model, item.grads); + accumulation_current += 1; + + if accumulation <= accumulation_current { + let grads = accumulator.grads(); + model = model.optimize(&mut optim, lr, grads); + accumulation_current = 0; + } + } + None => model = model.optimize(&mut optim, lr, item.grads), + } + + let item = LearnerItem::new( + item.item, + progress, + self.epoch, + self.epoch_total, + iteration, + Some(lr), + ); + + processor.process_train(Event::ProcessedItem(item)); + + if interrupter.should_stop() { + log::info!("Training interrupted."); + break; + } } - None => model = model.optimize(&mut optim, lr, item.grads), - } - - let item = LearnerItem::new( - item.item, - progress, - self.epoch, - self.epoch_total, - iteration, - Some(lr), - ); - - processor.process_train(Event::ProcessedItem(item)); - - if interrupter.should_stop() { - log::info!("Training interrupted."); - break; - } - } - processor.process_train(Event::EndEpoch(self.epoch)); + processor.process_train(Event::EndEpoch(self.epoch)); - (model, optim) - } + (model, optim) + } } impl TrainEpoch { - /// Runs the training epoch on multiple devices. - /// - /// # Arguments - /// - /// * `model` - The model to train. - /// * `optim` - The optimizer to use. - /// * `lr_scheduler` - The learning rate scheduler to use. - /// * `processor` - The event processor to use. - /// * `devices` - The devices to use. - /// - /// # Returns - /// - /// The trained model and the optimizer. - pub fn run_multi_device( - &self, - mut model: LC::Model, - mut optim: LC::Optimizer, - lr_scheduler: &mut LC::LrScheduler, - processor: &mut LC::EventProcessor, - devices: Vec<::Device>, - interrupter: &TrainingInterrupter, - ) -> (LC::Model, LC::Optimizer) - where - LC::EventProcessor: EventProcessor, - LC::Model: TrainStep, - TO: Send + 'static, - TI: Send + 'static, - { - log::info!( - "Executing training step for epoch {} on devices {:?}", - self.epoch, - devices - ); - - let mut iterator = self.dataloader.iter(); - let mut iteration = 0; - let mut accumulator = GradientsAccumulator::new(); - let mut accumulation_current = 0; - - let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len(); - let step = MultiDevicesTrainStep::new(&devices); - - // The main device is always the first in the list. 
- let device_main = devices.get(0).expect("A minimum of one device.").clone(); - let mut interrupted = false; - - loop { - let items = step.step(&mut iterator, &model); - if items.is_empty() { - break; - } - - for item in items { - iteration += 1; - let lr = lr_scheduler.step(); - let progress = iterator.progress(); - - let grads = item.grads.to_device(&device_main, &model); - - accumulator.accumulate(&model, grads); - accumulation_current += 1; - - if accumulation <= accumulation_current { - let grads = accumulator.grads(); - model = model.optimize(&mut optim, lr, grads); - accumulation_current = 0; - } - - let item = LearnerItem::new( - item.item, - progress, - self.epoch, - self.epoch_total, - iteration, - Some(lr), + /// Runs the training epoch on multiple devices. + /// + /// # Arguments + /// + /// * `model` - The model to train. + /// * `optim` - The optimizer to use. + /// * `lr_scheduler` - The learning rate scheduler to use. + /// * `processor` - The event processor to use. + /// * `devices` - The devices to use. + /// + /// # Returns + /// + /// The trained model and the optimizer. + pub fn run_multi_device( + &self, + mut model: LC::Model, + mut optim: LC::Optimizer, + lr_scheduler: &mut LC::LrScheduler, + processor: &mut LC::EventProcessor, + devices: Vec<::Device>, + interrupter: &TrainingInterrupter, + ) -> (LC::Model, LC::Optimizer) + where + LC::EventProcessor: EventProcessor, + LC::Model: TrainStep, + TO: Send + 'static, + TI: Send + 'static, + { + log::info!( + "Executing training step for epoch {} on devices {:?}", + self.epoch, + devices ); - processor.process_train(Event::ProcessedItem(item)); - - if interrupter.should_stop() { - log::info!("Training interrupted."); - interrupted = true; - break; + let mut iterator = self.dataloader.iter(); + let mut iteration = 0; + let mut accumulator = GradientsAccumulator::new(); + let mut accumulation_current = 0; + + let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len(); + let step = MultiDevicesTrainStep::new(&devices); + + // The main device is always the first in the list. 
+ let device_main = devices.get(0).expect("A minimum of one device.").clone(); + let mut interrupted = false; + + loop { + let items = step.step(&mut iterator, &model); + if items.is_empty() { + break; + } + + for item in items { + iteration += 1; + let lr = lr_scheduler.step(); + let progress = iterator.progress(); + + let grads = item.grads.to_device(&device_main, &model); + + accumulator.accumulate(&model, grads); + accumulation_current += 1; + + if accumulation <= accumulation_current { + let grads = accumulator.grads(); + model = model.optimize(&mut optim, lr, grads); + accumulation_current = 0; + } + + let item = LearnerItem::new( + item.item, + progress, + self.epoch, + self.epoch_total, + iteration, + Some(lr), + ); + + processor.process_train(Event::ProcessedItem(item)); + + if interrupter.should_stop() { + log::info!("Training interrupted."); + interrupted = true; + break; + } + } + + if interrupted { + break; + } } - } - - if interrupted { - break; - } - } - processor.process_train(Event::EndEpoch(self.epoch)); + processor.process_train(Event::EndEpoch(self.epoch)); - (model, optim) - } + (model, optim) + } } diff --git a/burn-train/src/learner/log.rs b/burn-train/src/learner/log.rs index e8e1025a02..35162cfc73 100644 --- a/burn-train/src/learner/log.rs +++ b/burn-train/src/learner/log.rs @@ -7,41 +7,40 @@ use tracing_subscriber::{registry, Layer}; /// If a global tracing subscriber is not already configured, set up logging to a file, /// and add our custom panic hook. pub(crate) fn install_file_logger(file_path: &str) { - let path = Path::new(file_path); - let writer = tracing_appender::rolling::never( - path.parent().unwrap_or_else(|| Path::new(".")), - path - .file_name() - .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")), - ); - let layer = tracing_subscriber::fmt::layer() - .with_ansi(false) - .with_writer(writer) - .with_filter(LevelFilter::INFO) - .with_filter(filter_fn(|m| { - if let Some(path) = m.module_path() { - // The wgpu crate is logging too much, so we skip `info` level. - if path.starts_with("wgpu") && *m.level() >= Level::INFO { - return false; - } - } - true - })); + let path = Path::new(file_path); + let writer = tracing_appender::rolling::never( + path.parent().unwrap_or_else(|| Path::new(".")), + path.file_name() + .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")), + ); + let layer = tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_writer(writer) + .with_filter(LevelFilter::INFO) + .with_filter(filter_fn(|m| { + if let Some(path) = m.module_path() { + // The wgpu crate is logging too much, so we skip `info` level. 
+ if path.starts_with("wgpu") && *m.level() >= Level::INFO { + return false; + } + } + true + })); - if registry().with(layer).try_init().is_ok() { - update_panic_hook(file_path); - } + if registry().with(layer).try_init().is_ok() { + update_panic_hook(file_path); + } } fn update_panic_hook(file_path: &str) { - let hook = std::panic::take_hook(); - let file_path = file_path.to_owned(); + let hook = std::panic::take_hook(); + let file_path = file_path.to_owned(); - std::panic::set_hook(Box::new(move |info| { - log::error!("PANIC => {}", info.to_string()); - eprintln!( + std::panic::set_hook(Box::new(move |info| { + log::error!("PANIC => {}", info.to_string()); + eprintln!( "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => '{file_path}'\n=============" ); - hook(info); - })); + hook(info); + })); } diff --git a/burn-train/src/learner/regression.rs b/burn-train/src/learner/regression.rs index d6d647db61..9aa5db2e94 100644 --- a/burn-train/src/learner/regression.rs +++ b/burn-train/src/learner/regression.rs @@ -5,18 +5,18 @@ use burn_core::tensor::Tensor; /// Simple regression output adapted for multiple metrics. #[derive(new)] pub struct RegressionOutput { - /// The loss. - pub loss: Tensor, + /// The loss. + pub loss: Tensor, - /// The output. - pub output: Tensor, + /// The output. + pub output: Tensor, - /// The targets. - pub targets: Tensor, + /// The targets. + pub targets: Tensor, } impl Adaptor> for RegressionOutput { - fn adapt(&self) -> LossInput { - LossInput::new(self.loss.clone()) - } + fn adapt(&self) -> LossInput { + LossInput::new(self.loss.clone()) + } } diff --git a/burn-train/src/learner/step/train.rs b/burn-train/src/learner/step/train.rs index c8c46ae6d8..c000e8b661 100644 --- a/burn-train/src/learner/step/train.rs +++ b/burn-train/src/learner/step/train.rs @@ -1,139 +1,139 @@ use crate::{TrainOutput, TrainStep}; use burn_core::{ - data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend, + data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend, }; use std::sync::mpsc::{Receiver, Sender}; use std::thread::spawn; /// Multi devices train step. 
pub struct MultiDevicesTrainStep { - workers: Vec>, - receiver: Receiver>, + workers: Vec>, + receiver: Receiver>, } struct Message { - item: TI, - model: M, + item: TI, + model: M, } struct Worker { - sender_input: Sender>, - device: B::Device, + sender_input: Sender>, + device: B::Device, } impl Worker where - B: AutodiffBackend, - M: AutodiffModule, + B: AutodiffBackend, + M: AutodiffModule, { - fn register(&self, item: TI, model: &M) { - let message = Message { - item, - model: model.clone(), - }; - self.sender_input.send(message).unwrap(); - } + fn register(&self, item: TI, model: &M) { + let message = Message { + item, + model: model.clone(), + }; + self.sender_input.send(message).unwrap(); + } - fn start( - &self, - sender_output: Sender>, - receiver_input: Receiver>, - ) where - TI: Send + 'static, - TO: Send + 'static, - M: TrainStep + Send + 'static, - { - let device = self.device.clone(); + fn start( + &self, + sender_output: Sender>, + receiver_input: Receiver>, + ) where + TI: Send + 'static, + TO: Send + 'static, + M: TrainStep + Send + 'static, + { + let device = self.device.clone(); - spawn(move || loop { - match receiver_input.recv() { - Ok(item) => { - let step = item.model.fork(&device); - let output = step.step(item.item); + spawn(move || loop { + match receiver_input.recv() { + Ok(item) => { + let step = item.model.fork(&device); + let output = step.step(item.item); - sender_output.send(output).unwrap(); - } - Err(_err) => { - log::info!("Closing thread on device {:?}", device); - break; - } - } - }); - } + sender_output.send(output).unwrap(); + } + Err(_err) => { + log::info!("Closing thread on device {:?}", device); + break; + } + } + }); + } } impl MultiDevicesTrainStep where - B: AutodiffBackend, - M: AutodiffModule + TrainStep + Send + Clone + 'static, - TI: Send + 'static, - TO: Send + 'static, -{ - /// Create a new multi devices train step. - /// - /// # Arguments - /// - /// * `devices` - Devices. - /// - /// # Returns - /// - /// MultiDevicesTrainStep instance. - pub fn new(devices: &[B::Device]) -> Self - where + B: AutodiffBackend, + M: AutodiffModule + TrainStep + Send + Clone + 'static, TI: Send + 'static, - { - let (sender_output, receiver_output) = std::sync::mpsc::channel(); - let workers = devices - .iter() - .map(|device| { - let (sender_input, receiver_input) = std::sync::mpsc::channel(); - let worker = Worker { - sender_input, - device: device.clone(), - }; + TO: Send + 'static, +{ + /// Create a new multi devices train step. + /// + /// # Arguments + /// + /// * `devices` - Devices. + /// + /// # Returns + /// + /// MultiDevicesTrainStep instance. + pub fn new(devices: &[B::Device]) -> Self + where + TI: Send + 'static, + { + let (sender_output, receiver_output) = std::sync::mpsc::channel(); + let workers = devices + .iter() + .map(|device| { + let (sender_input, receiver_input) = std::sync::mpsc::channel(); + let worker = Worker { + sender_input, + device: device.clone(), + }; - worker.start(sender_output.clone(), receiver_input); - worker - }) - .collect(); + worker.start(sender_output.clone(), receiver_input); + worker + }) + .collect(); - Self { - workers, - receiver: receiver_output, + Self { + workers, + receiver: receiver_output, + } } - } - /// Collect outputs from workers for one step. - /// - /// # Arguments - /// - /// * `dataloader` - Dataloader. - /// * `model` - Model. - /// - /// # Returns - /// - /// Outputs. 
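The `Worker`/`MultiDevicesTrainStep` machinery is a fan-out/fan-in over `std::sync::mpsc`: one dedicated input channel per worker thread, a single shared output channel, and `step` sends at most one item per worker before collecting exactly that many results. A stripped-down, self-contained sketch of the same channel topology (no burn types; doubling an integer stands in for the forked model's training step):

```rust
use std::sync::mpsc;
use std::thread;

fn main() {
    let n_workers = 3;
    let (out_tx, out_rx) = mpsc::channel();

    // Fan-out: one input channel per worker, all sharing the output sender.
    let workers: Vec<mpsc::Sender<i32>> = (0..n_workers)
        .map(|id| {
            let (in_tx, in_rx) = mpsc::channel::<i32>();
            let out_tx = out_tx.clone();
            thread::spawn(move || {
                for item in in_rx {
                    // Stand-in for `model.fork(&device)` + `step(item)`.
                    out_tx.send((id, item * 2)).unwrap();
                }
            });
            in_tx
        })
        .collect();

    // One "step": send one item per worker, then collect exactly as many results.
    let mut num_sent = 0;
    for (item, worker) in (10..).zip(workers.iter()) {
        worker.send(item).unwrap();
        num_sent += 1;
    }
    let outputs: Vec<_> = (0..num_sent).map(|_| out_rx.recv().unwrap()).collect();
    println!("{outputs:?}");
}
```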
- pub fn step<'a>( - &self, - dataloader: &mut Box + 'a>, - model: &M, - ) -> Vec> { - let mut num_send = 0; + /// Collect outputs from workers for one step. + /// + /// # Arguments + /// + /// * `dataloader` - Dataloader. + /// * `model` - Model. + /// + /// # Returns + /// + /// Outputs. + pub fn step<'a>( + &self, + dataloader: &mut Box + 'a>, + model: &M, + ) -> Vec> { + let mut num_send = 0; - for worker in self.workers.iter() { - if let Some(item) = dataloader.next() { - worker.register(item, model); - num_send += 1; - } - } + for worker in self.workers.iter() { + if let Some(item) = dataloader.next() { + worker.register(item, model); + num_send += 1; + } + } - let mut outputs = Vec::with_capacity(num_send); + let mut outputs = Vec::with_capacity(num_send); - for _ in 0..num_send { - let output = self.receiver.recv().unwrap(); - outputs.push(output); - } + for _ in 0..num_send { + let output = self.receiver.recv().unwrap(); + outputs.push(output); + } - outputs - } + outputs + } } diff --git a/burn-train/src/learner/train_val.rs b/burn-train/src/learner/train_val.rs index 6a9f902697..b8b16dddf9 100644 --- a/burn-train/src/learner/train_val.rs +++ b/burn-train/src/learner/train_val.rs @@ -9,33 +9,33 @@ use std::sync::Arc; /// A training output. pub struct TrainOutput { - /// The gradients. - pub grads: GradientsParams, + /// The gradients. + pub grads: GradientsParams, - /// The item. - pub item: TO, + /// The item. + pub item: TO, } impl TrainOutput { - /// Creates a new training output. - /// - /// # Arguments - /// - /// * `module` - The module. - /// * `grads` - The gradients. - /// * `item` - The item. - /// - /// # Returns - /// - /// A new training output. - pub fn new>( - module: &M, - grads: B::Gradients, - item: TO, - ) -> Self { - let grads = GradientsParams::from_grads(grads, module); - Self { grads, item } - } + /// Creates a new training output. + /// + /// # Arguments + /// + /// * `module` - The module. + /// * `grads` - The gradients. + /// * `item` - The item. + /// + /// # Returns + /// + /// A new training output. + pub fn new>( + module: &M, + grads: B::Gradients, + item: TO, + ) -> Self { + let grads = GradientsParams::from_grads(grads, module); + Self { grads, item } + } } /// Trait to be implemented for training models. @@ -52,144 +52,152 @@ impl TrainOutput { /// also implement the [AutodiffModule] trait, which is done automatically with the /// [Module](burn_core::module::Module) derive. pub trait TrainStep { - /// Runs the training step, which executes the forward and backward passes. - /// - /// # Arguments - /// - /// * `item` - The training input for the model. - /// - /// # Returns - /// - /// The training output containing the model output and the gradients. - fn step(&self, item: TI) -> TrainOutput; - /// Optimize the current module with the provided gradients and learning rate. - /// - /// # Arguments - /// - /// * `optim`: Optimizer used for training this model. - /// * `lr`: The learning rate used for this step. - /// * `grads`: The gradients of each parameter in the current model. - /// - /// # Returns - /// - /// The updated model. - fn optimize(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self - where - B: AutodiffBackend, - O: Optimizer, - Self: AutodiffModule, - { - optim.step(lr, self, grads) - } + /// Runs the training step, which executes the forward and backward passes. + /// + /// # Arguments + /// + /// * `item` - The training input for the model. 
+ /// + /// # Returns + /// + /// The training output containing the model output and the gradients. + fn step(&self, item: TI) -> TrainOutput; + /// Optimize the current module with the provided gradients and learning rate. + /// + /// # Arguments + /// + /// * `optim`: Optimizer used for training this model. + /// * `lr`: The learning rate used for this step. + /// * `grads`: The gradients of each parameter in the current model. + /// + /// # Returns + /// + /// The updated model. + fn optimize(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self + where + B: AutodiffBackend, + O: Optimizer, + Self: AutodiffModule, + { + optim.step(lr, self, grads) + } } /// Trait to be implemented for validating models. pub trait ValidStep { - /// Runs a validation step. - /// - /// # Arguments - /// - /// * `item` - The item to validate on. - /// - /// # Returns - /// - /// The validation output. - fn step(&self, item: VI) -> VO; + /// Runs a validation step. + /// + /// # Arguments + /// + /// * `item` - The item to validate on. + /// + /// # Returns + /// + /// The validation output. + fn step(&self, item: VI) -> VO; } impl Learner { - /// Fits the model. - /// - /// # Arguments - /// - /// * `dataloader_train` - The training dataloader. - /// * `dataloader_valid` - The validation dataloader. - /// - /// # Returns - /// - /// The fitted model. - pub fn fit( - mut self, - dataloader_train: Arc>, - dataloader_valid: Arc>, - ) -> LC::Model - where - InputTrain: Send + 'static, - InputValid: Send, - OutputTrain: Send + 'static, - OutputValid: Send, - LC::Model: TrainStep, - >::InnerModule: ValidStep, - LC::EventProcessor: EventProcessor, - { - log::info!("Fitting {}", self.model.to_string()); - // The reference model is always on the first device provided. - if let Some(device) = self.devices.get(0) { - self.model = self.model.fork(device); - } - - let starting_epoch = match self.checkpoint { - Some(checkpoint) => { - if let Some(checkpointer) = &mut self.checkpointer { - (self.model, self.optim, self.lr_scheduler) = - checkpointer.load_checkpoint(self.model, self.optim, self.lr_scheduler, checkpoint); + /// Fits the model. + /// + /// # Arguments + /// + /// * `dataloader_train` - The training dataloader. + /// * `dataloader_valid` - The validation dataloader. + /// + /// # Returns + /// + /// The fitted model. + pub fn fit( + mut self, + dataloader_train: Arc>, + dataloader_valid: Arc>, + ) -> LC::Model + where + InputTrain: Send + 'static, + InputValid: Send, + OutputTrain: Send + 'static, + OutputValid: Send, + LC::Model: TrainStep, + >::InnerModule: ValidStep, + LC::EventProcessor: EventProcessor, + { + log::info!("Fitting {}", self.model.to_string()); + // The reference model is always on the first device provided. 
+ if let Some(device) = self.devices.get(0) { + self.model = self.model.fork(device); } - checkpoint + 1 - } - None => 1, - }; - for epoch in starting_epoch..self.num_epochs + 1 { - let epoch_train = TrainEpoch::new( - dataloader_train.clone(), - epoch, - self.num_epochs, - self.grad_accumulation, - ); + let starting_epoch = match self.checkpoint { + Some(checkpoint) => { + if let Some(checkpointer) = &mut self.checkpointer { + (self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint( + self.model, + self.optim, + self.lr_scheduler, + checkpoint, + ); + } + checkpoint + 1 + } + None => 1, + }; - if self.devices.len() > 1 { - (self.model, self.optim) = epoch_train.run_multi_device::( - self.model, - self.optim, - &mut self.lr_scheduler, - &mut self.event_processor, - self.devices.clone(), - &self.interrupter, - ) - } else { - (self.model, self.optim) = epoch_train.run::( - self.model, - self.optim, - &mut self.lr_scheduler, - &mut self.event_processor, - &self.interrupter, - ); - } + for epoch in starting_epoch..self.num_epochs + 1 { + let epoch_train = TrainEpoch::new( + dataloader_train.clone(), + epoch, + self.num_epochs, + self.grad_accumulation, + ); - if self.interrupter.should_stop() { - break; - } + if self.devices.len() > 1 { + (self.model, self.optim) = epoch_train.run_multi_device::( + self.model, + self.optim, + &mut self.lr_scheduler, + &mut self.event_processor, + self.devices.clone(), + &self.interrupter, + ) + } else { + (self.model, self.optim) = epoch_train.run::( + self.model, + self.optim, + &mut self.lr_scheduler, + &mut self.event_processor, + &self.interrupter, + ); + } - let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs); - epoch_valid.run::(&self.model, &mut self.event_processor, &self.interrupter); + if self.interrupter.should_stop() { + break; + } - if let Some(checkpointer) = &mut self.checkpointer { - checkpointer.checkpoint( - &self.model, - &self.optim, - &self.lr_scheduler, - epoch, - &self.event_store, - ); - } + let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs); + epoch_valid.run::( + &self.model, + &mut self.event_processor, + &self.interrupter, + ); - if let Some(early_stopping) = &mut self.early_stopping { - if early_stopping.should_stop(epoch, &self.event_store) { - break; + if let Some(checkpointer) = &mut self.checkpointer { + checkpointer.checkpoint( + &self.model, + &self.optim, + &self.lr_scheduler, + epoch, + &self.event_store, + ); + } + + if let Some(early_stopping) = &mut self.early_stopping { + if early_stopping.should_stop(epoch, &self.event_store) { + break; + } + } } - } - } - self.model - } + self.model + } } diff --git a/burn-train/src/logger/async_logger.rs b/burn-train/src/logger/async_logger.rs index 79308e21f0..c659098b1e 100644 --- a/burn-train/src/logger/async_logger.rs +++ b/burn-train/src/logger/async_logger.rs @@ -2,93 +2,90 @@ use super::Logger; use std::sync::mpsc; enum Message { - Log(T), - End, - Sync(mpsc::Sender<()>), + Log(T), + End, + Sync(mpsc::Sender<()>), } /// Async logger. 
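Putting the pieces from `train_val.rs` together, an end-to-end call usually looks like the hedged sketch below. `MyModel`, `MyBatch`, `forward_step`, and the prepared dataloaders are hypothetical; the `TrainOutput`, `LearnerBuilder`, and `fit` calls come from the code in this patch, while `AdamConfig`, `CompactRecorder`, and the re-export paths are assumptions about burn-core/burn-train:

```rust
// Hedged sketch -- not a drop-in file: `MyModel`, `MyBatch` and `forward_step`
// are hypothetical and must be defined elsewhere.
use std::sync::Arc;

use burn_core::data::dataloader::DataLoader;
use burn_core::optim::AdamConfig; // optimizer config assumed
use burn_core::record::CompactRecorder; // recorder assumed
use burn_core::tensor::backend::{AutodiffBackend, Backend};
use burn_train::metric::LossMetric;
use burn_train::{ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep};

impl<B: AutodiffBackend> TrainStep<MyBatch<B>, ClassificationOutput<B>> for MyModel<B> {
    fn step(&self, batch: MyBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        // `forward_step` is assumed to return a ClassificationOutput
        // (loss, logits, targets) for the batch.
        let output = self.forward_step(batch);
        let grads = output.loss.backward();
        TrainOutput::new(self, grads, output)
    }
}

impl<B: Backend> ValidStep<MyBatch<B>, ClassificationOutput<B>> for MyModel<B> {
    fn step(&self, batch: MyBatch<B>) -> ClassificationOutput<B> {
        // No gradients here: this runs on the inner (non-autodiff) module.
        self.forward_step(batch)
    }
}

fn train<B: AutodiffBackend>(
    model: MyModel<B>,
    dataloader_train: Arc<dyn DataLoader<MyBatch<B>>>,
    dataloader_valid: Arc<dyn DataLoader<MyBatch<B::InnerBackend>>>,
    device: B::Device,
) -> MyModel<B> {
    let learner = LearnerBuilder::new("/tmp/my-experiment")
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device])
        .num_epochs(10)
        // A plain f64 learning rate is accepted as the scheduler.
        .build(model, AdamConfig::new().init(), 1.0e-3);

    learner.fit(dataloader_train, dataloader_valid)
}
```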
pub struct AsyncLogger { - sender: mpsc::Sender>, - handler: Option>, + sender: mpsc::Sender>, + handler: Option>, } #[derive(new)] struct LoggerThread> { - logger: L, - receiver: mpsc::Receiver>, + logger: L, + receiver: mpsc::Receiver>, } impl LoggerThread where - L: Logger, + L: Logger, { - fn run(mut self) { - for item in self.receiver.iter() { - match item { - Message::Log(item) => { - self.logger.log(item); + fn run(mut self) { + for item in self.receiver.iter() { + match item { + Message::Log(item) => { + self.logger.log(item); + } + Message::End => { + return; + } + Message::Sync(callback) => { + callback + .send(()) + .expect("Can return result with the callback channel."); + } + } } - Message::End => { - return; - } - Message::Sync(callback) => { - callback - .send(()) - .expect("Can return result with the callback channel."); - } - } } - } } impl AsyncLogger { - /// Create a new async logger. - pub fn new(logger: L) -> Self - where - L: Logger + 'static, - { - let (sender, receiver) = mpsc::channel(); - let thread = LoggerThread::new(logger, receiver); + /// Create a new async logger. + pub fn new(logger: L) -> Self + where + L: Logger + 'static, + { + let (sender, receiver) = mpsc::channel(); + let thread = LoggerThread::new(logger, receiver); - let handler = Some(std::thread::spawn(move || thread.run())); + let handler = Some(std::thread::spawn(move || thread.run())); - Self { sender, handler } - } + Self { sender, handler } + } - /// Sync the async logger. - pub(crate) fn sync(&self) { - let (sender, receiver) = mpsc::channel(); + /// Sync the async logger. + pub(crate) fn sync(&self) { + let (sender, receiver) = mpsc::channel(); - self - .sender - .send(Message::Sync(sender)) - .expect("Can send message to logger thread."); + self.sender + .send(Message::Sync(sender)) + .expect("Can send message to logger thread."); - receiver - .recv() - .expect("Should sync, otherwise the thread is dead."); - } + receiver + .recv() + .expect("Should sync, otherwise the thread is dead."); + } } impl Logger for AsyncLogger { - fn log(&mut self, item: T) { - self - .sender - .send(Message::Log(item)) - .expect("Can log using the logger thread."); - } + fn log(&mut self, item: T) { + self.sender + .send(Message::Log(item)) + .expect("Can log using the logger thread."); + } } impl Drop for AsyncLogger { - fn drop(&mut self) { - self - .sender - .send(Message::End) - .expect("Can send the end message to the logger thread."); - let handler = self.handler.take(); + fn drop(&mut self) { + self.sender + .send(Message::End) + .expect("Can send the end message to the logger thread."); + let handler = self.handler.take(); - if let Some(handler) = handler { - handler.join().expect("The logger thread should stop."); + if let Some(handler) = handler { + handler.join().expect("The logger thread should stop."); + } } - } } diff --git a/burn-train/src/logger/base.rs b/burn-train/src/logger/base.rs index 5e3fcd677b..3b37c55e61 100644 --- a/burn-train/src/logger/base.rs +++ b/burn-train/src/logger/base.rs @@ -1,26 +1,26 @@ /// The logger trait. pub trait Logger: Send { - /// Logs an item. - /// - /// # Arguments - /// - /// * `item` - The item. - fn log(&mut self, item: T); + /// Logs an item. + /// + /// # Arguments + /// + /// * `item` - The item. + fn log(&mut self, item: T); } /// The logger backend trait. pub trait LoggerBackend { - /// The logger type. - type Logger: Logger; + /// The logger type. + type Logger: Logger; - /// Create a new logger. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. 
- /// - /// # Returns - /// - /// The logger. - fn create(&self, epoch: usize) -> Self::Logger; + /// Create a new logger. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + /// + /// # Returns + /// + /// The logger. + fn create(&self, epoch: usize) -> Self::Logger; } diff --git a/burn-train/src/logger/file.rs b/burn-train/src/logger/file.rs index 7b13089af9..79c23b462d 100644 --- a/burn-train/src/logger/file.rs +++ b/burn-train/src/logger/file.rs @@ -3,37 +3,37 @@ use std::{fs::File, io::Write}; /// File logger. pub struct FileLogger { - file: File, + file: File, } impl FileLogger { - /// Create a new file logger. - /// - /// # Arguments - /// - /// * `path` - The path. - /// - /// # Returns - /// - /// The file logger. - pub fn new(path: &str) -> Self { - let mut options = std::fs::File::options(); - let file = options - .write(true) - .truncate(true) - .create(true) - .open(path) - .unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}")); + /// Create a new file logger. + /// + /// # Arguments + /// + /// * `path` - The path. + /// + /// # Returns + /// + /// The file logger. + pub fn new(path: &str) -> Self { + let mut options = std::fs::File::options(); + let file = options + .write(true) + .truncate(true) + .create(true) + .open(path) + .unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}")); - Self { file } - } + Self { file } + } } impl Logger for FileLogger where - T: std::fmt::Display, + T: std::fmt::Display, { - fn log(&mut self, item: T) { - writeln!(&mut self.file, "{item}").expect("Can log an item."); - } + fn log(&mut self, item: T) { + writeln!(&mut self.file, "{item}").expect("Can log an item."); + } } diff --git a/burn-train/src/logger/in_memory.rs b/burn-train/src/logger/in_memory.rs index 425c8dc76e..31cf3f165c 100644 --- a/burn-train/src/logger/in_memory.rs +++ b/burn-train/src/logger/in_memory.rs @@ -3,14 +3,14 @@ use super::Logger; /// In memory logger. #[derive(Default)] pub struct InMemoryLogger { - pub(crate) values: Vec, + pub(crate) values: Vec, } impl Logger for InMemoryLogger where - T: std::fmt::Display, + T: std::fmt::Display, { - fn log(&mut self, item: T) { - self.values.push(item.to_string()); - } + fn log(&mut self, item: T) { + self.values.push(item.to_string()); + } } diff --git a/burn-train/src/logger/metric.rs b/burn-train/src/logger/metric.rs index 42519e7c48..5751eff925 100644 --- a/burn-train/src/logger/metric.rs +++ b/burn-train/src/logger/metric.rs @@ -4,173 +4,169 @@ use std::collections::HashMap; /// Metric logger. pub trait MetricLogger: Send { - /// Logs an item. - /// - /// # Arguments - /// - /// * `item` - The item. - fn log(&mut self, item: &MetricEntry); - - /// Logs an epoch. - /// - /// # Arguments - /// - /// * `epoch` - The epoch. - fn end_epoch(&mut self, epoch: usize); - - /// Read the logs for an epoch. - fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String>; + /// Logs an item. + /// + /// # Arguments + /// + /// * `item` - The item. + fn log(&mut self, item: &MetricEntry); + + /// Logs an epoch. + /// + /// # Arguments + /// + /// * `epoch` - The epoch. + fn end_epoch(&mut self, epoch: usize); + + /// Read the logs for an epoch. + fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String>; } /// The file metric logger. pub struct FileMetricLogger { - loggers: HashMap>, - directory: String, - epoch: usize, + loggers: HashMap>, + directory: String, + epoch: usize, } impl FileMetricLogger { - /// Create a new file metric logger. 
- /// - /// # Arguments - /// - /// * `directory` - The directory. - /// - /// # Returns - /// - /// The file metric logger. - pub fn new(directory: &str) -> Self { - Self { - loggers: HashMap::new(), - directory: directory.to_string(), - epoch: 1, + /// Create a new file metric logger. + /// + /// # Arguments + /// + /// * `directory` - The directory. + /// + /// # Returns + /// + /// The file metric logger. + pub fn new(directory: &str) -> Self { + Self { + loggers: HashMap::new(), + directory: directory.to_string(), + epoch: 1, + } + } + + fn file_path(&self, name: &str, epoch: usize) -> String { + let directory = format!("{}/epoch-{}", self.directory, epoch); + let name = name.replace(' ', "_"); + + format!("{directory}/{name}.log") + } + fn create_directory(&self, epoch: usize) { + let directory = format!("{}/epoch-{}", self.directory, epoch); + std::fs::create_dir_all(directory).ok(); } - } - - fn file_path(&self, name: &str, epoch: usize) -> String { - let directory = format!("{}/epoch-{}", self.directory, epoch); - let name = name.replace(' ', "_"); - - format!("{directory}/{name}.log") - } - fn create_directory(&self, epoch: usize) { - let directory = format!("{}/epoch-{}", self.directory, epoch); - std::fs::create_dir_all(directory).ok(); - } } impl MetricLogger for FileMetricLogger { - fn log(&mut self, item: &MetricEntry) { - let key = &item.name; - let value = &item.serialize; - - let logger = match self.loggers.get_mut(key) { - Some(val) => val, - None => { - self.create_directory(self.epoch); - - let file_path = self.file_path(key, self.epoch); - let logger = FileLogger::new(&file_path); - let logger = AsyncLogger::new(logger); - - self.loggers.insert(key.clone(), logger); - self - .loggers - .get_mut(key) - .expect("Can get the previously saved logger.") - } - }; - - logger.log(value.clone()); - } - - fn end_epoch(&mut self, epoch: usize) { - self.loggers.clear(); - self.epoch = epoch + 1; - } - - fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { - if let Some(value) = self.loggers.get(name) { - value.sync() + fn log(&mut self, item: &MetricEntry) { + let key = &item.name; + let value = &item.serialize; + + let logger = match self.loggers.get_mut(key) { + Some(val) => val, + None => { + self.create_directory(self.epoch); + + let file_path = self.file_path(key, self.epoch); + let logger = FileLogger::new(&file_path); + let logger = AsyncLogger::new(logger); + + self.loggers.insert(key.clone(), logger); + self.loggers + .get_mut(key) + .expect("Can get the previously saved logger.") + } + }; + + logger.log(value.clone()); } - let file_path = self.file_path(name, epoch); + fn end_epoch(&mut self, epoch: usize) { + self.loggers.clear(); + self.epoch = epoch + 1; + } - let mut errors = false; + fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { + if let Some(value) = self.loggers.get(name) { + value.sync() + } - let data = std::fs::read_to_string(file_path) - .unwrap_or_default() - .split('\n') - .filter_map(|value| { - if value.is_empty() { - None + let file_path = self.file_path(name, epoch); + + let mut errors = false; + + let data = std::fs::read_to_string(file_path) + .unwrap_or_default() + .split('\n') + .filter_map(|value| { + if value.is_empty() { + None + } else { + match value.parse::() { + Ok(value) => Some(value), + Err(err) => { + log::error!("{err}"); + errors = true; + None + } + } + } + }) + .collect(); + + if errors { + Err("Parsing float errors".to_string()) } else { - match value.parse::() { - Ok(value) => 
Some(value), - Err(err) => { - log::error!("{err}"); - errors = true; - None - } - } + Ok(data) } - }) - .collect(); - - if errors { - Err("Parsing float errors".to_string()) - } else { - Ok(data) } - } } /// In memory metric logger, useful when testing and debugging. #[derive(Default)] pub struct InMemoryMetricLogger { - values: HashMap>, + values: HashMap>, } impl InMemoryMetricLogger { - /// Create a new in-memory metric logger. - pub fn new() -> Self { - Self::default() - } + /// Create a new in-memory metric logger. + pub fn new() -> Self { + Self::default() + } } impl MetricLogger for InMemoryMetricLogger { - fn log(&mut self, item: &MetricEntry) { - if !self.values.contains_key(&item.name) { - self - .values - .insert(item.name.clone(), vec![InMemoryLogger::default()]); - } + fn log(&mut self, item: &MetricEntry) { + if !self.values.contains_key(&item.name) { + self.values + .insert(item.name.clone(), vec![InMemoryLogger::default()]); + } - let values = self.values.get_mut(&item.name).unwrap(); + let values = self.values.get_mut(&item.name).unwrap(); - values.last_mut().unwrap().log(item.serialize.clone()); - } + values.last_mut().unwrap().log(item.serialize.clone()); + } - fn end_epoch(&mut self, _epoch: usize) { - for (_, values) in self.values.iter_mut() { - values.push(InMemoryLogger::default()); + fn end_epoch(&mut self, _epoch: usize) { + for (_, values) in self.values.iter_mut() { + values.push(InMemoryLogger::default()); + } } - } - - fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { - let values = match self.values.get(name) { - Some(values) => values, - None => return Ok(Vec::new()), - }; - - match values.get(epoch - 1) { - Some(logger) => Ok( - logger - .values - .iter() - .filter_map(|value| value.parse::().ok()) - .collect(), - ), - None => Ok(Vec::new()), + + fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { + let values = match self.values.get(name) { + Some(values) => values, + None => return Ok(Vec::new()), + }; + + match values.get(epoch - 1) { + Some(logger) => Ok(logger + .values + .iter() + .filter_map(|value| value.parse::().ok()) + .collect()), + None => Ok(Vec::new()), + } } - } } diff --git a/burn-train/src/metric/acc.rs b/burn-train/src/metric/acc.rs index 7157739231..a1a0f3ff2f 100644 --- a/burn-train/src/metric/acc.rs +++ b/burn-train/src/metric/acc.rs @@ -7,123 +7,123 @@ use burn_core::tensor::{ElementConversion, Int, Tensor}; /// The accuracy metric. #[derive(Default)] pub struct AccuracyMetric { - state: NumericMetricState, - pad_token: Option, - _b: B, + state: NumericMetricState, + pad_token: Option, + _b: B, } /// The [accuracy metric](AccuracyMetric) input type. #[derive(new)] pub struct AccuracyInput { - outputs: Tensor, - targets: Tensor, + outputs: Tensor, + targets: Tensor, } impl AccuracyMetric { - /// Creates the metric. - pub fn new() -> Self { - Self::default() - } - - /// Sets the pad token. - pub fn with_pad_token(mut self, index: usize) -> Self { - self.pad_token = Some(index); - self - } + /// Creates the metric. + pub fn new() -> Self { + Self::default() + } + + /// Sets the pad token. 
+ pub fn with_pad_token(mut self, index: usize) -> Self { + self.pad_token = Some(index); + self + } } impl Metric for AccuracyMetric { - const NAME: &'static str = "Accuracy"; - - type Input = AccuracyInput; - - fn update(&mut self, input: &AccuracyInput, _metadata: &MetricMetadata) -> MetricEntry { - let [batch_size, _n_classes] = input.outputs.dims(); - - let targets = input.targets.clone().to_device(&B::Device::default()); - let outputs = input - .outputs - .clone() - .argmax(1) - .to_device(&B::Device::default()) - .reshape([batch_size]); - - let accuracy = match self.pad_token { - Some(pad_token) => { - let mask = targets.clone().equal_elem(pad_token as i64); - let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0); - let num_pad = mask.int().sum().into_scalar().elem::(); - - matches.sum().into_scalar().elem::() / (batch_size as f64 - num_pad) - } - None => { - outputs - .equal(targets) - .int() - .sum() - .into_scalar() - .elem::() - / batch_size as f64 - } - }; - - self.state.update( - 100.0 * accuracy, - batch_size, - FormatOptions::new(Self::NAME).unit("%").precision(2), - ) - } - - fn clear(&mut self) { - self.state.reset() - } + const NAME: &'static str = "Accuracy"; + + type Input = AccuracyInput; + + fn update(&mut self, input: &AccuracyInput, _metadata: &MetricMetadata) -> MetricEntry { + let [batch_size, _n_classes] = input.outputs.dims(); + + let targets = input.targets.clone().to_device(&B::Device::default()); + let outputs = input + .outputs + .clone() + .argmax(1) + .to_device(&B::Device::default()) + .reshape([batch_size]); + + let accuracy = match self.pad_token { + Some(pad_token) => { + let mask = targets.clone().equal_elem(pad_token as i64); + let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0); + let num_pad = mask.int().sum().into_scalar().elem::(); + + matches.sum().into_scalar().elem::() / (batch_size as f64 - num_pad) + } + None => { + outputs + .equal(targets) + .int() + .sum() + .into_scalar() + .elem::() + / batch_size as f64 + } + }; + + self.state.update( + 100.0 * accuracy, + batch_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } } impl Numeric for AccuracyMetric { - fn value(&self) -> f64 { - self.state.value() - } + fn value(&self) -> f64 { + self.state.value() + } } #[cfg(test)] mod tests { - use super::*; - use crate::TestBackend; - - #[test] - fn test_accuracy_without_padding() { - let mut metric = AccuracyMetric::::new(); - let input = AccuracyInput::new( - Tensor::from_data([ - [0.0, 0.2, 0.8], // 2 - [1.0, 2.0, 0.5], // 1 - [0.4, 0.1, 0.2], // 0 - [0.6, 0.7, 0.2], // 1 - ]), - Tensor::from_data([2, 2, 1, 1]), - ); - - let _entry = metric.update(&input, &MetricMetadata::fake()); - assert_eq!(50.0, metric.value()); - } - - #[test] - fn test_accuracy_with_padding() { - let mut metric = AccuracyMetric::::new().with_pad_token(3); - let input = AccuracyInput::new( - Tensor::from_data([ - [0.0, 0.2, 0.8, 0.0], // 2 - [1.0, 2.0, 0.5, 0.0], // 1 - [0.4, 0.1, 0.2, 0.0], // 0 - [0.6, 0.7, 0.2, 0.0], // 1 - [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count - [0.0, 0.1, 0.2, 0.0], // Error on padding should not count - [0.6, 0.0, 0.2, 0.0], // Error on padding should not count - ]), - Tensor::from_data([2, 2, 1, 1, 3, 3, 3]), - ); - - let _entry = metric.update(&input, &MetricMetadata::fake()); - assert_eq!(50.0, metric.value()); - } + use super::*; + use crate::TestBackend; + + #[test] + fn test_accuracy_without_padding() { + let mut metric = 
AccuracyMetric::::new(); + let input = AccuracyInput::new( + Tensor::from_data([ + [0.0, 0.2, 0.8], // 2 + [1.0, 2.0, 0.5], // 1 + [0.4, 0.1, 0.2], // 0 + [0.6, 0.7, 0.2], // 1 + ]), + Tensor::from_data([2, 2, 1, 1]), + ); + + let _entry = metric.update(&input, &MetricMetadata::fake()); + assert_eq!(50.0, metric.value()); + } + + #[test] + fn test_accuracy_with_padding() { + let mut metric = AccuracyMetric::::new().with_pad_token(3); + let input = AccuracyInput::new( + Tensor::from_data([ + [0.0, 0.2, 0.8, 0.0], // 2 + [1.0, 2.0, 0.5, 0.0], // 1 + [0.4, 0.1, 0.2, 0.0], // 0 + [0.6, 0.7, 0.2, 0.0], // 1 + [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count + [0.0, 0.1, 0.2, 0.0], // Error on padding should not count + [0.6, 0.0, 0.2, 0.0], // Error on padding should not count + ]), + Tensor::from_data([2, 2, 1, 1, 3, 3, 3]), + ); + + let _entry = metric.update(&input, &MetricMetadata::fake()); + assert_eq!(50.0, metric.value()); + } } diff --git a/burn-train/src/metric/base.rs b/burn-train/src/metric/base.rs index 215c137f0a..1d0f2ca49b 100644 --- a/burn-train/src/metric/base.rs +++ b/burn-train/src/metric/base.rs @@ -2,36 +2,36 @@ use burn_core::{data::dataloader::Progress, LearningRate}; /// Metric metadata that can be used when computing metrics. pub struct MetricMetadata { - /// The current progress. - pub progress: Progress, + /// The current progress. + pub progress: Progress, - /// The current epoch. - pub epoch: usize, + /// The current epoch. + pub epoch: usize, - /// The total number of epochs. - pub epoch_total: usize, + /// The total number of epochs. + pub epoch_total: usize, - /// The current iteration. - pub iteration: usize, + /// The current iteration. + pub iteration: usize, - /// The current learning rate. - pub lr: Option, + /// The current learning rate. + pub lr: Option, } impl MetricMetadata { - #[cfg(test)] - pub fn fake() -> Self { - Self { - progress: Progress { - items_processed: 1, - items_total: 1, - }, - epoch: 0, - epoch_total: 1, - iteration: 0, - lr: None, + #[cfg(test)] + pub fn fake() -> Self { + Self { + progress: Progress { + items_processed: 1, + items_total: 1, + }, + epoch: 0, + epoch_total: 1, + iteration: 0, + lr: None, + } } - } } /// Metric trait. @@ -42,18 +42,18 @@ impl MetricMetadata { /// This is important since some conflict may happen when the model output is adapted for each /// metric's input type. pub trait Metric: Send + Sync { - /// The name of the metric. - /// - /// This should be unique, so avoid using short generic names, prefer using the long name. - const NAME: &'static str; + /// The name of the metric. + /// + /// This should be unique, so avoid using short generic names, prefer using the long name. + const NAME: &'static str; - /// The input type of the metric. - type Input; + /// The input type of the metric. + type Input; - /// Update the metric state and returns the current metric entry. - fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry; - /// Clear the metric state. - fn clear(&mut self); + /// Update the metric state and returns the current metric entry. + fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry; + /// Clear the metric state. + fn clear(&mut self); } /// Adaptor are used to transform types so that they can be used by metrics. 
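// A minimal, standalone sketch of the adaptor idea stated above: the model output adapts
// itself into each metric's input type, so one output struct can feed several metrics.
// Every name here (AdaptorSketch, SketchOutput, LossInputSketch) is illustrative only,
// not this crate's API.
trait AdaptorSketch<T> {
    fn adapt(&self) -> T;
}

// Hypothetical metric input: just the scalar loss.
struct LossInputSketch {
    loss: f64,
}

// Hypothetical model output carrying more than one value.
struct SketchOutput {
    loss: f64,
    num_correct: usize,
}

impl AdaptorSketch<LossInputSketch> for SketchOutput {
    fn adapt(&self) -> LossInputSketch {
        LossInputSketch { loss: self.loss }
    }
}

fn main() {
    let output = SketchOutput { loss: 0.42, num_correct: 7 };
    let input: LossInputSketch = output.adapt();
    assert_eq!(input.loss, 0.42);
    assert_eq!(output.num_correct, 7);
}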
@@ -61,35 +61,35 @@ pub trait Metric: Send + Sync { /// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are /// registered with the [leaner buidler](crate::learner::LearnerBuilder) . pub trait Adaptor { - /// Adapt the type to be passed to a [metric](Metric). - fn adapt(&self) -> T; + /// Adapt the type to be passed to a [metric](Metric). + fn adapt(&self) -> T; } /// Declare a metric to be numeric. /// /// This is useful to plot the values of a metric during training. pub trait Numeric { - /// Returns the numeric value of the metric. - fn value(&self) -> f64; + /// Returns the numeric value of the metric. + fn value(&self) -> f64; } /// Data type that contains the current state of a metric at a given time. #[derive(new, Debug, Clone)] pub struct MetricEntry { - /// The name of the metric. - pub name: String, - /// The string to be displayed. - pub formatted: String, - /// The string to be saved. - pub serialize: String, + /// The name of the metric. + pub name: String, + /// The string to be displayed. + pub formatted: String, + /// The string to be saved. + pub serialize: String, } /// Format a float with the given precision. Will use scientific notation if necessary. pub fn format_float(float: f64, precision: usize) -> String { - let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0); + let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0); - match scientific_notation_threshold >= float { - true => format!("{float:.precision$e}"), - false => format!("{float:.precision$}"), - } + match scientific_notation_threshold >= float { + true => format!("{float:.precision$e}"), + false => format!("{float:.precision$}"), + } } diff --git a/burn-train/src/metric/cpu_temp.rs b/burn-train/src/metric/cpu_temp.rs index ea96ec9f6b..a44aba8f05 100644 --- a/burn-train/src/metric/cpu_temp.rs +++ b/burn-train/src/metric/cpu_temp.rs @@ -5,51 +5,51 @@ use systemstat::{Platform, System}; /// CPU Temperature in celsius degrees pub struct CpuTemperature { - temp_celsius: f32, - sys: System, + temp_celsius: f32, + sys: System, } impl CpuTemperature { - /// Creates a new CPU temp metric - pub fn new() -> Self { - Self { - temp_celsius: 0., - sys: System::new(), + /// Creates a new CPU temp metric + pub fn new() -> Self { + Self { + temp_celsius: 0., + sys: System::new(), + } } - } } impl Default for CpuTemperature { - fn default() -> Self { - CpuTemperature::new() - } + fn default() -> Self { + CpuTemperature::new() + } } impl Metric for CpuTemperature { - const NAME: &'static str = "CPU Temperature"; + const NAME: &'static str = "CPU Temperature"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - match self.sys.cpu_temp() { - Ok(temp) => self.temp_celsius = temp, - Err(_) => self.temp_celsius = f32::NAN, - } + fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + match self.sys.cpu_temp() { + Ok(temp) => self.temp_celsius = temp, + Err(_) => self.temp_celsius = f32::NAN, + } - let formatted = match self.temp_celsius.is_nan() { - true => format!("{}: NaN °C", Self::NAME), - false => format!("{}: {:.2} °C", Self::NAME, self.temp_celsius), - }; - let raw = format!("{:.2}", self.temp_celsius); + let formatted = match self.temp_celsius.is_nan() { + true => format!("{}: NaN °C", Self::NAME), + false => format!("{}: {:.2} °C", Self::NAME, self.temp_celsius), + }; + let raw = format!("{:.2}", self.temp_celsius); - 
MetricEntry::new(Self::NAME.to_string(), formatted, raw) - } + MetricEntry::new(Self::NAME.to_string(), formatted, raw) + } - fn clear(&mut self) {} + fn clear(&mut self) {} } impl Numeric for CpuTemperature { - fn value(&self) -> f64 { - self.temp_celsius as f64 - } + fn value(&self) -> f64 { + self.temp_celsius as f64 + } } diff --git a/burn-train/src/metric/cpu_use.rs b/burn-train/src/metric/cpu_use.rs index 41849cb916..353165d289 100644 --- a/burn-train/src/metric/cpu_use.rs +++ b/burn-train/src/metric/cpu_use.rs @@ -5,65 +5,65 @@ use sysinfo::{CpuExt, CpuRefreshKind, RefreshKind, System, SystemExt}; /// General CPU Usage metric pub struct CpuUse { - last_refresh: Instant, - refresh_frequency: Duration, - sys: System, - current: f64, + last_refresh: Instant, + refresh_frequency: Duration, + sys: System, + current: f64, } impl CpuUse { - /// Creates a new CPU metric - pub fn new() -> Self { - let mut sys = System::new(); - let current = Self::refresh(&mut sys); + /// Creates a new CPU metric + pub fn new() -> Self { + let mut sys = System::new(); + let current = Self::refresh(&mut sys); - Self { - last_refresh: Instant::now(), - refresh_frequency: Duration::from_millis(200), - sys, - current, + Self { + last_refresh: Instant::now(), + refresh_frequency: Duration::from_millis(200), + sys, + current, + } } - } - fn refresh(sys: &mut System) -> f64 { - sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); + fn refresh(sys: &mut System) -> f64 { + sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); - let cpus = sys.cpus(); - let num_cpus = cpus.len(); - let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64; + let cpus = sys.cpus(); + let num_cpus = cpus.len(); + let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64; - use_percentage / num_cpus as f64 - } + use_percentage / num_cpus as f64 + } } impl Default for CpuUse { - fn default() -> Self { - CpuUse::new() - } + fn default() -> Self { + CpuUse::new() + } } impl Metric for CpuUse { - const NAME: &'static str = "CPU Usage"; + const NAME: &'static str = "CPU Usage"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - if self.last_refresh.elapsed() >= self.refresh_frequency { - self.current = Self::refresh(&mut self.sys); - self.last_refresh = Instant::now(); - } + fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + if self.last_refresh.elapsed() >= self.refresh_frequency { + self.current = Self::refresh(&mut self.sys); + self.last_refresh = Instant::now(); + } - let formatted = format!("{}: {:.2} %", Self::NAME, self.current); - let raw = format!("{:.2}", self.current); + let formatted = format!("{}: {:.2} %", Self::NAME, self.current); + let raw = format!("{:.2}", self.current); - MetricEntry::new(Self::NAME.to_string(), formatted, raw) - } + MetricEntry::new(Self::NAME.to_string(), formatted, raw) + } - fn clear(&mut self) {} + fn clear(&mut self) {} } impl Numeric for CpuUse { - fn value(&self) -> f64 { - self.current - } + fn value(&self) -> f64 { + self.current + } } diff --git a/burn-train/src/metric/cuda.rs b/burn-train/src/metric/cuda.rs index c7f15ef1df..e69e11ffc1 100644 --- a/burn-train/src/metric/cuda.rs +++ b/burn-train/src/metric/cuda.rs @@ -4,101 +4,101 @@ use nvml_wrapper::Nvml; /// Track basic cuda infos. 
pub struct CUDAMetric { - nvml: Option, + nvml: Option, } impl CUDAMetric { - /// Creates a new metric for CUDA. - pub fn new() -> Self { - Self { - nvml: Nvml::init().map(Some).unwrap_or_else(|err| { - log::warn!("Unable to initialize CUDA Metric: {err}"); - None - }), + /// Creates a new metric for CUDA. + pub fn new() -> Self { + Self { + nvml: Nvml::init().map(Some).unwrap_or_else(|err| { + log::warn!("Unable to initialize CUDA Metric: {err}"); + None + }), + } } - } } impl Default for CUDAMetric { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl Adaptor<()> for T { - fn adapt(&self) {} + fn adapt(&self) {} } impl Metric for CUDAMetric { - const NAME: &'static str = "CUDA Stats"; - - type Input = (); - - fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry { - let not_available = || { - MetricEntry::new( - Self::NAME.to_string(), - "Unavailable".to_string(), - "Unavailable".to_string(), - ) - }; - - let available = |nvml: &Nvml| { - let mut formatted = String::new(); - let mut raw_running = String::new(); - - let device_count = match nvml.device_count() { - Ok(val) => val, - Err(err) => { - log::warn!("Unable to get the number of cuda devices: {err}"); - return not_available(); - } - }; - - for index in 0..device_count { - let device = match nvml.device_by_index(index) { - Ok(val) => val, - Err(err) => { - log::warn!("Unable to get device {index}: {err}"); - return not_available(); - } - }; - let memory_info = match device.memory_info() { - Ok(info) => info, - Err(err) => { - log::warn!("Unable to get memory info from device {index}: {err}"); - return not_available(); - } - }; + const NAME: &'static str = "CUDA Stats"; - let used_gb = memory_info.used as f64 * 1e-9; - let total_gb = memory_info.total as f64 * 1e-9; + type Input = (); - let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb"); - let memory_info_raw = format!("{used_gb}/{total_gb}"); - - formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}"); - raw_running = format!("{memory_info_raw} "); - - let utilization_rates = match device.utilization_rates() { - Ok(rate) => rate, - Err(err) => { - log::warn!("Unable to get utilization rates from device {index}: {err}"); - return not_available(); - } + fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry { + let not_available = || { + MetricEntry::new( + Self::NAME.to_string(), + "Unavailable".to_string(), + "Unavailable".to_string(), + ) }; - let utilization_rate_formatted = format!("{}%", utilization_rates.gpu); - formatted = format!("{formatted} - Usage {utilization_rate_formatted}"); - } - MetricEntry::new(Self::NAME.to_string(), formatted, raw_running) - }; + let available = |nvml: &Nvml| { + let mut formatted = String::new(); + let mut raw_running = String::new(); + + let device_count = match nvml.device_count() { + Ok(val) => val, + Err(err) => { + log::warn!("Unable to get the number of cuda devices: {err}"); + return not_available(); + } + }; + + for index in 0..device_count { + let device = match nvml.device_by_index(index) { + Ok(val) => val, + Err(err) => { + log::warn!("Unable to get device {index}: {err}"); + return not_available(); + } + }; + let memory_info = match device.memory_info() { + Ok(info) => info, + Err(err) => { + log::warn!("Unable to get memory info from device {index}: {err}"); + return not_available(); + } + }; + + let used_gb = memory_info.used as f64 * 1e-9; + let total_gb = memory_info.total as f64 * 1e-9; + + let 
memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb"); + let memory_info_raw = format!("{used_gb}/{total_gb}"); + + formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}"); + raw_running = format!("{memory_info_raw} "); + + let utilization_rates = match device.utilization_rates() { + Ok(rate) => rate, + Err(err) => { + log::warn!("Unable to get utilization rates from device {index}: {err}"); + return not_available(); + } + }; + let utilization_rate_formatted = format!("{}%", utilization_rates.gpu); + formatted = format!("{formatted} - Usage {utilization_rate_formatted}"); + } + + MetricEntry::new(Self::NAME.to_string(), formatted, raw_running) + }; - match &self.nvml { - Some(nvml) => available(nvml), - None => not_available(), + match &self.nvml { + Some(nvml) => available(nvml), + None => not_available(), + } } - } - fn clear(&mut self) {} + fn clear(&mut self) {} } diff --git a/burn-train/src/metric/learning_rate.rs b/burn-train/src/metric/learning_rate.rs index c7a542dbbc..c6ca58c018 100644 --- a/burn-train/src/metric/learning_rate.rs +++ b/burn-train/src/metric/learning_rate.rs @@ -1,49 +1,48 @@ use super::{ - state::{FormatOptions, NumericMetricState}, - MetricMetadata, Numeric, + state::{FormatOptions, NumericMetricState}, + MetricMetadata, Numeric, }; use crate::metric::{Metric, MetricEntry}; /// Track the learning rate across iterations. pub struct LearningRateMetric { - state: NumericMetricState, + state: NumericMetricState, } impl LearningRateMetric { - /// Creates a new learning rate metric. - pub fn new() -> Self { - Self { - state: NumericMetricState::new(), + /// Creates a new learning rate metric. + pub fn new() -> Self { + Self { + state: NumericMetricState::new(), + } } - } } impl Default for LearningRateMetric { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl Metric for LearningRateMetric { - const NAME: &'static str = "Learning Rate"; + const NAME: &'static str = "Learning Rate"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry { - let lr = metadata.lr.unwrap_or(0.0); + fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry { + let lr = metadata.lr.unwrap_or(0.0); - self - .state - .update(lr, 1, FormatOptions::new("Learning Rate").precision(2)) - } + self.state + .update(lr, 1, FormatOptions::new("Learning Rate").precision(2)) + } - fn clear(&mut self) { - self.state.reset() - } + fn clear(&mut self) { + self.state.reset() + } } impl Numeric for LearningRateMetric { - fn value(&self) -> f64 { - self.state.value() - } + fn value(&self) -> f64 { + self.state.value() + } } diff --git a/burn-train/src/metric/loss.rs b/burn-train/src/metric/loss.rs index 877cc7ed74..62ed71d816 100644 --- a/burn-train/src/metric/loss.rs +++ b/burn-train/src/metric/loss.rs @@ -10,43 +10,42 @@ use burn_core::tensor::Tensor; /// The loss metric. #[derive(Default)] pub struct LossMetric { - state: NumericMetricState, - _b: B, + state: NumericMetricState, + _b: B, } /// The [loss metric](LossMetric) input type. #[derive(new)] pub struct LossInput { - tensor: Tensor, + tensor: Tensor, } impl LossMetric { - /// Create the metric. - pub fn new() -> Self { - Self::default() - } + /// Create the metric. 
+ pub fn new() -> Self { + Self::default() + } } impl Metric for LossMetric { - const NAME: &'static str = "Loss"; + const NAME: &'static str = "Loss"; - type Input = LossInput; + type Input = LossInput; - fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]); + fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]); - self - .state - .update(loss, 1, FormatOptions::new(Self::NAME).precision(2)) - } + self.state + .update(loss, 1, FormatOptions::new(Self::NAME).precision(2)) + } - fn clear(&mut self) { - self.state.reset() - } + fn clear(&mut self) { + self.state.reset() + } } impl Numeric for LossMetric { - fn value(&self) -> f64 { - self.state.value() - } + fn value(&self) -> f64 { + self.state.value() + } } diff --git a/burn-train/src/metric/memory_use.rs b/burn-train/src/metric/memory_use.rs index 72c85285ea..832c910f69 100644 --- a/burn-train/src/metric/memory_use.rs +++ b/burn-train/src/metric/memory_use.rs @@ -6,74 +6,74 @@ use sysinfo::{System, SystemExt}; /// Memory information pub struct CpuMemory { - last_refresh: Instant, - refresh_frequency: Duration, - sys: System, - ram_bytes_total: u64, - ram_bytes_used: u64, + last_refresh: Instant, + refresh_frequency: Duration, + sys: System, + ram_bytes_total: u64, + ram_bytes_used: u64, } impl CpuMemory { - /// Creates a new memory metric - pub fn new() -> Self { - let mut metric = Self { - last_refresh: Instant::now(), - refresh_frequency: Duration::from_millis(200), - sys: System::new(), - ram_bytes_total: 0, - ram_bytes_used: 0, - }; - metric.refresh(); - metric - } + /// Creates a new memory metric + pub fn new() -> Self { + let mut metric = Self { + last_refresh: Instant::now(), + refresh_frequency: Duration::from_millis(200), + sys: System::new(), + ram_bytes_total: 0, + ram_bytes_used: 0, + }; + metric.refresh(); + metric + } - fn refresh(&mut self) { - self.sys.refresh_memory(); - self.last_refresh = Instant::now(); + fn refresh(&mut self) { + self.sys.refresh_memory(); + self.last_refresh = Instant::now(); - // bytes of RAM available - self.ram_bytes_total = self.sys.total_memory(); + // bytes of RAM available + self.ram_bytes_total = self.sys.total_memory(); - // bytes of RAM in use - self.ram_bytes_used = self.sys.used_memory(); - } + // bytes of RAM in use + self.ram_bytes_used = self.sys.used_memory(); + } } impl Default for CpuMemory { - fn default() -> Self { - CpuMemory::new() - } + fn default() -> Self { + CpuMemory::new() + } } impl Metric for CpuMemory { - const NAME: &'static str = "CPU Memory"; + const NAME: &'static str = "CPU Memory"; - type Input = (); + type Input = (); - fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - if self.last_refresh.elapsed() >= self.refresh_frequency { - self.refresh(); - } + fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + if self.last_refresh.elapsed() >= self.refresh_frequency { + self.refresh(); + } - let raw = bytes2gb(self.ram_bytes_used); - let formatted = format!( - "RAM Used: {:.2} / {:.2} Gb", - raw, - bytes2gb(self.ram_bytes_total), - ); + let raw = bytes2gb(self.ram_bytes_used); + let formatted = format!( + "RAM Used: {:.2} / {:.2} Gb", + raw, + bytes2gb(self.ram_bytes_total), + ); - MetricEntry::new(Self::NAME.to_string(), formatted, raw.to_string()) - } + 
MetricEntry::new(Self::NAME.to_string(), formatted, raw.to_string()) + } - fn clear(&mut self) {} + fn clear(&mut self) {} } impl Numeric for CpuMemory { - fn value(&self) -> f64 { - bytes2gb(self.ram_bytes_used) - } + fn value(&self) -> f64 { + bytes2gb(self.ram_bytes_used) + } } fn bytes2gb(bytes: u64) -> f64 { - bytes as f64 / 1e9 + bytes as f64 / 1e9 } diff --git a/burn-train/src/metric/processor/base.rs b/burn-train/src/metric/processor/base.rs index d82d246c50..9093d26457 100644 --- a/burn-train/src/metric/processor/base.rs +++ b/burn-train/src/metric/processor/base.rs @@ -3,43 +3,43 @@ use burn_core::LearningRate; /// Event happening during the training/validation process. pub enum Event { - /// Signal that an item have been processed. - ProcessedItem(LearnerItem), - /// Signal the end of an epoch. - EndEpoch(usize), + /// Signal that an item have been processed. + ProcessedItem(LearnerItem), + /// Signal the end of an epoch. + EndEpoch(usize), } /// Process events happening during training and validation. pub trait EventProcessor { - /// The training item. - type ItemTrain; - /// The validation item. - type ItemValid; - - /// Collect a training event. - fn process_train(&mut self, event: Event); - /// Collect a validation event. - fn process_valid(&mut self, event: Event); + /// The training item. + type ItemTrain; + /// The validation item. + type ItemValid; + + /// Collect a training event. + fn process_train(&mut self, event: Event); + /// Collect a validation event. + fn process_valid(&mut self, event: Event); } /// A learner item. #[derive(new)] pub struct LearnerItem { - /// The item. - pub item: T, + /// The item. + pub item: T, - /// The progress. - pub progress: Progress, + /// The progress. + pub progress: Progress, - /// The epoch. - pub epoch: usize, + /// The epoch. + pub epoch: usize, - /// The total number of epochs. - pub epoch_total: usize, + /// The total number of epochs. + pub epoch_total: usize, - /// The iteration. - pub iteration: usize, + /// The iteration. + pub iteration: usize, - /// The learning rate. - pub lr: Option, + /// The learning rate. + pub lr: Option, } diff --git a/burn-train/src/metric/processor/full.rs b/burn-train/src/metric/processor/full.rs index d392932692..b25870dfb4 100644 --- a/burn-train/src/metric/processor/full.rs +++ b/burn-train/src/metric/processor/full.rs @@ -7,100 +7,94 @@ use std::sync::Arc; /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). /// - Render metrics using a [metrics renderer](MetricsRenderer). 
pub struct FullEventProcessor { - metrics: Metrics, - renderer: Box, - store: Arc, -} - -impl FullEventProcessor { - pub(crate) fn new( metrics: Metrics, renderer: Box, store: Arc, - ) -> Self { - Self { - metrics, - renderer, - store, +} + +impl FullEventProcessor { + pub(crate) fn new( + metrics: Metrics, + renderer: Box, + store: Arc, + ) -> Self { + Self { + metrics, + renderer, + store, + } } - } } impl EventProcessor for FullEventProcessor { - type ItemTrain = T; - type ItemValid = V; + type ItemTrain = T; + type ItemValid = V; - fn process_train(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let progress = (&item).into(); - let metadata = (&item).into(); + fn process_train(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let progress = (&item).into(); + let metadata = (&item).into(); - let update = self.metrics.update_train(&item, &metadata); + let update = self.metrics.update_train(&item, &metadata); - self - .store - .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); + self.store + .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); - update - .entries - .into_iter() - .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); - update - .entries_numeric - .into_iter() - .for_each(|(entry, value)| { - self - .renderer - .update_train(MetricState::Numeric(entry, value)) - }); + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self.renderer + .update_train(MetricState::Numeric(entry, value)) + }); - self.renderer.render_train(progress); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_train(); - self - .store - .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); - } + self.renderer.render_train(progress); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_train(); + self.store + .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); + } + } } - } - fn process_valid(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let progress = (&item).into(); - let metadata = (&item).into(); + fn process_valid(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let progress = (&item).into(); + let metadata = (&item).into(); - let update = self.metrics.update_valid(&item, &metadata); + let update = self.metrics.update_valid(&item, &metadata); - self - .store - .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); + self.store + .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); - update - .entries - .into_iter() - .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); - update - .entries_numeric - .into_iter() - .for_each(|(entry, value)| { - self - .renderer - .update_valid(MetricState::Numeric(entry, value)) - }); + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self.renderer + .update_valid(MetricState::Numeric(entry, value)) + }); - self.renderer.render_valid(progress); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_valid(); - self - .store - .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); - } + self.renderer.render_valid(progress); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_valid(); + 
self.store + .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); + } + } } - } } diff --git a/burn-train/src/metric/processor/metrics.rs b/burn-train/src/metric/processor/metrics.rs index b5c50d27be..e2992f12b0 100644 --- a/burn-train/src/metric/processor/metrics.rs +++ b/burn-train/src/metric/processor/metrics.rs @@ -1,196 +1,200 @@ use super::LearnerItem; use crate::{ - metric::{store::MetricsUpdate, Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, - renderer::TrainingProgress, + metric::{store::MetricsUpdate, Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, + renderer::TrainingProgress, }; pub(crate) struct Metrics { - train: Vec>>, - valid: Vec>>, - train_numeric: Vec>>, - valid_numeric: Vec>>, + train: Vec>>, + valid: Vec>>, + train_numeric: Vec>>, + valid_numeric: Vec>>, } impl Default for Metrics { - fn default() -> Self { - Self { - train: Vec::default(), - valid: Vec::default(), - train_numeric: Vec::default(), - valid_numeric: Vec::default(), + fn default() -> Self { + Self { + train: Vec::default(), + valid: Vec::default(), + train_numeric: Vec::default(), + valid_numeric: Vec::default(), + } } - } } impl Metrics { - /// Register a training metric. - pub(crate) fn register_metric_train(&mut self, metric: Me) - where - T: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.train.push(Box::new(metric)) - } - - /// Register a validation metric. - pub(crate) fn register_valid_metric(&mut self, metric: Me) - where - V: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.valid.push(Box::new(metric)) - } - - /// Register a numeric training metric. - pub(crate) fn register_train_metric_numeric(&mut self, metric: Me) - where - T: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.train_numeric.push(Box::new(metric)) - } - - /// Register a numeric validation metric. - pub(crate) fn register_valid_metric_numeric(&mut self, metric: Me) - where - V: Adaptor + 'static, - { - let metric = MetricWrapper::new(metric); - self.valid_numeric.push(Box::new(metric)) - } - - /// Update the training information from the training item. - pub(crate) fn update_train( - &mut self, - item: &LearnerItem, - metadata: &MetricMetadata, - ) -> MetricsUpdate { - let mut entries = Vec::with_capacity(self.train.len()); - let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); - - for metric in self.train.iter_mut() { - let state = metric.update(item, metadata); - entries.push(state); + /// Register a training metric. + pub(crate) fn register_metric_train(&mut self, metric: Me) + where + T: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.train.push(Box::new(metric)) } - for metric in self.train_numeric.iter_mut() { - let (state, value) = metric.update(item, metadata); - entries_numeric.push((state, value)); + /// Register a validation metric. + pub(crate) fn register_valid_metric(&mut self, metric: Me) + where + V: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.valid.push(Box::new(metric)) } - MetricsUpdate::new(entries, entries_numeric) - } - - /// Update the training information from the validation item. 
- pub(crate) fn update_valid( - &mut self, - item: &LearnerItem, - metadata: &MetricMetadata, - ) -> MetricsUpdate { - let mut entries = Vec::with_capacity(self.valid.len()); - let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); - - for metric in self.valid.iter_mut() { - let state = metric.update(item, metadata); - entries.push(state); + /// Register a numeric training metric. + pub(crate) fn register_train_metric_numeric( + &mut self, + metric: Me, + ) where + T: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.train_numeric.push(Box::new(metric)) } - for metric in self.valid_numeric.iter_mut() { - let (state, value) = metric.update(item, metadata); - entries_numeric.push((state, value)); + /// Register a numeric validation metric. + pub(crate) fn register_valid_metric_numeric( + &mut self, + metric: Me, + ) where + V: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.valid_numeric.push(Box::new(metric)) } - MetricsUpdate::new(entries, entries_numeric) - } - - /// Signal the end of a training epoch. - pub(crate) fn end_epoch_train(&mut self) { - for metric in self.train.iter_mut() { - metric.clear(); + /// Update the training information from the training item. + pub(crate) fn update_train( + &mut self, + item: &LearnerItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.train.len()); + let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); + + for metric in self.train.iter_mut() { + let state = metric.update(item, metadata); + entries.push(state); + } + + for metric in self.train_numeric.iter_mut() { + let (state, value) = metric.update(item, metadata); + entries_numeric.push((state, value)); + } + + MetricsUpdate::new(entries, entries_numeric) } - for metric in self.train_numeric.iter_mut() { - metric.clear(); + + /// Update the training information from the validation item. + pub(crate) fn update_valid( + &mut self, + item: &LearnerItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.valid.len()); + let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); + + for metric in self.valid.iter_mut() { + let state = metric.update(item, metadata); + entries.push(state); + } + + for metric in self.valid_numeric.iter_mut() { + let (state, value) = metric.update(item, metadata); + entries_numeric.push((state, value)); + } + + MetricsUpdate::new(entries, entries_numeric) } - } - /// Signal the end of a validation epoch. - pub(crate) fn end_epoch_valid(&mut self) { - for metric in self.valid.iter_mut() { - metric.clear(); + /// Signal the end of a training epoch. + pub(crate) fn end_epoch_train(&mut self) { + for metric in self.train.iter_mut() { + metric.clear(); + } + for metric in self.train_numeric.iter_mut() { + metric.clear(); + } } - for metric in self.valid_numeric.iter_mut() { - metric.clear(); + + /// Signal the end of a validation epoch. 
+ pub(crate) fn end_epoch_valid(&mut self) { + for metric in self.valid.iter_mut() { + metric.clear(); + } + for metric in self.valid_numeric.iter_mut() { + metric.clear(); + } } - } } impl From<&LearnerItem> for TrainingProgress { - fn from(item: &LearnerItem) -> Self { - Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, + } } - } } impl From<&LearnerItem> for MetricMetadata { - fn from(item: &LearnerItem) -> Self { - Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, - lr: item.lr, + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, + lr: item.lr, + } } - } } trait NumericMetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64); - fn clear(&mut self); + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64); + fn clear(&mut self); } trait MetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry; - fn clear(&mut self); + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry; + fn clear(&mut self); } #[derive(new)] struct MetricWrapper { - metric: M, + metric: M, } impl NumericMetricUpdater for MetricWrapper where - T: 'static, - M: Metric + Numeric + 'static, - T: Adaptor, + T: 'static, + M: Metric + Numeric + 'static, + T: Adaptor, { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64) { - let update = self.metric.update(&item.item.adapt(), metadata); - let numeric = self.metric.value(); + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64) { + let update = self.metric.update(&item.item.adapt(), metadata); + let numeric = self.metric.value(); - (update, numeric) - } + (update, numeric) + } - fn clear(&mut self) { - self.metric.clear() - } + fn clear(&mut self) { + self.metric.clear() + } } impl MetricUpdater for MetricWrapper where - T: 'static, - M: Metric + 'static, - T: Adaptor, + T: 'static, + M: Metric + 'static, + T: Adaptor, { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry { - self.metric.update(&item.item.adapt(), metadata) - } + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry { + self.metric.update(&item.item.adapt(), metadata) + } - fn clear(&mut self) { - self.metric.clear() - } + fn clear(&mut self) { + self.metric.clear() + } } diff --git a/burn-train/src/metric/processor/minimal.rs b/burn-train/src/metric/processor/minimal.rs index 7350ca6e18..bb60713e45 100644 --- a/burn-train/src/metric/processor/minimal.rs +++ b/burn-train/src/metric/processor/minimal.rs @@ -6,51 +6,47 @@ use std::sync::Arc; /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). 
#[derive(new)] pub(crate) struct MinimalEventProcessor { - metrics: Metrics, - store: Arc, + metrics: Metrics, + store: Arc, } impl EventProcessor for MinimalEventProcessor { - type ItemTrain = T; - type ItemValid = V; - - fn process_train(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let metadata = (&item).into(); - - let update = self.metrics.update_train(&item, &metadata); - - self - .store - .add_event_train(crate::metric::store::Event::MetricsUpdate(update)); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_train(); - self - .store - .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); - } + type ItemTrain = T; + type ItemValid = V; + + fn process_train(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let metadata = (&item).into(); + + let update = self.metrics.update_train(&item, &metadata); + + self.store + .add_event_train(crate::metric::store::Event::MetricsUpdate(update)); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_train(); + self.store + .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); + } + } } - } - - fn process_valid(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let metadata = (&item).into(); - - let update = self.metrics.update_valid(&item, &metadata); - - self - .store - .add_event_valid(crate::metric::store::Event::MetricsUpdate(update)); - } - Event::EndEpoch(epoch) => { - self.metrics.end_epoch_valid(); - self - .store - .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); - } + + fn process_valid(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let metadata = (&item).into(); + + let update = self.metrics.update_valid(&item, &metadata); + + self.store + .add_event_valid(crate::metric::store::Event::MetricsUpdate(update)); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_valid(); + self.store + .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); + } + } } - } } diff --git a/burn-train/src/metric/processor/mod.rs b/burn-train/src/metric/processor/mod.rs index 532ad59620..f889894098 100644 --- a/burn-train/src/metric/processor/mod.rs +++ b/burn-train/src/metric/processor/mod.rs @@ -12,42 +12,42 @@ pub(crate) use minimal::*; #[cfg(test)] pub(crate) mod test_utils { - use crate::metric::{ - processor::{Event, EventProcessor, LearnerItem, MinimalEventProcessor}, - Adaptor, LossInput, - }; - use burn_core::tensor::{backend::Backend, ElementConversion, Tensor}; + use crate::metric::{ + processor::{Event, EventProcessor, LearnerItem, MinimalEventProcessor}, + Adaptor, LossInput, + }; + use burn_core::tensor::{backend::Backend, ElementConversion, Tensor}; - impl Adaptor> for f64 { - fn adapt(&self) -> LossInput { - LossInput::new(Tensor::from_data([self.elem()])) + impl Adaptor> for f64 { + fn adapt(&self) -> LossInput { + LossInput::new(Tensor::from_data([self.elem()])) + } } - } - pub(crate) fn process_train( - processor: &mut MinimalEventProcessor, - value: f64, - epoch: usize, - ) { - let dummy_progress = burn_core::data::dataloader::Progress { - items_processed: 1, - items_total: 10, - }; - let num_epochs = 3; - let dummy_iteration = 1; + pub(crate) fn process_train( + processor: &mut MinimalEventProcessor, + value: f64, + epoch: usize, + ) { + let dummy_progress = burn_core::data::dataloader::Progress { + items_processed: 1, + items_total: 10, + }; + let num_epochs = 3; + let dummy_iteration = 1; - processor.process_train(Event::ProcessedItem(LearnerItem::new( - 
value, - dummy_progress, - epoch, - num_epochs, - dummy_iteration, - None, - ))); - } + processor.process_train(Event::ProcessedItem(LearnerItem::new( + value, + dummy_progress, + epoch, + num_epochs, + dummy_iteration, + None, + ))); + } - pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor, epoch: usize) { - processor.process_train(Event::EndEpoch(epoch)); - processor.process_valid(Event::EndEpoch(epoch)); - } + pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor, epoch: usize) { + processor.process_train(Event::EndEpoch(epoch)); + processor.process_valid(Event::EndEpoch(epoch)); + } } diff --git a/burn-train/src/metric/state.rs b/burn-train/src/metric/state.rs index db8b887dfc..9a188198dc 100644 --- a/burn-train/src/metric/state.rs +++ b/burn-train/src/metric/state.rs @@ -7,95 +7,95 @@ use crate::metric::{format_float, MetricEntry, Numeric}; /// The numeric metric store values inside floats. /// Even if some metric are integers, their mean are floats. pub struct NumericMetricState { - sum: f64, - count: usize, - current: f64, + sum: f64, + count: usize, + current: f64, } /// Formatting options for the [numeric metric state](NumericMetricState). pub struct FormatOptions { - name: String, - unit: Option, - precision: Option, + name: String, + unit: Option, + precision: Option, } impl FormatOptions { - /// Create the [formatting options](FormatOptions) with a name. - pub fn new(name: &str) -> Self { - Self { - name: name.to_string(), - unit: None, - precision: None, + /// Create the [formatting options](FormatOptions) with a name. + pub fn new(name: &str) -> Self { + Self { + name: name.to_string(), + unit: None, + precision: None, + } } - } - /// Specify the metric unit. - pub fn unit(mut self, unit: &str) -> Self { - self.unit = Some(unit.to_string()); - self - } + /// Specify the metric unit. + pub fn unit(mut self, unit: &str) -> Self { + self.unit = Some(unit.to_string()); + self + } - /// Specify the floating point precision. - pub fn precision(mut self, precision: usize) -> Self { - self.precision = Some(precision); - self - } + /// Specify the floating point precision. + pub fn precision(mut self, precision: usize) -> Self { + self.precision = Some(precision); + self + } } impl NumericMetricState { - /// Create a new [numeric metric state](NumericMetricState). - pub fn new() -> Self { - Self { - sum: 0.0, - count: 0, - current: f64::NAN, + /// Create a new [numeric metric state](NumericMetricState). + pub fn new() -> Self { + Self { + sum: 0.0, + count: 0, + current: f64::NAN, + } } - } - /// Reset the state. - pub fn reset(&mut self) { - self.sum = 0.0; - self.count = 0; - self.current = f64::NAN; - } + /// Reset the state. + pub fn reset(&mut self) { + self.sum = 0.0; + self.count = 0; + self.current = f64::NAN; + } - /// Update the state. - pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry { - self.sum += value * batch_size as f64; - self.count += batch_size; - self.current = value; + /// Update the state. 
+ pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry { + self.sum += value * batch_size as f64; + self.count += batch_size; + self.current = value; - let value_current = value; - let value_running = self.sum / self.count as f64; - let serialized = value_current.to_string(); + let value_current = value; + let value_running = self.sum / self.count as f64; + let serialized = value_current.to_string(); - let (formatted_current, formatted_running) = match format.precision { - Some(precision) => ( - format_float(value_current, precision), - format_float(value_running, precision), - ), - None => (format!("{value_current}"), format!("{value_running}")), - }; + let (formatted_current, formatted_running) = match format.precision { + Some(precision) => ( + format_float(value_current, precision), + format_float(value_running, precision), + ), + None => (format!("{value_current}"), format!("{value_running}")), + }; - let formatted = match format.unit { - Some(unit) => { - format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}") - } - None => format!("epoch {formatted_running} - batch {formatted_current}"), - }; + let formatted = match format.unit { + Some(unit) => { + format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}") + } + None => format!("epoch {formatted_running} - batch {formatted_current}"), + }; - MetricEntry::new(format.name, formatted, serialized) - } + MetricEntry::new(format.name, formatted, serialized) + } } impl Numeric for NumericMetricState { - fn value(&self) -> f64 { - self.current - } + fn value(&self) -> f64 { + self.current + } } impl Default for NumericMetricState { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } diff --git a/burn-train/src/metric/store/aggregate.rs b/burn-train/src/metric/store/aggregate.rs index 2579f23bda..679f6fa22e 100644 --- a/burn-train/src/metric/store/aggregate.rs +++ b/burn-train/src/metric/store/aggregate.rs @@ -6,157 +6,157 @@ use super::{Aggregate, Direction}; /// Type that can be used to fetch and use numeric metric aggregates. 
#[derive(Default, Debug)] pub(crate) struct NumericMetricsAggregate { - value_for_each_epoch: HashMap, + value_for_each_epoch: HashMap, } #[derive(new, Hash, PartialEq, Eq, Debug)] struct Key { - name: String, - epoch: usize, - aggregate: Aggregate, + name: String, + epoch: usize, + aggregate: Aggregate, } impl NumericMetricsAggregate { - pub(crate) fn aggregate( - &mut self, - name: &str, - epoch: usize, - aggregate: Aggregate, - loggers: &mut [Box], - ) -> Option { - let key = Key::new(name.to_string(), epoch, aggregate); + pub(crate) fn aggregate( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + loggers: &mut [Box], + ) -> Option { + let key = Key::new(name.to_string(), epoch, aggregate); + + if let Some(value) = self.value_for_each_epoch.get(&key) { + return Some(*value); + } - if let Some(value) = self.value_for_each_epoch.get(&key) { - return Some(*value); - } + let points = || { + let mut errors = Vec::new(); + for logger in loggers { + match logger.read_numeric(name, epoch) { + Ok(points) => return Ok(points), + Err(err) => errors.push(err), + }; + } - let points = || { - let mut errors = Vec::new(); - for logger in loggers { - match logger.read_numeric(name, epoch) { - Ok(points) => return Ok(points), - Err(err) => errors.push(err), + Err(errors.join(" ")) }; - } - Err(errors.join(" ")) - }; + let points = points().expect("Can read values"); - let points = points().expect("Can read values"); - - if points.is_empty() { - return None; - } - - let num_points = points.len(); - let sum = points.into_iter().sum::(); - let value = match aggregate { - Aggregate::Mean => sum / num_points as f64, - }; + if points.is_empty() { + return None; + } - self.value_for_each_epoch.insert(key, value); - Some(value) - } + let num_points = points.len(); + let sum = points.into_iter().sum::(); + let value = match aggregate { + Aggregate::Mean => sum / num_points as f64, + }; - pub(crate) fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - loggers: &mut [Box], - ) -> Option { - let mut data = Vec::new(); - let mut current_epoch = 1; - - while let Some(value) = self.aggregate(name, current_epoch, aggregate, loggers) { - data.push(value); - current_epoch += 1; + self.value_for_each_epoch.insert(key, value); + Some(value) } - if data.is_empty() { - return None; - } + pub(crate) fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + loggers: &mut [Box], + ) -> Option { + let mut data = Vec::new(); + let mut current_epoch = 1; + + while let Some(value) = self.aggregate(name, current_epoch, aggregate, loggers) { + data.push(value); + current_epoch += 1; + } - let mut current_value = match &direction { - Direction::Lowest => f64::MAX, - Direction::Highest => f64::MIN, - }; - - for (i, value) in data.into_iter().enumerate() { - match &direction { - Direction::Lowest => { - if value < current_value { - current_value = value; - current_epoch = i + 1; - } + if data.is_empty() { + return None; } - Direction::Highest => { - if value > current_value { - current_value = value; - current_epoch = i + 1; - } + + let mut current_value = match &direction { + Direction::Lowest => f64::MAX, + Direction::Highest => f64::MIN, + }; + + for (i, value) in data.into_iter().enumerate() { + match &direction { + Direction::Lowest => { + if value < current_value { + current_value = value; + current_epoch = i + 1; + } + } + Direction::Highest => { + if value > current_value { + current_value = value; + current_epoch = i + 1; + } + } + } } - 
} - } - Some(current_epoch) - } + Some(current_epoch) + } } #[cfg(test)] mod tests { - use crate::{logger::FileMetricLogger, metric::MetricEntry}; + use crate::{logger::FileMetricLogger, metric::MetricEntry}; - use super::*; + use super::*; - struct TestLogger { - logger: FileMetricLogger, - epoch: usize, - } - const NAME: &str = "test-logger"; - - impl TestLogger { - fn new() -> Self { - Self { - logger: FileMetricLogger::new("/tmp"), - epoch: 1, - } + struct TestLogger { + logger: FileMetricLogger, + epoch: usize, } - fn log(&mut self, num: f64) { - self.logger.log(&MetricEntry::new( - NAME.into(), - num.to_string(), - num.to_string(), - )); + const NAME: &str = "test-logger"; + + impl TestLogger { + fn new() -> Self { + Self { + logger: FileMetricLogger::new("/tmp"), + epoch: 1, + } + } + fn log(&mut self, num: f64) { + self.logger.log(&MetricEntry::new( + NAME.into(), + num.to_string(), + num.to_string(), + )); + } + fn new_epoch(&mut self) { + self.logger.end_epoch(self.epoch); + self.epoch += 1; + } } - fn new_epoch(&mut self) { - self.logger.end_epoch(self.epoch); - self.epoch += 1; + + #[test] + fn should_find_epoch() { + let mut logger = TestLogger::new(); + let mut aggregate = NumericMetricsAggregate::default(); + + logger.log(500.); // Epoch 1 + logger.log(1000.); // Epoch 1 + logger.new_epoch(); + logger.log(200.); // Epoch 2 + logger.log(1000.); // Epoch 2 + logger.new_epoch(); + logger.log(10000.); // Epoch 3 + + let value = aggregate + .find_epoch( + NAME, + Aggregate::Mean, + Direction::Lowest, + &mut [Box::new(logger.logger)], + ) + .unwrap(); + + assert_eq!(value, 2); } - } - - #[test] - fn should_find_epoch() { - let mut logger = TestLogger::new(); - let mut aggregate = NumericMetricsAggregate::default(); - - logger.log(500.); // Epoch 1 - logger.log(1000.); // Epoch 1 - logger.new_epoch(); - logger.log(200.); // Epoch 2 - logger.log(1000.); // Epoch 2 - logger.new_epoch(); - logger.log(10000.); // Epoch 3 - - let value = aggregate - .find_epoch( - NAME, - Aggregate::Mean, - Direction::Lowest, - &mut [Box::new(logger.logger)], - ) - .unwrap(); - - assert_eq!(value, 2); - } } diff --git a/burn-train/src/metric/store/base.rs b/burn-train/src/metric/store/base.rs index 6039d2dcc6..51592a683c 100644 --- a/burn-train/src/metric/store/base.rs +++ b/burn-train/src/metric/store/base.rs @@ -2,68 +2,68 @@ use crate::metric::MetricEntry; /// Event happening during the training/validation process. pub enum Event { - /// Signal that metrics have been updated. - MetricsUpdate(MetricsUpdate), - /// Signal the end of an epoch. - EndEpoch(usize), + /// Signal that metrics have been updated. + MetricsUpdate(MetricsUpdate), + /// Signal the end of an epoch. + EndEpoch(usize), } /// Contains all metric information. #[derive(new, Clone)] pub struct MetricsUpdate { - /// Metrics information related to non-numeric metrics. - pub entries: Vec, - /// Metrics information related to numeric metrics. - pub entries_numeric: Vec<(MetricEntry, f64)>, + /// Metrics information related to non-numeric metrics. + pub entries: Vec, + /// Metrics information related to numeric metrics. + pub entries_numeric: Vec<(MetricEntry, f64)>, } /// Defines how training and validation events are collected and searched. /// /// This trait also exposes methods that uses the collected data to compute useful information. pub trait EventStore: Send { - /// Collect a training/validation event. - fn add_event(&mut self, event: Event, split: Split); + /// Collect a training/validation event. 
+ fn add_event(&mut self, event: Event, split: Split); - /// Find the epoch following the given criteria from the collected data. - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option; + /// Find the epoch following the given criteria from the collected data. + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option; - /// Find the metric value for the current epoch following the given criteria. - fn find_metric( - &mut self, - name: &str, - epoch: usize, - aggregate: Aggregate, - split: Split, - ) -> Option; + /// Find the metric value for the current epoch following the given criteria. + fn find_metric( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option; } #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)] /// How to aggregate the metric. pub enum Aggregate { - /// Compute the average. - Mean, + /// Compute the average. + Mean, } #[derive(Copy, Clone)] /// The split to use. pub enum Split { - /// The training split. - Train, - /// The validation split. - Valid, + /// The training split. + Train, + /// The validation split. + Valid, } #[derive(Copy, Clone)] /// The direction of the query. pub enum Direction { - /// Lower is better. - Lowest, - /// Higher is better. - Highest, + /// Lower is better. + Lowest, + /// Higher is better. + Highest, } diff --git a/burn-train/src/metric/store/client.rs b/burn-train/src/metric/store/client.rs index 192f4f2abd..74ba83ab74 100644 --- a/burn-train/src/metric/store/client.rs +++ b/burn-train/src/metric/store/client.rs @@ -4,161 +4,156 @@ use std::{sync::mpsc, thread::JoinHandle}; /// Type that allows to communicate with an [event store](EventStore). pub struct EventStoreClient { - sender: mpsc::Sender, - handler: Option>, + sender: mpsc::Sender, + handler: Option>, } impl EventStoreClient { - /// Create a new [event store](EventStore) client. - pub(crate) fn new(store: C) -> Self - where - C: EventStore + 'static, - { - let (sender, receiver) = mpsc::channel(); - let thread = WorkerThread::new(store, receiver); + /// Create a new [event store](EventStore) client. + pub(crate) fn new(store: C) -> Self + where + C: EventStore + 'static, + { + let (sender, receiver) = mpsc::channel(); + let thread = WorkerThread::new(store, receiver); - let handler = std::thread::spawn(move || thread.run()); - let handler = Some(handler); + let handler = std::thread::spawn(move || thread.run()); + let handler = Some(handler); - Self { sender, handler } - } + Self { sender, handler } + } } impl EventStoreClient { - /// Add a training event to the [event store](EventStore). - pub(crate) fn add_event_train(&self, event: Event) { - self - .sender - .send(Message::OnEventTrain(event)) - .expect("Can send event to event store thread."); - } + /// Add a training event to the [event store](EventStore). + pub(crate) fn add_event_train(&self, event: Event) { + self.sender + .send(Message::OnEventTrain(event)) + .expect("Can send event to event store thread."); + } - /// Add a validation event to the [event store](EventStore). - pub(crate) fn add_event_valid(&self, event: Event) { - self - .sender - .send(Message::OnEventValid(event)) - .expect("Can send event to event store thread."); - } + /// Add a validation event to the [event store](EventStore). 
+ pub(crate) fn add_event_valid(&self, event: Event) { + self.sender + .send(Message::OnEventValid(event)) + .expect("Can send event to event store thread."); + } - /// Find the epoch following the given criteria from the collected data. - pub fn find_epoch( - &self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - let (sender, receiver) = mpsc::sync_channel(1); - self - .sender - .send(Message::FindEpoch( - name.to_string(), - aggregate, - direction, - split, - sender, - )) - .expect("Can send event to event store thread."); + /// Find the epoch following the given criteria from the collected data. + pub fn find_epoch( + &self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + let (sender, receiver) = mpsc::sync_channel(1); + self.sender + .send(Message::FindEpoch( + name.to_string(), + aggregate, + direction, + split, + sender, + )) + .expect("Can send event to event store thread."); - match receiver.recv() { - Ok(value) => value, - Err(err) => panic!("Event store thread crashed: {:?}", err), + match receiver.recv() { + Ok(value) => value, + Err(err) => panic!("Event store thread crashed: {:?}", err), + } } - } - /// Find the metric value for the current epoch following the given criteria. - pub fn find_metric( - &self, - name: &str, - epoch: usize, - aggregate: Aggregate, - split: Split, - ) -> Option { - let (sender, receiver) = mpsc::sync_channel(1); - self - .sender - .send(Message::FindMetric( - name.to_string(), - epoch, - aggregate, - split, - sender, - )) - .expect("Can send event to event store thread."); + /// Find the metric value for the current epoch following the given criteria. + pub fn find_metric( + &self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option { + let (sender, receiver) = mpsc::sync_channel(1); + self.sender + .send(Message::FindMetric( + name.to_string(), + epoch, + aggregate, + split, + sender, + )) + .expect("Can send event to event store thread."); - match receiver.recv() { - Ok(value) => value, - Err(err) => panic!("Event store thread crashed: {:?}", err), + match receiver.recv() { + Ok(value) => value, + Err(err) => panic!("Event store thread crashed: {:?}", err), + } } - } } #[derive(new)] struct WorkerThread { - store: S, - receiver: mpsc::Receiver, + store: S, + receiver: mpsc::Receiver, } impl WorkerThread where - C: EventStore, + C: EventStore, { - fn run(mut self) { - for item in self.receiver.iter() { - match item { - Message::End => { - return; - } - Message::FindEpoch(name, aggregate, direction, split, callback) => { - let response = self.store.find_epoch(&name, aggregate, direction, split); - callback - .send(response) - .expect("Can send response using callback channel."); + fn run(mut self) { + for item in self.receiver.iter() { + match item { + Message::End => { + return; + } + Message::FindEpoch(name, aggregate, direction, split, callback) => { + let response = self.store.find_epoch(&name, aggregate, direction, split); + callback + .send(response) + .expect("Can send response using callback channel."); + } + Message::FindMetric(name, epoch, aggregate, split, callback) => { + let response = self.store.find_metric(&name, epoch, aggregate, split); + callback + .send(response) + .expect("Can send response using callback channel."); + } + Message::OnEventTrain(event) => self.store.add_event(event, Split::Train), + Message::OnEventValid(event) => self.store.add_event(event, Split::Valid), + } } - 
Message::FindMetric(name, epoch, aggregate, split, callback) => { - let response = self.store.find_metric(&name, epoch, aggregate, split); - callback - .send(response) - .expect("Can send response using callback channel."); - } - Message::OnEventTrain(event) => self.store.add_event(event, Split::Train), - Message::OnEventValid(event) => self.store.add_event(event, Split::Valid), - } } - } } enum Message { - OnEventTrain(Event), - OnEventValid(Event), - End, - FindEpoch( - String, - Aggregate, - Direction, - Split, - mpsc::SyncSender>, - ), - FindMetric( - String, - usize, - Aggregate, - Split, - mpsc::SyncSender>, - ), + OnEventTrain(Event), + OnEventValid(Event), + End, + FindEpoch( + String, + Aggregate, + Direction, + Split, + mpsc::SyncSender>, + ), + FindMetric( + String, + usize, + Aggregate, + Split, + mpsc::SyncSender>, + ), } impl Drop for EventStoreClient { - fn drop(&mut self) { - self - .sender - .send(Message::End) - .expect("Can send the end message to the event store thread."); - let handler = self.handler.take(); + fn drop(&mut self) { + self.sender + .send(Message::End) + .expect("Can send the end message to the event store thread."); + let handler = self.handler.take(); - if let Some(handler) = handler { - handler.join().expect("The event store thread should stop."); + if let Some(handler) = handler { + handler.join().expect("The event store thread should stop."); + } } - } } diff --git a/burn-train/src/metric/store/log.rs b/burn-train/src/metric/store/log.rs index c8b88c6b72..9272e32330 100644 --- a/burn-train/src/metric/store/log.rs +++ b/burn-train/src/metric/store/log.rs @@ -3,105 +3,99 @@ use crate::logger::MetricLogger; #[derive(Default)] pub(crate) struct LogEventStore { - loggers_train: Vec>, - loggers_valid: Vec>, - aggregate_train: NumericMetricsAggregate, - aggregate_valid: NumericMetricsAggregate, + loggers_train: Vec>, + loggers_valid: Vec>, + aggregate_train: NumericMetricsAggregate, + aggregate_valid: NumericMetricsAggregate, } impl EventStore for LogEventStore { - fn add_event(&mut self, event: Event, split: Split) { - match event { - Event::MetricsUpdate(update) => match split { - Split::Train => { - update - .entries - .iter() - .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) - .for_each(|entry| { - self - .loggers_train - .iter_mut() - .for_each(|logger| logger.log(entry)); - }); + fn add_event(&mut self, event: Event, split: Split) { + match event { + Event::MetricsUpdate(update) => match split { + Split::Train => { + update + .entries + .iter() + .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) + .for_each(|entry| { + self.loggers_train + .iter_mut() + .for_each(|logger| logger.log(entry)); + }); + } + Split::Valid => { + update + .entries + .iter() + .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) + .for_each(|entry| { + self.loggers_valid + .iter_mut() + .for_each(|logger| logger.log(entry)); + }); + } + }, + Event::EndEpoch(epoch) => match split { + Split::Train => self + .loggers_train + .iter_mut() + .for_each(|logger| logger.end_epoch(epoch)), + Split::Valid => self + .loggers_valid + .iter_mut() + .for_each(|logger| logger.end_epoch(epoch + 1)), + }, } - Split::Valid => { - update - .entries - .iter() - .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) - .for_each(|entry| { - self - .loggers_valid - .iter_mut() - .for_each(|logger| logger.log(entry)); - }); - } - }, - Event::EndEpoch(epoch) => match split { - Split::Train => self - .loggers_train - .iter_mut() - 
.for_each(|logger| logger.end_epoch(epoch)), - Split::Valid => self - .loggers_valid - .iter_mut() - .for_each(|logger| logger.end_epoch(epoch + 1)), - }, } - } - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - match split { - Split::Train => { - self - .aggregate_train - .find_epoch(name, aggregate, direction, &mut self.loggers_train) - } - Split::Valid => { - self - .aggregate_valid - .find_epoch(name, aggregate, direction, &mut self.loggers_valid) - } + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + match split { + Split::Train => { + self.aggregate_train + .find_epoch(name, aggregate, direction, &mut self.loggers_train) + } + Split::Valid => { + self.aggregate_valid + .find_epoch(name, aggregate, direction, &mut self.loggers_valid) + } + } } - } - fn find_metric( - &mut self, - name: &str, - epoch: usize, - aggregate: Aggregate, - split: Split, - ) -> Option { - match split { - Split::Train => { - self - .aggregate_train - .aggregate(name, epoch, aggregate, &mut self.loggers_train) - } - Split::Valid => { - self - .aggregate_valid - .aggregate(name, epoch, aggregate, &mut self.loggers_valid) - } + fn find_metric( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option { + match split { + Split::Train => { + self.aggregate_train + .aggregate(name, epoch, aggregate, &mut self.loggers_train) + } + Split::Valid => { + self.aggregate_valid + .aggregate(name, epoch, aggregate, &mut self.loggers_valid) + } + } } - } } impl LogEventStore { - /// Register a logger for training metrics. - pub(crate) fn register_logger_train(&mut self, logger: ML) { - self.loggers_train.push(Box::new(logger)); - } + /// Register a logger for training metrics. + pub(crate) fn register_logger_train(&mut self, logger: ML) { + self.loggers_train.push(Box::new(logger)); + } - /// Register a logger for validation metrics. - pub(crate) fn register_logger_valid(&mut self, logger: ML) { - self.loggers_valid.push(Box::new(logger)); - } + /// Register a logger for validation metrics. + pub(crate) fn register_logger_valid(&mut self, logger: ML) { + self.loggers_valid.push(Box::new(logger)); + } } diff --git a/burn-train/src/renderer/base.rs b/burn-train/src/renderer/base.rs index 2258c32e46..6cfc2a5eb0 100644 --- a/burn-train/src/renderer/base.rs +++ b/burn-train/src/renderer/base.rs @@ -4,72 +4,72 @@ use crate::metric::MetricEntry; /// Trait for rendering metrics. pub trait MetricsRenderer: Send + Sync { - /// Updates the training metric state. - /// - /// # Arguments - /// - /// * `state` - The metric state. - fn update_train(&mut self, state: MetricState); + /// Updates the training metric state. + /// + /// # Arguments + /// + /// * `state` - The metric state. + fn update_train(&mut self, state: MetricState); - /// Updates the validation metric state. - /// - /// # Arguments - /// - /// * `state` - The metric state. - fn update_valid(&mut self, state: MetricState); + /// Updates the validation metric state. + /// + /// # Arguments + /// + /// * `state` - The metric state. + fn update_valid(&mut self, state: MetricState); - /// Renders the training progress. - /// - /// # Arguments - /// - /// * `item` - The training progress. - fn render_train(&mut self, item: TrainingProgress); + /// Renders the training progress. + /// + /// # Arguments + /// + /// * `item` - The training progress. 
+ fn render_train(&mut self, item: TrainingProgress); - /// Renders the validation progress. - /// - /// # Arguments - /// - /// * `item` - The validation progress. - fn render_valid(&mut self, item: TrainingProgress); + /// Renders the validation progress. + /// + /// # Arguments + /// + /// * `item` - The validation progress. + fn render_valid(&mut self, item: TrainingProgress); } /// The state of a metric. #[derive(Debug)] pub enum MetricState { - /// A generic metric. - Generic(MetricEntry), + /// A generic metric. + Generic(MetricEntry), - /// A numeric metric. - Numeric(MetricEntry, f64), + /// A numeric metric. + Numeric(MetricEntry, f64), } /// Training progress. #[derive(Debug)] pub struct TrainingProgress { - /// The progress. - pub progress: Progress, + /// The progress. + pub progress: Progress, - /// The epoch. - pub epoch: usize, + /// The epoch. + pub epoch: usize, - /// The total number of epochs. - pub epoch_total: usize, + /// The total number of epochs. + pub epoch_total: usize, - /// The iteration. - pub iteration: usize, + /// The iteration. + pub iteration: usize, } impl TrainingProgress { - /// Creates a new empty training progress. - pub fn none() -> Self { - Self { - progress: Progress { - items_processed: 0, - items_total: 0, - }, - epoch: 0, - epoch_total: 0, - iteration: 0, + /// Creates a new empty training progress. + pub fn none() -> Self { + Self { + progress: Progress { + items_processed: 0, + items_total: 0, + }, + epoch: 0, + epoch_total: 0, + iteration: 0, + } } - } } diff --git a/burn-train/src/renderer/cli.rs b/burn-train/src/renderer/cli.rs index 1ed3cf3acb..d5a974a51e 100644 --- a/burn-train/src/renderer/cli.rs +++ b/burn-train/src/renderer/cli.rs @@ -4,22 +4,22 @@ use crate::renderer::{MetricState, MetricsRenderer, TrainingProgress}; pub struct CliMetricsRenderer; impl CliMetricsRenderer { - /// Create a new instance. - pub fn new() -> Self { - Self {} - } + /// Create a new instance. + pub fn new() -> Self { + Self {} + } } impl MetricsRenderer for CliMetricsRenderer { - fn update_train(&mut self, _state: MetricState) {} + fn update_train(&mut self, _state: MetricState) {} - fn update_valid(&mut self, _state: MetricState) {} + fn update_valid(&mut self, _state: MetricState) {} - fn render_train(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_train(&mut self, item: TrainingProgress) { + dbg!(item); + } - fn render_valid(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_valid(&mut self, item: TrainingProgress) { + dbg!(item); + } } diff --git a/burn-train/src/renderer/mod.rs b/burn-train/src/renderer/mod.rs index 7e2132bc74..9002184326 100644 --- a/burn-train/src/renderer/mod.rs +++ b/burn-train/src/renderer/mod.rs @@ -15,12 +15,12 @@ pub use tui::TuiMetricsRenderer as SelectedMetricsRenderer; /// The TUI renderer, or a simple stub if the tui feature is not enabled. 
#[allow(unused_variables)] pub(crate) fn default_renderer( - interuptor: TrainingInterrupter, - checkpoint: Option, + interuptor: TrainingInterrupter, + checkpoint: Option, ) -> SelectedMetricsRenderer { - #[cfg(feature = "tui")] - return SelectedMetricsRenderer::new(interuptor, checkpoint); + #[cfg(feature = "tui")] + return SelectedMetricsRenderer::new(interuptor, checkpoint); - #[cfg(not(feature = "tui"))] - return SelectedMetricsRenderer::new(); + #[cfg(not(feature = "tui"))] + return SelectedMetricsRenderer::new(); } diff --git a/burn-train/src/renderer/tui/base.rs b/burn-train/src/renderer/tui/base.rs index 04cce37144..38d6c25b31 100644 --- a/burn-train/src/renderer/tui/base.rs +++ b/burn-train/src/renderer/tui/base.rs @@ -1,45 +1,45 @@ use super::{ - ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView, + ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView, }; use ratatui::prelude::{Constraint, Direction, Layout, Rect}; #[derive(new)] pub(crate) struct MetricsView<'a> { - metric_numeric: NumericMetricView<'a>, - metric_text: TextMetricView, - progress: ProgressBarView, - controls: ControlsView, - status: StatusView, + metric_numeric: NumericMetricView<'a>, + metric_text: TextMetricView, + progress: ProgressBarView, + controls: ControlsView, + status: StatusView, } impl<'a> MetricsView<'a> { - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Min(16), Constraint::Max(3)].as_ref()) - .split(size); - let size_other = chunks[0]; - let size_progress = chunks[1]; + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Min(16), Constraint::Max(3)].as_ref()) + .split(size); + let size_other = chunks[0]; + let size_progress = chunks[1]; - let chunks = Layout::default() - .direction(Direction::Horizontal) - .constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref()) - .split(size_other); - let size_other = chunks[0]; - let size_metric_numeric = chunks[1]; + let chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref()) + .split(size_other); + let size_other = chunks[0]; + let size_metric_numeric = chunks[1]; - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref()) - .split(size_other); - let size_controls = chunks[0]; - let size_metric_text = chunks[1]; - let size_status = chunks[2]; + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref()) + .split(size_other); + let size_controls = chunks[0]; + let size_metric_text = chunks[1]; + let size_status = chunks[2]; - self.metric_numeric.render(frame, size_metric_numeric); - self.metric_text.render(frame, size_metric_text); - self.controls.render(frame, size_controls); - self.progress.render(frame, size_progress); - self.status.render(frame, size_status); - } + self.metric_numeric.render(frame, size_metric_numeric); + self.metric_text.render(frame, size_metric_text); + self.controls.render(frame, size_controls); + self.progress.render(frame, size_progress); + self.status.render(frame, size_status); + } } diff --git 
a/burn-train/src/renderer/tui/controls.rs b/burn-train/src/renderer/tui/controls.rs index 50ed7491dc..e48778034b 100644 --- a/burn-train/src/renderer/tui/controls.rs +++ b/burn-train/src/renderer/tui/controls.rs @@ -1,46 +1,46 @@ use super::TerminalFrame; use ratatui::{ - prelude::{Alignment, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; /// Controls view. pub(crate) struct ControlsView; impl ControlsView { - /// Render the view. - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let lines = vec![ - vec![ - Span::from(" Quit : ").yellow().bold(), - Span::from("q ").bold(), - Span::from(" Stop the training.").italic(), - ], - vec![ - Span::from(" Plots Metrics : ").yellow().bold(), - Span::from("⬅ ➡").bold(), - Span::from(" Switch between metrics.").italic(), - ], - vec![ - Span::from(" Plots Type : ").yellow().bold(), - Span::from("⬆ ⬇").bold(), - Span::from(" Switch between types.").italic(), - ], - ]; - let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::>()) - .alignment(Alignment::Left) - .wrap(Wrap { trim: false }) - .style(Style::default().fg(Color::Gray)) - .block( - Block::default() - .borders(Borders::ALL) - .style(Style::default().fg(Color::Gray)) - .title_alignment(Alignment::Left) - .title("Controls"), - ); + /// Render the view. + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let lines = vec![ + vec![ + Span::from(" Quit : ").yellow().bold(), + Span::from("q ").bold(), + Span::from(" Stop the training.").italic(), + ], + vec![ + Span::from(" Plots Metrics : ").yellow().bold(), + Span::from("⬅ ➡").bold(), + Span::from(" Switch between metrics.").italic(), + ], + vec![ + Span::from(" Plots Type : ").yellow().bold(), + Span::from("⬆ ⬇").bold(), + Span::from(" Switch between types.").italic(), + ], + ]; + let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::>()) + .alignment(Alignment::Left) + .wrap(Wrap { trim: false }) + .style(Style::default().fg(Color::Gray)) + .block( + Block::default() + .borders(Borders::ALL) + .style(Style::default().fg(Color::Gray)) + .title_alignment(Alignment::Left) + .title("Controls"), + ); - frame.render_widget(paragraph, size); - } + frame.render_widget(paragraph, size); + } } diff --git a/burn-train/src/renderer/tui/full_history.rs b/burn-train/src/renderer/tui/full_history.rs index d33641ad19..3c2e4e90e7 100644 --- a/burn-train/src/renderer/tui/full_history.rs +++ b/burn-train/src/renderer/tui/full_history.rs @@ -1,216 +1,216 @@ use super::PlotAxes; use ratatui::{ - style::{Color, Style, Stylize}, - symbols, - widgets::{Dataset, GraphType}, + style::{Color, Style, Stylize}, + symbols, + widgets::{Dataset, GraphType}, }; /// A plot that shows the full history at a reduced resolution. 
pub(crate) struct FullHistoryPlot { - pub(crate) axes: PlotAxes, - train: FullHistoryPoints, - valid: FullHistoryPoints, - next_x_state: usize, + pub(crate) axes: PlotAxes, + train: FullHistoryPoints, + valid: FullHistoryPoints, + next_x_state: usize, } struct FullHistoryPoints { - min_x: f64, - max_x: f64, - min_y: f64, - max_y: f64, - points: Vec<(f64, f64)>, - max_samples: usize, - step_size: usize, + min_x: f64, + max_x: f64, + min_y: f64, + max_y: f64, + points: Vec<(f64, f64)>, + max_samples: usize, + step_size: usize, } impl FullHistoryPlot { - /// Create a new history plot. - pub(crate) fn new(max_samples: usize) -> Self { - Self { - axes: PlotAxes::default(), - train: FullHistoryPoints::new(max_samples), - valid: FullHistoryPoints::new(max_samples), - next_x_state: 0, - } - } - - /// Update the maximum amount of sample to display for the validation points. - /// - /// This is necessary if we want the validation line to have the same point density as the - /// training line. - pub(crate) fn update_max_sample_valid(&mut self, ratio_train: f64) { - if self.valid.step_size == 1 { - self.valid.max_samples = (ratio_train * self.train.max_samples as f64) as usize; + /// Create a new history plot. + pub(crate) fn new(max_samples: usize) -> Self { + Self { + axes: PlotAxes::default(), + train: FullHistoryPoints::new(max_samples), + valid: FullHistoryPoints::new(max_samples), + next_x_state: 0, + } } - } - - /// Register a training data point. - pub(crate) fn push_train(&mut self, data: f64) { - let x_current = self.next_x(); - self.train.push((x_current, data)); - self.update_bounds(); - } + /// Update the maximum amount of sample to display for the validation points. + /// + /// This is necessary if we want the validation line to have the same point density as the + /// training line. + pub(crate) fn update_max_sample_valid(&mut self, ratio_train: f64) { + if self.valid.step_size == 1 { + self.valid.max_samples = (ratio_train * self.train.max_samples as f64) as usize; + } + } - /// Register a validation data point. - pub(crate) fn push_valid(&mut self, data: f64) { - let x_current = self.next_x(); + /// Register a training data point. + pub(crate) fn push_train(&mut self, data: f64) { + let x_current = self.next_x(); + self.train.push((x_current, data)); - self.valid.push((x_current, data)); + self.update_bounds(); + } - self.update_bounds(); - } + /// Register a validation data point. + pub(crate) fn push_valid(&mut self, data: f64) { + let x_current = self.next_x(); - /// Create the training and validation datasets from the data points. - pub(crate) fn datasets(&self) -> Vec> { - let mut datasets = Vec::with_capacity(2); + self.valid.push((x_current, data)); - if !self.train.is_empty() { - datasets.push(self.train.dataset("Train", Color::LightRed)); + self.update_bounds(); } - if !self.valid.is_empty() { - datasets.push(self.valid.dataset("Valid", Color::LightBlue)); - } + /// Create the training and validation datasets from the data points. 
+ pub(crate) fn datasets(&self) -> Vec> { + let mut datasets = Vec::with_capacity(2); - datasets - } - - fn next_x(&mut self) -> f64 { - let value = self.next_x_state; - self.next_x_state += 1; - value as f64 - } - - fn update_bounds(&mut self) { - self.axes.update_bounds( - (self.train.min_x, self.train.max_x), - (self.valid.min_x, self.valid.max_x), - (self.train.min_y, self.train.max_y), - (self.valid.min_y, self.valid.max_y), - ); - } -} + if !self.train.is_empty() { + datasets.push(self.train.dataset("Train", Color::LightRed)); + } -impl FullHistoryPoints { - fn new(max_samples: usize) -> Self { - Self { - min_x: 0., - max_x: 0., - min_y: f64::MAX, - max_y: f64::MIN, - points: Vec::with_capacity(max_samples), - max_samples, - step_size: 1, - } - } + if !self.valid.is_empty() { + datasets.push(self.valid.dataset("Valid", Color::LightBlue)); + } - fn push(&mut self, (x, y): (f64, f64)) { - if x as usize % self.step_size != 0 { - return; + datasets } - if x > self.max_x { - self.max_x = x; - } - if x < self.min_x { - self.min_x = x; + fn next_x(&mut self) -> f64 { + let value = self.next_x_state; + self.next_x_state += 1; + value as f64 } - if y > self.max_y { - self.max_y = y; + + fn update_bounds(&mut self) { + self.axes.update_bounds( + (self.train.min_x, self.train.max_x), + (self.valid.min_x, self.valid.max_x), + (self.train.min_y, self.train.max_y), + (self.valid.min_y, self.valid.max_y), + ); } - if y < self.min_y { - self.min_y = y +} + +impl FullHistoryPoints { + fn new(max_samples: usize) -> Self { + Self { + min_x: 0., + max_x: 0., + min_y: f64::MAX, + max_y: f64::MIN, + points: Vec::with_capacity(max_samples), + max_samples, + step_size: 1, + } } - self.points.push((x, y)); + fn push(&mut self, (x, y): (f64, f64)) { + if x as usize % self.step_size != 0 { + return; + } - if self.points.len() > self.max_samples { - self.resize(); - } - } - - /// We keep only half the points and we double the step size. - /// - /// This ensure that we have the same amount of points across the X axis. - fn resize(&mut self) { - let mut points = Vec::with_capacity(self.max_samples / 2); - let mut max_x = f64::MIN; - let mut max_y = f64::MIN; - let mut min_x = f64::MAX; - let mut min_y = f64::MAX; - - for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() { - if i % 2 == 0 { - if x > max_x { - max_x = x; + if x > self.max_x { + self.max_x = x; } - if x < min_x { - min_x = x; + if x < self.min_x { + self.min_x = x; } - if y > max_y { - max_y = y; + if y > self.max_y { + self.max_y = y; } - if y < min_y { - min_y = y; + if y < self.min_y { + self.min_y = y } - points.push((x, y)); - } + self.points.push((x, y)); + + if self.points.len() > self.max_samples { + self.resize(); + } + } + + /// We keep only half the points and we double the step size. + /// + /// This ensure that we have the same amount of points across the X axis. 
+ fn resize(&mut self) { + let mut points = Vec::with_capacity(self.max_samples / 2); + let mut max_x = f64::MIN; + let mut max_y = f64::MIN; + let mut min_x = f64::MAX; + let mut min_y = f64::MAX; + + for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() { + if i % 2 == 0 { + if x > max_x { + max_x = x; + } + if x < min_x { + min_x = x; + } + if y > max_y { + max_y = y; + } + if y < min_y { + min_y = y; + } + + points.push((x, y)); + } + } + + self.points = points; + self.step_size *= 2; + + self.min_x = min_x; + self.max_x = max_x; + self.min_y = min_y; + self.max_y = max_y; + } + + fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { + Dataset::default() + .name(name) + .marker(symbols::Marker::Braille) + .style(Style::default().fg(color).bold()) + .graph_type(GraphType::Line) + .data(&self.points) } - self.points = points; - self.step_size *= 2; - - self.min_x = min_x; - self.max_x = max_x; - self.min_y = min_y; - self.max_y = max_y; - } - - fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { - Dataset::default() - .name(name) - .marker(symbols::Marker::Braille) - .style(Style::default().fg(color).bold()) - .graph_type(GraphType::Line) - .data(&self.points) - } - - fn is_empty(&self) -> bool { - self.points.is_empty() - } + fn is_empty(&self) -> bool { + self.points.is_empty() + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_points() { - let mut chart = FullHistoryPlot::new(10); - chart.update_max_sample_valid(0.6); + #[test] + fn test_points() { + let mut chart = FullHistoryPlot::new(10); + chart.update_max_sample_valid(0.6); - for i in 0..100 { - chart.push_train(i as f64); - } - for i in 0..60 { - chart.push_valid(i as f64); - } + for i in 0..100 { + chart.push_train(i as f64); + } + for i in 0..60 { + chart.push_valid(i as f64); + } + + let expected_train = vec![ + (0.0, 0.0), + (16.0, 16.0), + (32.0, 32.0), + (48.0, 48.0), + (64.0, 64.0), + (80.0, 80.0), + (96.0, 96.0), + ]; - let expected_train = vec![ - (0.0, 0.0), - (16.0, 16.0), - (32.0, 32.0), - (48.0, 48.0), - (64.0, 64.0), - (80.0, 80.0), - (96.0, 96.0), - ]; - - let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)]; - - assert_eq!(chart.train.points, expected_train); - assert_eq!(chart.valid.points, expected_valid); - } + let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)]; + + assert_eq!(chart.train.points, expected_train); + assert_eq!(chart.valid.points, expected_valid); + } } diff --git a/burn-train/src/renderer/tui/metric_numeric.rs b/burn-train/src/renderer/tui/metric_numeric.rs index 5a27182147..ccae8e295c 100644 --- a/burn-train/src/renderer/tui/metric_numeric.rs +++ b/burn-train/src/renderer/tui/metric_numeric.rs @@ -3,10 +3,10 @@ use crate::renderer::TrainingProgress; use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame}; use crossterm::event::{Event, KeyCode}; use ratatui::{ - prelude::{Alignment, Constraint, Direction, Layout, Rect}, - style::{Color, Modifier, Style, Stylize}, - text::Line, - widgets::{Axis, Block, Borders, Chart, Paragraph, Tabs}, + prelude::{Alignment, Constraint, Direction, Layout, Rect}, + style::{Color, Modifier, Style, Stylize}, + text::Line, + widgets::{Axis, Block, Borders, Chart, Paragraph, Tabs}, }; use std::collections::HashMap; @@ -19,213 +19,214 @@ const MAX_NUM_SAMPLES_FULL: usize = 250; /// Numeric metrics state that handles creating plots. 
#[derive(Default)] pub(crate) struct NumericMetricsState { - data: HashMap, - names: Vec, - selected: usize, - kind: PlotKind, - num_samples_train: Option, - num_samples_valid: Option, + data: HashMap, + names: Vec, + selected: usize, + kind: PlotKind, + num_samples_train: Option, + num_samples_valid: Option, } /// The kind of plot to display. #[derive(Default, Clone, Copy)] pub(crate) enum PlotKind { - /// Display the full history of the metric with reduced resolution. - #[default] - Full, - /// Display only the recent history of the metric, but with more resolution. - Recent, + /// Display the full history of the metric with reduced resolution. + #[default] + Full, + /// Display only the recent history of the metric, but with more resolution. + Recent, } impl NumericMetricsState { - /// Register a new training value for the metric with the given name. - pub(crate) fn push_train(&mut self, name: String, data: f64) { - if let Some((recent, full)) = self.data.get_mut(&name) { - recent.push_train(data); - full.push_train(data); - } else { - let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); - let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); - - recent.push_train(data); - full.push_train(data); - - self.names.push(name.clone()); - self.data.insert(name, (recent, full)); + /// Register a new training value for the metric with the given name. + pub(crate) fn push_train(&mut self, name: String, data: f64) { + if let Some((recent, full)) = self.data.get_mut(&name) { + recent.push_train(data); + full.push_train(data); + } else { + let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); + let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); + + recent.push_train(data); + full.push_train(data); + + self.names.push(name.clone()); + self.data.insert(name, (recent, full)); + } } - } - /// Register a new validation value for the metric with the given name. - pub(crate) fn push_valid(&mut self, key: String, data: f64) { - if let Some((recent, full)) = self.data.get_mut(&key) { - recent.push_valid(data); - full.push_valid(data); - } else { - let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); - let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); + /// Register a new validation value for the metric with the given name. + pub(crate) fn push_valid(&mut self, key: String, data: f64) { + if let Some((recent, full)) = self.data.get_mut(&key) { + recent.push_valid(data); + full.push_valid(data); + } else { + let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); + let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); - recent.push_valid(data); - full.push_valid(data); + recent.push_valid(data); + full.push_valid(data); - self.data.insert(key, (recent, full)); + self.data.insert(key, (recent, full)); + } } - } - /// Update the state with the training progress. - pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) { - if self.num_samples_train.is_some() { - return; + /// Update the state with the training progress. + pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) { + if self.num_samples_train.is_some() { + return; + } + + self.num_samples_train = Some(progress.progress.items_total); } - self.num_samples_train = Some(progress.progress.items_total); - } + /// Update the state with the validation progress. 
+ pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) { + if self.num_samples_valid.is_some() { + return; + } + + if let Some(num_sample_train) = self.num_samples_train { + for (_, (_recent, full)) in self.data.iter_mut() { + let ratio = progress.progress.items_total as f64 / num_sample_train as f64; + full.update_max_sample_valid(ratio); + } + } - /// Update the state with the validation progress. - pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) { - if self.num_samples_valid.is_some() { - return; + self.num_samples_valid = Some(progress.progress.items_total); } - if let Some(num_sample_train) = self.num_samples_train { - for (_, (_recent, full)) in self.data.iter_mut() { - let ratio = progress.progress.items_total as f64 / num_sample_train as f64; - full.update_max_sample_valid(ratio); - } + /// Create a view to display the numeric metrics. + pub(crate) fn view(&self) -> NumericMetricView<'_> { + match self.names.is_empty() { + true => NumericMetricView::None, + false => NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind), + } } - self.num_samples_valid = Some(progress.progress.items_total); - } + /// Handle the current event. + pub(crate) fn on_event(&mut self, event: &Event) { + if let Event::Key(key) = event { + match key.code { + KeyCode::Right => self.next_metric(), + KeyCode::Left => self.previous_metric(), + KeyCode::Up => self.switch_kind(), + KeyCode::Down => self.switch_kind(), + _ => {} + } + } + } - /// Create a view to display the numeric metrics. - pub(crate) fn view(&self) -> NumericMetricView<'_> { - match self.names.is_empty() { - true => NumericMetricView::None, - false => NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind), + fn switch_kind(&mut self) { + self.kind = match self.kind { + PlotKind::Full => PlotKind::Recent, + PlotKind::Recent => PlotKind::Full, + }; } - } - - /// Handle the current event. 
- pub(crate) fn on_event(&mut self, event: &Event) { - if let Event::Key(key) = event { - match key.code { - KeyCode::Right => self.next_metric(), - KeyCode::Left => self.previous_metric(), - KeyCode::Up => self.switch_kind(), - KeyCode::Down => self.switch_kind(), - _ => {} - } + + fn next_metric(&mut self) { + self.selected = (self.selected + 1) % { + let this = &self; + this.data.len() + }; } - } - - fn switch_kind(&mut self) { - self.kind = match self.kind { - PlotKind::Full => PlotKind::Recent, - PlotKind::Recent => PlotKind::Full, - }; - } - - fn next_metric(&mut self) { - self.selected = (self.selected + 1) % { - let this = &self; - this.data.len() - }; - } - - fn previous_metric(&mut self) { - if self.selected > 0 { - self.selected -= 1; - } else { - self.selected = ({ - let this = &self; - this.data.len() - }) - 1; + + fn previous_metric(&mut self) { + if self.selected > 0 { + self.selected -= 1; + } else { + self.selected = ({ + let this = &self; + this.data.len() + }) - 1; + } + } + + fn chart<'a>(&'a self) -> Chart<'a> { + let name = self.names.get(self.selected).unwrap(); + let (recent, full) = self.data.get(name).unwrap(); + + let (datasets, axes) = match self.kind { + PlotKind::Full => (full.datasets(), &full.axes), + PlotKind::Recent => (recent.datasets(), &recent.axes), + }; + + Chart::<'a>::new(datasets) + .block(Block::default()) + .x_axis( + Axis::default() + .style(Style::default().fg(Color::DarkGray)) + .title("Iteration") + .labels(axes.labels_x.iter().map(|s| s.bold()).collect()) + .bounds(axes.bounds_x), + ) + .y_axis( + Axis::default() + .style(Style::default().fg(Color::DarkGray)) + .labels(axes.labels_y.iter().map(|s| s.bold()).collect()) + .bounds(axes.bounds_y), + ) } - } - - fn chart<'a>(&'a self) -> Chart<'a> { - let name = self.names.get(self.selected).unwrap(); - let (recent, full) = self.data.get(name).unwrap(); - - let (datasets, axes) = match self.kind { - PlotKind::Full => (full.datasets(), &full.axes), - PlotKind::Recent => (recent.datasets(), &recent.axes), - }; - - Chart::<'a>::new(datasets) - .block(Block::default()) - .x_axis( - Axis::default() - .style(Style::default().fg(Color::DarkGray)) - .title("Iteration") - .labels(axes.labels_x.iter().map(|s| s.bold()).collect()) - .bounds(axes.bounds_x), - ) - .y_axis( - Axis::default() - .style(Style::default().fg(Color::DarkGray)) - .labels(axes.labels_y.iter().map(|s| s.bold()).collect()) - .bounds(axes.bounds_y), - ) - } } #[derive(new)] pub(crate) enum NumericMetricView<'a> { - Plots(&'a [String], usize, Chart<'a>, PlotKind), - None, + Plots(&'a [String], usize, Chart<'a>, PlotKind), + None, } impl<'a> NumericMetricView<'a> { - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - match self { - Self::Plots(titles, selected, chart, kind) => { - let block = Block::default() - .borders(Borders::ALL) - .title("Plots") - .title_alignment(Alignment::Left); - let size_new = block.inner(size); - frame.render_widget(block, size); - - let size = size_new; - - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints( - [ - Constraint::Length(2), - Constraint::Length(1), - Constraint::Min(0), - ] - .as_ref(), - ) - .split(size); - - let titles = titles - .iter() - .map(|i| Line::from(vec![i.yellow()])) - .collect(); - - let tabs = Tabs::new(titles) - .select(selected) - .style(Style::default()) - .highlight_style( - Style::default() - .add_modifier(Modifier::BOLD) - .add_modifier(Modifier::UNDERLINED) - .fg(Color::LightYellow), - ); - let title = match kind { - 
PlotKind::Full => "Full History", - PlotKind::Recent => "Recent History", + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + match self { + Self::Plots(titles, selected, chart, kind) => { + let block = Block::default() + .borders(Borders::ALL) + .title("Plots") + .title_alignment(Alignment::Left); + let size_new = block.inner(size); + frame.render_widget(block, size); + + let size = size_new; + + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints( + [ + Constraint::Length(2), + Constraint::Length(1), + Constraint::Min(0), + ] + .as_ref(), + ) + .split(size); + + let titles = titles + .iter() + .map(|i| Line::from(vec![i.yellow()])) + .collect(); + + let tabs = Tabs::new(titles) + .select(selected) + .style(Style::default()) + .highlight_style( + Style::default() + .add_modifier(Modifier::BOLD) + .add_modifier(Modifier::UNDERLINED) + .fg(Color::LightYellow), + ); + let title = match kind { + PlotKind::Full => "Full History", + PlotKind::Recent => "Recent History", + }; + + let plot_type = + Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center); + + frame.render_widget(tabs, chunks[0]); + frame.render_widget(plot_type, chunks[1]); + frame.render_widget(chart, chunks[2]); + } + Self::None => {} }; - - let plot_type = Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center); - - frame.render_widget(tabs, chunks[0]); - frame.render_widget(plot_type, chunks[1]); - frame.render_widget(chart, chunks[2]); - } - Self::None => {} - }; - } + } } diff --git a/burn-train/src/renderer/tui/metric_text.rs b/burn-train/src/renderer/tui/metric_text.rs index 338877b808..c6d6b021c1 100644 --- a/burn-train/src/renderer/tui/metric_text.rs +++ b/burn-train/src/renderer/tui/metric_text.rs @@ -1,101 +1,101 @@ use super::TerminalFrame; use crate::metric::MetricEntry; use ratatui::{ - prelude::{Alignment, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; use std::collections::HashMap; #[derive(Default)] pub(crate) struct TextMetricsState { - data: HashMap, - names: Vec, + data: HashMap, + names: Vec, } #[derive(new)] pub(crate) struct MetricData { - train: Option, - valid: Option, + train: Option, + valid: Option, } impl TextMetricsState { - pub(crate) fn update_train(&mut self, metric: MetricEntry) { - if let Some(existing) = self.data.get_mut(&metric.name) { - existing.train = Some(metric); - } else { - let key = metric.name.clone(); - let value = MetricData::new(Some(metric), None); - - self.names.push(key.clone()); - self.data.insert(key, value); + pub(crate) fn update_train(&mut self, metric: MetricEntry) { + if let Some(existing) = self.data.get_mut(&metric.name) { + existing.train = Some(metric); + } else { + let key = metric.name.clone(); + let value = MetricData::new(Some(metric), None); + + self.names.push(key.clone()); + self.data.insert(key, value); + } } - } - pub(crate) fn update_valid(&mut self, metric: MetricEntry) { - if let Some(existing) = self.data.get_mut(&metric.name) { - existing.valid = Some(metric); - } else { - let key = metric.name.clone(); - let value = MetricData::new(None, Some(metric)); - - self.names.push(key.clone()); - self.data.insert(key, value); + pub(crate) fn update_valid(&mut self, metric: MetricEntry) { + if let Some(existing) = self.data.get_mut(&metric.name) { + existing.valid = Some(metric); + } else { 
+ let key = metric.name.clone(); + let value = MetricData::new(None, Some(metric)); + + self.names.push(key.clone()); + self.data.insert(key, value); + } + } + pub(crate) fn view(&self) -> TextMetricView { + TextMetricView::new(&self.names, &self.data) } - } - pub(crate) fn view(&self) -> TextMetricView { - TextMetricView::new(&self.names, &self.data) - } } pub(crate) struct TextMetricView { - lines: Vec>>, + lines: Vec>>, } impl TextMetricView { - fn new(names: &[String], data: &HashMap) -> Self { - let mut lines = Vec::with_capacity(names.len() * 4); - - let start_line = |title: &str| vec![Span::from(format!(" {title} ")).bold().yellow()]; - let train_line = |formatted: &str| { - vec![ - Span::from(" Train ").bold(), - Span::from(formatted.to_string()).italic(), - ] - }; - let valid_line = |formatted: &str| { - vec![ - Span::from(" Valid ").bold(), - Span::from(formatted.to_string()).italic(), - ] - }; - - for name in names { - lines.push(start_line(name)); - - let entry = data.get(name).unwrap(); - - if let Some(entry) = &entry.train { - lines.push(train_line(&entry.formatted)); - } - - if let Some(entry) = &entry.valid { - lines.push(valid_line(&entry.formatted)); - } - - lines.push(vec![Span::from("")]); + fn new(names: &[String], data: &HashMap) -> Self { + let mut lines = Vec::with_capacity(names.len() * 4); + + let start_line = |title: &str| vec![Span::from(format!(" {title} ")).bold().yellow()]; + let train_line = |formatted: &str| { + vec![ + Span::from(" Train ").bold(), + Span::from(formatted.to_string()).italic(), + ] + }; + let valid_line = |formatted: &str| { + vec![ + Span::from(" Valid ").bold(), + Span::from(formatted.to_string()).italic(), + ] + }; + + for name in names { + lines.push(start_line(name)); + + let entry = data.get(name).unwrap(); + + if let Some(entry) = &entry.train { + lines.push(train_line(&entry.formatted)); + } + + if let Some(entry) = &entry.valid { + lines.push(valid_line(&entry.formatted)); + } + + lines.push(vec![Span::from("")]); + } + + Self { lines } } - Self { lines } - } - - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) - .alignment(Alignment::Left) - .wrap(Wrap { trim: false }) - .block(Block::default().borders(Borders::ALL).title("Metrics")) - .style(Style::default().fg(Color::Gray)); + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) + .alignment(Alignment::Left) + .wrap(Wrap { trim: false }) + .block(Block::default().borders(Borders::ALL).title("Metrics")) + .style(Style::default().fg(Color::Gray)); - frame.render_widget(paragraph, size); - } + frame.render_widget(paragraph, size); + } } diff --git a/burn-train/src/renderer/tui/plot_utils.rs b/burn-train/src/renderer/tui/plot_utils.rs index 30207ced9a..615cad8184 100644 --- a/burn-train/src/renderer/tui/plot_utils.rs +++ b/burn-train/src/renderer/tui/plot_utils.rs @@ -4,45 +4,45 @@ const AXIS_TITLE_PRECISION: usize = 2; /// The data describing both X and Y axes. 
pub(crate) struct PlotAxes { - pub(crate) labels_x: Vec, - pub(crate) labels_y: Vec, - pub(crate) bounds_x: [f64; 2], - pub(crate) bounds_y: [f64; 2], + pub(crate) labels_x: Vec, + pub(crate) labels_y: Vec, + pub(crate) bounds_x: [f64; 2], + pub(crate) bounds_y: [f64; 2], } impl Default for PlotAxes { - fn default() -> Self { - Self { - bounds_x: [f64::MAX, f64::MIN], - bounds_y: [f64::MAX, f64::MIN], - labels_x: Vec::new(), - labels_y: Vec::new(), + fn default() -> Self { + Self { + bounds_x: [f64::MAX, f64::MIN], + bounds_y: [f64::MAX, f64::MIN], + labels_x: Vec::new(), + labels_y: Vec::new(), + } } - } } impl PlotAxes { - /// Update the bounds based on the min max of each X and Y axes with both train and valid data. - pub(crate) fn update_bounds( - &mut self, - (x_train_min, x_train_max): (f64, f64), - (x_valid_min, x_valid_max): (f64, f64), - (y_train_min, y_train_max): (f64, f64), - (y_valid_min, y_valid_max): (f64, f64), - ) { - let x_min = f64::min(x_train_min, x_valid_min); - let x_max = f64::max(x_train_max, x_valid_max); - let y_min = f64::min(y_train_min, y_valid_min); - let y_max = f64::max(y_train_max, y_valid_max); + /// Update the bounds based on the min max of each X and Y axes with both train and valid data. + pub(crate) fn update_bounds( + &mut self, + (x_train_min, x_train_max): (f64, f64), + (x_valid_min, x_valid_max): (f64, f64), + (y_train_min, y_train_max): (f64, f64), + (y_valid_min, y_valid_max): (f64, f64), + ) { + let x_min = f64::min(x_train_min, x_valid_min); + let x_max = f64::max(x_train_max, x_valid_max); + let y_min = f64::min(y_train_min, y_valid_min); + let y_max = f64::max(y_train_max, y_valid_max); - self.bounds_x = [x_min, x_max]; - self.bounds_y = [y_min, y_max]; + self.bounds_x = [x_min, x_max]; + self.bounds_y = [y_min, y_max]; - // We know x are integers. - self.labels_x = vec![format!("{x_min}"), format!("{x_max}")]; - self.labels_y = vec![ - format_float(y_min, AXIS_TITLE_PRECISION), - format_float(y_max, AXIS_TITLE_PRECISION), - ]; - } + // We know x are integers. + self.labels_x = vec![format!("{x_min}"), format!("{x_max}")]; + self.labels_y = vec![ + format_float(y_min, AXIS_TITLE_PRECISION), + format_float(y_max, AXIS_TITLE_PRECISION), + ]; + } } diff --git a/burn-train/src/renderer/tui/popup.rs b/burn-train/src/renderer/tui/popup.rs index b75773d0e7..b39a9c8f66 100644 --- a/burn-train/src/renderer/tui/popup.rs +++ b/burn-train/src/renderer/tui/popup.rs @@ -1,144 +1,144 @@ use crossterm::event::{Event, KeyCode}; use ratatui::{ - prelude::{Alignment, Constraint, Direction, Layout, Rect}, - style::{Color, Modifier, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Constraint, Direction, Layout, Rect}, + style::{Color, Modifier, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; use super::TerminalFrame; /// Popup callback function. pub(crate) trait CallbackFn: Send + Sync { - /// Call the function and return if the popup state should be reset. - fn call(&self) -> bool; + /// Call the function and return if the popup state should be reset. + fn call(&self) -> bool; } /// Popup callback. pub(crate) struct Callback { - title: String, - description: String, - trigger: char, - callback: Box, + title: String, + description: String, + trigger: char, + callback: Box, } impl Callback { - /// Create a new popup. 
- pub(crate) fn new(title: T, description: D, trigger: char, callback: C) -> Self - where - T: Into, - D: Into, - C: CallbackFn + 'static, - { - Self { - title: title.into(), - description: description.into(), - trigger, - callback: Box::new(callback), + /// Create a new popup. + pub(crate) fn new(title: T, description: D, trigger: char, callback: C) -> Self + where + T: Into, + D: Into, + C: CallbackFn + 'static, + { + Self { + title: title.into(), + description: description.into(), + trigger, + callback: Box::new(callback), + } } - } } /// Popup state. pub(crate) enum PopupState { - Empty, - Full(String, Vec), + Empty, + Full(String, Vec), } impl PopupState { - /// If the popup is empty. - pub(crate) fn is_empty(&self) -> bool { - matches!(&self, PopupState::Empty) - } - /// Handle popup events. - pub(crate) fn on_event(&mut self, event: &Event) { - let mut reset = false; + /// If the popup is empty. + pub(crate) fn is_empty(&self) -> bool { + matches!(&self, PopupState::Empty) + } + /// Handle popup events. + pub(crate) fn on_event(&mut self, event: &Event) { + let mut reset = false; - match self { - PopupState::Empty => {} - PopupState::Full(_, callbacks) => { - for callback in callbacks.iter() { - if let Event::Key(key) = event { - if let KeyCode::Char(key) = &key.code { - if &callback.trigger == key && callback.callback.call() { - reset = true; - } + match self { + PopupState::Empty => {} + PopupState::Full(_, callbacks) => { + for callback in callbacks.iter() { + if let Event::Key(key) = event { + if let KeyCode::Char(key) = &key.code { + if &callback.trigger == key && callback.callback.call() { + reset = true; + } + } + } + } } - } - } - } - }; + }; - if reset { - *self = Self::Empty; + if reset { + *self = Self::Empty; + } } - } - /// Create the popup view. - pub(crate) fn view(&self) -> Option> { - match self { - PopupState::Empty => None, - PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)), + /// Create the popup view. + pub(crate) fn view(&self) -> Option> { + match self { + PopupState::Empty => None, + PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)), + } } - } } #[derive(new)] pub(crate) struct PopupView<'a> { - title: &'a String, - callbacks: &'a [Callback], + title: &'a String, + callbacks: &'a [Callback], } impl<'a> PopupView<'a> { - /// Render the view. - pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) { - let lines = self - .callbacks - .iter() - .flat_map(|callback| { - vec![ - Line::from(vec![ - Span::from(format!("[{}] ", callback.trigger)).bold(), - Span::from(format!("{} ", callback.title)).yellow().bold(), - ]), - Line::from(Span::from("")), - Line::from(Span::from(callback.description.to_string()).italic()), - Line::from(Span::from("")), - ] - }) - .collect::>(); + /// Render the view. 
+ pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) { + let lines = self + .callbacks + .iter() + .flat_map(|callback| { + vec![ + Line::from(vec![ + Span::from(format!("[{}] ", callback.trigger)).bold(), + Span::from(format!("{} ", callback.title)).yellow().bold(), + ]), + Line::from(Span::from("")), + Line::from(Span::from(callback.description.to_string()).italic()), + Line::from(Span::from("")), + ] + }) + .collect::>(); - let paragraph = Paragraph::new(lines) - .alignment(Alignment::Left) - .wrap(Wrap { trim: false }) - .style(Style::default().fg(Color::Gray)) - .block( - Block::default() - .borders(Borders::ALL) - .title_alignment(Alignment::Center) - .style(Style::default().fg(Color::Gray)) - .title(Span::styled( - self.title, - Style::default().add_modifier(Modifier::BOLD), - )), - ); + let paragraph = Paragraph::new(lines) + .alignment(Alignment::Left) + .wrap(Wrap { trim: false }) + .style(Style::default().fg(Color::Gray)) + .block( + Block::default() + .borders(Borders::ALL) + .title_alignment(Alignment::Center) + .style(Style::default().fg(Color::Gray)) + .title(Span::styled( + self.title, + Style::default().add_modifier(Modifier::BOLD), + )), + ); - let area = centered_percent(20, size, Direction::Horizontal); - let area = centered_percent(20, area, Direction::Vertical); + let area = centered_percent(20, size, Direction::Horizontal); + let area = centered_percent(20, area, Direction::Vertical); - frame.render_widget(paragraph, area); - } + frame.render_widget(paragraph, area); + } } /// The percent represents the amount of space that will be taken by each side. fn centered_percent(percent: u16, size: Rect, direction: Direction) -> Rect { - let center = 100 - (percent * 2); + let center = 100 - (percent * 2); - Layout::default() - .direction(direction) - .constraints([ - Constraint::Percentage(percent), - Constraint::Percentage(center), - Constraint::Percentage(percent), - ]) - .split(size)[1] + Layout::default() + .direction(direction) + .constraints([ + Constraint::Percentage(percent), + Constraint::Percentage(center), + Constraint::Percentage(percent), + ]) + .split(size)[1] } diff --git a/burn-train/src/renderer/tui/progress.rs b/burn-train/src/renderer/tui/progress.rs index 647eeab38e..5921db2438 100644 --- a/burn-train/src/renderer/tui/progress.rs +++ b/burn-train/src/renderer/tui/progress.rs @@ -1,10 +1,10 @@ use super::TerminalFrame; use crate::renderer::TrainingProgress; use ratatui::{ - prelude::{Alignment, Constraint, Direction, Layout, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Gauge, Paragraph}, + prelude::{Alignment, Constraint, Direction, Layout, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Gauge, Paragraph}, }; use std::time::{Duration, Instant}; @@ -12,9 +12,9 @@ use std::time::{Duration, Instant}; /// /// We currently ignore the time taken for the validation part. pub(crate) struct ProgressBarState { - progress_train: f64, // Progress for total training. - starting_epoch: usize, - estimate: ProgressEstimate, + progress_train: f64, // Progress for total training. 
+ starting_epoch: usize, + estimate: ProgressEstimate, } const MINUTE: u64 = 60; @@ -22,243 +22,243 @@ const HOUR: u64 = 60 * 60; const DAY: u64 = 24 * 60 * 60; impl ProgressBarState { - pub fn new(checkpoint: Option) -> Self { - Self { - progress_train: 0.0, - estimate: ProgressEstimate::new(), - starting_epoch: checkpoint.unwrap_or(0), + pub fn new(checkpoint: Option) -> Self { + Self { + progress_train: 0.0, + estimate: ProgressEstimate::new(), + starting_epoch: checkpoint.unwrap_or(0), + } + } + /// Update the training progress. + pub(crate) fn update_train(&mut self, progress: &TrainingProgress) { + self.progress_train = calculate_progress(progress, 0, 0); + self.estimate.update(progress, self.starting_epoch); + } + + /// Update the validation progress. + pub(crate) fn update_valid(&mut self, _progress: &TrainingProgress) { + // We don't use the validation for the progress yet. + } + + /// Create a view for the current progress. + pub(crate) fn view(&self) -> ProgressBarView { + const NO_ETA: &str = "---"; + + let eta = match self.estimate.secs() { + Some(eta) => format_eta(eta), + None => NO_ETA.to_string(), + }; + ProgressBarView::new(self.progress_train, eta) } - } - /// Update the training progress. - pub(crate) fn update_train(&mut self, progress: &TrainingProgress) { - self.progress_train = calculate_progress(progress, 0, 0); - self.estimate.update(progress, self.starting_epoch); - } - - /// Update the validation progress. - pub(crate) fn update_valid(&mut self, _progress: &TrainingProgress) { - // We don't use the validation for the progress yet. - } - - /// Create a view for the current progress. - pub(crate) fn view(&self) -> ProgressBarView { - const NO_ETA: &str = "---"; - - let eta = match self.estimate.secs() { - Some(eta) => format_eta(eta), - None => NO_ETA.to_string(), - }; - ProgressBarView::new(self.progress_train, eta) - } } #[derive(new)] pub(crate) struct ProgressBarView { - progress: f64, - eta: String, + progress: f64, + eta: String, } impl ProgressBarView { - /// Render the view. - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let block = Block::default() - .borders(Borders::ALL) - .title("Progress") - .title_alignment(Alignment::Left); - let size_new = block.inner(size); - frame.render_widget(block, size); - let size = size_new; - - let chunks = Layout::default() - .direction(Direction::Horizontal) - .constraints( - [ - Constraint::Length(1), // Empty space - Constraint::Min(0), - Constraint::Length(self.eta.len() as u16 + 4), - ] - .as_ref(), - ) - .split(size); - - let size_gauge = chunks[1]; - let size_eta = chunks[2]; - - let iteration = Gauge::default() - .gauge_style(Style::default().fg(Color::Yellow)) - .ratio(self.progress); - let eta = Paragraph::new(Line::from(vec![ - Span::from(" ("), - Span::from(self.eta).italic(), - Span::from(") "), - ])); - - frame.render_widget(iteration, size_gauge); - frame.render_widget(eta, size_eta); - } + /// Render the view. 
+ pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let block = Block::default() + .borders(Borders::ALL) + .title("Progress") + .title_alignment(Alignment::Left); + let size_new = block.inner(size); + frame.render_widget(block, size); + let size = size_new; + + let chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints( + [ + Constraint::Length(1), // Empty space + Constraint::Min(0), + Constraint::Length(self.eta.len() as u16 + 4), + ] + .as_ref(), + ) + .split(size); + + let size_gauge = chunks[1]; + let size_eta = chunks[2]; + + let iteration = Gauge::default() + .gauge_style(Style::default().fg(Color::Yellow)) + .ratio(self.progress); + let eta = Paragraph::new(Line::from(vec![ + Span::from(" ("), + Span::from(self.eta).italic(), + Span::from(") "), + ])); + + frame.render_widget(iteration, size_gauge); + frame.render_widget(eta, size_eta); + } } struct ProgressEstimate { - started: Instant, - started_after_warmup: Option, - warmup_num_items: usize, - progress: f64, + started: Instant, + started_after_warmup: Option, + warmup_num_items: usize, + progress: f64, } impl ProgressEstimate { - fn new() -> Self { - Self { - started: Instant::now(), - started_after_warmup: None, - warmup_num_items: 0, - progress: 0.0, + fn new() -> Self { + Self { + started: Instant::now(), + started_after_warmup: None, + warmup_num_items: 0, + progress: 0.0, + } } - } - fn secs(&self) -> Option { - let eta = match self.started_after_warmup { - Some(started) => started.elapsed(), - None => return None, - }; - - let total_estimated = (eta.as_secs() as f64) / self.progress; - - if total_estimated.is_normal() { - let remaining = 1.0 - self.progress; - let eta = (total_estimated * remaining) as u64; - Some(eta) - } else { - None + fn secs(&self) -> Option { + let eta = match self.started_after_warmup { + Some(started) => started.elapsed(), + None => return None, + }; + + let total_estimated = (eta.as_secs() as f64) / self.progress; + + if total_estimated.is_normal() { + let remaining = 1.0 - self.progress; + let eta = (total_estimated * remaining) as u64; + Some(eta) + } else { + None + } } - } - fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) { - if self.started_after_warmup.is_some() { - self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); - return; + fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) { + if self.started_after_warmup.is_some() { + self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); + return; + } + + const WARMUP_NUM_ITERATION: usize = 10; + + // When the training has started since 30 seconds. + if self.started.elapsed() > Duration::from_secs(30) { + self.init(progress, starting_epoch); + return; + } + + // When the training has started since at least 10 seconds and completed 10 iterations. + if progress.iteration >= WARMUP_NUM_ITERATION + && self.started.elapsed() > Duration::from_secs(10) + { + self.init(progress, starting_epoch); + } } - const WARMUP_NUM_ITERATION: usize = 10; + fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) { + let epoch = progress.epoch - starting_epoch; + let epoch_items = (epoch - 1) * progress.progress.items_total; + let iteration_items = progress.progress.items_processed; - // When the training has started since 30 seconds. 
- if self.started.elapsed() > Duration::from_secs(30) { - self.init(progress, starting_epoch); - return; + self.warmup_num_items = epoch_items + iteration_items; + self.started_after_warmup = Some(Instant::now()); + self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); } - - // When the training has started since at least 10 seconds and completed 10 iterations. - if progress.iteration >= WARMUP_NUM_ITERATION - && self.started.elapsed() > Duration::from_secs(10) - { - self.init(progress, starting_epoch); - } - } - - fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) { - let epoch = progress.epoch - starting_epoch; - let epoch_items = (epoch - 1) * progress.progress.items_total; - let iteration_items = progress.progress.items_processed; - - self.warmup_num_items = epoch_items + iteration_items; - self.started_after_warmup = Some(Instant::now()); - self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); - } } fn calculate_progress( - progress: &TrainingProgress, - starting_epoch: usize, - ignore_num_items: usize, + progress: &TrainingProgress, + starting_epoch: usize, + ignore_num_items: usize, ) -> f64 { - let epoch_total = progress.epoch_total - starting_epoch; - let epoch = progress.epoch - starting_epoch; + let epoch_total = progress.epoch_total - starting_epoch; + let epoch = progress.epoch - starting_epoch; - let total_items = progress.progress.items_total * epoch_total; - let epoch_items = (epoch - 1) * progress.progress.items_total; - let iteration_items = progress.progress.items_processed; - let num_items = epoch_items + iteration_items - ignore_num_items; + let total_items = progress.progress.items_total * epoch_total; + let epoch_items = (epoch - 1) * progress.progress.items_total; + let iteration_items = progress.progress.items_processed; + let num_items = epoch_items + iteration_items - ignore_num_items; - num_items as f64 / total_items as f64 + num_items as f64 / total_items as f64 } fn format_eta(eta_secs: u64) -> String { - let seconds = eta_secs % 60; - let minutes = eta_secs / MINUTE % 60; - let hours = eta_secs / HOUR % 24; - let days = eta_secs / DAY; - - if days > 1 { - format!("{days} days") - } else if days == 1 { - "1 day".to_string() - } else if hours > 1 { - format!("{hours} hours") - } else if hours == 1 { - "1 hour".to_string() - } else if minutes > 1 { - format!("{minutes} mins") - } else if minutes == 1 { - "1 min".to_string() - } else if seconds > 1 { - format!("{seconds} secs") - } else { - "1 sec".to_string() - } + let seconds = eta_secs % 60; + let minutes = eta_secs / MINUTE % 60; + let hours = eta_secs / HOUR % 24; + let days = eta_secs / DAY; + + if days > 1 { + format!("{days} days") + } else if days == 1 { + "1 day".to_string() + } else if hours > 1 { + format!("{hours} hours") + } else if hours == 1 { + "1 hour".to_string() + } else if minutes > 1 { + format!("{minutes} mins") + } else if minutes == 1 { + "1 min".to_string() + } else if seconds > 1 { + format!("{seconds} secs") + } else { + "1 sec".to_string() + } } #[cfg(test)] mod tests { - use super::*; - use burn_core::data::dataloader::Progress; - - #[test] - fn test_format_eta() { - assert_eq!("55 secs", format_eta(55), "Less than 1 minutes"); - assert_eq!("1 min", format_eta(61), "More than 1 minutes"); - assert_eq!("2 mins", format_eta(2 * 61), "More than 2 minutes"); - assert_eq!("1 hour", format_eta(3601), "More than 1 hour"); - assert_eq!("2 hours", format_eta(2 * 3601), "More than 2 hour"); - assert_eq!("1 day", 
format_eta(24 * 3601), "More than 1 day"); - assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day"); - } - - #[test] - fn calculate_progress_for_eta() { - let half = Progress { - items_processed: 5, - items_total: 10, - }; - let progress = TrainingProgress { - progress: half, - epoch: 9, - epoch_total: 10, - iteration: 500, - }; - - let starting_epoch = 8; - let progress = calculate_progress(&progress, starting_epoch, 0); - - // Two epochs remaining while the first is half done. - assert_eq!(0.25, progress); - } - - #[test] - fn calculate_progress_for_eta_with_warmup() { - let half = Progress { - items_processed: 110, - items_total: 1000, - }; - let progress = TrainingProgress { - progress: half, - epoch: 9, - epoch_total: 10, - iteration: 500, - }; - - let starting_epoch = 8; - let progress = calculate_progress(&progress, starting_epoch, 10); - - // Two epochs remaining while the first is half done. - assert_eq!(0.05, progress); - } + use super::*; + use burn_core::data::dataloader::Progress; + + #[test] + fn test_format_eta() { + assert_eq!("55 secs", format_eta(55), "Less than 1 minutes"); + assert_eq!("1 min", format_eta(61), "More than 1 minutes"); + assert_eq!("2 mins", format_eta(2 * 61), "More than 2 minutes"); + assert_eq!("1 hour", format_eta(3601), "More than 1 hour"); + assert_eq!("2 hours", format_eta(2 * 3601), "More than 2 hour"); + assert_eq!("1 day", format_eta(24 * 3601), "More than 1 day"); + assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day"); + } + + #[test] + fn calculate_progress_for_eta() { + let half = Progress { + items_processed: 5, + items_total: 10, + }; + let progress = TrainingProgress { + progress: half, + epoch: 9, + epoch_total: 10, + iteration: 500, + }; + + let starting_epoch = 8; + let progress = calculate_progress(&progress, starting_epoch, 0); + + // Two epochs remaining while the first is half done. + assert_eq!(0.25, progress); + } + + #[test] + fn calculate_progress_for_eta_with_warmup() { + let half = Progress { + items_processed: 110, + items_total: 1000, + }; + let progress = TrainingProgress { + progress: half, + epoch: 9, + epoch_total: 10, + iteration: 500, + }; + + let starting_epoch = 8; + let progress = calculate_progress(&progress, starting_epoch, 10); + + // Two epochs remaining while the first is half done. + assert_eq!(0.05, progress); + } } diff --git a/burn-train/src/renderer/tui/recent_history.rs b/burn-train/src/renderer/tui/recent_history.rs index 60b1ce77c3..ac91d60888 100644 --- a/burn-train/src/renderer/tui/recent_history.rs +++ b/burn-train/src/renderer/tui/recent_history.rs @@ -1,244 +1,244 @@ use super::PlotAxes; use ratatui::{ - style::{Color, Style, Stylize}, - symbols, - widgets::{Dataset, GraphType}, + style::{Color, Style, Stylize}, + symbols, + widgets::{Dataset, GraphType}, }; const FACTOR_BEFORE_RESIZE: usize = 2; /// A plot that shows the recent history at full resolution. 
pub(crate) struct RecentHistoryPlot { - pub(crate) axes: PlotAxes, - train: RecentHistoryPoints, - valid: RecentHistoryPoints, - max_samples: usize, + pub(crate) axes: PlotAxes, + train: RecentHistoryPoints, + valid: RecentHistoryPoints, + max_samples: usize, } struct RecentHistoryPoints { - min_x: f64, - max_x: f64, - min_y: f64, - max_y: f64, - cursor: usize, - points: Vec<(f64, f64)>, - max_samples: usize, - factor_before_resize: usize, + min_x: f64, + max_x: f64, + min_y: f64, + max_y: f64, + cursor: usize, + points: Vec<(f64, f64)>, + max_samples: usize, + factor_before_resize: usize, } impl RecentHistoryPlot { - pub(crate) fn new(max_samples: usize) -> Self { - Self { - axes: PlotAxes::default(), - train: RecentHistoryPoints::new(max_samples), - valid: RecentHistoryPoints::new(max_samples), - max_samples, + pub(crate) fn new(max_samples: usize) -> Self { + Self { + axes: PlotAxes::default(), + train: RecentHistoryPoints::new(max_samples), + valid: RecentHistoryPoints::new(max_samples), + max_samples, + } } - } - pub(crate) fn push_train(&mut self, data: f64) { - let (x_min, x_current) = self.x(); + pub(crate) fn push_train(&mut self, data: f64) { + let (x_min, x_current) = self.x(); - self.train.push((x_current, data)); - self.train.update_cursor(x_min); - self.valid.update_cursor(x_min); + self.train.push((x_current, data)); + self.train.update_cursor(x_min); + self.valid.update_cursor(x_min); - self.update_bounds(); - } + self.update_bounds(); + } - pub(crate) fn push_valid(&mut self, data: f64) { - let (x_min, x_current) = self.x(); + pub(crate) fn push_valid(&mut self, data: f64) { + let (x_min, x_current) = self.x(); - self.valid.push((x_current, data)); - self.valid.update_cursor(x_min); - self.train.update_cursor(x_min); + self.valid.push((x_current, data)); + self.valid.update_cursor(x_min); + self.train.update_cursor(x_min); - self.update_bounds(); - } + self.update_bounds(); + } - pub(crate) fn datasets(&self) -> Vec> { - let mut datasets = Vec::with_capacity(2); + pub(crate) fn datasets(&self) -> Vec> { + let mut datasets = Vec::with_capacity(2); - if self.train.num_visible_points() > 0 { - datasets.push(self.train.dataset("Train", Color::LightRed)); - } + if self.train.num_visible_points() > 0 { + datasets.push(self.train.dataset("Train", Color::LightRed)); + } - if self.valid.num_visible_points() > 0 { - datasets.push(self.valid.dataset("Valid", Color::LightBlue)); + if self.valid.num_visible_points() > 0 { + datasets.push(self.valid.dataset("Valid", Color::LightBlue)); + } + + datasets } - datasets - } + fn x(&mut self) -> (f64, f64) { + let x_current = f64::max(self.train.max_x, self.valid.max_x) + 1.0; + let mut x_min = f64::min(self.train.min_x, self.valid.min_x); + if x_current - x_min >= self.max_samples as f64 { + x_min += 1.0; + } - fn x(&mut self) -> (f64, f64) { - let x_current = f64::max(self.train.max_x, self.valid.max_x) + 1.0; - let mut x_min = f64::min(self.train.min_x, self.valid.min_x); - if x_current - x_min >= self.max_samples as f64 { - x_min += 1.0; + (x_min, x_current) } - (x_min, x_current) - } - - fn update_bounds(&mut self) { - self.axes.update_bounds( - (self.train.min_x, self.train.max_x), - (self.valid.min_x, self.valid.max_x), - (self.train.min_y, self.train.max_y), - (self.valid.min_y, self.valid.max_y), - ); - } + fn update_bounds(&mut self) { + self.axes.update_bounds( + (self.train.min_x, self.train.max_x), + (self.valid.min_x, self.valid.max_x), + (self.train.min_y, self.train.max_y), + (self.valid.min_y, self.valid.max_y), + ); + } } 
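// A minimal, self-contained sketch of the sliding-window bookkeeping that
// `RecentHistoryPoints` below implements: points are appended to a Vec, a cursor marks
// the first visible sample, and the cached y extrema are only recomputed when the
// evicted point was the one defining them. Type and field names here are illustrative
// and not part of the patch; only std is used.
struct SlidingWindow {
    points: Vec<(f64, f64)>,
    cursor: usize,
    max_samples: usize,
    min_y: f64,
    max_y: f64,
}

impl SlidingWindow {
    fn new(max_samples: usize) -> Self {
        Self {
            points: Vec::new(),
            cursor: 0,
            max_samples,
            min_y: f64::MAX,
            max_y: f64::MIN,
        }
    }

    fn push(&mut self, x: f64, y: f64) {
        self.min_y = self.min_y.min(y);
        self.max_y = self.max_y.max(y);
        self.points.push((x, y));

        // Keep at most `max_samples` points visible.
        while self.points.len() - self.cursor > self.max_samples {
            let (_, evicted_y) = self.points[self.cursor];
            self.cursor += 1;

            // Recompute an extremum only when the evicted point defined it.
            if evicted_y == self.max_y {
                self.max_y = self.visible().iter().map(|&(_, y)| y).fold(f64::MIN, f64::max);
            }
            if evicted_y == self.min_y {
                self.min_y = self.visible().iter().map(|&(_, y)| y).fold(f64::MAX, f64::min);
            }
        }
    }

    fn visible(&self) -> &[(f64, f64)] {
        &self.points[self.cursor..]
    }
}

// Mirrors `test_push_update_bounds_max_y` further down: with a window of 3, pushing a
// fourth value evicts the old peak and the upper bound drops from 15.0 to 14.0.
#[cfg(test)]
mod sliding_window_tests {
    use super::*;

    #[test]
    fn max_drops_when_old_peak_leaves_window() {
        let mut w = SlidingWindow::new(3);
        w.push(0.0, 15.0);
        w.push(1.0, 10.0);
        w.push(2.0, 14.0);
        assert_eq!(w.max_y, 15.0);
        w.push(3.0, 10.0);
        assert_eq!(w.max_y, 14.0);
    }
}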
impl RecentHistoryPoints { - fn new(max_samples: usize) -> Self { - let factor_before_resize = FACTOR_BEFORE_RESIZE; + fn new(max_samples: usize) -> Self { + let factor_before_resize = FACTOR_BEFORE_RESIZE; - Self { - min_x: 0., - max_x: 0., - min_y: f64::MAX, - max_y: f64::MIN, - points: Vec::with_capacity(factor_before_resize * max_samples), - cursor: 0, - max_samples, - factor_before_resize, + Self { + min_x: 0., + max_x: 0., + min_y: f64::MAX, + max_y: f64::MIN, + points: Vec::with_capacity(factor_before_resize * max_samples), + cursor: 0, + max_samples, + factor_before_resize, + } } - } - fn num_visible_points(&self) -> usize { - self.points.len() - } - - fn push(&mut self, (x, y): (f64, f64)) { - if x > self.max_x { - self.max_x = x; - } - if x < self.min_x { - self.min_x = x; - } - if y > self.max_y { - self.max_y = y; + fn num_visible_points(&self) -> usize { + self.points.len() } - if y < self.min_y { - self.min_y = y - } - self.points.push((x, y)); - } - fn update_cursor(&mut self, min_x: f64) { - if self.min_x >= min_x { - return; + fn push(&mut self, (x, y): (f64, f64)) { + if x > self.max_x { + self.max_x = x; + } + if x < self.min_x { + self.min_x = x; + } + if y > self.max_y { + self.max_y = y; + } + if y < self.min_y { + self.min_y = y + } + self.points.push((x, y)); } - self.min_x = min_x; - let mut update_y_max = false; - let mut update_y_min = false; + fn update_cursor(&mut self, min_x: f64) { + if self.min_x >= min_x { + return; + } + self.min_x = min_x; - while let Some((x, y)) = self.points.get(self.cursor) { - if *x >= self.min_x { - break; - } + let mut update_y_max = false; + let mut update_y_min = false; - if *y == self.max_y { - update_y_max = true - } - if *y == self.min_y { - update_y_min = true; - } + while let Some((x, y)) = self.points.get(self.cursor) { + if *x >= self.min_x { + break; + } - self.cursor += 1; - } + if *y == self.max_y { + update_y_max = true + } + if *y == self.min_y { + update_y_min = true; + } - if update_y_max { - self.max_y = self.calculate_max_y(); - } + self.cursor += 1; + } + + if update_y_max { + self.max_y = self.calculate_max_y(); + } + + if update_y_min { + self.min_y = self.calculate_min_y(); + } - if update_y_min { - self.min_y = self.calculate_min_y(); + if self.points.len() >= self.max_samples * self.factor_before_resize { + self.resize(); + } } - if self.points.len() >= self.max_samples * self.factor_before_resize { - self.resize(); + fn slice(&self) -> &[(f64, f64)] { + &self.points[self.cursor..self.points.len()] } - } - fn slice(&self) -> &[(f64, f64)] { - &self.points[self.cursor..self.points.len()] - } + fn calculate_max_y(&self) -> f64 { + let mut max_y = f64::MIN; - fn calculate_max_y(&self) -> f64 { - let mut max_y = f64::MIN; + for (_x, y) in self.slice() { + if *y > max_y { + max_y = *y; + } + } - for (_x, y) in self.slice() { - if *y > max_y { - max_y = *y; - } + max_y } - max_y - } + fn calculate_min_y(&self) -> f64 { + let mut min_y = f64::MAX; - fn calculate_min_y(&self) -> f64 { - let mut min_y = f64::MAX; + for (_x, y) in self.slice() { + if *y < min_y { + min_y = *y; + } + } - for (_x, y) in self.slice() { - if *y < min_y { - min_y = *y; - } + min_y } - min_y - } + fn resize(&mut self) { + let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize); - fn resize(&mut self) { - let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize); + for i in self.cursor..self.points.len() { + points.push(self.points[i]); + } - for i in self.cursor..self.points.len() { - 
points.push(self.points[i]); + self.points = points; + self.cursor = 0; } - self.points = points; - self.cursor = 0; - } + fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { + let data = &self.points[self.cursor..self.points.len()]; - fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> { - let data = &self.points[self.cursor..self.points.len()]; - - Dataset::default() - .name(name) - .marker(symbols::Marker::Braille) - .style(Style::default().fg(color).bold()) - .graph_type(GraphType::Scatter) - .data(data) - } + Dataset::default() + .name(name) + .marker(symbols::Marker::Braille) + .style(Style::default().fg(color).bold()) + .graph_type(GraphType::Scatter) + .data(data) + } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_push_update_bounds_max_y() { - let mut chart = RecentHistoryPlot::new(3); - chart.push_train(15.0); - chart.push_train(10.0); - chart.push_train(14.0); - - assert_eq!(chart.axes.bounds_y[1], 15.); - chart.push_train(10.0); - assert_eq!(chart.axes.bounds_y[1], 14.); - } - - #[test] - fn test_push_update_bounds_min_y() { - let mut chart = RecentHistoryPlot::new(3); - chart.push_train(5.0); - chart.push_train(10.0); - chart.push_train(14.0); - - assert_eq!(chart.axes.bounds_y[0], 5.); - chart.push_train(10.0); - assert_eq!(chart.axes.bounds_y[0], 10.); - } + use super::*; + + #[test] + fn test_push_update_bounds_max_y() { + let mut chart = RecentHistoryPlot::new(3); + chart.push_train(15.0); + chart.push_train(10.0); + chart.push_train(14.0); + + assert_eq!(chart.axes.bounds_y[1], 15.); + chart.push_train(10.0); + assert_eq!(chart.axes.bounds_y[1], 14.); + } + + #[test] + fn test_push_update_bounds_min_y() { + let mut chart = RecentHistoryPlot::new(3); + chart.push_train(5.0); + chart.push_train(10.0); + chart.push_train(14.0); + + assert_eq!(chart.axes.bounds_y[0], 5.); + chart.push_train(10.0); + assert_eq!(chart.axes.bounds_y[0], 10.); + } } diff --git a/burn-train/src/renderer/tui/renderer.rs b/burn-train/src/renderer/tui/renderer.rs index ae09c9885c..015e3e88ca 100644 --- a/burn-train/src/renderer/tui/renderer.rs +++ b/burn-train/src/renderer/tui/renderer.rs @@ -2,20 +2,20 @@ use crate::renderer::{tui::NumericMetricsState, MetricsRenderer}; use crate::renderer::{MetricState, TrainingProgress}; use crate::TrainingInterrupter; use crossterm::{ - event::{self, Event, KeyCode}, - execute, - terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, + event::{self, Event, KeyCode}, + execute, + terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; use ratatui::{prelude::*, Terminal}; use std::{ - error::Error, - io::{self, Stdout}, - time::{Duration, Instant}, + error::Error, + io::{self, Stdout}, + time::{Duration, Instant}, }; use super::{ - Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState, - TextMetricsState, + Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState, + TextMetricsState, }; /// The current terminal backend. @@ -27,124 +27,124 @@ const MAX_REFRESH_RATE_MILLIS: u64 = 100; /// The terminal UI metrics renderer. 
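// The renderer defined below redraws at most once every MAX_REFRESH_RATE_MILLIS
// (100 ms in this hunk): `render()` returns early while `last_update.elapsed()` is
// still below that interval. A std-only sketch of that throttling idea, with
// illustrative names (the real code resets the timestamp only after drawing and
// handling events):
use std::time::{Duration, Instant};

struct RedrawThrottle {
    last_update: Instant,
    min_interval: Duration,
}

impl RedrawThrottle {
    fn new(min_interval: Duration) -> Self {
        Self {
            last_update: Instant::now(),
            min_interval,
        }
    }

    /// Returns true when enough time has passed for another expensive redraw.
    fn ready(&mut self) -> bool {
        if self.last_update.elapsed() < self.min_interval {
            return false;
        }
        self.last_update = Instant::now();
        true
    }
}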
pub struct TuiMetricsRenderer { - terminal: Terminal, - last_update: std::time::Instant, - progress: ProgressBarState, - metrics_numeric: NumericMetricsState, - metrics_text: TextMetricsState, - status: StatusState, - interuptor: TrainingInterrupter, - popup: PopupState, + terminal: Terminal, + last_update: std::time::Instant, + progress: ProgressBarState, + metrics_numeric: NumericMetricsState, + metrics_text: TextMetricsState, + status: StatusState, + interuptor: TrainingInterrupter, + popup: PopupState, } impl MetricsRenderer for TuiMetricsRenderer { - fn update_train(&mut self, state: MetricState) { - match state { - MetricState::Generic(entry) => { - self.metrics_text.update_train(entry); - } - MetricState::Numeric(entry, value) => { - self.metrics_numeric.push_train(entry.name.clone(), value); - self.metrics_text.update_train(entry); - } - }; - } - - fn update_valid(&mut self, state: MetricState) { - match state { - MetricState::Generic(entry) => { - self.metrics_text.update_valid(entry); - } - MetricState::Numeric(entry, value) => { - self.metrics_numeric.push_valid(entry.name.clone(), value); - self.metrics_text.update_valid(entry); - } - }; - } - - fn render_train(&mut self, item: TrainingProgress) { - self.progress.update_train(&item); - self.metrics_numeric.update_progress_train(&item); - self.status.update_train(item); - self.render().unwrap(); - } - - fn render_valid(&mut self, item: TrainingProgress) { - self.progress.update_valid(&item); - self.metrics_numeric.update_progress_valid(&item); - self.status.update_valid(item); - self.render().unwrap(); - } -} + fn update_train(&mut self, state: MetricState) { + match state { + MetricState::Generic(entry) => { + self.metrics_text.update_train(entry); + } + MetricState::Numeric(entry, value) => { + self.metrics_numeric.push_train(entry.name.clone(), value); + self.metrics_text.update_train(entry); + } + }; + } -impl TuiMetricsRenderer { - /// Create a new terminal UI renderer. - pub fn new(interuptor: TrainingInterrupter, checkpoint: Option) -> Self { - let mut stdout = io::stdout(); - execute!(stdout, EnterAlternateScreen).unwrap(); - enable_raw_mode().unwrap(); - let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap(); - - Self { - terminal, - last_update: Instant::now(), - progress: ProgressBarState::new(checkpoint), - metrics_numeric: NumericMetricsState::default(), - metrics_text: TextMetricsState::default(), - status: StatusState::default(), - interuptor, - popup: PopupState::Empty, + fn update_valid(&mut self, state: MetricState) { + match state { + MetricState::Generic(entry) => { + self.metrics_text.update_valid(entry); + } + MetricState::Numeric(entry, value) => { + self.metrics_numeric.push_valid(entry.name.clone(), value); + self.metrics_text.update_valid(entry); + } + }; } - } - fn render(&mut self) -> Result<(), Box> { - let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); - if self.last_update.elapsed() < tick_rate { - return Ok(()); + fn render_train(&mut self, item: TrainingProgress) { + self.progress.update_train(&item); + self.metrics_numeric.update_progress_train(&item); + self.status.update_train(item); + self.render().unwrap(); } - self.draw()?; - self.handle_events()?; + fn render_valid(&mut self, item: TrainingProgress) { + self.progress.update_valid(&item); + self.metrics_numeric.update_progress_valid(&item); + self.status.update_valid(item); + self.render().unwrap(); + } +} - self.last_update = Instant::now(); +impl TuiMetricsRenderer { + /// Create a new terminal UI renderer. 
+ pub fn new(interuptor: TrainingInterrupter, checkpoint: Option) -> Self { + let mut stdout = io::stdout(); + execute!(stdout, EnterAlternateScreen).unwrap(); + enable_raw_mode().unwrap(); + let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap(); + + Self { + terminal, + last_update: Instant::now(), + progress: ProgressBarState::new(checkpoint), + metrics_numeric: NumericMetricsState::default(), + metrics_text: TextMetricsState::default(), + status: StatusState::default(), + interuptor, + popup: PopupState::Empty, + } + } - Ok(()) - } + fn render(&mut self) -> Result<(), Box> { + let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); + if self.last_update.elapsed() < tick_rate { + return Ok(()); + } - fn draw(&mut self) -> Result<(), Box> { - self.terminal.draw(|frame| { - let size = frame.size(); + self.draw()?; + self.handle_events()?; - match self.popup.view() { - Some(view) => view.render(frame, size), - None => { - let view = MetricsView::new( - self.metrics_numeric.view(), - self.metrics_text.view(), - self.progress.view(), - ControlsView, - self.status.view(), - ); + self.last_update = Instant::now(); - view.render(frame, size); - } - }; - })?; + Ok(()) + } - Ok(()) - } + fn draw(&mut self) -> Result<(), Box> { + self.terminal.draw(|frame| { + let size = frame.size(); + + match self.popup.view() { + Some(view) => view.render(frame, size), + None => { + let view = MetricsView::new( + self.metrics_numeric.view(), + self.metrics_text.view(), + self.progress.view(), + ControlsView, + self.status.view(), + ); + + view.render(frame, size); + } + }; + })?; - fn handle_events(&mut self) -> Result<(), Box> { - while event::poll(Duration::from_secs(0))? { - let event = event::read()?; - self.popup.on_event(&event); + Ok(()) + } - if self.popup.is_empty() { - self.metrics_numeric.on_event(&event); + fn handle_events(&mut self) -> Result<(), Box> { + while event::poll(Duration::from_secs(0))? 
{ + let event = event::read()?; + self.popup.on_event(&event); - if let Event::Key(key) = event { - if let KeyCode::Char('q') = key.code { - self.popup = PopupState::Full( + if self.popup.is_empty() { + self.metrics_numeric.on_event(&event); + + if let Event::Key(key) = event { + if let KeyCode::Char('q') = key.code { + self.popup = PopupState::Full( "Quit".to_string(), vec![ Callback::new( @@ -162,13 +162,13 @@ impl TuiMetricsRenderer { Callback::new("Cancel", "Cancel the action, continue the training.", 'c', PopupCancel), ], ); - } + } + } + } } - } - } - Ok(()) - } + Ok(()) + } } struct QuitPopupAccept(TrainingInterrupter); @@ -176,28 +176,28 @@ struct KillPopupAccept; struct PopupCancel; impl CallbackFn for KillPopupAccept { - fn call(&self) -> bool { - panic!("Killing training from user input."); - } + fn call(&self) -> bool { + panic!("Killing training from user input."); + } } impl CallbackFn for QuitPopupAccept { - fn call(&self) -> bool { - self.0.stop(); - true - } + fn call(&self) -> bool { + self.0.stop(); + true + } } impl CallbackFn for PopupCancel { - fn call(&self) -> bool { - true - } + fn call(&self) -> bool { + true + } } impl Drop for TuiMetricsRenderer { - fn drop(&mut self) { - disable_raw_mode().ok(); - execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap(); - self.terminal.show_cursor().ok(); - } + fn drop(&mut self) { + disable_raw_mode().ok(); + execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap(); + self.terminal.show_cursor().ok(); + } } diff --git a/burn-train/src/renderer/tui/status.rs b/burn-train/src/renderer/tui/status.rs index c067168498..3519d217cf 100644 --- a/burn-train/src/renderer/tui/status.rs +++ b/burn-train/src/renderer/tui/status.rs @@ -1,91 +1,91 @@ use super::TerminalFrame; use crate::renderer::TrainingProgress; use ratatui::{ - prelude::{Alignment, Rect}, - style::{Color, Style, Stylize}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph, Wrap}, + prelude::{Alignment, Rect}, + style::{Color, Style, Stylize}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, }; /// Show the training status with various information. pub(crate) struct StatusState { - progress: TrainingProgress, - mode: Mode, + progress: TrainingProgress, + mode: Mode, } enum Mode { - Valid, - Train, + Valid, + Train, } impl Default for StatusState { - fn default() -> Self { - Self { - progress: TrainingProgress::none(), - mode: Mode::Train, + fn default() -> Self { + Self { + progress: TrainingProgress::none(), + mode: Mode::Train, + } } - } } impl StatusState { - /// Update the training information. - pub(crate) fn update_train(&mut self, progress: TrainingProgress) { - self.progress = progress; - self.mode = Mode::Train; - } - /// Update the validation information. - pub(crate) fn update_valid(&mut self, progress: TrainingProgress) { - self.progress = progress; - self.mode = Mode::Valid; - } - /// Create a view. - pub(crate) fn view(&self) -> StatusView { - StatusView::new(&self.progress, &self.mode) - } + /// Update the training information. + pub(crate) fn update_train(&mut self, progress: TrainingProgress) { + self.progress = progress; + self.mode = Mode::Train; + } + /// Update the validation information. + pub(crate) fn update_valid(&mut self, progress: TrainingProgress) { + self.progress = progress; + self.mode = Mode::Valid; + } + /// Create a view. 
+ pub(crate) fn view(&self) -> StatusView { + StatusView::new(&self.progress, &self.mode) + } } pub(crate) struct StatusView { - lines: Vec>>, + lines: Vec>>, } impl StatusView { - fn new(progress: &TrainingProgress, mode: &Mode) -> Self { - let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow(); - let value = |value: String| Span::from(value).italic(); - let mode = match mode { - Mode::Valid => "Validating", - Mode::Train => "Training", - }; + fn new(progress: &TrainingProgress, mode: &Mode) -> Self { + let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow(); + let value = |value: String| Span::from(value).italic(); + let mode = match mode { + Mode::Valid => "Validating", + Mode::Train => "Training", + }; - Self { - lines: vec![ - vec![title("Mode :"), value(mode.to_string())], - vec![ - title("Epoch :"), - value(format!("{}/{}", progress.epoch, progress.epoch_total)), - ], - vec![ - title("Iteration :"), - value(format!("{}", progress.iteration)), - ], - vec![ - title("Items :"), - value(format!( - "{}/{}", - progress.progress.items_processed, progress.progress.items_total - )), - ], - ], + Self { + lines: vec![ + vec![title("Mode :"), value(mode.to_string())], + vec![ + title("Epoch :"), + value(format!("{}/{}", progress.epoch, progress.epoch_total)), + ], + vec![ + title("Iteration :"), + value(format!("{}", progress.iteration)), + ], + vec![ + title("Items :"), + value(format!( + "{}/{}", + progress.progress.items_processed, progress.progress.items_total + )), + ], + ], + } } - } - pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { - let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) - .alignment(Alignment::Left) - .block(Block::default().borders(Borders::ALL).title("Status")) - .wrap(Wrap { trim: false }) - .style(Style::default().fg(Color::Gray)); + pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { + let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) + .alignment(Alignment::Left) + .block(Block::default().borders(Borders::ALL).title("Status")) + .wrap(Wrap { trim: false }) + .style(Style::default().fg(Color::Gray)); - frame.render_widget(paragraph, size); - } + frame.render_widget(paragraph, size); + } } diff --git a/burn-wgpu/benches/fused_elemwise.rs b/burn-wgpu/benches/fused_elemwise.rs index 417c91433e..ff8443bd2a 100644 --- a/burn-wgpu/benches/fused_elemwise.rs +++ b/burn-wgpu/benches/fused_elemwise.rs @@ -9,66 +9,66 @@ use std::marker::PhantomData; #[derive(new)] struct ElemWiseBenchmark { - shape: Shape<3>, - device: B::Device, - repeat: usize, - _b: PhantomData, + shape: Shape<3>, + device: B::Device, + repeat: usize, + _b: PhantomData, } impl Benchmark for ElemWiseBenchmark { - type Args = (Tensor, Tensor); + type Args = (Tensor, Tensor); - fn name(&self) -> String { - format!( - "Backend {} Shape {:?} Repeat {}", - B::name(), - self.shape.dims, - self.repeat - ) - } + fn name(&self) -> String { + format!( + "Backend {} Shape {:?} Repeat {}", + B::name(), + self.shape.dims, + self.repeat + ) + } - fn num_samples(&self) -> usize { - 10 - } + fn num_samples(&self) -> usize { + 10 + } - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.repeat { - let tmp_0 = lhs.clone() + rhs.clone(); - let tmp_1 = rhs.clone() * tmp_0.clone(); - let tmp_2 = rhs.clone().exp(); - let tmp_3 = tmp_0 * tmp_1; - let _tmp_4 = tmp_2 / tmp_3; + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.repeat { + let tmp_0 = lhs.clone() + 
rhs.clone(); + let tmp_1 = rhs.clone() * tmp_0.clone(); + let tmp_2 = rhs.clone().exp(); + let tmp_3 = tmp_0 * tmp_1; + let _tmp_4 = tmp_2 / tmp_3; + } } - } - fn prepare(&self) -> Self::Args { - B::seed(10); - let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + B::seed(10); + let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - B::sync(&self.device) - } + fn sync(&self) { + B::sync(&self.device) + } } #[allow(dead_code)] /// Runs the benchmarks for wgpu matmul implementations pub fn bench(device: &WgpuDevice) { - run_benchmark(ElemWiseBenchmark::::new( - Shape::new([256, 256, 1024]), - device.clone(), - 10, - )); - run_benchmark(ElemWiseBenchmark::>::new( - Shape::new([256, 256, 1024]), - device.clone(), - 10, - )); + run_benchmark(ElemWiseBenchmark::::new( + Shape::new([256, 256, 1024]), + device.clone(), + 10, + )); + run_benchmark(ElemWiseBenchmark::>::new( + Shape::new([256, 256, 1024]), + device.clone(), + 10, + )); } fn main() { - bench(&WgpuDevice::BestAvailable) + bench(&WgpuDevice::BestAvailable) } diff --git a/burn-wgpu/benches/matmul.rs b/burn-wgpu/benches/matmul.rs index f0a5cfa531..586faf0e9c 100644 --- a/burn-wgpu/benches/matmul.rs +++ b/burn-wgpu/benches/matmul.rs @@ -11,129 +11,131 @@ use derive_new::new; use std::marker::PhantomData; use burn_wgpu::{ - kernel::matmul::{matmul_mem_coalescing_default, matmul_naive_default}, - GraphicsApi, + kernel::matmul::{matmul_mem_coalescing_default, matmul_naive_default}, + GraphicsApi, }; type WTensor = Tensor, D>; #[derive(new)] struct MatmulBenchmark { - shape_lhs: Shape, - shape_rhs: Shape, - num_repeats: usize, - device: B::Device, - matmul: PhantomData, + shape_lhs: Shape, + shape_rhs: Shape, + num_repeats: usize, + device: B::Device, + matmul: PhantomData, } trait MatmulFunction { - fn run(lhs: WTensor, rhs: WTensor) -> WTensor; + fn run(lhs: WTensor, rhs: WTensor) -> WTensor; } impl Benchmark for MatmulBenchmark, F, D> where - F: MatmulFunction, - G: GraphicsApi, + F: MatmulFunction, + G: GraphicsApi, { - type Args = (WTensor, WTensor); + type Args = (WTensor, WTensor); - fn name(&self) -> String { - format!( - "{:?} {:?} x {:?}", - std::any::type_name::(), - self.shape_lhs.dims, - self.shape_rhs.dims - ) - } + fn name(&self) -> String { + format!( + "{:?} {:?} x {:?}", + std::any::type_name::(), + self.shape_lhs.dims, + self.shape_rhs.dims + ) + } - fn num_samples(&self) -> usize { - 10 - } + fn num_samples(&self) -> usize { + 10 + } - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.num_repeats { - F::run(lhs.clone(), rhs.clone()); + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.num_repeats { + F::run(lhs.clone(), rhs.clone()); + } } - } - fn prepare(&self) -> Self::Args { - let lhs = WTensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); - let rhs = WTensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); + fn prepare(&self) -> Self::Args { + let lhs = + WTensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device); + let rhs = + WTensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device); - (lhs, rhs) - } + (lhs, rhs) + } - fn sync(&self) { - 
Wgpu::::sync(&self.device) - } + fn sync(&self) { + Wgpu::::sync(&self.device) + } } macro_rules! bench_matmul { - ($benchmark:ident, $matmul_name:ident, $func:expr) => { - struct $matmul_name {} - impl MatmulFunction for $matmul_name { - fn run(lhs: WTensor, rhs: WTensor) -> WTensor { - let lhs = lhs.into_primitive(); - let rhs = rhs.into_primitive(); - let output = init_matmul_output(&lhs, &rhs); - Tensor::from_primitive($func(lhs, rhs, output)) - } - } - type $benchmark = - MatmulBenchmark, $matmul_name, D>; - }; + ($benchmark:ident, $matmul_name:ident, $func:expr) => { + struct $matmul_name {} + impl MatmulFunction for $matmul_name { + fn run(lhs: WTensor, rhs: WTensor) -> WTensor { + let lhs = lhs.into_primitive(); + let rhs = rhs.into_primitive(); + let output = init_matmul_output(&lhs, &rhs); + Tensor::from_primitive($func(lhs, rhs, output)) + } + } + type $benchmark = + MatmulBenchmark, $matmul_name, D>; + }; } bench_matmul!(NaiveMatmulBenchmark, NaiveMatmul, matmul_naive_default); bench_matmul!( - MemCoalescingMatmulBenchmark, - MemCoalescingMatmul, - matmul_mem_coalescing_default + MemCoalescingMatmulBenchmark, + MemCoalescingMatmul, + matmul_mem_coalescing_default ); bench_matmul!( - Tiling2DMatmulVec4LHSBenchmark, - Tiling2DMatmulVec4LHS, - matmul_tiling_2d_vec4_lhs + Tiling2DMatmulVec4LHSBenchmark, + Tiling2DMatmulVec4LHS, + matmul_tiling_2d_vec4_lhs ); bench_matmul!( - Tiling2DMatmulVec4Benchmark, - Tiling2DMatmulVec4, - matmul_tiling_2d_vec4 + Tiling2DMatmulVec4Benchmark, + Tiling2DMatmulVec4, + matmul_tiling_2d_vec4 ); bench_matmul!( - Tiling2DMatmulUnpaddedBenchmark, - Tiling2DMatmulUnpadded, - matmul_tiling_2d_unpadded + Tiling2DMatmulUnpaddedBenchmark, + Tiling2DMatmulUnpadded, + matmul_tiling_2d_unpadded ); #[allow(dead_code)] /// Runs the benchmarks for wgpu matmul implementations pub fn bench(device: &WgpuDevice) { - const D: usize = 3; - let num_repeats = 3; - let batch_size = 3; - let m = 1007; - let k = 1023; - let n = 1005; - let shape_lhs = Shape::new([batch_size, m, k]); - let shape_rhs = Shape::new([batch_size, k, n]); + const D: usize = 3; + let num_repeats = 3; + let batch_size = 3; + let m = 1007; + let k = 1023; + let n = 1005; + let shape_lhs = Shape::new([batch_size, m, k]); + let shape_rhs = Shape::new([batch_size, k, n]); - macro_rules! run_matmul_benchmark { - ($benchmark:ident) => { - run_benchmark($benchmark::new( - shape_lhs.clone(), - shape_rhs.clone(), - num_repeats, - device.clone(), - )); - }; - } - run_matmul_benchmark!(NaiveMatmulBenchmark); - run_matmul_benchmark!(MemCoalescingMatmulBenchmark); - run_matmul_benchmark!(Tiling2DMatmulUnpaddedBenchmark); - run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark); - run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark); + macro_rules! 
run_matmul_benchmark { + ($benchmark:ident) => { + run_benchmark($benchmark::new( + shape_lhs.clone(), + shape_rhs.clone(), + num_repeats, + device.clone(), + )); + }; + } + run_matmul_benchmark!(NaiveMatmulBenchmark); + run_matmul_benchmark!(MemCoalescingMatmulBenchmark); + run_matmul_benchmark!(Tiling2DMatmulUnpaddedBenchmark); + run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark); + run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark); } fn main() { - bench(&WgpuDevice::BestAvailable) + bench(&WgpuDevice::BestAvailable) } diff --git a/burn-wgpu/benches/reduction.rs b/burn-wgpu/benches/reduction.rs index a642192ab8..7eac3440ae 100644 --- a/burn-wgpu/benches/reduction.rs +++ b/burn-wgpu/benches/reduction.rs @@ -13,96 +13,96 @@ type WTensor = Tensor, D>; #[derive(new)] struct ReduceBenchmark { - shape: Shape, - dim: usize, - num_repeats: usize, - device: B::Device, - reduce: PhantomData, + shape: Shape, + dim: usize, + num_repeats: usize, + device: B::Device, + reduce: PhantomData, } trait ReduceFunction { - fn run(input: WTensor, dim: usize) -> WTensor; + fn run(input: WTensor, dim: usize) -> WTensor; } impl Benchmark for ReduceBenchmark, F, D> where - F: ReduceFunction, - G: GraphicsApi, + F: ReduceFunction, + G: GraphicsApi, { - type Args = WTensor; - - fn name(&self) -> String { - format!( - "{:?} {:?} dim={:?}", - std::any::type_name::(), - self.shape.dims, - self.dim - ) - } - - fn num_samples(&self) -> usize { - 10 - } - - fn execute(&self, input: Self::Args) { - for _ in 0..self.num_repeats { - F::run(input.clone(), self.dim); + type Args = WTensor; + + fn name(&self) -> String { + format!( + "{:?} {:?} dim={:?}", + std::any::type_name::(), + self.shape.dims, + self.dim + ) } - } - fn prepare(&self) -> Self::Args { - WTensor::random_device(self.shape.clone(), Distribution::Default, &self.device) - } + fn num_samples(&self) -> usize { + 10 + } + + fn execute(&self, input: Self::Args) { + for _ in 0..self.num_repeats { + F::run(input.clone(), self.dim); + } + } + + fn prepare(&self) -> Self::Args { + WTensor::random_device(self.shape.clone(), Distribution::Default, &self.device) + } - fn sync(&self) { - Wgpu::::sync(&self.device) - } + fn sync(&self) { + Wgpu::::sync(&self.device) + } } macro_rules! bench_reduce { - ($benchmark:ident, $reduce_name:ident, $func:expr) => { - struct $reduce_name {} - impl ReduceFunction for $reduce_name { - fn run(input: WTensor, dim: usize) -> WTensor { - let input = input.into_primitive(); - let output = init_reduce_output(&input, dim); - Tensor::from_primitive($func(input, output, dim)) - } - } - type $benchmark = - ReduceBenchmark, $reduce_name, D>; - }; + ($benchmark:ident, $reduce_name:ident, $func:expr) => { + struct $reduce_name {} + impl ReduceFunction for $reduce_name { + fn run(input: WTensor, dim: usize) -> WTensor { + let input = input.into_primitive(); + let output = init_reduce_output(&input, dim); + Tensor::from_primitive($func(input, output, dim)) + } + } + type $benchmark = + ReduceBenchmark, $reduce_name, D>; + }; } bench_reduce!(SumDimBenchmark, SumDim, sum_dim); bench_reduce!( - SumDimSharedMemoryBenchmark, - SumDimSharedMemory, - sum_dim_shared_memory + SumDimSharedMemoryBenchmark, + SumDimSharedMemory, + sum_dim_shared_memory ); #[allow(dead_code)] /// Runs the benchmarks for wgpu matmul implementations pub fn bench(device: &WgpuDevice) { - let num_repeats = 3; - let shape = Shape::new([50, 8000, 50]); - let dim = 1; - - macro_rules! 
run_reduce_benchmark { - ($benchmark:ident) => { - run_benchmark($benchmark::new( - shape.clone(), - dim, - num_repeats, - device.clone(), - )); - }; - } + let num_repeats = 3; + let shape = Shape::new([50, 8000, 50]); + let dim = 1; + + macro_rules! run_reduce_benchmark { + ($benchmark:ident) => { + run_benchmark($benchmark::new( + shape.clone(), + dim, + num_repeats, + device.clone(), + )); + }; + } - run_reduce_benchmark!(SumDimSharedMemoryBenchmark); - run_reduce_benchmark!(SumDimBenchmark); + run_reduce_benchmark!(SumDimSharedMemoryBenchmark); + run_reduce_benchmark!(SumDimBenchmark); } fn main() { - bench(&WgpuDevice::BestAvailable) + bench(&WgpuDevice::BestAvailable) } diff --git a/burn-wgpu/src/backend.rs b/burn-wgpu/src/backend.rs index d81bf2504e..39180b9a0f 100644 --- a/burn-wgpu/src/backend.rs +++ b/burn-wgpu/src/backend.rs @@ -1,8 +1,8 @@ use crate::{ - compute::compute_client, - element::{FloatElement, IntElement}, - tensor::WgpuTensor, - AutoGraphicsApi, GraphicsApi, WgpuDevice, + compute::compute_client, + element::{FloatElement, IntElement}, + tensor::WgpuTensor, + AutoGraphicsApi, GraphicsApi, WgpuDevice, }; use burn_tensor::backend::Backend; use rand::{rngs::StdRng, SeedableRng}; @@ -22,43 +22,43 @@ pub(crate) static SEED: Mutex> = Mutex::new(None); #[derive(Debug, Default, Clone)] pub struct Wgpu where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - _g: PhantomData, - _f: PhantomData, - _i: PhantomData, + _g: PhantomData, + _f: PhantomData, + _i: PhantomData, } impl Backend for Wgpu { - type Device = WgpuDevice; - type FullPrecisionBackend = Wgpu; + type Device = WgpuDevice; + type FullPrecisionBackend = Wgpu; - type FullPrecisionElem = f32; - type FloatElem = F; - type IntElem = I; + type FullPrecisionElem = f32; + type FloatElem = F; + type IntElem = I; - type TensorPrimitive = WgpuTensor; - type IntTensorPrimitive = WgpuTensor; - type BoolTensorPrimitive = WgpuTensor; + type TensorPrimitive = WgpuTensor; + type IntTensorPrimitive = WgpuTensor; + type BoolTensorPrimitive = WgpuTensor; - fn name() -> String { - String::from("wgpu") - } + fn name() -> String { + String::from("wgpu") + } - fn seed(seed: u64) { - let rng = StdRng::seed_from_u64(seed); - let mut seed = SEED.lock().unwrap(); - *seed = Some(rng); - } + fn seed(seed: u64) { + let rng = StdRng::seed_from_u64(seed); + let mut seed = SEED.lock().unwrap(); + *seed = Some(rng); + } - fn ad_enabled() -> bool { - false - } + fn ad_enabled() -> bool { + false + } - fn sync(device: &Self::Device) { - let client = compute_client::(device); - client.sync(); - } + fn sync(device: &Self::Device) { + let client = compute_client::(device); + client.sync(); + } } diff --git a/burn-wgpu/src/compute/base.rs b/burn-wgpu/src/compute/base.rs index b978478791..864b40151b 100644 --- a/burn-wgpu/src/compute/base.rs +++ b/burn-wgpu/src/compute/base.rs @@ -2,11 +2,11 @@ use super::WgpuServer; use crate::{compute::WgpuStorage, GraphicsApi, WgpuDevice}; use alloc::sync::Arc; use burn_compute::{ - channel::MutexComputeChannel, - client::ComputeClient, - memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, - tune::Tuner, - Compute, + channel::MutexComputeChannel, + client::ComputeClient, + memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, + tune::Tuner, + Compute, }; use spin::Mutex; use wgpu::DeviceDescriptor; @@ -26,207 +26,207 @@ static COMPUTE: Compute, Channel> = Com /// Get the [compute client](ComputeClient) for the 
given [device](WgpuDevice). pub fn compute_client(device: &WgpuDevice) -> ComputeClient { - let device = Arc::new(device); + let device = Arc::new(device); - COMPUTE.client(&device, move || { - pollster::block_on(create_client::(&device)) - }) + COMPUTE.client(&device, move || { + pollster::block_on(create_client::(&device)) + }) } /// Init the client async, necessary for wasm. pub async fn init_async(device: &WgpuDevice) { - let device = Arc::new(device); - let client = create_client::(&device).await; + let device = Arc::new(device); + let client = create_client::(&device).await; - COMPUTE.register(&device, client) + COMPUTE.register(&device, client) } async fn create_client(device: &WgpuDevice) -> ComputeClient { - let (device_wgpu, queue, info) = select_device::(device).await; - - log::info!( - "Created wgpu compute server on device {:?} => {:?}", - device, - info - ); - - // TODO: Support a way to modify max_tasks without std. - let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { - Ok(value) => value - .parse::() - .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), - Err(_) => 64, // 64 tasks by default - }; - - let device = Arc::new(device_wgpu); - let storage = WgpuStorage::new(device.clone()); - let memory_management = SimpleMemoryManagement::new( - storage, - DeallocStrategy::new_period_tick(max_tasks * 2), - SliceStrategy::Ratio(0.8), - ); - let server = WgpuServer::new(memory_management, device, queue, max_tasks); - let channel = Channel::new(server); - - ComputeClient::new(channel, Arc::new(Mutex::new(Tuner::new()))) + let (device_wgpu, queue, info) = select_device::(device).await; + + log::info!( + "Created wgpu compute server on device {:?} => {:?}", + device, + info + ); + + // TODO: Support a way to modify max_tasks without std. + let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { + Ok(value) => value + .parse::() + .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), + Err(_) => 64, // 64 tasks by default + }; + + let device = Arc::new(device_wgpu); + let storage = WgpuStorage::new(device.clone()); + let memory_management = SimpleMemoryManagement::new( + storage, + DeallocStrategy::new_period_tick(max_tasks * 2), + SliceStrategy::Ratio(0.8), + ); + let server = WgpuServer::new(memory_management, device, queue, max_tasks); + let channel = Channel::new(server); + + ComputeClient::new(channel, Arc::new(Mutex::new(Tuner::new()))) } /// Select the wgpu device and queue based on the provided [device](WgpuDevice). 
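Before the `select_device` implementation that follows, a minimal standalone sketch of the `BURN_WGPU_MAX_TASKS` fallback used in `create_client` above: read the variable if set, otherwise default to 64. The function name `max_tasks_from_env` is illustrative and not part of the crate.

// Illustrative only: mirrors the env-var fallback pattern in `create_client`.
fn max_tasks_from_env() -> usize {
    match std::env::var("BURN_WGPU_MAX_TASKS") {
        Ok(value) => value
            .parse::<usize>()
            .expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
        Err(_) => 64, // Default used when the variable is unset.
    }
}

fn main() {
    // With the variable unset this prints 64; e.g. `BURN_WGPU_MAX_TASKS=128` overrides it.
    println!("max_tasks = {}", max_tasks_from_env());
}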
pub async fn select_device( - device: &WgpuDevice, + device: &WgpuDevice, ) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { - #[cfg(target_family = "wasm")] - let adapter = select_adapter::(device).await; - - #[cfg(not(target_family = "wasm"))] - let adapter = select_adapter::(device); - - let limits = adapter.limits(); - - let (device, queue) = adapter - .request_device( - &DeviceDescriptor { - label: None, - features: wgpu::Features::empty(), - limits, - }, - None, - ) - .await - .map_err(|err| { - format!( - "Unable to request the device with the adapter {:?}, err {:?}", - adapter.get_info(), - err - ) - }) - .unwrap(); - - (device, queue, adapter.get_info()) + #[cfg(target_family = "wasm")] + let adapter = select_adapter::(device).await; + + #[cfg(not(target_family = "wasm"))] + let adapter = select_adapter::(device); + + let limits = adapter.limits(); + + let (device, queue) = adapter + .request_device( + &DeviceDescriptor { + label: None, + features: wgpu::Features::empty(), + limits, + }, + None, + ) + .await + .map_err(|err| { + format!( + "Unable to request the device with the adapter {:?}, err {:?}", + adapter.get_info(), + err + ) + }) + .unwrap(); + + (device, queue, adapter.get_info()) } #[cfg(target_family = "wasm")] async fn select_adapter(_device: &WgpuDevice) -> wgpu::Adapter { - let instance = wgpu::Instance::default(); + let instance = wgpu::Instance::default(); - instance - .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) - .await - .unwrap() + instance + .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) + .await + .unwrap() } #[cfg(not(target_family = "wasm"))] fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { - use wgpu::DeviceType; - - let instance = wgpu::Instance::default(); - let mut adapters_other = Vec::new(); - let mut adapters = Vec::new(); - - instance - .enumerate_adapters(G::backend().into()) - .for_each(|adapter| { - let device_type = adapter.get_info().device_type; - - if let DeviceType::Other = device_type { - adapters_other.push(adapter); - return; - } - - let is_same_type = match device { - WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, - WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, - WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, - WgpuDevice::Cpu => device_type == DeviceType::Cpu, - WgpuDevice::BestAvailable => true, - }; - - if is_same_type { - adapters.push(adapter); - } - }); - - fn select( - num: usize, - error: &str, - mut adapters: Vec, - mut adapters_other: Vec, - ) -> wgpu::Adapter { - if adapters.len() <= num { - if adapters_other.len() <= num { - panic!( - "{}, adapters {:?}, other adapters {:?}", - error, - adapters - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - adapters_other - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - ); - } else { - return adapters_other.remove(num); - } - } + use wgpu::DeviceType; + + let instance = wgpu::Instance::default(); + let mut adapters_other = Vec::new(); + let mut adapters = Vec::new(); - adapters.remove(num) - } - - let adapter = match device { - WgpuDevice::DiscreteGpu(num) => select( - *num, - "No Discrete GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::IntegratedGpu(num) => select( - *num, - "No Integrated GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::VirtualGpu(num) => select( - *num, - "No Virtual GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::Cpu => select(0, "No CPU device found", 
adapters, adapters_other), - WgpuDevice::BestAvailable => { - let mut most_performant_adapter = None; - let mut current_score = -1; - - adapters - .into_iter() - .chain(adapters_other) + instance + .enumerate_adapters(G::backend().into()) .for_each(|adapter| { - let info = adapter.get_info(); - let score = match info.device_type { - DeviceType::DiscreteGpu => 5, - DeviceType::Other => 4, // Let's be optimistic with the Other device, it's - // often a Discrete Gpu. - DeviceType::IntegratedGpu => 3, - DeviceType::VirtualGpu => 2, - DeviceType::Cpu => 1, - }; - - if score > current_score { - most_performant_adapter = Some(adapter); - current_score = score; - } + let device_type = adapter.get_info().device_type; + + if let DeviceType::Other = device_type { + adapters_other.push(adapter); + return; + } + + let is_same_type = match device { + WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, + WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, + WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, + WgpuDevice::Cpu => device_type == DeviceType::Cpu, + WgpuDevice::BestAvailable => true, + }; + + if is_same_type { + adapters.push(adapter); + } }); - if let Some(adapter) = most_performant_adapter { - adapter - } else { - panic!("No adapter found for graphics API {:?}", G::default()); - } + fn select( + num: usize, + error: &str, + mut adapters: Vec, + mut adapters_other: Vec, + ) -> wgpu::Adapter { + if adapters.len() <= num { + if adapters_other.len() <= num { + panic!( + "{}, adapters {:?}, other adapters {:?}", + error, + adapters + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + adapters_other + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + ); + } else { + return adapters_other.remove(num); + } + } + + adapters.remove(num) } - }; - - log::info!("Using adapter {:?}", adapter.get_info()); - adapter + let adapter = match device { + WgpuDevice::DiscreteGpu(num) => select( + *num, + "No Discrete GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::IntegratedGpu(num) => select( + *num, + "No Integrated GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::VirtualGpu(num) => select( + *num, + "No Virtual GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), + WgpuDevice::BestAvailable => { + let mut most_performant_adapter = None; + let mut current_score = -1; + + adapters + .into_iter() + .chain(adapters_other) + .for_each(|adapter| { + let info = adapter.get_info(); + let score = match info.device_type { + DeviceType::DiscreteGpu => 5, + DeviceType::Other => 4, // Let's be optimistic with the Other device, it's + // often a Discrete Gpu. 
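A dependency-free sketch of the `BestAvailable` adapter scoring in this hunk: each adapter type gets a rank (discrete 5, other 4, integrated 3, virtual 2, CPU 1) and the highest-ranked candidate wins. The local `DeviceType` enum stands in for `wgpu::DeviceType`; this illustrates only the selection rule, not the crate's code.

// Stand-in for wgpu::DeviceType, for illustration only.
#[derive(Debug, Clone, Copy)]
enum DeviceType {
    DiscreteGpu,
    Other,
    IntegratedGpu,
    VirtualGpu,
    Cpu,
}

// Same ranking as the `BestAvailable` arm in this hunk.
fn score(device_type: DeviceType) -> i32 {
    match device_type {
        DeviceType::DiscreteGpu => 5,
        DeviceType::Other => 4, // Often a discrete GPU in practice.
        DeviceType::IntegratedGpu => 3,
        DeviceType::VirtualGpu => 2,
        DeviceType::Cpu => 1,
    }
}

// Pick the candidate with the highest score, if any.
fn best_available(candidates: &[DeviceType]) -> Option<DeviceType> {
    candidates.iter().copied().max_by_key(|ty| score(*ty))
}

fn main() {
    let adapters = [
        DeviceType::Cpu,
        DeviceType::VirtualGpu,
        DeviceType::IntegratedGpu,
        DeviceType::Other,
        DeviceType::DiscreteGpu,
    ];
    // Prints `Some(DiscreteGpu)`: the discrete GPU outranks every other candidate.
    println!("{:?}", best_available(&adapters));
}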
+ DeviceType::IntegratedGpu => 3, + DeviceType::VirtualGpu => 2, + DeviceType::Cpu => 1, + }; + + if score > current_score { + most_performant_adapter = Some(adapter); + current_score = score; + } + }); + + if let Some(adapter) = most_performant_adapter { + adapter + } else { + panic!("No adapter found for graphics API {:?}", G::default()); + } + } + }; + + log::info!("Using adapter {:?}", adapter.get_info()); + + adapter } diff --git a/burn-wgpu/src/compute/kernel.rs b/burn-wgpu/src/compute/kernel.rs index 8254908ba2..332e725073 100644 --- a/burn-wgpu/src/compute/kernel.rs +++ b/burn-wgpu/src/compute/kernel.rs @@ -5,101 +5,102 @@ use core::marker::PhantomData; /// Provides launch information specifying the number of work groups to be used by a compute shader. #[derive(new, Clone, Debug)] pub struct WorkGroup { - /// Work groups for the x axis. - pub x: u32, - /// Work groups for the y axis. - pub y: u32, - /// Work groups for the z axis. - pub z: u32, + /// Work groups for the x axis. + pub x: u32, + /// Work groups for the y axis. + pub y: u32, + /// Work groups for the z axis. + pub z: u32, } impl WorkGroup { - /// Calculate the number of invocations of a compute shader. - pub fn num_invocations(&self) -> usize { - (self.x * self.y * self.z) as usize - } + /// Calculate the number of invocations of a compute shader. + pub fn num_invocations(&self) -> usize { + (self.x * self.y * self.z) as usize + } } /// Wraps a [dynamic kernel source](DynamicKernelSource) into a [kernel](Kernel) with launch /// information such as [workgroup](WorkGroup). #[derive(new)] pub struct DynamicKernel { - kernel: K, - workgroup: WorkGroup, + kernel: K, + workgroup: WorkGroup, } /// Wraps a [static kernel source](StaticKernelSource) into a [kernel](Kernel) with launch /// information such as [workgroup](WorkGroup). 
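Before the `StaticKernel` definition that follows, an illustrative sketch of the launch-size arithmetic behind `WorkGroup`: `num_invocations` multiplies the three axes, and a 1D elementwise launch can be sized by ceiling division. The `elemwise_workgroup` helper here is a hypothetical stand-in, not the crate's kernel utilities.

// Illustrative launch-size math only; `WorkGroup` mirrors the struct above.
#[derive(Debug, Clone)]
struct WorkGroup {
    x: u32,
    y: u32,
    z: u32,
}

impl WorkGroup {
    // Same formula as `num_invocations` above: the product of the three axes.
    fn num_invocations(&self) -> usize {
        (self.x * self.y * self.z) as usize
    }
}

// Hypothetical helper: ceiling division so that `num_elems` elements are
// covered by 1D workgroups of `workgroup_size` invocations each.
fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
    let x = ((num_elems + workgroup_size - 1) / workgroup_size) as u32;
    WorkGroup { x, y: 1, z: 1 }
}

fn main() {
    let wg = elemwise_workgroup(1000, 256);
    // ceil(1000 / 256) = 4 workgroups along x; y and z stay at 1.
    println!("workgroups = {:?}, total = {}", wg, wg.num_invocations());
}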
#[derive(new)] pub struct StaticKernel { - workgroup: WorkGroup, - _kernel: PhantomData, + workgroup: WorkGroup, + _kernel: PhantomData, } impl Kernel for DynamicKernel where - K: DynamicKernelSource + 'static, + K: DynamicKernelSource + 'static, { - fn source(&self) -> SourceTemplate { - self.kernel.source() - } + fn source(&self) -> SourceTemplate { + self.kernel.source() + } - fn id(&self) -> String { - self.kernel.id() - } + fn id(&self) -> String { + self.kernel.id() + } - fn workgroup(&self) -> WorkGroup { - self.workgroup.clone() - } + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } } impl Kernel for StaticKernel where - K: StaticKernelSource + 'static, + K: StaticKernelSource + 'static, { - fn source(&self) -> SourceTemplate { - K::source() - } + fn source(&self) -> SourceTemplate { + K::source() + } - fn id(&self) -> String { - format!("{:?}", core::any::TypeId::of::()) - } + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } - fn workgroup(&self) -> WorkGroup { - self.workgroup.clone() - } + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi, WgpuDevice, - }; + use super::*; + use crate::{ + binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi, + WgpuDevice, + }; - #[test] - fn can_run_kernel() { - binary_elemwise!(Add, "+"); + #[test] + fn can_run_kernel() { + binary_elemwise!(Add, "+"); - let client = compute_client::(&WgpuDevice::default()); + let client = compute_client::(&WgpuDevice::default()); - let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; - let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; - let info: Vec = vec![1, 1, 1, 1, 8, 8, 8]; + let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; + let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; + let info: Vec = vec![1, 1, 1, 1, 8, 8, 8]; - let lhs = client.create(bytemuck::cast_slice(&lhs)); - let rhs = client.create(bytemuck::cast_slice(&rhs)); - let out = client.empty(core::mem::size_of::() * 8); - let info = client.create(bytemuck::cast_slice(&info)); + let lhs = client.create(bytemuck::cast_slice(&lhs)); + let rhs = client.create(bytemuck::cast_slice(&rhs)); + let out = client.empty(core::mem::size_of::() * 8); + let info = client.create(bytemuck::cast_slice(&info)); - type Kernel = KernelSettings; - let kernel = Box::new(StaticKernel::::new(WorkGroup::new(1, 1, 1))); + type Kernel = KernelSettings; + let kernel = Box::new(StaticKernel::::new(WorkGroup::new(1, 1, 1))); - client.execute(kernel, &[&lhs, &rhs, &out, &info]); + client.execute(kernel, &[&lhs, &rhs, &out, &info]); - let data = client.read(&out).read_sync().unwrap(); - let output: &[f32] = bytemuck::cast_slice(&data); + let data = client.read(&out).read_sync().unwrap(); + let output: &[f32] = bytemuck::cast_slice(&data); - assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); - } + assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); + } } diff --git a/burn-wgpu/src/compute/server.rs b/burn-wgpu/src/compute/server.rs index ca47e7b0ab..1789f2f00c 100644 --- a/burn-wgpu/src/compute/server.rs +++ b/burn-wgpu/src/compute/server.rs @@ -2,35 +2,35 @@ use super::{WgpuAutotuneKey, WgpuStorage, WorkGroup}; use crate::kernel::SourceTemplate; use alloc::{borrow::Cow, sync::Arc}; use burn_compute::{ - memory_management::MemoryManagement, - server::{self, ComputeServer}, + memory_management::MemoryManagement, + server::{self, 
ComputeServer}, }; use burn_tensor::Reader; use hashbrown::HashMap; use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, + util::{BufferInitDescriptor, DeviceExt}, + BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, }; /// Wgpu compute server. #[derive(Debug)] pub struct WgpuServer> { - memory_management: MM, - device: Arc, - queue: wgpu::Queue, - encoder: CommandEncoder, - pipelines: HashMap>, - tasks: Vec, - max_tasks: usize, - manual_available: HashMap>>, - manual_taken: Vec<(usize, server::Handle)>, + memory_management: MM, + device: Arc, + queue: wgpu::Queue, + encoder: CommandEncoder, + pipelines: HashMap>, + tasks: Vec, + max_tasks: usize, + manual_available: HashMap>>, + manual_taken: Vec<(usize, server::Handle)>, } #[derive(new, Debug)] struct ComputeTask { - pipeline: Arc, - bind_group: BindGroup, - work_group: WorkGroup, + pipeline: Arc, + bind_group: BindGroup, + work_group: WorkGroup, } /// Kernel trait with the [source](SourceTemplate) that will be compiled and cached based on the @@ -38,306 +38,308 @@ struct ComputeTask { /// /// The kernel will be launched with the given [workgroup](WorkGroup). pub trait Kernel: 'static + Send + Sync { - /// Source template for the kernel. - fn source(&self) -> SourceTemplate; - /// Identifier for the kernel, used for caching kernel compilation. - fn id(&self) -> String; - /// Launch information. - fn workgroup(&self) -> WorkGroup; + /// Source template for the kernel. + fn source(&self) -> SourceTemplate; + /// Identifier for the kernel, used for caching kernel compilation. + fn id(&self) -> String; + /// Launch information. + fn workgroup(&self) -> WorkGroup; } impl WgpuServer where - MM: MemoryManagement, + MM: MemoryManagement, { - /// Create a new server. - pub fn new( - memory_management: MM, - device: Arc, - queue: wgpu::Queue, - max_tasks: usize, - ) -> Self { - let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Command Encoder"), - }); - - Self { - memory_management, - device, - queue, - encoder, - pipelines: HashMap::new(), - tasks: Vec::new(), - max_tasks, - manual_available: HashMap::new(), - manual_taken: Vec::new(), + /// Create a new server. + pub fn new( + memory_management: MM, + device: Arc, + queue: wgpu::Queue, + max_tasks: usize, + ) -> Self { + let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Command Encoder"), + }); + + Self { + memory_management, + device, + queue, + encoder, + pipelines: HashMap::new(), + tasks: Vec::new(), + max_tasks, + manual_available: HashMap::new(), + manual_taken: Vec::new(), + } + } + + fn submit(&mut self) { + assert!( + self.tasks.is_empty(), + "Tasks should be completed before submitting the current encoder." + ); + let mut new_encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + core::mem::swap(&mut new_encoder, &mut self.encoder); + + self.queue.submit(Some(new_encoder.finish())); + + // Cleanup allocations and deallocations. + self.free_manual_allocations(); + self.memory_management.storage().perform_deallocations(); } - } - - fn submit(&mut self) { - assert!( - self.tasks.is_empty(), - "Tasks should be completed before submitting the current encoder." 
- ); - let mut new_encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - core::mem::swap(&mut new_encoder, &mut self.encoder); - - self.queue.submit(Some(new_encoder.finish())); - - // Cleanup allocations and deallocations. - self.free_manual_allocations(); - self.memory_management.storage().perform_deallocations(); - } - - fn free_manual_allocations(&mut self) { - let mut manual_taken_tmp = Vec::new(); - core::mem::swap(&mut manual_taken_tmp, &mut self.manual_taken); - - for (size, handle) in manual_taken_tmp.drain(..) { - if handle.can_mut() { - self.register_manual(size, handle); - } else { - self.manual_taken.push((size, handle)); - } + + fn free_manual_allocations(&mut self) { + let mut manual_taken_tmp = Vec::new(); + core::mem::swap(&mut manual_taken_tmp, &mut self.manual_taken); + + for (size, handle) in manual_taken_tmp.drain(..) { + if handle.can_mut() { + self.register_manual(size, handle); + } else { + self.manual_taken.push((size, handle)); + } + } } - } - - // Finds a free, manually-added handle of specified size, or creates it if none is found - fn manual_reserve(&mut self, size: usize) -> server::Handle { - let handle = self - .manual_available - .get_mut(&size) - .and_then(|h| h.pop()) - .unwrap_or_else(|| { - let memory = self.memory_management.alloc(size); - server::Handle::new(memory) - }); - - self.manual_taken.push((size, handle.clone())); - - handle - } - - // Manually adds a handle of given size - fn register_manual(&mut self, size: usize, handle: server::Handle) { - if let Some(handles) = self.manual_available.get_mut(&size) { - handles.push(handle); - } else { - self.manual_available.insert(size, [handle].into()); + + // Finds a free, manually-added handle of specified size, or creates it if none is found + fn manual_reserve(&mut self, size: usize) -> server::Handle { + let handle = self + .manual_available + .get_mut(&size) + .and_then(|h| h.pop()) + .unwrap_or_else(|| { + let memory = self.memory_management.alloc(size); + server::Handle::new(memory) + }); + + self.manual_taken.push((size, handle.clone())); + + handle } - } - fn register_tasks(&mut self) { - if self.tasks.is_empty() { - return; + // Manually adds a handle of given size + fn register_manual(&mut self, size: usize, handle: server::Handle) { + if let Some(handles) = self.manual_available.get_mut(&size) { + handles.push(handle); + } else { + self.manual_available.insert(size, [handle].into()); + } } - let mut compute = self - .encoder - .begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); + fn register_tasks(&mut self) { + if self.tasks.is_empty() { + return; + } - for task in self.tasks.iter() { - compute.set_pipeline(&task.pipeline); - compute.set_bind_group(0, &task.bind_group, &[]); - compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z); + let mut compute = self + .encoder + .begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); + + for task in self.tasks.iter() { + compute.set_pipeline(&task.pipeline); + compute.set_bind_group(0, &task.bind_group, &[]); + compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z); + } + + std::mem::drop(compute); + self.tasks.clear(); } - std::mem::drop(compute); - self.tasks.clear(); - } + fn pipeline(&mut self, kernel: Box) -> Arc { + let kernel_id = kernel.id(); + if let Some(pipeline) = self.pipelines.get(&kernel_id) { + return pipeline.clone(); + } + + let source = kernel.source().complete(); + log::trace!("Compiling kernel 
{kernel_id}:\n {source}"); + let pipeline = self.compile_source(&source); + self.pipelines.insert(kernel_id.clone(), pipeline.clone()); + + pipeline + } - fn pipeline(&mut self, kernel: Box) -> Arc { - let kernel_id = kernel.id(); - if let Some(pipeline) = self.pipelines.get(&kernel_id) { - return pipeline.clone(); + fn compile_source(&self, source: &str) -> Arc { + let module = self.device.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }); + + Arc::new( + self.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "main", + }), + ) } - let source = kernel.source().complete(); - log::trace!("Compiling kernel {kernel_id}:\n {source}"); - let pipeline = self.compile_source(&source); - self.pipelines.insert(kernel_id.clone(), pipeline.clone()); - - pipeline - } - - fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); - - Arc::new( - self - .device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: "main", - }), - ) - } - - fn buffer_reader(&mut self, handle: &server::Handle) -> BufferReader { - // Register previous tasks before reading the buffer so that it is up to date. - self.register_tasks(); - - let resource = self.memory_management.get(&handle.memory); - - let size = resource.size(); - let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, - mapped_at_creation: false, - }); - - self - .encoder - .copy_buffer_to_buffer(&resource.buffer, resource.offset(), &buffer_dest, 0, size); - - self.submit(); - - BufferReader::new(buffer_dest) - } + fn buffer_reader(&mut self, handle: &server::Handle) -> BufferReader { + // Register previous tasks before reading the buffer so that it is up to date. 
+ self.register_tasks(); + + let resource = self.memory_management.get(&handle.memory); + + let size = resource.size(); + let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + self.encoder.copy_buffer_to_buffer( + &resource.buffer, + resource.offset(), + &buffer_dest, + 0, + size, + ); + + self.submit(); + + BufferReader::new(buffer_dest) + } } #[derive(new)] struct BufferReader { - buffer: wgpu::Buffer, + buffer: wgpu::Buffer, } impl BufferReader { - #[cfg(target_family = "wasm")] - async fn read(self, device: alloc::sync::Arc) -> Vec { - self.read_async(&device).await - } - - #[cfg(not(target_family = "wasm"))] - fn read(self, device: &wgpu::Device) -> Vec { - pollster::block_on(self.read_async(device)) - } - - async fn read_async(&self, device: &wgpu::Device) -> Vec { - let buffer_slice = self.buffer.slice(..); - let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - - device.poll(wgpu::Maintain::Wait); - - let result = receiver.receive().await; - - if let Some(Ok(())) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); - - drop(data); - self.buffer.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) + #[cfg(target_family = "wasm")] + async fn read(self, device: alloc::sync::Arc) -> Vec { + self.read_async(&device).await + } + + #[cfg(not(target_family = "wasm"))] + fn read(self, device: &wgpu::Device) -> Vec { + pollster::block_on(self.read_async(device)) + } + + async fn read_async(&self, device: &wgpu::Device) -> Vec { + let buffer_slice = self.buffer.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .send(v) + .expect("Unable to send buffer slice result to async channel.") + }); + + device.poll(wgpu::Maintain::Wait); + + let result = receiver.receive().await; + + if let Some(Ok(())) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + self.buffer.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) + } } - } } impl ComputeServer for WgpuServer where - MM: MemoryManagement, + MM: MemoryManagement, { - type Kernel = Box; - type Storage = WgpuStorage; - type MemoryManagement = MM; - type AutotuneKey = WgpuAutotuneKey; + type Kernel = Box; + type Storage = WgpuStorage; + type MemoryManagement = MM; + type AutotuneKey = WgpuAutotuneKey; + + fn read(&mut self, handle: &server::Handle) -> Reader> { + #[cfg(target_family = "wasm")] + { + let future = self.buffer_reader(handle).read(self.device.clone()); + return Reader::Future(Box::pin(future)); + } + + #[cfg(not(target_family = "wasm"))] + Reader::Concrete(self.buffer_reader(handle).read(&self.device)) + } - fn read(&mut self, handle: &server::Handle) -> Reader> { - #[cfg(target_family = "wasm")] - { - let future = self.buffer_reader(handle).read(self.device.clone()); - return Reader::Future(Box::pin(future)); + /// When we create a new handle from existing data, we use custom allocations so that we don't + /// have to execute the current pending tasks. 
+ /// + /// This is important, otherwise the compute passes are going to be too small and we won't be able to + /// fully utilize the GPU. + fn create(&mut self, data: &[u8]) -> server::Handle { + let handle = self.manual_reserve(data.len()); + + let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { + label: Some("Buffer Src"), + contents: data, + usage: wgpu::BufferUsages::COPY_SRC, + })); + + let resource = self.memory_management.get(&handle.memory); + + self.encoder.copy_buffer_to_buffer( + &buffer_src, + 0, + &resource.buffer, + resource.offset(), + buffer_src.size(), + ); + + handle } - #[cfg(not(target_family = "wasm"))] - Reader::Concrete(self.buffer_reader(handle).read(&self.device)) - } - - /// When we create a new handle from existing data, we use custom allocations so that we don't - /// have to execute the current pending tasks. - /// - /// This is important, otherwise the compute passes are going to be too small and we won't be able to - /// fully utilize the GPU. - fn create(&mut self, data: &[u8]) -> server::Handle { - let handle = self.manual_reserve(data.len()); - - let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { - label: Some("Buffer Src"), - contents: data, - usage: wgpu::BufferUsages::COPY_SRC, - })); - - let resource = self.memory_management.get(&handle.memory); - - self.encoder.copy_buffer_to_buffer( - &buffer_src, - 0, - &resource.buffer, - resource.offset(), - buffer_src.size(), - ); - - handle - } - - fn empty(&mut self, size: usize) -> server::Handle { - server::Handle::new(self.memory_management.reserve(size)) - } - - fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle]) { - let work_group = kernel.workgroup(); - let pipeline = self.pipeline(kernel); - let group_layout = pipeline.get_bind_group_layout(0); - - let handles = handles - .iter() - .map(|handle| self.memory_management.get(&handle.memory)) - .collect::>(); - - let entries = handles - .iter() - .enumerate() - .map(|(i, buffer)| wgpu::BindGroupEntry { - binding: i as u32, - resource: buffer.as_binding(), - }) - .collect::>(); - - let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { - label: None, - layout: &group_layout, - entries: &entries, - }); - - self - .tasks - .push(ComputeTask::new(pipeline, bind_group, work_group)); - - if self.tasks.len() >= self.max_tasks { - self.register_tasks(); - self.submit(); + fn empty(&mut self, size: usize) -> server::Handle { + server::Handle::new(self.memory_management.reserve(size)) } - } - fn sync(&mut self) { - if !self.tasks.is_empty() { - self.register_tasks(); - self.submit(); + fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle]) { + let work_group = kernel.workgroup(); + let pipeline = self.pipeline(kernel); + let group_layout = pipeline.get_bind_group_layout(0); + + let handles = handles + .iter() + .map(|handle| self.memory_management.get(&handle.memory)) + .collect::>(); + + let entries = handles + .iter() + .enumerate() + .map(|(i, buffer)| wgpu::BindGroupEntry { + binding: i as u32, + resource: buffer.as_binding(), + }) + .collect::>(); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &group_layout, + entries: &entries, + }); + + self.tasks + .push(ComputeTask::new(pipeline, bind_group, work_group)); + + if self.tasks.len() >= self.max_tasks { + self.register_tasks(); + self.submit(); + } } - self.device.poll(wgpu::Maintain::Wait); - } + fn sync(&mut self) { + if 
!self.tasks.is_empty() { + self.register_tasks(); + self.submit(); + } + + self.device.poll(wgpu::Maintain::Wait); + } } diff --git a/burn-wgpu/src/compute/storage.rs b/burn-wgpu/src/compute/storage.rs index 11314cc14b..ef74a927a3 100644 --- a/burn-wgpu/src/compute/storage.rs +++ b/burn-wgpu/src/compute/storage.rs @@ -4,119 +4,121 @@ use std::{num::NonZeroU64, sync::Arc}; /// Buffer storage for wgpu. pub struct WgpuStorage { - memory: HashMap>, - deallocations: Vec, - device: Arc, + memory: HashMap>, + deallocations: Vec, + device: Arc, } impl core::fmt::Debug for WgpuStorage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) + } } /// The memory resource that can be allocated for wgpu. #[derive(new, Debug)] pub struct WgpuResource { - /// The wgpu buffer. - pub buffer: Arc, - /// How the resource is used. - pub kind: WgpuResourceKind, + /// The wgpu buffer. + pub buffer: Arc, + /// How the resource is used. + pub kind: WgpuResourceKind, } impl WgpuResource { - /// Return the binding view of the buffer. - pub fn as_binding(&self) -> wgpu::BindingResource { - let binding = match &self.kind { - WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), - WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { - buffer: &self.buffer, - offset: *offs, - size: Some(*size), - }, - }; - wgpu::BindingResource::Buffer(binding) - } - - /// Return the buffer size. - pub fn size(&self) -> u64 { - match self.kind { - WgpuResourceKind::Full => self.buffer.size(), - WgpuResourceKind::Slice(_, size) => size.get(), + /// Return the binding view of the buffer. + pub fn as_binding(&self) -> wgpu::BindingResource { + let binding = match &self.kind { + WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), + WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { + buffer: &self.buffer, + offset: *offs, + size: Some(*size), + }, + }; + wgpu::BindingResource::Buffer(binding) + } + + /// Return the buffer size. + pub fn size(&self) -> u64 { + match self.kind { + WgpuResourceKind::Full => self.buffer.size(), + WgpuResourceKind::Slice(_, size) => size.get(), + } } - } - /// Return the buffer offset. - pub fn offset(&self) -> u64 { - match self.kind { - WgpuResourceKind::Full => 0, - WgpuResourceKind::Slice(offset, _) => offset, + /// Return the buffer offset. + pub fn offset(&self) -> u64 { + match self.kind { + WgpuResourceKind::Full => 0, + WgpuResourceKind::Slice(offset, _) => offset, + } } - } } /// How the resource is used, either as a slice or fully. #[derive(Debug)] pub enum WgpuResourceKind { - /// Represents an entire buffer. - Full, - /// A slice over a buffer. - Slice(wgpu::BufferAddress, wgpu::BufferSize), + /// Represents an entire buffer. + Full, + /// A slice over a buffer. + Slice(wgpu::BufferAddress, wgpu::BufferSize), } /// Keeps actual wgpu buffer references in a hashmap with ids as key. impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). - pub fn new(device: Arc) -> Self { - Self { - memory: HashMap::new(), - deallocations: Vec::new(), - device, + /// Create a new storage on the given [device](wgpu::Device). 
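Before the constructor that follows, a dependency-free sketch of the deferred-deallocation pattern in `WgpuStorage` (the `deallocations` list above, drained by `perform_deallocations` further down, which the server calls after queue submission). `FakeBuffer` and the integer ids are stand-ins for the wgpu types; this is an illustration of the bookkeeping only.

use std::collections::HashMap;

// Stand-in for an Arc<wgpu::Buffer>; the usize is a fake byte size.
#[derive(Debug)]
struct FakeBuffer(usize);

#[derive(Default)]
struct DeferredStorage {
    memory: HashMap<u64, FakeBuffer>,
    deallocations: Vec<u64>,
}

impl DeferredStorage {
    fn alloc(&mut self, id: u64, size: usize) {
        self.memory.insert(id, FakeBuffer(size));
    }

    // Mirrors `dealloc`: only tag the id, don't free anything yet.
    fn dealloc(&mut self, id: u64) {
        self.deallocations.push(id);
    }

    // Mirrors `perform_deallocations`: drop everything tagged since the last call.
    fn perform_deallocations(&mut self) {
        for id in self.deallocations.drain(..) {
            self.memory.remove(&id);
        }
    }
}

fn main() {
    let mut storage = DeferredStorage::default();
    storage.alloc(1, 1024);
    storage.dealloc(1);
    // Still alive until the explicit flush.
    println!("before flush: {:?}", storage.memory.get(&1));
    storage.perform_deallocations();
    // Gone after the flush.
    println!("after flush: {:?}", storage.memory.get(&1));
}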
+ pub fn new(device: Arc) -> Self { + Self { + memory: HashMap::new(), + deallocations: Vec::new(), + device, + } } - } - - /// Actually deallocates buffers tagged to be deallocated. - pub fn perform_deallocations(&mut self) { - for id in self.deallocations.drain(..) { - if let Some(buffer) = self.memory.remove(&id) { - buffer.destroy() - } + + /// Actually deallocates buffers tagged to be deallocated. + pub fn perform_deallocations(&mut self) { + for id in self.deallocations.drain(..) { + if let Some(buffer) = self.memory.remove(&id) { + buffer.destroy() + } + } } - } } impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; + type Resource = WgpuResource; + + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let buffer = self.memory.get(&handle.id).unwrap(); + + match handle.utilization { + StorageUtilization::Full(_) => { + WgpuResource::new(buffer.clone(), WgpuResourceKind::Full) + } + StorageUtilization::Slice(offset, size) => WgpuResource::new( + buffer.clone(), + WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), + ), + } + } + + fn alloc(&mut self, size: usize) -> StorageHandle { + let id = StorageId::new(); + let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: size as u64, + usage: wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + })); - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let buffer = self.memory.get(&handle.id).unwrap(); + self.memory.insert(id.clone(), buffer); + + StorageHandle::new(id, StorageUtilization::Full(size)) + } - match handle.utilization { - StorageUtilization::Full(_) => WgpuResource::new(buffer.clone(), WgpuResourceKind::Full), - StorageUtilization::Slice(offset, size) => WgpuResource::new( - buffer.clone(), - WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), - ), + fn dealloc(&mut self, id: StorageId) { + self.deallocations.push(id); } - } - - fn alloc(&mut self, size: usize) -> StorageHandle { - let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size: size as u64, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC, - mapped_at_creation: false, - })); - - self.memory.insert(id.clone(), buffer); - - StorageHandle::new(id, StorageUtilization::Full(size)) - } - - fn dealloc(&mut self, id: StorageId) { - self.deallocations.push(id); - } } diff --git a/burn-wgpu/src/compute/tune_key.rs b/burn-wgpu/src/compute/tune_key.rs index 3015b19697..2b2ce25018 100644 --- a/burn-wgpu/src/compute/tune_key.rs +++ b/burn-wgpu/src/compute/tune_key.rs @@ -7,22 +7,22 @@ use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey}; #[derive(Hash, Eq, PartialEq, Debug, Clone)] /// Key for all autotune-enabled operations pub enum WgpuAutotuneKey { - /// Key for matmul operation - Matmul(MatmulAutotuneKey), - /// Key for sum_dim operations - SumDim(ReduceAutotuneKey), - /// Key for mean_dim operations - MeanDim(ReduceAutotuneKey), + /// Key for matmul operation + Matmul(MatmulAutotuneKey), + /// Key for sum_dim operations + SumDim(ReduceAutotuneKey), + /// Key for mean_dim operations + MeanDim(ReduceAutotuneKey), } impl Display for WgpuAutotuneKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), - 
WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), - WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), + WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + } } - } } impl AutotuneKey for WgpuAutotuneKey {} diff --git a/burn-wgpu/src/device.rs b/burn-wgpu/src/device.rs index 1e9eacc5eb..c9d6657fb8 100644 --- a/burn-wgpu/src/device.rs +++ b/burn-wgpu/src/device.rs @@ -12,39 +12,39 @@ /// ``` #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum WgpuDevice { - /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list - /// of all discrete GPUs found on the system. - DiscreteGpu(usize), + /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list + /// of all discrete GPUs found on the system. + DiscreteGpu(usize), - /// Integrated GPU with the given index. The index is the index of the integrated GPU in the - /// list of all integrated GPUs found on the system. - IntegratedGpu(usize), + /// Integrated GPU with the given index. The index is the index of the integrated GPU in the + /// list of all integrated GPUs found on the system. + IntegratedGpu(usize), - /// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of - /// all virtual GPUs found on the system. - VirtualGpu(usize), + /// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of + /// all virtual GPUs found on the system. + VirtualGpu(usize), - /// CPU. - Cpu, + /// CPU. + Cpu, - /// The best available device found with the current [graphics API](crate::GraphicsApi). - /// - /// Priority - /// - /// 1. DiscreteGpu - /// 2. IntegratedGpu - /// 3. VirtualGpu - /// 4. Cpu - /// - /// # Notes - /// - /// A device might be identified as [Other](wgpu::DeviceType::Other) by [wgpu](wgpu), in this case, we chose this device over - /// `IntegratedGpu` since it's often a discrete GPU. - BestAvailable, + /// The best available device found with the current [graphics API](crate::GraphicsApi). + /// + /// Priority + /// + /// 1. DiscreteGpu + /// 2. IntegratedGpu + /// 3. VirtualGpu + /// 4. Cpu + /// + /// # Notes + /// + /// A device might be identified as [Other](wgpu::DeviceType::Other) by [wgpu](wgpu), in this case, we chose this device over + /// `IntegratedGpu` since it's often a discrete GPU. + BestAvailable, } impl Default for WgpuDevice { - fn default() -> Self { - Self::BestAvailable - } + fn default() -> Self { + Self::BestAvailable + } } diff --git a/burn-wgpu/src/element.rs b/burn-wgpu/src/element.rs index baf8613cb4..b14ddfe8a8 100644 --- a/burn-wgpu/src/element.rs +++ b/burn-wgpu/src/element.rs @@ -2,13 +2,13 @@ use burn_tensor::Element; /// The base element trait for the wgou backend. 
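Before the `WgpuElement` trait that follows, a small self-contained sketch of the `Display` delegation used by `WgpuAutotuneKey` above: each variant wraps a key type and formatting forwards to the inner key. `MatmulKey` and `ReduceKey` are illustrative stand-ins for the real autotune key types.

use std::fmt::{self, Display};

// Stand-ins for MatmulAutotuneKey / ReduceAutotuneKey; illustration only.
struct MatmulKey(String);
struct ReduceKey(String);

impl Display for MatmulKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "matmul-{}", self.0)
    }
}

impl Display for ReduceKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "reduce-{}", self.0)
    }
}

// Same shape as `WgpuAutotuneKey`: each variant wraps a key type and
// `Display` simply forwards to the inner key, as in the impl above.
enum AutotuneKey {
    Matmul(MatmulKey),
    SumDim(ReduceKey),
    MeanDim(ReduceKey),
}

impl Display for AutotuneKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AutotuneKey::Matmul(key) => Display::fmt(key, f),
            AutotuneKey::SumDim(key) | AutotuneKey::MeanDim(key) => Display::fmt(key, f),
        }
    }
}

fn main() {
    let key = AutotuneKey::Matmul(MatmulKey("1024x1024x1024".into()));
    println!("{key}"); // Prints: matmul-1024x1024x1024
}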
pub trait WgpuElement: - burn_tensor::Element + core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod + burn_tensor::Element + core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod where - Self: Sized, + Self: Sized, { - fn type_name() -> &'static str; - fn as_bytes(slice: &[Self]) -> &[u8]; - fn from_bytes(bytes: &[u8]) -> &[Self]; + fn type_name() -> &'static str; + fn as_bytes(slice: &[Self]) -> &[u8]; + fn from_bytes(bytes: &[u8]) -> &[Self]; } /// The float element type for the wgpu backend. @@ -18,39 +18,39 @@ pub trait FloatElement: WgpuElement + Element {} pub trait IntElement: WgpuElement + Element {} impl WgpuElement for u32 { - fn type_name() -> &'static str { - "u32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } + fn type_name() -> &'static str { + "u32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } } impl WgpuElement for i32 { - fn type_name() -> &'static str { - "i32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } + fn type_name() -> &'static str { + "i32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } } impl WgpuElement for f32 { - fn type_name() -> &'static str { - "f32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } + fn type_name() -> &'static str { + "f32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } } impl FloatElement for f32 {} diff --git a/burn-wgpu/src/fusion/base.rs b/burn-wgpu/src/fusion/base.rs index 1c1274e26a..b1e1f864c5 100644 --- a/burn-wgpu/src/fusion/base.rs +++ b/burn-wgpu/src/fusion/base.rs @@ -1,144 +1,144 @@ use crate::{ - compute::{WgpuComputeClient, WgpuHandle}, - element::WgpuElement, - fusion::FloatElementWiseFusionOps, - tensor::WgpuTensor, - FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice, + compute::{WgpuComputeClient, WgpuHandle}, + element::WgpuElement, + fusion::FloatElementWiseFusionOps, + tensor::WgpuTensor, + FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice, }; use burn_fusion::{ - client::MutexFusionClient, graph::GreedyGraphExecution, DeviceId, FusionBackend, FusionDevice, + client::MutexFusionClient, graph::GreedyGraphExecution, DeviceId, FusionBackend, FusionDevice, }; use burn_tensor::Shape; use core::marker::PhantomData; impl FusionDevice for WgpuDevice { - fn id(&self) -> DeviceId { - match self { - WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), - WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), - WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), - WgpuDevice::Cpu => DeviceId::new(3, 0), - WgpuDevice::BestAvailable => DeviceId::new(4, 0), + fn id(&self) -> DeviceId { + match self { + WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), + WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), + WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), + WgpuDevice::Cpu => DeviceId::new(3, 0), + WgpuDevice::BestAvailable => DeviceId::new(4, 0), + } } - } } impl 
FusionBackend for Wgpu where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - type FusionDevice = WgpuDevice; - type Handle = WgpuFusionHandle; - type FusionClient = MutexFusionClient; - - fn operations(device: &WgpuDevice) -> Vec>> { - vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))] - } - - fn float_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::TensorPrimitive { - handle.into_tensor(shape) - } - - fn int_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::IntTensorPrimitive { - handle.into_tensor(shape) - } - - fn bool_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::BoolTensorPrimitive { - handle.into_tensor(shape) - } - - fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle { - tensor.into() - } - - fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle { - tensor.into() - } - - fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle { - tensor.into() - } + type FusionDevice = WgpuDevice; + type Handle = WgpuFusionHandle; + type FusionClient = MutexFusionClient; + + fn operations(device: &WgpuDevice) -> Vec>> { + vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))] + } + + fn float_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::TensorPrimitive { + handle.into_tensor(shape) + } + + fn int_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::IntTensorPrimitive { + handle.into_tensor(shape) + } + + fn bool_tensor( + handle: Self::Handle, + shape: Shape, + ) -> Self::BoolTensorPrimitive { + handle.into_tensor(shape) + } + + fn float_tensor_handle(tensor: Self::TensorPrimitive) -> Self::Handle { + tensor.into() + } + + fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle { + tensor.into() + } + + fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle { + tensor.into() + } } pub fn strides_dyn_rank(shape: &[usize]) -> Vec { - let mut strides = vec![0; shape.len()]; + let mut strides = vec![0; shape.len()]; - let mut current = 1; - shape.iter().enumerate().rev().for_each(|(index, val)| { - strides[index] = current; - current *= val; - }); + let mut current = 1; + shape.iter().enumerate().rev().for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); - strides + strides } pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize { - let mut num_elems = 1; - for i in shape.iter() { - num_elems *= i; - } - num_elems + let mut num_elems = 1; + for i in shape.iter() { + num_elems *= i; + } + num_elems } #[derive(new, Debug, Clone)] /// Handle to be used when fusing operations. pub struct WgpuFusionHandle { - /// Compute client for wgpu. - pub client: WgpuComputeClient, - /// The buffer where the data are stored. - pub handle: WgpuHandle, - /// The device of the current tensor. - pub device: WgpuDevice, - pub(crate) strides: Vec, + /// Compute client for wgpu. + pub client: WgpuComputeClient, + /// The buffer where the data are stored. + pub handle: WgpuHandle, + /// The device of the current tensor. 
+ pub device: WgpuDevice, + pub(crate) strides: Vec, } impl WgpuFusionHandle { - pub(crate) fn into_tensor( - self, - shape: Shape, - ) -> WgpuTensor { - WgpuTensor { - client: self.client, - handle: self.handle, - device: self.device, - shape, - strides: self.strides.try_into().expect("Wrong dimension"), - elem: PhantomData, + pub(crate) fn into_tensor( + self, + shape: Shape, + ) -> WgpuTensor { + WgpuTensor { + client: self.client, + handle: self.handle, + device: self.device, + shape, + strides: self.strides.try_into().expect("Wrong dimension"), + elem: PhantomData, + } } - } } impl From> for WgpuFusionHandle { - fn from(value: WgpuTensor) -> Self { - Self { - client: value.client, - handle: value.handle, - device: value.device, - strides: value.strides.into(), + fn from(value: WgpuTensor) -> Self { + Self { + client: value.client, + handle: value.handle, + device: value.device, + strides: value.strides.into(), + } } - } } #[cfg(test)] mod tests { - use super::*; - use burn_fusion::Fusion; + use super::*; + use burn_fusion::Fusion; - pub type TestBackend = Fusion; - pub type TestTensor = burn_tensor::Tensor; - pub type TestTensorInt = burn_tensor::Tensor; + pub type TestBackend = Fusion; + pub type TestTensor = burn_tensor::Tensor; + pub type TestTensorInt = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!(); + burn_autodiff::testgen_all!(); } diff --git a/burn-wgpu/src/fusion/codegen/body.rs b/burn-wgpu/src/fusion/codegen/body.rs index a08cf100ad..cab35bf75d 100644 --- a/burn-wgpu/src/fusion/codegen/body.rs +++ b/burn-wgpu/src/fusion/codegen/body.rs @@ -7,19 +7,21 @@ use std::fmt::Display; /// X and Y, but with Z=1. #[derive(Hash, new)] pub struct Body { - operators: Vec, + operators: Vec, } impl Display for Body { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n")?; - f.write_str("let rank: u32 = info[0];\n\n")?; + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + "let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n", + )?; + f.write_str("let rank: u32 = info[0];\n\n")?; - for ops in self.operators.iter() { - f.write_fmt(format_args!("{ops}"))?; - f.write_str("\n")?; - } + for ops in self.operators.iter() { + f.write_fmt(format_args!("{ops}"))?; + f.write_str("\n")?; + } - Ok(()) - } + Ok(()) + } } diff --git a/burn-wgpu/src/fusion/codegen/function.rs b/burn-wgpu/src/fusion/codegen/function.rs index fa3c27d1f6..fceae4e399 100644 --- a/burn-wgpu/src/fusion/codegen/function.rs +++ b/burn-wgpu/src/fusion/codegen/function.rs @@ -4,22 +4,22 @@ use std::fmt::Display; /// Not all functions are native to WGSL, so this struct allows to support more functions. 
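Before the `Function` enum that follows, a worked sketch of the stride math from `fusion/base.rs` above: `strides_dyn_rank` builds row-major strides right to left, and `calculate_num_elems_dyn_rank` is simply the product of the shape. The functions are restated here in standalone form for illustration.

// Row-major strides: the last dimension is contiguous (stride 1) and each
// earlier dimension strides over everything to its right, as in
// `strides_dyn_rank` above.
fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut current = 1;
    for (index, val) in shape.iter().enumerate().rev() {
        strides[index] = current;
        current *= val;
    }
    strides
}

// Total number of elements: the product of all dimensions.
fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn main() {
    let shape = [2, 3, 4];
    // [12, 4, 1]: moving one step in dim 0 skips 3 * 4 = 12 elements, etc.
    println!("strides = {:?}", strides_dyn_rank(&shape));
    // 2 * 3 * 4 = 24 elements in total.
    println!("num_elems = {}", calculate_num_elems_dyn_rank(&shape));
}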
#[derive(Hash, PartialEq, Eq, Clone)] pub enum Function { - Powf(Elem), - Erf(Elem), + Powf(Elem), + Erf(Elem), } impl Display for Function { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Function::Powf(elem) => format_powf(f, elem), - Function::Erf(elem) => format_erf(f, elem), + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Function::Powf(elem) => format_powf(f, elem), + Function::Erf(elem) => format_erf(f, elem), + } } - } } fn format_powf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result { - f.write_fmt(format_args!( - " + f.write_fmt(format_args!( + " fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{ let modulo = rhs % 2.0; @@ -35,11 +35,11 @@ fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{ }} }} " - )) + )) } fn format_erf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result { - f.write_fmt(format_args!( + f.write_fmt(format_args!( " /// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations /// diff --git a/burn-wgpu/src/fusion/codegen/operator.rs b/burn-wgpu/src/fusion/codegen/operator.rs index 474cefeedf..922f7cf81d 100644 --- a/burn-wgpu/src/fusion/codegen/operator.rs +++ b/burn-wgpu/src/fusion/codegen/operator.rs @@ -4,119 +4,133 @@ use std::fmt::Display; /// All operators that can be fused in a WGSL compute shader. #[derive(Debug, Hash, Clone)] pub enum Operator { - Add { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Sub { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Mul { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Div { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Abs { - input: Variable, - out: Variable, - }, - Exp { - input: Variable, - out: Variable, - }, - Log { - input: Variable, - out: Variable, - }, - Log1p { - input: Variable, - out: Variable, - }, - Cos { - input: Variable, - out: Variable, - }, - Sin { - input: Variable, - out: Variable, - }, - Tanh { - input: Variable, - out: Variable, - }, - Powf { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Erf { - input: Variable, - out: Variable, - }, - Recip { - input: Variable, - out: Variable, - }, - AssignGlobal { - input: Variable, - out: Variable, - }, - ReadGlobal { - variable: Variable, - position: usize, - position_out: usize, - }, + Add { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Sub { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Mul { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Div { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Abs { + input: Variable, + out: Variable, + }, + Exp { + input: Variable, + out: Variable, + }, + Log { + input: Variable, + out: Variable, + }, + Log1p { + input: Variable, + out: Variable, + }, + Cos { + input: Variable, + out: Variable, + }, + Sin { + input: Variable, + out: Variable, + }, + Tanh { + input: Variable, + out: Variable, + }, + Powf { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Erf { + input: Variable, + out: Variable, + }, + Recip { + input: Variable, + out: Variable, + }, + AssignGlobal { + input: Variable, + out: Variable, + }, + ReadGlobal { + variable: Variable, + position: usize, + position_out: usize, + }, } impl Display for Operator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Operator::Add { lhs, rhs, out } => f.write_fmt(format_args!("let {out} = {lhs} + {rhs};")), - Operator::Sub { lhs, rhs, out } => 
f.write_fmt(format_args!("let {out} = {lhs} - {rhs};")), - Operator::Mul { lhs, rhs, out } => f.write_fmt(format_args!("let {out} = {lhs} * {rhs};")), - Operator::Div { lhs, rhs, out } => f.write_fmt(format_args!("let {out} = {lhs} / {rhs};")), - Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")), - Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")), - Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")), - Operator::Powf { lhs, rhs, out } => { - f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});")) - } - Operator::Log1p { input, out } => { - f.write_fmt(format_args!("let {out} = log({input} + 1.0);")) - } - Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")), - Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")), - Operator::Tanh { input, out } => f.write_fmt(format_args!("let {out} = tanh({input});")), - Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")), - Operator::Recip { input, out } => f.write_fmt(format_args!("let {out} = 1.0 / {input};")), - Operator::AssignGlobal { input, out } => { - f.write_fmt(format_args!("{out}_global[id] = {input};")) - } - Operator::ReadGlobal { - variable, - position, - position_out, - } => { - let (global, local) = match variable { - Variable::Input(number) => (format!("input_{number}_global"), format!("input_{number}")), - Variable::Local(_) => panic!("can't read globala local variable."), - Variable::Output(number) => ( - format!("output_{number}_global"), - format!("output_{number}"), - ), - Variable::Scalar(_, _) => panic!("Can't read global scalar variable."), - }; + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Operator::Add { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} + {rhs};")) + } + Operator::Sub { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} - {rhs};")) + } + Operator::Mul { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} * {rhs};")) + } + Operator::Div { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} / {rhs};")) + } + Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")), + Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")), + Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")), + Operator::Powf { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});")) + } + Operator::Log1p { input, out } => { + f.write_fmt(format_args!("let {out} = log({input} + 1.0);")) + } + Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")), + Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")), + Operator::Tanh { input, out } => { + f.write_fmt(format_args!("let {out} = tanh({input});")) + } + Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")), + Operator::Recip { input, out } => { + f.write_fmt(format_args!("let {out} = 1.0 / {input};")) + } + Operator::AssignGlobal { input, out } => { + f.write_fmt(format_args!("{out}_global[id] = {input};")) + } + Operator::ReadGlobal { + variable, + position, + position_out, + } => { + let (global, local) = match variable { + Variable::Input(number) => { + (format!("input_{number}_global"), format!("input_{number}")) + } + Variable::Local(_) => panic!("can't 
read globala local variable."), + Variable::Output(number) => ( + format!("output_{number}_global"), + format!("output_{number}"), + ), + Variable::Scalar(_, _) => panic!("Can't read global scalar variable."), + }; - f.write_fmt(format_args!( - " + f.write_fmt(format_args!( + " var index_{local}: u32 = 0u; for (var i: u32 = 1u; i <= rank; i++) {{ @@ -132,8 +146,8 @@ for (var i: u32 = 1u; i <= rank; i++) {{ let {local} = {global}[index_{local}]; " - )) - } + )) + } + } } - } } diff --git a/burn-wgpu/src/fusion/codegen/shader.rs b/burn-wgpu/src/fusion/codegen/shader.rs index ee7f8f330b..8ce3999784 100644 --- a/burn-wgpu/src/fusion/codegen/shader.rs +++ b/burn-wgpu/src/fusion/codegen/shader.rs @@ -1,201 +1,201 @@ use super::{Body, Function}; use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT}; use std::{ - collections::hash_map::DefaultHasher, - fmt::Display, - hash::{Hash, Hasher}, + collections::hash_map::DefaultHasher, + fmt::Display, + hash::{Hash, Hasher}, }; #[derive(Hash, PartialEq, Eq)] pub enum Location { - Storage, - #[allow(dead_code)] - Workgroup, + Storage, + #[allow(dead_code)] + Workgroup, } #[derive(Hash, PartialEq, Eq)] pub enum Visibility { - Read, - ReadWrite, + Read, + ReadWrite, } #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Elem { - F32, - #[allow(dead_code)] - I32, - U32, + F32, + #[allow(dead_code)] + I32, + U32, } #[derive(Hash, PartialEq, Eq)] pub struct Binding { - pub location: Location, - pub visibility: Visibility, - pub elem: Elem, - pub size: Option, + pub location: Location, + pub visibility: Visibility, + pub elem: Elem, + pub size: Option, } #[derive(Hash, PartialEq, Eq)] pub struct WorkgroupSize { - pub x: usize, - pub y: usize, - pub z: usize, + pub x: usize, + pub y: usize, + pub z: usize, } impl Default for WorkgroupSize { - fn default() -> Self { - Self { - x: WORKGROUP_DEFAULT, - y: WORKGROUP_DEFAULT, - z: 1, + fn default() -> Self { + Self { + x: WORKGROUP_DEFAULT, + y: WORKGROUP_DEFAULT, + z: 1, + } } - } } #[derive(Hash)] pub struct ComputeShader { - pub inputs: Vec, - pub outputs: Vec, - pub named: Vec<(String, Binding)>, - pub workgroup_size: WorkgroupSize, - pub global_invocation_id: bool, - pub num_workgroups: bool, - pub body: Body, - pub functions: Vec, + pub inputs: Vec, + pub outputs: Vec, + pub named: Vec<(String, Binding)>, + pub workgroup_size: WorkgroupSize, + pub global_invocation_id: bool, + pub num_workgroups: bool, + pub body: Body, + pub functions: Vec, } impl DynamicKernelSource for ComputeShader { - fn source(&self) -> SourceTemplate { - SourceTemplate::new(self.to_string()) - } + fn source(&self) -> SourceTemplate { + SourceTemplate::new(self.to_string()) + } - fn id(&self) -> String { - let mut s = DefaultHasher::new(); - self.hash(&mut s); + fn id(&self) -> String { + let mut s = DefaultHasher::new(); + self.hash(&mut s); - s.finish().to_string() - } + s.finish().to_string() + } } impl Display for ComputeShader { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Self::format_bindings(f, "input", &self.inputs, 0)?; - Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?; - - for (i, (name, binding)) in self.named.iter().enumerate() { - Self::format_binding( - f, - name.as_str(), - binding, - self.inputs.len() + self.outputs.len() + i, - )?; - } - - f.write_fmt(format_args!( - "const WORKGROUP_SIZE_X = {}u; + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Self::format_bindings(f, "input", &self.inputs, 0)?; + Self::format_bindings(f, 
"output", &self.outputs, self.inputs.len())?; + + for (i, (name, binding)) in self.named.iter().enumerate() { + Self::format_binding( + f, + name.as_str(), + binding, + self.inputs.len() + self.outputs.len() + i, + )?; + } + + f.write_fmt(format_args!( + "const WORKGROUP_SIZE_X = {}u; const WORKGROUP_SIZE_Y = {}u; const WORKGROUP_SIZE_Z = {}u;\n", - self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z - ))?; + self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z + ))?; - f.write_fmt(format_args!( - " + f.write_fmt(format_args!( + " @compute @workgroup_size({}, {}, {}) fn main( ", - self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z - ))?; + self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z + ))?; - if self.global_invocation_id { - f.write_str(" @builtin(global_invocation_id) global_id: vec3,\n")?; - } + if self.global_invocation_id { + f.write_str(" @builtin(global_invocation_id) global_id: vec3,\n")?; + } - if self.num_workgroups { - f.write_str(" @builtin(num_workgroups) num_workgroups: vec3,\n")?; - } + if self.num_workgroups { + f.write_str(" @builtin(num_workgroups) num_workgroups: vec3,\n")?; + } - f.write_fmt(format_args!( - ") {{ + f.write_fmt(format_args!( + ") {{ {} }}", - self.body - ))?; + self.body + ))?; - for function in self.functions.iter() { - f.write_fmt(format_args!("{function}\n\n"))?; - } + for function in self.functions.iter() { + f.write_fmt(format_args!("{function}\n\n"))?; + } - Ok(()) - } + Ok(()) + } } impl ComputeShader { - fn format_bindings( - f: &mut core::fmt::Formatter<'_>, - prefix: &str, - bindings: &[Binding], - num_entry: usize, - ) -> core::fmt::Result { - for (i, binding) in bindings.iter().enumerate() { - Self::format_binding( - f, - format!("{prefix}_{i}_global").as_str(), - binding, - num_entry + i, - )?; + fn format_bindings( + f: &mut core::fmt::Formatter<'_>, + prefix: &str, + bindings: &[Binding], + num_entry: usize, + ) -> core::fmt::Result { + for (i, binding) in bindings.iter().enumerate() { + Self::format_binding( + f, + format!("{prefix}_{i}_global").as_str(), + binding, + num_entry + i, + )?; + } + + Ok(()) } - Ok(()) - } - - fn format_binding( - f: &mut core::fmt::Formatter<'_>, - name: &str, - binding: &Binding, - num_entry: usize, - ) -> core::fmt::Result { - let ty = match binding.size { - Some(size) => format!("array<{}, {}>", binding.elem, size), - None => format!("array<{}>", binding.elem), - }; - - f.write_fmt(format_args!( - "@group(0) + fn format_binding( + f: &mut core::fmt::Formatter<'_>, + name: &str, + binding: &Binding, + num_entry: usize, + ) -> core::fmt::Result { + let ty = match binding.size { + Some(size) => format!("array<{}, {}>", binding.elem, size), + None => format!("array<{}>", binding.elem), + }; + + f.write_fmt(format_args!( + "@group(0) @binding({}) var<{}, {}> {}: {}; \n", - num_entry, binding.location, binding.visibility, name, ty - ))?; + num_entry, binding.location, binding.visibility, name, ty + ))?; - Ok(()) - } + Ok(()) + } } impl Display for Location { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Location::Storage => f.write_str("storage"), - Location::Workgroup => f.write_str("workgroup"), + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Location::Storage => f.write_str("storage"), + Location::Workgroup => f.write_str("workgroup"), + } } - } } impl Display for Elem { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - 
Elem::F32 => f.write_str("f32"), - Elem::I32 => f.write_str("i32"), - Elem::U32 => f.write_str("u32"), + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Elem::F32 => f.write_str("f32"), + Elem::I32 => f.write_str("i32"), + Elem::U32 => f.write_str("u32"), + } } - } } impl Display for Visibility { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Visibility::Read => f.write_str("read"), - Visibility::ReadWrite => f.write_str("read_write"), + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Visibility::Read => f.write_str("read"), + Visibility::ReadWrite => f.write_str("read_write"), + } } - } } diff --git a/burn-wgpu/src/fusion/codegen/variable.rs b/burn-wgpu/src/fusion/codegen/variable.rs index f827bcef6d..b74c4dbb80 100644 --- a/burn-wgpu/src/fusion/codegen/variable.rs +++ b/burn-wgpu/src/fusion/codegen/variable.rs @@ -3,19 +3,19 @@ use std::fmt::Display; #[derive(Debug, Hash, Clone)] pub enum Variable { - Input(u16), - Scalar(u16, Elem), - Local(u16), - Output(u16), + Input(u16), + Scalar(u16, Elem), + Local(u16), + Output(u16), } impl Display for Variable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Variable::Input(number) => f.write_fmt(format_args!("input_{number}")), - Variable::Local(number) => f.write_fmt(format_args!("local_{number}")), - Variable::Output(number) => f.write_fmt(format_args!("output_{number}")), - Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")), + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Variable::Input(number) => f.write_fmt(format_args!("input_{number}")), + Variable::Local(number) => f.write_fmt(format_args!("local_{number}")), + Variable::Output(number) => f.write_fmt(format_args!("output_{number}")), + Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")), + } } - } } diff --git a/burn-wgpu/src/fusion/elemwise/ops.rs b/burn-wgpu/src/fusion/elemwise/ops.rs index 87db9e7b99..78dec96fd6 100644 --- a/burn-wgpu/src/fusion/elemwise/ops.rs +++ b/burn-wgpu/src/fusion/elemwise/ops.rs @@ -1,15 +1,15 @@ use crate::{ - fusion::codegen::{Elem, Operator, Variable}, - fusion::kernel::FusionKernel, - FloatElement, GraphicsApi, IntElement, Wgpu, + fusion::codegen::{Elem, Operator, Variable}, + fusion::kernel::FusionKernel, + FloatElement, GraphicsApi, IntElement, Wgpu, }; use burn_fusion::{ - graph::{ - BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription, - TensorOpsDescription, UnaryOpsDescription, - }, - FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, - TensorId, + graph::{ + BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription, + TensorOpsDescription, UnaryOpsDescription, + }, + FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, + TensorId, }; use burn_tensor::{Device, Element}; use hashbrown::HashMap; @@ -18,454 +18,453 @@ use std::sync::Arc; /// Fused element wise operations that are normally memory bound. 
pub struct FloatElementWiseFusionOps where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - pub(crate) inputs: Vec, - pub(crate) locals: HashMap, - pub(crate) tensors: HashMap, - pub(crate) scalars_f32: Vec, - pub(crate) operators: Vec, - pub(crate) properties: FusionProperties, - pub(crate) current_output_shape: Vec, - device: Device>, + pub(crate) inputs: Vec, + pub(crate) locals: HashMap, + pub(crate) tensors: HashMap, + pub(crate) scalars_f32: Vec, + pub(crate) operators: Vec, + pub(crate) properties: FusionProperties, + pub(crate) current_output_shape: Vec, + device: Device>, } impl FusionOps> - for FloatElementWiseFusionOps + for FloatElementWiseFusionOps { - fn register(&mut self, ops: Arc>>) -> FusionStatus { - match ops.as_ref() { - TensorOpsDescription::FloatOps(ops) => { - if !self.register_float(ops) { - return FusionStatus::Closed(self.properties); - } - } - TensorOpsDescription::NumericOpsFloat(ops) => { - if !self.register_numeric(ops) { - return FusionStatus::Closed(self.properties); - } - } - _ => { - return FusionStatus::Closed(self.properties); - } - }; - - self.properties.score += 1; - self.properties.ready = self.operators.len() > 1; - - FusionStatus::Open(self.properties) - } - - fn execute(&mut self, handles: &mut HandleContainer>) { - let inputs = self.input_descriptions(); - let outputs = self.output_descriptions(); - let locals = outputs - .iter() - .map(|out| *self.locals.get(&out.id).unwrap()) - .collect::>(); - - FusionKernel::new(&self.device) - .inputs(&inputs, &self.scalars_f32) - .body(&self.operators) - .outputs(&outputs, &locals) - .execute(handles); - } - - fn reset(&mut self) { - self.inputs.clear(); - self.locals.drain(); - self.tensors.clear(); - self.scalars_f32.clear(); - self.operators.clear(); - self.properties = FusionProperties::default(); - self.current_output_shape.clear(); - } - - fn len(&self) -> usize { - self.operators.len() - } + fn register(&mut self, ops: Arc>>) -> FusionStatus { + match ops.as_ref() { + TensorOpsDescription::FloatOps(ops) => { + if !self.register_float(ops) { + return FusionStatus::Closed(self.properties); + } + } + TensorOpsDescription::NumericOpsFloat(ops) => { + if !self.register_numeric(ops) { + return FusionStatus::Closed(self.properties); + } + } + _ => { + return FusionStatus::Closed(self.properties); + } + }; + + self.properties.score += 1; + self.properties.ready = self.operators.len() > 1; + + FusionStatus::Open(self.properties) + } + + fn execute(&mut self, handles: &mut HandleContainer>) { + let inputs = self.input_descriptions(); + let outputs = self.output_descriptions(); + let locals = outputs + .iter() + .map(|out| *self.locals.get(&out.id).unwrap()) + .collect::>(); + + FusionKernel::new(&self.device) + .inputs(&inputs, &self.scalars_f32) + .body(&self.operators) + .outputs(&outputs, &locals) + .execute(handles); + } + + fn reset(&mut self) { + self.inputs.clear(); + self.locals.drain(); + self.tensors.clear(); + self.scalars_f32.clear(); + self.operators.clear(); + self.properties = FusionProperties::default(); + self.current_output_shape.clear(); + } + + fn len(&self) -> usize { + self.operators.len() + } } impl FloatElementWiseFusionOps where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - pub fn new(device: Device>) -> Self { - Self { - inputs: Vec::new(), - locals: HashMap::new(), - tensors: HashMap::new(), - scalars_f32: Vec::new(), - operators: Vec::new(), - 
current_output_shape: Vec::new(), - properties: FusionProperties::default(), - device, - } - } - - fn input_descriptions(&self) -> Vec<&TensorDescription> { - self - .inputs - .iter() - .map(|input| { - let updated_tensor = self.tensors.get(&input.id).unwrap(); - updated_tensor - }) - .collect::>() - } - - fn output_descriptions(&self) -> Vec<&TensorDescription> { - let mut outputs = Vec::new(); - let mut local_tensor_ids_input = Vec::new(); - let mut local_tensor_ids_output = Vec::new(); - - // Mark a variable to the provided list of tensor ids using the variable list. - // - // Only local variables can become outputs. - let mark = |var: &Variable, list: &mut Vec| { - if let Variable::Local(index) = var { - if let Some((id, _)) = self - .locals - .iter() - .find(|(_id, position)| *position == index) - { - if !list.contains(id) { - list.push(id.clone()); - } - } - } - }; - - // For all operators, mark their local tensor id in the proper set. - for ops in self.operators.iter() { - match ops { - Operator::AssignGlobal { input: _, out: _ } => { - // Nothing to do here. - } - Operator::ReadGlobal { - variable: _, - position: _, - position_out: _, - } => { - // Nothing to do here. - } - Operator::Add { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Sub { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Mul { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Div { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Exp { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Abs { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Erf { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Log { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Log1p { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Cos { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Sin { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Tanh { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - Operator::Powf { lhs, rhs, out } => { - mark(lhs, &mut local_tensor_ids_input); - mark(rhs, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); + pub fn new(device: Device>) -> Self { + Self { + inputs: Vec::new(), + locals: HashMap::new(), + tensors: HashMap::new(), + scalars_f32: Vec::new(), + operators: Vec::new(), + current_output_shape: Vec::new(), + properties: FusionProperties::default(), + device, } - Operator::Recip { input, out } => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } - } } - // All output tensors that are never read by a following operation should be written 
to - // since they are essentially the "logical" output of the shader. - for out in local_tensor_ids_output { - let is_read = local_tensor_ids_input.contains(&out); - - if !is_read { - outputs.push(self.tensors.get(&out).unwrap()); - } + fn input_descriptions(&self) -> Vec<&TensorDescription> { + self.inputs + .iter() + .map(|input| { + let updated_tensor = self.tensors.get(&input.id).unwrap(); + updated_tensor + }) + .collect::>() } - // All tensors where their latest description is read only should be written to since they - // are going to be used after the fused kernel by other operations. - for tensor in self.tensors.values() { - if let burn_fusion::TensorStatus::ReadOnly = tensor.status { - if self.locals.contains_key(&tensor.id) { - outputs.push(tensor); + fn output_descriptions(&self) -> Vec<&TensorDescription> { + let mut outputs = Vec::new(); + let mut local_tensor_ids_input = Vec::new(); + let mut local_tensor_ids_output = Vec::new(); + + // Mark a variable to the provided list of tensor ids using the variable list. + // + // Only local variables can become outputs. + let mark = |var: &Variable, list: &mut Vec| { + if let Variable::Local(index) = var { + if let Some((id, _)) = self + .locals + .iter() + .find(|(_id, position)| *position == index) + { + if !list.contains(id) { + list.push(id.clone()); + } + } + } + }; + + // For all operators, mark their local tensor id in the proper set. + for ops in self.operators.iter() { + match ops { + Operator::AssignGlobal { input: _, out: _ } => { + // Nothing to do here. + } + Operator::ReadGlobal { + variable: _, + position: _, + position_out: _, + } => { + // Nothing to do here. + } + Operator::Add { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Sub { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Mul { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Div { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Exp { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Abs { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Erf { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Log { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Log1p { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Cos { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Sin { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Tanh { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Powf { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Recip { 
input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + } } - } - } - outputs - } - - fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable { - let already_exists = self.tensors.contains_key(&tensor.id); - - let variable = match already_exists { - false => { - // New input - let var = Variable::Input(self.inputs.len() as u16); - self.inputs.push(tensor.clone()); - var - } - true => match self.locals.get(&tensor.id) { - // Is a local variable. - Some(local_index) => Variable::Local(*local_index), - // Isn't a local variable, so must be an existing input. - None => { - let input = self - .inputs - .iter() - .enumerate() - .find(|(_, input)| input.id == tensor.id) - .unwrap(); - let input_index = input.0; - Variable::Input(input_index as u16) - } - }, - }; + // All output tensors that are never read by a following operation should be written to + // since they are essentially the "logical" output of the shader. + for out in local_tensor_ids_output { + let is_read = local_tensor_ids_input.contains(&out); - // Update the tensor description with the new version. - self.tensors.insert(tensor.id.clone(), tensor.clone()); + if !is_read { + outputs.push(self.tensors.get(&out).unwrap()); + } + } - variable - } + // All tensors where their latest description is read only should be written to since they + // are going to be used after the fused kernel by other operations. + for tensor in self.tensors.values() { + if let burn_fusion::TensorStatus::ReadOnly = tensor.status { + if self.locals.contains_key(&tensor.id) { + outputs.push(tensor); + } + } + } - fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable { - // Update the tensor description to the new version. - self.tensors.insert(tensor.id.clone(), tensor.clone()); + outputs + } - // Output already registered as a local variable. - if let Some(index) = self.locals.get(&tensor.id) { - return Variable::Local(*index); + fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable { + let already_exists = self.tensors.contains_key(&tensor.id); + + let variable = match already_exists { + false => { + // New input + let var = Variable::Input(self.inputs.len() as u16); + self.inputs.push(tensor.clone()); + var + } + true => match self.locals.get(&tensor.id) { + // Is a local variable. + Some(local_index) => Variable::Local(*local_index), + // Isn't a local variable, so must be an existing input. + None => { + let input = self + .inputs + .iter() + .enumerate() + .find(|(_, input)| input.id == tensor.id) + .unwrap(); + let input_index = input.0; + Variable::Input(input_index as u16) + } + }, + }; + + // Update the tensor description with the new version. + self.tensors.insert(tensor.id.clone(), tensor.clone()); + + variable } - // New local variable. 
- let local_index = self.locals.len() as u16; - self.locals.insert(tensor.id.clone(), local_index); - Variable::Local(local_index) - } - - fn register_float(&mut self, ops: &FloatOpsDescription) -> bool { - match ops { - FloatOpsDescription::Exp(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Exp { input, out }) - } - FloatOpsDescription::Log(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Log { input, out }) - } - FloatOpsDescription::Log1p(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out }) - } - FloatOpsDescription::Cos(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Cos { input, out }) - } - FloatOpsDescription::Sin(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Sin { input, out }) - } - FloatOpsDescription::Powf(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out }) - } - FloatOpsDescription::Tanh(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out }) - } - FloatOpsDescription::Erf(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Erf { input, out }) - } - FloatOpsDescription::Recip(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Recip { input, out }) - } - _ => false, + fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable { + // Update the tensor description to the new version. + self.tensors.insert(tensor.id.clone(), tensor.clone()); + + // Output already registered as a local variable. + if let Some(index) = self.locals.get(&tensor.id) { + return Variable::Local(*index); + } + + // New local variable. + let local_index = self.locals.len() as u16; + self.locals.insert(tensor.id.clone(), local_index); + Variable::Local(local_index) } - } - - fn register_numeric( - &mut self, - ops: &NumericOpsDescription, - ) -> bool { - match ops { - NumericOpsDescription::Add(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) - } - NumericOpsDescription::AddScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) - } - NumericOpsDescription::Sub(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) - } - NumericOpsDescription::SubScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) - } - NumericOpsDescription::Mul(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) - } - NumericOpsDescription::MulScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) - } - NumericOpsDescription::Div(desc, _) => { - self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) - } - NumericOpsDescription::DivScalar(desc, _) => { - self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) - } - NumericOpsDescription::Abs(desc, _) => { - self.register_unary_ops(desc, |input, out| Operator::Abs { input, out }) - } - _ => false, + + fn register_float(&mut self, ops: &FloatOpsDescription) -> bool { + match ops { + FloatOpsDescription::Exp(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Exp { input, out }) + } + FloatOpsDescription::Log(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Log { input, out }) + } + FloatOpsDescription::Log1p(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Log1p { input, 
out }) + } + FloatOpsDescription::Cos(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Cos { input, out }) + } + FloatOpsDescription::Sin(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Sin { input, out }) + } + FloatOpsDescription::Powf(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out }) + } + FloatOpsDescription::Tanh(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out }) + } + FloatOpsDescription::Erf(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Erf { input, out }) + } + FloatOpsDescription::Recip(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Recip { input, out }) + } + _ => false, + } } - } - - fn register_binary_ops(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool - where - Func: Fn(Variable, Variable, Variable) -> Operator, - { - if !self.output_is_compatible(&desc.out) { - return false; + + fn register_numeric( + &mut self, + ops: &NumericOpsDescription, + ) -> bool { + match ops { + NumericOpsDescription::Add(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) + } + NumericOpsDescription::AddScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) + } + NumericOpsDescription::Sub(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) + } + NumericOpsDescription::SubScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) + } + NumericOpsDescription::Mul(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) + } + NumericOpsDescription::MulScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) + } + NumericOpsDescription::Div(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) + } + NumericOpsDescription::DivScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) + } + NumericOpsDescription::Abs(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Abs { input, out }) + } + _ => false, + } } - let lhs = self.input_to_var(&desc.lhs); - let rhs = self.input_to_var(&desc.rhs); - let out = self.output_to_var(&desc.out); + fn register_binary_ops(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool + where + Func: Fn(Variable, Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; + } - self.operators.push(func(lhs, rhs, out)); + let lhs = self.input_to_var(&desc.lhs); + let rhs = self.input_to_var(&desc.rhs); + let out = self.output_to_var(&desc.out); - true - } + self.operators.push(func(lhs, rhs, out)); - fn register_unary_ops(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool - where - Func: Fn(Variable, Variable) -> Operator, - { - if !self.output_is_compatible(&desc.out) { - return false; + true } - let input = self.input_to_var(&desc.input); - let out = self.output_to_var(&desc.out); + fn register_unary_ops(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool + where + Func: Fn(Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; + } - self.operators.push(func(input, out)); + let input = self.input_to_var(&desc.input); + let out = self.output_to_var(&desc.out); - true - } + self.operators.push(func(input, out)); - fn 
register_scalar_ops( - &mut self, - desc: &ScalarOpsDescription, - func: Func, - ) -> bool - where - Func: Fn(Variable, Variable, Variable) -> Operator, - { - if !self.output_is_compatible(&desc.out) { - return false; + true } - let lhs = self.input_to_var(&desc.lhs); - let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32); - self.scalars_f32.push(desc.rhs.elem()); - let out = self.output_to_var(&desc.out); + fn register_scalar_ops( + &mut self, + desc: &ScalarOpsDescription, + func: Func, + ) -> bool + where + Func: Fn(Variable, Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; + } - self.operators.push(func(lhs, rhs, out)); + let lhs = self.input_to_var(&desc.lhs); + let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32); + self.scalars_f32.push(desc.rhs.elem()); + let out = self.output_to_var(&desc.out); - true - } + self.operators.push(func(lhs, rhs, out)); - fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { - if self.current_output_shape.is_empty() { - self.current_output_shape = out.shape.clone(); - } else if self.current_output_shape != out.shape { - return false; + true } - true - } + fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { + if self.current_output_shape.is_empty() { + self.current_output_shape = out.shape.clone(); + } else if self.current_output_shape != out.shape { + return false; + } + + true + } } #[cfg(test)] mod tests { - use super::*; - use burn_fusion::graph::{BinaryOpsDescription, Ops}; - use burn_fusion::Fusion; - use burn_tensor::Tensor; + use super::*; + use burn_fusion::graph::{BinaryOpsDescription, Ops}; + use burn_fusion::Fusion; + use burn_tensor::Tensor; - struct FakeAddOps; + struct FakeAddOps; - impl Ops for FakeAddOps { - type Args = BinaryOpsDescription; + impl Ops for FakeAddOps { + type Args = BinaryOpsDescription; + + fn execute(&self, _: &Self::Args, _: &mut HandleContainer) { + todo!() + } + } - fn execute(&self, _: &Self::Args, _: &mut HandleContainer) { - todo!() + #[test] + fn test_fusion_same_behavior() { + type Backend = Wgpu; + type FusedBackend = Fusion; + + let data_1 = + Tensor::::random([1, 32], burn_tensor::Distribution::Default).into_data(); + let data_2 = + Tensor::::random([32, 32], burn_tensor::Distribution::Default).into_data(); + + let tensor_1 = Tensor::::from_data(data_1.clone()); + let tensor_2 = Tensor::::from_data(data_2.clone()); + let tensor_3 = tensor_1.clone() + tensor_2; + let tensor_4 = tensor_3.clone() - tensor_1; + let tensor_5 = tensor_4 + 5.0; + let tensor_6 = tensor_5 + tensor_3; + let result_ref = tensor_6.recip().into_data(); + + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + let tensor_3 = tensor_1.clone() + tensor_2; + let tensor_4 = tensor_3.clone() - tensor_1; + let tensor_5 = tensor_4 + 5.0; + let tensor_6 = tensor_5 + tensor_3; + let result_fused = tensor_6.recip().into_data(); + + result_fused.assert_approx_eq(&result_ref, 3); } - } - - #[test] - fn test_fusion_same_behavior() { - type Backend = Wgpu; - type FusedBackend = Fusion; - - let data_1 = - Tensor::::random([1, 32], burn_tensor::Distribution::Default).into_data(); - let data_2 = - Tensor::::random([32, 32], burn_tensor::Distribution::Default).into_data(); - - let tensor_1 = Tensor::::from_data(data_1.clone()); - let tensor_2 = Tensor::::from_data(data_2.clone()); - let tensor_3 = tensor_1.clone() + tensor_2; - let tensor_4 = tensor_3.clone() - tensor_1; - let tensor_5 = tensor_4 + 5.0; - 
let tensor_6 = tensor_5 + tensor_3; - let result_ref = tensor_6.recip().into_data(); - - let tensor_1 = Tensor::::from_data(data_1); - let tensor_2 = Tensor::::from_data(data_2); - let tensor_3 = tensor_1.clone() + tensor_2; - let tensor_4 = tensor_3.clone() - tensor_1; - let tensor_5 = tensor_4 + 5.0; - let tensor_6 = tensor_5 + tensor_3; - let result_fused = tensor_6.recip().into_data(); - - result_fused.assert_approx_eq(&result_ref, 3); - } } diff --git a/burn-wgpu/src/fusion/kernel.rs b/burn-wgpu/src/fusion/kernel.rs index d1abc04111..31a0353cb7 100644 --- a/burn-wgpu/src/fusion/kernel.rs +++ b/burn-wgpu/src/fusion/kernel.rs @@ -3,10 +3,10 @@ use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient}; use crate::fusion::codegen::Function; use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank}; use crate::fusion::{ - codegen::{ - Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize, - }, - WgpuFusionHandle, + codegen::{ + Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize, + }, + WgpuFusionHandle, }; use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT}; use crate::{FloatElement, GraphicsApi, IntElement, Wgpu}; @@ -42,279 +42,279 @@ pub struct ExecutionPhase; /// handles provided. pub struct FusionKernel where - G: GraphicsApi, - F: FloatElement, - I: IntElement, + G: GraphicsApi, + F: FloatElement, + I: IntElement, { - operations: Vec, - input_bindings: Vec<(Binding, TensorDescription)>, - output_bindings: Vec<(Binding, TensorDescription)>, - named_bindings: Vec<(String, Binding, DataBuffer)>, - functions: Vec, - num_elems_output: usize, - device: Device>, - client: WgpuComputeClient, - _phase: PhantomData, + operations: Vec, + input_bindings: Vec<(Binding, TensorDescription)>, + output_bindings: Vec<(Binding, TensorDescription)>, + named_bindings: Vec<(String, Binding, DataBuffer)>, + functions: Vec, + num_elems_output: usize, + device: Device>, + client: WgpuComputeClient, + _phase: PhantomData, } enum DataBuffer { - F32(Vec), - U32(Vec), + F32(Vec), + U32(Vec), } impl FusionKernel { - /// Create a new fusion kernel on the given device. - pub fn new(device: &Device>) -> Self { - let client = compute_client::(device); + /// Create a new fusion kernel on the given device. + pub fn new(device: &Device>) -> Self { + let client = compute_client::(device); - Self { - operations: Vec::new(), - input_bindings: Vec::new(), - output_bindings: Vec::new(), - named_bindings: Vec::new(), - functions: Vec::new(), - num_elems_output: 0, - device: device.clone(), - client, - _phase: PhantomData, + Self { + operations: Vec::new(), + input_bindings: Vec::new(), + output_bindings: Vec::new(), + named_bindings: Vec::new(), + functions: Vec::new(), + num_elems_output: 0, + device: device.clone(), + client, + _phase: PhantomData, + } } - } - /// Register the inputs used by the kernel. - pub fn inputs( - mut self, - inputs_tensor: &[&TensorDescription], - inputs_scalar_f32: &[f32], - ) -> FusionKernel { - for (i, input) in inputs_tensor.iter().enumerate() { - self.input_bindings.push(( - Binding { - elem: Elem::F32, - visibility: Visibility::Read, - location: Location::Storage, - size: None, - }, - (*input).clone(), - )); + /// Register the inputs used by the kernel. 
+ pub fn inputs( + mut self, + inputs_tensor: &[&TensorDescription], + inputs_scalar_f32: &[f32], + ) -> FusionKernel { + for (i, input) in inputs_tensor.iter().enumerate() { + self.input_bindings.push(( + Binding { + elem: Elem::F32, + visibility: Visibility::Read, + location: Location::Storage, + size: None, + }, + (*input).clone(), + )); - self.operations.push(Operator::ReadGlobal { - variable: Variable::Input(i as u16), - position: i, - position_out: inputs_tensor.len(), // First output - }); - } + self.operations.push(Operator::ReadGlobal { + variable: Variable::Input(i as u16), + position: i, + position_out: inputs_tensor.len(), // First output + }); + } - if !inputs_scalar_f32.is_empty() { - self.named_bindings.push(( - "scalars_f32".to_string(), - Binding { - elem: Elem::F32, - visibility: Visibility::Read, - location: Location::Storage, - size: Some(inputs_scalar_f32.len()), - }, - DataBuffer::F32(inputs_scalar_f32.to_vec()), - )); - } + if !inputs_scalar_f32.is_empty() { + self.named_bindings.push(( + "scalars_f32".to_string(), + Binding { + elem: Elem::F32, + visibility: Visibility::Read, + location: Location::Storage, + size: Some(inputs_scalar_f32.len()), + }, + DataBuffer::F32(inputs_scalar_f32.to_vec()), + )); + } - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, + } } - } } impl FusionKernel { - /// Register the [operators](Operator) that the kernel must execute in the order provided. - pub fn body(mut self, operators: &[Operator]) -> FusionKernel { - let mut register_function = |function: Function| { - if !self.functions.contains(&function) { - self.functions.push(function); - } - }; + /// Register the [operators](Operator) that the kernel must execute in the order provided. + pub fn body(mut self, operators: &[Operator]) -> FusionKernel { + let mut register_function = |function: Function| { + if !self.functions.contains(&function) { + self.functions.push(function); + } + }; - // Since not all operators are native to WGSL, we need to add the custom ones. - for ops in operators.iter() { - match ops { - Operator::Powf { - lhs: _, - rhs: _, - out: _, - } => { - register_function(Function::Powf(Elem::F32)); - } - Operator::Erf { input: _, out: _ } => { - register_function(Function::Erf(Elem::F32)); + // Since not all operators are native to WGSL, we need to add the custom ones. 
+ for ops in operators.iter() { + match ops { + Operator::Powf { + lhs: _, + rhs: _, + out: _, + } => { + register_function(Function::Powf(Elem::F32)); + } + Operator::Erf { input: _, out: _ } => { + register_function(Function::Erf(Elem::F32)); + } + _ => {} + } + self.operations.push(ops.clone()); } - _ => {} - } - self.operations.push(ops.clone()); - } - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, + } } - } } impl FusionKernel { - /// Register the outputs with their local variable index. - /// - /// Note that the index corresponds to the registered [operator](Operator) number at the - /// [body phase](BodyPhase). - /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0). - pub fn outputs( - mut self, - outputs: &[&TensorDescription], - locals: &[u16], - ) -> FusionKernel { - let mut num_elems_launch_option = 0; + /// Register the outputs with their local variable index. + /// + /// Note that the index corresponds to the registered [operator](Operator) number at the + /// [body phase](BodyPhase). + /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0). + pub fn outputs( + mut self, + outputs: &[&TensorDescription], + locals: &[u16], + ) -> FusionKernel { + let mut num_elems_launch_option = 0; - for (i, (output, local)) in outputs.iter().zip(locals).enumerate() { - let num_elems_output = calculate_num_elems_dyn_rank(&output.shape); - if num_elems_output > num_elems_launch_option { - num_elems_launch_option = num_elems_output; - } + for (i, (output, local)) in outputs.iter().zip(locals).enumerate() { + let num_elems_output = calculate_num_elems_dyn_rank(&output.shape); + if num_elems_output > num_elems_launch_option { + num_elems_launch_option = num_elems_output; + } - self.output_bindings.push(( - Binding { - elem: Elem::F32, - visibility: Visibility::ReadWrite, - location: Location::Storage, - size: None, - }, - (*output).clone(), - )); + self.output_bindings.push(( + Binding { + elem: Elem::F32, + visibility: Visibility::ReadWrite, + location: Location::Storage, + size: None, + }, + (*output).clone(), + )); - self.operations.push(Operator::AssignGlobal { - input: Variable::Local(*local), - out: Variable::Output(i as u16), - }); - } + self.operations.push(Operator::AssignGlobal { + input: Variable::Local(*local), + out: Variable::Output(i as u16), + }); + } - self.num_elems_output = num_elems_launch_option; + self.num_elems_output = num_elems_launch_option; - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: 
self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, + } } - } } impl FusionKernel { - /// Execute the kernel on the provided [handles](HandleContainer). - pub fn execute(mut self, handle_container: &mut HandleContainer>) { - let mut inputs = Vec::with_capacity(self.input_bindings.len()); - let mut outputs = Vec::with_capacity(self.output_bindings.len()); - let mut named = Vec::with_capacity(2); - let mut info = Vec::new(); - let mut handles = Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity()); + /// Execute the kernel on the provided [handles](HandleContainer). + pub fn execute(mut self, handle_container: &mut HandleContainer>) { + let mut inputs = Vec::with_capacity(self.input_bindings.len()); + let mut outputs = Vec::with_capacity(self.output_bindings.len()); + let mut named = Vec::with_capacity(2); + let mut info = Vec::new(); + let mut handles = + Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity()); - // Inner function to fill the info buffer. - let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| { - if info.is_empty() { - info.push(handle.strides.len() as u32); - } + // Inner function to fill the info buffer. + let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| { + if info.is_empty() { + info.push(handle.strides.len() as u32); + } - for s in handle.strides.iter() { - info.push(*s as u32); - } - for s in tensor.shape.iter() { - info.push(*s as u32); - } - }; + for s in handle.strides.iter() { + info.push(*s as u32); + } + for s in tensor.shape.iter() { + info.push(*s as u32); + } + }; - // We start by registering the inputs. - for (binding, tensor) in self.input_bindings.into_iter() { - let handle = handle_container.get_handle(&tensor); - register_info_tensor(&tensor, &handle); + // We start by registering the inputs. + for (binding, tensor) in self.input_bindings.into_iter() { + let handle = handle_container.get_handle(&tensor); + register_info_tensor(&tensor, &handle); - inputs.push(binding); - handles.push(handle.handle); - } + inputs.push(binding); + handles.push(handle.handle); + } - // Then we follow with the outputs. - for (binding, tensor) in self.output_bindings { - let num_elems = calculate_num_elems_dyn_rank(&tensor.shape); - let handle_fusion = WgpuFusionHandle { - client: self.client.clone(), - device: self.device.clone(), - strides: strides_dyn_rank(&tensor.shape), - handle: self.client.empty(core::mem::size_of::() * num_elems), - }; - register_info_tensor(&tensor, &handle_fusion); + // Then we follow with the outputs. + for (binding, tensor) in self.output_bindings { + let num_elems = calculate_num_elems_dyn_rank(&tensor.shape); + let handle_fusion = WgpuFusionHandle { + client: self.client.clone(), + device: self.device.clone(), + strides: strides_dyn_rank(&tensor.shape), + handle: self.client.empty(core::mem::size_of::() * num_elems), + }; + register_info_tensor(&tensor, &handle_fusion); - handles.push(handle_fusion.handle.clone()); - handle_container.register_handle(tensor.id, handle_fusion); - outputs.push(binding); - } + handles.push(handle_fusion.handle.clone()); + handle_container.register_handle(tensor.id, handle_fusion); + outputs.push(binding); + } - // Now we can create the info handle. - Self::build_info_handle(&mut self.named_bindings, info); + // Now we can create the info handle. 
+ Self::build_info_handle(&mut self.named_bindings, info); - // Finally we finish with the named bindings. - for (name, binding, data) in self.named_bindings { - let handle = self.client.create(match &data { - DataBuffer::F32(values) => bytemuck::cast_slice(values), - DataBuffer::U32(values) => bytemuck::cast_slice(values), - }); - named.push((name, binding)); - handles.push(handle); - } + // Finally we finish with the named bindings. + for (name, binding, data) in self.named_bindings { + let handle = self.client.create(match &data { + DataBuffer::F32(values) => bytemuck::cast_slice(values), + DataBuffer::U32(values) => bytemuck::cast_slice(values), + }); + named.push((name, binding)); + handles.push(handle); + } - // We create the shader codegen type and launch the kernel. - let kernel = ComputeShader { - inputs, - outputs, - named, - workgroup_size: WorkgroupSize::default(), - body: Body::new(self.operations), - num_workgroups: true, - global_invocation_id: true, - functions: self.functions, - }; + // We create the shader codegen type and launch the kernel. + let kernel = ComputeShader { + inputs, + outputs, + named, + workgroup_size: WorkgroupSize::default(), + body: Body::new(self.operations), + num_workgroups: true, + global_invocation_id: true, + functions: self.functions, + }; - let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT); - let kernel = Box::new(DynamicKernel::new(kernel, workgroup)); + let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT); + let kernel = Box::new(DynamicKernel::new(kernel, workgroup)); - self - .client - .execute(kernel, &handles.iter().collect::>()); - } + self.client + .execute(kernel, &handles.iter().collect::>()); + } - fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec) { - named_bindings.push(( - "info".to_string(), - Binding { - elem: Elem::U32, - visibility: Visibility::Read, - location: Location::Storage, - size: None, // We avoid putting the length here since it will force a new kernel - // for each tensor rank. - }, - DataBuffer::U32(info), - )); - } + fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec) { + named_bindings.push(( + "info".to_string(), + Binding { + elem: Elem::U32, + visibility: Visibility::Read, + location: Location::Storage, + size: None, // We avoid putting the length here since it will force a new kernel + // for each tensor rank. + }, + DataBuffer::U32(info), + )); + } } diff --git a/burn-wgpu/src/graphics.rs b/burn-wgpu/src/graphics.rs index 8d709c6e4b..c3b3c01f19 100644 --- a/burn-wgpu/src/graphics.rs +++ b/burn-wgpu/src/graphics.rs @@ -8,8 +8,8 @@ /// - [DirectX 12](Dx12) /// - [WebGpu](WebGpu) pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static { - /// The wgpu backend. - fn backend() -> wgpu::Backend; + /// The wgpu backend. + fn backend() -> wgpu::Backend; } /// Vulkan graphics API. 
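As a concrete illustration of the `info` buffer that the `register_info_tensor` closure in `FusionKernel::execute` assembles above (a sketch only, not part of the patch: the function name and the 2x3 tensor shapes are assumptions made up for the example):

// Expected `info` contents when the fused kernel registers two contiguous 2x3
// tensors: the rank is pushed once, then strides followed by shape for each
// tensor, in registration order (inputs first, then outputs).
fn info_for_two_contiguous_2x3() -> Vec<u32> {
    vec![
        2,        // rank
        3, 1,     // strides of tensor 0
        2, 3,     // shape of tensor 0
        3, 1,     // strides of tensor 1
        2, 3,     // shape of tensor 1
    ]
}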
@@ -41,46 +41,46 @@ pub struct WebGpu; pub struct AutoGraphicsApi; impl GraphicsApi for Vulkan { - fn backend() -> wgpu::Backend { - wgpu::Backend::Vulkan - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Vulkan + } } impl GraphicsApi for Metal { - fn backend() -> wgpu::Backend { - wgpu::Backend::Metal - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Metal + } } impl GraphicsApi for OpenGl { - fn backend() -> wgpu::Backend { - wgpu::Backend::Gl - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Gl + } } impl GraphicsApi for Dx11 { - fn backend() -> wgpu::Backend { - wgpu::Backend::Dx11 - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Dx11 + } } impl GraphicsApi for Dx12 { - fn backend() -> wgpu::Backend { - wgpu::Backend::Dx12 - } + fn backend() -> wgpu::Backend { + wgpu::Backend::Dx12 + } } impl GraphicsApi for WebGpu { - fn backend() -> wgpu::Backend { - wgpu::Backend::BrowserWebGpu - } + fn backend() -> wgpu::Backend { + wgpu::Backend::BrowserWebGpu + } } impl GraphicsApi for AutoGraphicsApi { - fn backend() -> wgpu::Backend { - #[cfg(target_os = "macos")] - return wgpu::Backend::Metal; - #[cfg(not(target_os = "macos"))] - wgpu::Backend::Vulkan - } + fn backend() -> wgpu::Backend { + #[cfg(target_os = "macos")] + return wgpu::Backend::Metal; + #[cfg(not(target_os = "macos"))] + wgpu::Backend::Vulkan + } } diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs index aca2333233..133015f49d 100644 --- a/burn-wgpu/src/kernel/base.rs +++ b/burn-wgpu/src/kernel/base.rs @@ -1,9 +1,9 @@ use super::SourceTemplate; use crate::{ - compute::{StaticKernel, WgpuComputeClient, WgpuHandle, WorkGroup}, - element::WgpuElement, - kernel, - tensor::WgpuTensor, + compute::{StaticKernel, WgpuComputeClient, WgpuHandle, WorkGroup}, + element::WgpuElement, + kernel, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -14,169 +14,169 @@ pub(crate) const WORKGROUP_DEFAULT: usize = 32; /// Static wgpu kernel to create a [source template](SourceTemplate). pub trait StaticKernelSource: Send + 'static + Sync { - /// Source template for the kernel. - fn source() -> SourceTemplate; + /// Source template for the kernel. + fn source() -> SourceTemplate; } /// Dynamic wgpu kernel to create a [source template](SourceTemplate). pub trait DynamicKernelSource: Send + Sync { - /// Source template for the kernel. - fn source(&self) -> SourceTemplate; - /// Identifier for the kernel, used for caching kernel compilation. - fn id(&self) -> String; + /// Source template for the kernel. + fn source(&self) -> SourceTemplate; + /// Identifier for the kernel, used for caching kernel compilation. + fn id(&self) -> String; } /// Generates kernel source code by replacing some information using templating. #[macro_export] macro_rules! kernel_wgsl { - ( + ( $struct:ident, $file:expr ) => { - /// Generated kernel from wgsl file. - #[derive(new)] - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::SourceTemplate::new(include_str!($file)) - } - } - }; + /// Generated kernel from wgsl file. + #[derive(new)] + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::SourceTemplate::new(include_str!($file)) + } + } + }; } kernel_wgsl!(ContiguousRaw, "../template/contiguous.wgsl"); /// Make a wgpu tensor contiguous. 
pub fn into_contiguous( - tensor: WgpuTensor, + tensor: WgpuTensor, ) -> WgpuTensor { - if tensor.is_contiguous() { - return tensor; - } - - let num_elems = tensor.shape.num_elements(); - let handle = tensor.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - handle, - ); - let info = build_info(&[&tensor, &output]); - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); - - tensor - .client - .execute(kernel, &[&tensor.handle, &output.handle, &info_handle]); - - output + if tensor.is_contiguous() { + return tensor; + } + + let num_elems = tensor.shape.num_elements(); + let handle = tensor.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + handle, + ); + let info = build_info(&[&tensor, &output]); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = Box::new(StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); + + tensor + .client + .execute(kernel, &[&tensor.handle, &output.handle, &info_handle]); + + output } /// Similar to [into contiguous](into_contiguous) but with dynamic rank. pub fn into_contiguous_dyn( - client: WgpuComputeClient, - input: WgpuHandle, - input_shape: &[usize], - input_strides: &[usize], - output_shape: &[usize], - output_strides: &[usize], - num_elems: usize, + client: WgpuComputeClient, + input: WgpuHandle, + input_shape: &[usize], + input_strides: &[usize], + output_shape: &[usize], + output_strides: &[usize], + num_elems: usize, ) -> WgpuHandle { - let handle = client.empty(num_elems * core::mem::size_of::()); - let info = kernel::build_info_dyn::( - &[input_shape, output_shape], - &[input_strides, output_strides], - ); + let handle = client.empty(num_elems * core::mem::size_of::()); + let info = kernel::build_info_dyn::( + &[input_shape, output_shape], + &[input_strides, output_strides], + ); - let info_handle = client.create(bytemuck::cast_slice(&info)); + let info_handle = client.create(bytemuck::cast_slice(&info)); - let kernel = Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); + let kernel = Box::new(StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); - client.execute(kernel, &[&input, &handle, &info_handle]); + client.execute(kernel, &[&input, &handle, &info_handle]); - handle + handle } /// Generates kernel source code by replacing some information using templating. 
pub struct KernelSettings< - K: StaticKernelSource, - E: WgpuElement, - I: WgpuElement, - const WORKGROUP_X_SIZE: usize, - const WORKGROUP_Y_SIZE: usize, - const WORKGROUP_Z_SIZE: usize, -> { - _k: PhantomData, - _e: PhantomData, - _i: PhantomData, -} - -impl< K: StaticKernelSource, E: WgpuElement, I: WgpuElement, const WORKGROUP_X_SIZE: usize, const WORKGROUP_Y_SIZE: usize, const WORKGROUP_Z_SIZE: usize, - > StaticKernelSource - for KernelSettings +> { + _k: PhantomData, + _e: PhantomData, + _i: PhantomData, +} + +impl< + K: StaticKernelSource, + E: WgpuElement, + I: WgpuElement, + const WORKGROUP_X_SIZE: usize, + const WORKGROUP_Y_SIZE: usize, + const WORKGROUP_Z_SIZE: usize, + > StaticKernelSource + for KernelSettings { - fn source() -> SourceTemplate { - K::source() - .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string()) - .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string()) - .register( - "workgroup_size", - (WORKGROUP_X_SIZE * WORKGROUP_Y_SIZE * WORKGROUP_Z_SIZE).to_string(), - ) - .register("elem", E::type_name()) - .register("int", I::type_name()) - } + fn source() -> SourceTemplate { + K::source() + .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string()) + .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string()) + .register( + "workgroup_size", + (WORKGROUP_X_SIZE * WORKGROUP_Y_SIZE * WORKGROUP_Z_SIZE).to_string(), + ) + .register("elem", E::type_name()) + .register("int", I::type_name()) + } } /// Generate kernel source code by replacing some information using templating. #[derive(new)] pub struct DynamicKernelSettings { - workgroup_x_size: usize, - workgroup_y_size: usize, - workgroup_z_size: usize, - _k: PhantomData, - _e: PhantomData, - _i: PhantomData, + workgroup_x_size: usize, + workgroup_y_size: usize, + workgroup_z_size: usize, + _k: PhantomData, + _e: PhantomData, + _i: PhantomData, } impl DynamicKernelSource - for DynamicKernelSettings + for DynamicKernelSettings { - fn source(&self) -> SourceTemplate { - K::source() - .register("workgroup_size_x", self.workgroup_x_size.to_string()) - .register("workgroup_size_y", self.workgroup_y_size.to_string()) - .register("workgroup_size_z", self.workgroup_z_size.to_string()) - .register( - "workgroup_size", - (self.workgroup_x_size * self.workgroup_y_size * self.workgroup_z_size).to_string(), - ) - .register("elem", E::type_name()) - .register("int", I::type_name()) - } - - fn id(&self) -> String { - let id = core::any::TypeId::of::(); - - format!( - "{:?}-dyn-settings{}-{}-{}", - id, self.workgroup_x_size, self.workgroup_y_size, self.workgroup_z_size - ) - } + fn source(&self) -> SourceTemplate { + K::source() + .register("workgroup_size_x", self.workgroup_x_size.to_string()) + .register("workgroup_size_y", self.workgroup_y_size.to_string()) + .register("workgroup_size_z", self.workgroup_z_size.to_string()) + .register( + "workgroup_size", + (self.workgroup_x_size * self.workgroup_y_size * self.workgroup_z_size).to_string(), + ) + .register("elem", E::type_name()) + .register("int", I::type_name()) + } + + fn id(&self) -> String { + let id = core::any::TypeId::of::(); + + format!( + "{:?}-dyn-settings{}-{}-{}", + id, self.workgroup_x_size, self.workgroup_y_size, self.workgroup_z_size + ) + } } /// Create a vector containing the dimension, strides and shape of tensors. 
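A minimal sketch of how the `KernelSettings` template above is typically instantiated and launched (illustration only, not part of the patch: the `f32`/`i32` element choices and the `ContiguousKernel`/`contiguous_kernel` names are assumptions, while `ContiguousRaw`, `StaticKernel`, `WORKGROUP_DEFAULT` and `elemwise_workgroup` are the items shown earlier in this file):

// Bind the contiguous.wgsl source declared with kernel_wgsl! to concrete element
// types and the default 32x32x1 workgroup, filling the template keys registered by
// KernelSettings::source (workgroup_size_x/y/z, workgroup_size, elem, int).
type ContiguousKernel =
    KernelSettings<ContiguousRaw, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>;

// Wrap the settings in a static kernel sized from the number of elements, as
// into_contiguous does above before handing it to the compute client.
fn contiguous_kernel(num_elems: usize) -> StaticKernel<ContiguousKernel> {
    StaticKernel::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))
}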
@@ -193,84 +193,84 @@ impl DynamicKernelSource /// | (2 * D + 1)..(3 * D + 1) | lhs shape | /// | (3 * D + 1)..(4 * D + 1) | rhs shape | pub fn build_info(tensors: &[&WgpuTensor]) -> Vec { - let mut info: Vec = vec![0; tensors.len() * 2 * D + 1]; - info[0] = D as u32; - - let mut current = 1; - for tensor in tensors.iter() { - for d in 0..D { - info[current] = tensor.strides[d] as u32; - current += 1; + let mut info: Vec = vec![0; tensors.len() * 2 * D + 1]; + info[0] = D as u32; + + let mut current = 1; + for tensor in tensors.iter() { + for d in 0..D { + info[current] = tensor.strides[d] as u32; + current += 1; + } } - } - for tensor in tensors.iter() { - for d in 0..D { - info[current] = tensor.shape.dims[d] as u32; - current += 1; + for tensor in tensors.iter() { + for d in 0..D { + info[current] = tensor.shape.dims[d] as u32; + current += 1; + } } - } - info + info } /// Similar to [build info](build_info) but with dynamic rank. pub fn build_info_dyn(shapes: &[&[usize]], strides: &[&[usize]]) -> Vec { - let rank = shapes.get(0).unwrap().len(); - let mut info: Vec = vec![0; shapes.len() * 2 * rank + 1]; - info[0] = rank as u32; - - let mut current = 1; - for stride in strides.iter() { - for d in 0..rank { - info[current] = stride[d] as u32; - current += 1; + let rank = shapes.get(0).unwrap().len(); + let mut info: Vec = vec![0; shapes.len() * 2 * rank + 1]; + info[0] = rank as u32; + + let mut current = 1; + for stride in strides.iter() { + for d in 0..rank { + info[current] = stride[d] as u32; + current += 1; + } } - } - for shape in shapes.iter() { - for d in 0..rank { - info[current] = shape[d] as u32; - current += 1; + for shape in shapes.iter() { + for d in 0..rank { + info[current] = shape[d] as u32; + current += 1; + } } - } - info + info } pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup { - let num_elem_per_invocation = workgroup_size * workgroup_size; - let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32); - let workgroup_x = f32::ceil(f32::sqrt(workgroups)); - let workgroup_y = f32::ceil(num_elems as f32 / (workgroup_x * num_elem_per_invocation as f32)); + let num_elem_per_invocation = workgroup_size * workgroup_size; + let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32); + let workgroup_x = f32::ceil(f32::sqrt(workgroups)); + let workgroup_y = f32::ceil(num_elems as f32 / (workgroup_x * num_elem_per_invocation as f32)); - WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) + WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) } pub(crate) fn prng_workgroup( - num_elems: usize, - workgroup_size: usize, - n_values_per_thread: usize, + num_elems: usize, + workgroup_size: usize, + n_values_per_thread: usize, ) -> WorkGroup { - let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32); - let num_elem_per_invocation = workgroup_size * workgroup_size; - let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32); - let workgroup_x = f32::ceil(f32::sqrt(num_invocations)); - let workgroup_y = f32::ceil(num_invocations / workgroup_x); + let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32); + let num_elem_per_invocation = workgroup_size * workgroup_size; + let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32); + let workgroup_x = f32::ceil(f32::sqrt(num_invocations)); + let workgroup_y = f32::ceil(num_invocations / workgroup_x); - WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) + 
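// Two pieces of the file above, restated as plain CPU Rust: `build_info` flattens
// `[rank, strides..., shapes...]` for every tensor into one `Vec<u32>` that the WGSL
// kernels index, and `elemwise_workgroup` spreads `num_elems` over a roughly square
// (x, y, 1) dispatch grid. Function names are illustrative, not burn-wgpu APIs.

fn build_info_cpu(shapes: &[&[usize]], strides: &[&[usize]]) -> Vec<u32> {
    let rank = shapes[0].len();
    let mut info = Vec::with_capacity(shapes.len() * 2 * rank + 1);
    info.push(rank as u32);
    // All strides first, then all shapes, matching the layout table in the doc comment.
    for stride in strides {
        info.extend(stride.iter().map(|&s| s as u32));
    }
    for shape in shapes {
        info.extend(shape.iter().map(|&s| s as u32));
    }
    info
}

fn elemwise_workgroup_cpu(num_elems: usize, workgroup_size: usize) -> (u32, u32, u32) {
    // Each workgroup covers `workgroup_size * workgroup_size` invocations.
    let per_workgroup = (workgroup_size * workgroup_size) as f32;
    let workgroups = (num_elems as f32 / per_workgroup).ceil();
    let x = workgroups.sqrt().ceil();
    let y = (num_elems as f32 / (x * per_workgroup)).ceil();
    (x as u32, y as u32, 1)
}

fn main() {
    let shapes: [&[usize]; 2] = [&[2, 3], &[2, 3]];
    let strides: [&[usize]; 2] = [&[3, 1], &[3, 1]];
    assert_eq!(build_info_cpu(&shapes, &strides), vec![2, 3, 1, 3, 1, 2, 3, 2, 3]);
    // 6 * 256 elements with a 32x32 workgroup fit in two workgroups: a 2x1 grid.
    assert_eq!(elemwise_workgroup_cpu(6 * 256, 32), (2, 1, 1));
}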
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1) } #[cfg(test)] mod tests { - use super::*; - use core::any::TypeId; + use super::*; + use core::any::TypeId; - #[test] - fn test_kernel_type_id() { - kernel_wgsl!(Add, "../template/binary_elemwise.wgsl"); + #[test] + fn test_kernel_type_id() { + kernel_wgsl!(Add, "../template/binary_elemwise.wgsl"); - let type_id_1 = TypeId::of::>(); - let type_id_2 = TypeId::of::>(); - let type_id_3 = TypeId::of::>(); + let type_id_1 = TypeId::of::>(); + let type_id_2 = TypeId::of::>(); + let type_id_3 = TypeId::of::>(); - assert_ne!(type_id_1, type_id_2); - assert_eq!(type_id_1, type_id_3); - } + assert_ne!(type_id_1, type_id_2); + assert_eq!(type_id_1, type_id_3); + } } diff --git a/burn-wgpu/src/kernel/binary_elemwise.rs b/burn-wgpu/src/kernel/binary_elemwise.rs index 05b2324dea..0fc4a0e93d 100644 --- a/burn-wgpu/src/kernel/binary_elemwise.rs +++ b/burn-wgpu/src/kernel/binary_elemwise.rs @@ -1,5 +1,5 @@ use super::{ - build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, + build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, }; use crate::compute::StaticKernel; use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; @@ -7,177 +7,175 @@ use burn_tensor::Shape; kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl"); kernel_wgsl!( - BinaryElemwiseInplaceRaw, - "../template/binary_elemwise_inplace.wgsl" + BinaryElemwiseInplaceRaw, + "../template/binary_elemwise_inplace.wgsl" ); /// Creates a binary elementwise kernel. #[macro_export] macro_rules! binary_elemwise { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::BinaryElemwiseRaw::source().register( - "body", - format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops), - ) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::BinaryElemwiseRaw::source().register( + "body", + format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops), + ) + } + } + }; } /// Creates a binary elementwise inplace kernel. #[macro_export] macro_rules! binary_elemwise_inplace { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::BinaryElemwiseInplaceRaw::source().register( - "body", - format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops), - ) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::BinaryElemwiseInplaceRaw::source().register( + "body", + format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops), + ) + } + } + }; } /// Execute a binary kernel using the default settings. pub fn binary_elemwise_default( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise::(lhs, rhs) + binary_elemwise::(lhs, rhs) } /// Execute a binary kernel using the provided WORKGROUP. 
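// The `binary_elemwise!` / `binary_elemwise_inplace!` macros above stamp out one
// zero-sized struct per operator and splice the operator string into the kernel body.
// A self-contained mirror of that pattern; the trait and body text are stand-ins, not
// the burn-wgpu types.

trait KernelSourceSketch {
    fn source() -> String;
}

macro_rules! binary_elemwise_sketch {
    ($struct:ident, $ops:expr) => {
        struct $struct;

        impl KernelSourceSketch for $struct {
            fn source() -> String {
                format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops)
            }
        }
    };
}

binary_elemwise_sketch!(AddKernel, "+");
binary_elemwise_sketch!(MulKernel, "*");

fn main() {
    assert_eq!(AddKernel::source(), "output[id] = lhs[index_lhs] + rhs[index_rhs];");
    assert_eq!(MulKernel::source(), "output[id] = lhs[index_lhs] * rhs[index_rhs];");
}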
pub fn binary_elemwise< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); - - let mut shape_out = [0; D]; - lhs - .shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - - let shape_out = Shape::new(shape_out); - let num_elems = shape_out.num_elements(); - - let handle = lhs.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle); - - let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], - ); - - output + lhs.assert_is_on_same_device(&rhs); + + let mut shape_out = [0; D]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::new(shape_out); + let num_elems = shape_out.num_elements(); + + let handle = lhs.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle); + + let info = build_info(&[&lhs, &rhs, &output]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + ); + + output } /// Execute a binary inplace kernel using the default settings. pub fn binary_elemwise_inplace_default( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise_inplace::(lhs, rhs) + binary_elemwise_inplace::(lhs, rhs) } /// Execute a binary inplace kernel using the provided WORKGROUP. 
pub fn binary_elemwise_inplace< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); + lhs.assert_is_on_same_device(&rhs); - let info = build_info(&[&lhs, &rhs]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::>::new( - elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP), - ); + let info = build_info(&[&lhs, &rhs]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::>::new( + elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP), + ); - lhs - .client - .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); - lhs + lhs } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; - - binary_elemwise!(TestKernel, "*"); - binary_elemwise_inplace!(TestKernelInplace, "*"); - - #[test] - fn binary_should_work_with_multiple_invocations() { - let lhs = Tensor::::random([6, 256], Distribution::Default); - let rhs = Tensor::::random([6, 256], Distribution::Default); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - let rhs_ref = Tensor::::from_data(rhs.to_data()); - - let actual = - binary_elemwise::(lhs.into_primitive(), rhs.into_primitive()); - let expected = lhs_ref * rhs_ref; - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } - - #[test] - fn binary_inplace_should_work_with_multiple_invocations() { - let lhs = Tensor::::random([6, 256], Distribution::Default); - let rhs = Tensor::::random([6, 256], Distribution::Default); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - let rhs_ref = Tensor::::from_data(rhs.to_data()); - - let actual = binary_elemwise_inplace::( - lhs.into_primitive(), - rhs.into_primitive(), - ); - let expected = lhs_ref * rhs_ref; + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; + + binary_elemwise!(TestKernel, "*"); + binary_elemwise_inplace!(TestKernelInplace, "*"); + + #[test] + fn binary_should_work_with_multiple_invocations() { + let lhs = Tensor::::random([6, 256], Distribution::Default); + let rhs = Tensor::::random([6, 256], Distribution::Default); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + let rhs_ref = Tensor::::from_data(rhs.to_data()); + + let actual = + binary_elemwise::(lhs.into_primitive(), rhs.into_primitive()); + let expected = lhs_ref * rhs_ref; + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + #[test] + fn binary_inplace_should_work_with_multiple_invocations() { + let lhs = Tensor::::random([6, 256], Distribution::Default); + let rhs = Tensor::::random([6, 256], Distribution::Default); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + let rhs_ref = Tensor::::from_data(rhs.to_data()); + + let actual = binary_elemwise_inplace::( + lhs.into_primitive(), + rhs.into_primitive(), + ); + let expected = lhs_ref * rhs_ref; + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } } diff --git 
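// `binary_elemwise` above derives its output shape by taking the per-dimension maximum
// of the operand shapes (broadcasting for same-rank tensors) before allocating the
// output buffer. A small sketch of that shape rule:

fn broadcast_shape<const D: usize>(lhs: [usize; D], rhs: [usize; D]) -> [usize; D] {
    let mut out = [0; D];
    for (i, (l, r)) in lhs.iter().zip(rhs.iter()).enumerate() {
        out[i] = usize::max(*l, *r);
    }
    out
}

fn main() {
    // [6, 1] combined with [1, 256] broadcasts to [6, 256]; the kernel then dispatches
    // enough workgroups for 6 * 256 output elements.
    let shape = broadcast_shape([6, 1], [1, 256]);
    assert_eq!(shape, [6, 256]);
    assert_eq!(shape.iter().product::<usize>(), 1536);
}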
a/burn-wgpu/src/kernel/cast.rs b/burn-wgpu/src/kernel/cast.rs index 8daf33b83b..d840c5268e 100644 --- a/burn-wgpu/src/kernel/cast.rs +++ b/burn-wgpu/src/kernel/cast.rs @@ -1,77 +1,84 @@ use super::{KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}; use crate::{ - compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, + tensor::WgpuTensor, }; use std::{any::TypeId, marker::PhantomData}; kernel_wgsl!(CastRaw, "../template/cast.wgsl"); struct Cast { - _i: PhantomData, - _o: PhantomData, + _i: PhantomData, + _o: PhantomData, } impl StaticKernelSource - for Cast + for Cast { - fn source() -> SourceTemplate { - CastRaw::source() - .register("input_elem", InputElem::type_name()) - .register("output_elem", OutputElem::type_name()) - } + fn source() -> SourceTemplate { + CastRaw::source() + .register("input_elem", InputElem::type_name()) + .register("output_elem", OutputElem::type_name()) + } } /// Cast a tensor to the given element type. pub fn cast( - tensor: WgpuTensor, + tensor: WgpuTensor, ) -> WgpuTensor { - if TypeId::of::() == TypeId::of::() { - return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); - } + if TypeId::of::() == TypeId::of::() { + return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); + } - let num_elems = tensor.shape.num_elements(); - let kernel = StaticKernel::< - KernelSettings, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let num_elems = tensor.shape.num_elements(); + let kernel = StaticKernel::< + KernelSettings< + Cast, + f32, + i32, + WORKGROUP_DEFAULT, + WORKGROUP_DEFAULT, + 1, + >, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let handle = tensor - .client - .empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - tensor.client.clone(), - tensor.device, - tensor.shape.clone(), - handle, - ); + let handle = tensor + .client + .empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + tensor.client.clone(), + tensor.device, + tensor.shape.clone(), + handle, + ); - tensor - .client - .execute(Box::new(kernel), &[&tensor.handle, &output.handle]); + tensor + .client + .execute(Box::new(kernel), &[&tensor.handle, &output.handle]); - output + output } #[cfg(test)] mod tests { - use super::*; - use crate::tests::TestBackend; - use burn_tensor::{Int, Tensor}; + use super::*; + use crate::tests::TestBackend; + use burn_tensor::{Int, Tensor}; - #[test] - fn should_cast_int_to_float() { - const START: usize = 0; - const END: usize = 100; + #[test] + fn should_cast_int_to_float() { + const START: usize = 0; + const END: usize = 100; - let tensor = Tensor::::arange(START..END); - let tensor_float = cast::(tensor.clone().into_primitive()); + let tensor = Tensor::::arange(START..END); + let tensor_float = cast::(tensor.clone().into_primitive()); - let data_int = tensor.into_data(); - let data_float = Tensor::::from_primitive(tensor_float).into_data(); + let data_int = tensor.into_data(); + let data_float = Tensor::::from_primitive(tensor_float).into_data(); - for i in START..END { - assert_eq!(data_int.value[i], i as i32); - assert_eq!(data_float.value[i], i as f32); + for i in START..END { + assert_eq!(data_int.value[i], i as i32); + assert_eq!(data_float.value[i], i as f32); + } } - } } diff --git a/burn-wgpu/src/kernel/cat.rs 
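// The `cast` kernel above guards on `TypeId` equality: when source and target element
// types are the same concrete type it returns the existing buffer untouched, otherwise
// it allocates `num_elems * size_of::<Output>()` bytes and runs the conversion kernel.
// A CPU-side sketch of those two checks (illustrative names, not burn-wgpu APIs):

use std::any::TypeId;

/// True when the cast can be skipped and the existing handle reused as-is.
fn cast_is_noop<Input: 'static, Output: 'static>() -> bool {
    TypeId::of::<Input>() == TypeId::of::<Output>()
}

/// Size of the output buffer a non-trivial cast has to allocate.
fn output_buffer_bytes<Output>(num_elems: usize) -> usize {
    num_elems * std::mem::size_of::<Output>()
}

fn main() {
    assert!(cast_is_noop::<f32, f32>());
    assert!(!cast_is_noop::<i32, f32>());
    assert_eq!(output_buffer_bytes::<f32>(100), 400);
}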
b/burn-wgpu/src/kernel/cat.rs index e4c4b44cac..4271541153 100644 --- a/burn-wgpu/src/kernel/cat.rs +++ b/burn-wgpu/src/kernel/cat.rs @@ -1,9 +1,9 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings}, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings}, + kernel_wgsl, + tensor::WgpuTensor, }; use super::WORKGROUP_DEFAULT; @@ -11,82 +11,84 @@ use super::WORKGROUP_DEFAULT; kernel_wgsl!(Cat, "../template/cat.wgsl"); pub fn cat( - inputs: Vec>, - dim: usize, + inputs: Vec>, + dim: usize, ) -> WgpuTensor { - let first_input = inputs.get(0).unwrap(); - let client = &first_input.client; - let mut shape_output = first_input.shape.clone(); - shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum(); - - let buffer = first_input - .client - .empty(shape_output.num_elements() * std::mem::size_of::()); - - let output = WgpuTensor::new( - client.clone(), - first_input.device.clone(), - shape_output, - buffer, - ); - - let mut dim_cat_index = 0; - - for input in inputs.iter() { - let mut info = build_info(&[input, &output]); - info.push(dim as u32); - info.push(dim_cat_index as u32); - dim_cat_index += input.shape.dims[dim]; - let info_buffer = client.create(bytemuck::cast_slice(&info)); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT), - ); - - client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_buffer], + let first_input = inputs.get(0).unwrap(); + let client = &first_input.client; + let mut shape_output = first_input.shape.clone(); + shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum(); + + let buffer = first_input + .client + .empty(shape_output.num_elements() * std::mem::size_of::()); + + let output = WgpuTensor::new( + client.clone(), + first_input.device.clone(), + shape_output, + buffer, ); - } - output + let mut dim_cat_index = 0; + + for input in inputs.iter() { + let mut info = build_info(&[input, &output]); + info.push(dim as u32); + info.push(dim_cat_index as u32); + dim_cat_index += input.shape.dims[dim]; + let info_buffer = client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + input.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + + client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_buffer], + ); + } + + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Tensor}; - - #[test] - fn cat_should_support_multiple_invocations_dim0() { - test_same_as_reference([6, 256], 2, 0); - } - - #[test] - fn cat_should_support_multiple_invocations_dim1() { - test_same_as_reference([6, 256], 2, 1); - } - - #[test] - fn cat_should_support_uneven_launch() { - test_same_as_reference([1, 137], 2, 0); - } - - fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) { - TestBackend::seed(0); - let tensors = (0..num_tensors) - .map(|_| Tensor::::random(shape, Distribution::Default)) - .collect::>(); - let tensors_ref = tensors - .iter() - .map(|tensor| Tensor::::from_data(tensor.to_data())) - .collect::>(); - - let tensor = Tensor::::cat(tensors, dim); - let tensor_ref = Tensor::::cat(tensors_ref, dim); - - tensor - .into_data() - .assert_approx_eq(&tensor_ref.into_data(), 3); - } + use crate::tests::{ReferenceBackend, 
TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Tensor}; + + #[test] + fn cat_should_support_multiple_invocations_dim0() { + test_same_as_reference([6, 256], 2, 0); + } + + #[test] + fn cat_should_support_multiple_invocations_dim1() { + test_same_as_reference([6, 256], 2, 1); + } + + #[test] + fn cat_should_support_uneven_launch() { + test_same_as_reference([1, 137], 2, 0); + } + + fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) { + TestBackend::seed(0); + let tensors = (0..num_tensors) + .map(|_| Tensor::::random(shape, Distribution::Default)) + .collect::>(); + let tensors_ref = tensors + .iter() + .map(|tensor| Tensor::::from_data(tensor.to_data())) + .collect::>(); + + let tensor = Tensor::::cat(tensors, dim); + let tensor_ref = Tensor::::cat(tensors_ref, dim); + + tensor + .into_data() + .assert_approx_eq(&tensor_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/clamp.rs b/burn-wgpu/src/kernel/clamp.rs index b3195a7de7..dcd774d8e0 100644 --- a/burn-wgpu/src/kernel/clamp.rs +++ b/burn-wgpu/src/kernel/clamp.rs @@ -1,11 +1,11 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, - unary_scalar, unary_scalar_inplace, + compute::StaticKernel, + element::WgpuElement, + kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, + unary_scalar, unary_scalar_inplace, }; use super::{elemwise_workgroup, KernelSettings}; @@ -14,106 +14,105 @@ kernel_wgsl!(Clamp, "../template/clamp/clamp.wgsl"); kernel_wgsl!(ClampInplace, "../template/clamp/clamp_inplace.wgsl"); pub(crate) fn clamp_min( - input: WgpuTensor, - min_value: E, + input: WgpuTensor, + min_value: E, ) -> WgpuTensor { - unary_scalar!(ClampMin, func "max"); - unary_scalar_inplace!(ClampMinInplace, func "max"); + unary_scalar!(ClampMin, func "max"); + unary_scalar_inplace!(ClampMinInplace, func "max"); - if input.can_mut() { - return unary_scalar_inplace_default::(input, min_value); - } + if input.can_mut() { + return unary_scalar_inplace_default::(input, min_value); + } - unary_scalar::(input, min_value) + unary_scalar::(input, min_value) } pub(crate) fn clamp_max( - input: WgpuTensor, - max_value: E, + input: WgpuTensor, + max_value: E, ) -> WgpuTensor { - unary_scalar!(ClampMax, func "min"); - unary_scalar_inplace!(ClampMaxInPlace, func "min"); + unary_scalar!(ClampMax, func "min"); + unary_scalar_inplace!(ClampMaxInPlace, func "min"); - if input.can_mut() { - return unary_scalar_inplace_default::(input, max_value); - } + if input.can_mut() { + return unary_scalar_inplace_default::(input, max_value); + } - unary_scalar::(input, max_value) + unary_scalar::(input, max_value) } pub(crate) fn clamp( - input: WgpuTensor, - min_value: E, - max_value: E, + input: WgpuTensor, + min_value: E, + max_value: E, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let min_handle = input.client.create(E::as_bytes(&[min_value])); - let max_handle = input.client.create(E::as_bytes(&[max_value])); + let num_elems = input.shape.num_elements(); + let min_handle = input.client.create(E::as_bytes(&[min_value])); + let max_handle = input.client.create(E::as_bytes(&[max_value])); - if input.can_mut() { - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + if input.can_mut() { + let kernel = 
StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - input - .client - .execute(Box::new(kernel), &[&input.handle, &min_handle, &max_handle]); + input + .client + .execute(Box::new(kernel), &[&input.handle, &min_handle, &max_handle]); - return input; - } + return input; + } - let output = empty_device(input.client.clone(), input.device.clone(), input.shape); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), - ); + let output = empty_device(input.client.clone(), input.device.clone(), input.shape); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &min_handle, &max_handle], - ); + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &min_handle, &max_handle], + ); - output + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; - #[test] - fn clamp_min_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); + #[test] + fn clamp_min_should_match_reference() { + let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); - let output = input.clamp_min(0.5); + let output = input.clamp_min(0.5); - output - .into_data() - .assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3); + } - #[test] - fn clamp_max_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); + #[test] + fn clamp_max_should_match_reference() { + let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); - let output = input.clamp_max(0.5); + let output = input.clamp_max(0.5); - output - .into_data() - .assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3); + } - #[test] - fn clamp_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); + #[test] + fn clamp_should_match_reference() { + let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); - let output = input.clamp(0.3, 0.7); + let output = input.clamp(0.3, 0.7); - output - .into_data() - .assert_approx_eq(&input_ref.clamp(0.3, 0.7).into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&input_ref.clamp(0.3, 0.7).into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/comparison/base.rs b/burn-wgpu/src/kernel/comparison/base.rs index 8db310f587..570b5deedd 100644 --- a/burn-wgpu/src/kernel/comparison/base.rs +++ b/burn-wgpu/src/kernel/comparison/base.rs @@ -1,8 +1,8 @@ use crate::{ - comparison, comparison_elem, comparison_elem_inplace, comparison_inplace, - element::WgpuElement, - kernel::{comparison, comparison_elem, comparison_elem_inplace, comparison_inplace}, - tensor::WgpuTensor, + comparison, comparison_elem, comparison_elem_inplace, comparison_inplace, + element::WgpuElement, 
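// In `clamp.rs` above, `clamp_min` is built from the WGSL `max` function and `clamp_max`
// from `min`: clamping from below keeps the larger of value and bound, clamping from
// above keeps the smaller, and `clamp` composes the two. A plain-Rust sketch:

fn clamp_min(x: f32, min_value: f32) -> f32 {
    f32::max(x, min_value)
}

fn clamp_max(x: f32, max_value: f32) -> f32 {
    f32::min(x, max_value)
}

fn clamp(x: f32, min_value: f32, max_value: f32) -> f32 {
    clamp_max(clamp_min(x, min_value), max_value)
}

fn main() {
    assert_eq!(clamp(0.1, 0.3, 0.7), 0.3);
    assert_eq!(clamp(0.5, 0.3, 0.7), 0.5);
    assert_eq!(clamp(0.9, 0.3, 0.7), 0.7);
}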
+ kernel::{comparison, comparison_elem, comparison_elem_inplace, comparison_inplace}, + tensor::WgpuTensor, }; use std::mem; @@ -31,136 +31,136 @@ comparison_elem_inplace!(LowerElemInplace, "<"); comparison_elem_inplace!(LowerEqualElemInplace, "<="); pub fn equal( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn greater( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn greater_equal( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn lower( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn lower_equal( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); + let can_be_used_as_bool = mem::size_of::() == mem::size_of::(); - if can_be_used_as_bool && lhs.can_mut_broadcast(&rhs) { - return comparison_inplace::(lhs, rhs); - } - if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { - return comparison_inplace::(rhs, lhs); - } + if can_be_used_as_bool 
&& lhs.can_mut_broadcast(&rhs) { + return comparison_inplace::(lhs, rhs); + } + if can_be_used_as_bool && rhs.can_mut_broadcast(&lhs) { + return comparison_inplace::(rhs, lhs); + } - comparison::(lhs, rhs) + comparison::(lhs, rhs) } pub fn equal_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn greater_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn lower_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn greater_equal_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } pub fn lower_equal_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - if mem::size_of::() == mem::size_of::() && lhs.can_mut() { - return comparison_elem_inplace::(lhs, rhs); - } + if mem::size_of::() == mem::size_of::() && lhs.can_mut() { + return comparison_elem_inplace::(lhs, rhs); + } - comparison_elem::(lhs, rhs) + comparison_elem::(lhs, rhs) } diff --git a/burn-wgpu/src/kernel/comparison/binary.rs b/burn-wgpu/src/kernel/comparison/binary.rs index 9c4b8c7178..257585dd2a 100644 --- a/burn-wgpu/src/kernel/comparison/binary.rs +++ b/burn-wgpu/src/kernel/comparison/binary.rs @@ -1,172 +1,177 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{ + build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, + }, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::Shape; kernel_wgsl!(ComparisonRaw, "../../template/comparison/binary.wgsl"); kernel_wgsl!( - ComparisonInplaceRaw, - "../../template/comparison/binary_inplace.wgsl" + ComparisonInplaceRaw, + "../../template/comparison/binary_inplace.wgsl" ); /// Creates a comparison kernel. #[macro_export] macro_rules! 
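// The comparison functions above only reuse an input buffer for the boolean result when
// the element type has the same size as `u32`, since the kernels write `u32(lhs <op> rhs)`
// into the output. A sketch of that guard (illustrative, not the burn-wgpu API):

use std::mem;

fn can_reuse_as_bool<E>() -> bool {
    mem::size_of::<E>() == mem::size_of::<u32>()
}

fn main() {
    assert!(can_reuse_as_bool::<f32>()); // 4-byte elements can hold the 0/1 result in place
    assert!(!can_reuse_as_bool::<f64>()); // wider elements fall back to a fresh output buffer
}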
comparison { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonRaw::source().register( - "body", - format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops), - ) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonRaw::source().register( + "body", + format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops), + ) + } + } + }; } /// Creates a comparison inplace kernel. #[macro_export] macro_rules! comparison_inplace { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonInplaceRaw::source() - .register( - "body", - "lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);", - ) - .add_template(format!( - "{}return {{{{ elem }}}}(lhs {} rhs);{}", - "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", $ops, "\n}\n" - )) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonInplaceRaw::source() + .register( + "body", + "lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);", + ) + .add_template(format!( + "{}return {{{{ elem }}}}(lhs {} rhs);{}", + "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", + $ops, + "\n}\n" + )) + } + } + }; } pub fn comparison( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); - let mut shape_out = [0; D]; - lhs - .shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - - let shape_out = Shape::new(shape_out); - let num_elems = shape_out.num_elements(); - - let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); - - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), + lhs.assert_is_on_same_device(&rhs); + let mut shape_out = [0; D]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::new(shape_out); + let num_elems = shape_out.num_elements(); + + let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); + + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), + ); + let info = build_info(&[&lhs, &rhs, &output]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], ); - let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], - ); - - WgpuTensor::new(output.client, output.device, output.shape, output.handle) + WgpuTensor::new(output.client, output.device, output.shape, output.handle) } pub fn comparison_inplace( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); + 
lhs.assert_is_on_same_device(&rhs); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), - ); - let info = build_info(&[&lhs, &rhs]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), + ); + let info = build_info(&[&lhs, &rhs]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - lhs - .client - .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); - WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; - - comparison!(LowerEqual, "<="); - comparison_inplace!(LowerEqualInplace, "<="); - - #[test] - fn comparison_should_work_with_multiple_invocations() { - let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); - - let value = Tensor::::from_primitive(comparison::( - lhs.into_primitive(), - rhs.into_primitive(), - )); - - let value_ref = lhs_ref.lower_equal(rhs_ref); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[test] - fn comparison_inplace_should_work_with_multiple_invocations() { - let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); - - let value = - Tensor::::from_primitive( - comparison_inplace::(lhs.into_primitive(), rhs.into_primitive()), - ); - - let value_ref = lhs_ref.lower_equal(rhs_ref); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs() -> ( - Tensor, - Tensor, - Tensor, - Tensor, - ) { - TestBackend::seed(0); - let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); - let rhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - let rhs_ref = Tensor::::from_data(rhs.to_data()); - - (lhs, rhs, lhs_ref, rhs_ref) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; + + comparison!(LowerEqual, "<="); + comparison_inplace!(LowerEqualInplace, "<="); + + #[test] + fn comparison_should_work_with_multiple_invocations() { + let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); + + let value = Tensor::::from_primitive( + comparison::(lhs.into_primitive(), rhs.into_primitive()), + ); + + let value_ref = lhs_ref.lower_equal(rhs_ref); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[test] + fn comparison_inplace_should_work_with_multiple_invocations() { + let (lhs, rhs, lhs_ref, rhs_ref) = inputs(); + + let value = Tensor::::from_primitive(comparison_inplace::< + LowerEqualInplace, + f32, + 3, + >( + lhs.into_primitive(), + rhs.into_primitive(), + )); + + let value_ref = lhs_ref.lower_equal(rhs_ref); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs() -> ( + Tensor, + Tensor, + Tensor, + Tensor, + ) { + TestBackend::seed(0); + let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); + let rhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + let rhs_ref = Tensor::::from_data(rhs.to_data()); + + (lhs, rhs, lhs_ref, rhs_ref) + 
} } diff --git a/burn-wgpu/src/kernel/comparison/elem.rs b/burn-wgpu/src/kernel/comparison/elem.rs index 56d8b2b6a0..b358307222 100644 --- a/burn-wgpu/src/kernel/comparison/elem.rs +++ b/burn-wgpu/src/kernel/comparison/elem.rs @@ -1,138 +1,141 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!(ComparisonElemRaw, "../../template/comparison/elem.wgsl"); kernel_wgsl!( - ComparisonElemInplaceRaw, - "../../template/comparison/elem_inplace.wgsl" + ComparisonElemInplaceRaw, + "../../template/comparison/elem_inplace.wgsl" ); /// Creates a comparison elementwise kernel. #[macro_export] macro_rules! comparison_elem { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonElemRaw::source() - .register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonElemRaw::source() + .register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops)) + } + } + }; } /// Creates a comparison elementwise inplace kernel. #[macro_export] macro_rules! comparison_elem_inplace { - ( + ( $struct:ident, $ops:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonElemInplaceRaw::source() - .register("body", "lhs[id] = compare(lhs[id], rhs);") - .add_template(format!( - "{}return {{{{ elem }}}}(lhs {} rhs);{}", - "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", $ops, "\n}\n" - )) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonElemInplaceRaw::source() + .register("body", "lhs[id] = compare(lhs[id], rhs);") + .add_template(format!( + "{}return {{{{ elem }}}}(lhs {} rhs);{}", + "fn compare(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {\n", + $ops, + "\n}\n" + )) + } + } + }; } pub fn comparison_elem( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); + let num_elems = lhs.shape.num_elements(); - let handle = lhs.client.empty(num_elems * core::mem::size_of::()); - let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), - ); + let handle = lhs.client.empty(num_elems * core::mem::size_of::()); + let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), + ); - lhs - .client - .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]); + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]); - WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle) + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle) } pub fn comparison_elem_inplace( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - let kernel = - StaticKernel::>::new( - 
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); - lhs - .client - .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); - - WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) + let kernel = + StaticKernel::>::new( + elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT), + ); + let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); + + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; - - comparison_elem!(LowerEqual, "<="); - comparison_elem_inplace!(LowerEqualInplace, "<="); - - #[test] - fn comparison_elem_should_work_with_multiple_invocations() { - let (lhs, lhs_ref, rhs) = inputs(); - - let value = Tensor::::from_primitive( - comparison_elem::(lhs.into_primitive(), rhs), - ); - - let value_ref = lhs_ref.lower_equal_elem(rhs); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[test] - fn comparison_elem_inplace_should_work_with_multiple_invocations() { - let (lhs, lhs_ref, rhs) = inputs(); - - let value = Tensor::::from_primitive(comparison_elem_inplace::< - LowerEqualInplace, - f32, - 3, - >(lhs.into_primitive(), rhs)); - - let value_ref = lhs_ref.lower_equal_elem(rhs); - value - .into_data() - .assert_approx_eq(&value_ref.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs() -> (Tensor, Tensor, f32) { - TestBackend::seed(0); - let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); - let lhs_ref = Tensor::::from_data(lhs.to_data()); - - (lhs, lhs_ref, 5.0) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; + + comparison_elem!(LowerEqual, "<="); + comparison_elem_inplace!(LowerEqualInplace, "<="); + + #[test] + fn comparison_elem_should_work_with_multiple_invocations() { + let (lhs, lhs_ref, rhs) = inputs(); + + let value = + Tensor::::from_primitive(comparison_elem::( + lhs.into_primitive(), + rhs, + )); + + let value_ref = lhs_ref.lower_equal_elem(rhs); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[test] + fn comparison_elem_inplace_should_work_with_multiple_invocations() { + let (lhs, lhs_ref, rhs) = inputs(); + + let value = + Tensor::::from_primitive(comparison_elem_inplace::< + LowerEqualInplace, + f32, + 3, + >(lhs.into_primitive(), rhs)); + + let value_ref = lhs_ref.lower_equal_elem(rhs); + value + .into_data() + .assert_approx_eq(&value_ref.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs() -> (Tensor, Tensor, f32) { + TestBackend::seed(0); + let lhs = Tensor::::random([2, 6, 256], Distribution::Uniform(0.0, 1.0)); + let lhs_ref = Tensor::::from_data(lhs.to_data()); + + (lhs, lhs_ref, 5.0) + } } diff --git a/burn-wgpu/src/kernel/conv/conv2d.rs b/burn-wgpu/src/kernel/conv/conv2d.rs index 86c03f303d..39b0ecf45d 100644 --- a/burn-wgpu/src/kernel/conv/conv2d.rs +++ b/burn-wgpu/src/kernel/conv/conv2d.rs @@ -1,106 +1,108 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, 
elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions}, - Element, ElementConversion, Shape, + ops::{conv::calculate_conv_output_size, ConvOptions}, + Element, ElementConversion, Shape, }; kernel_wgsl!(Conv2d, "../../template/conv/conv2d.wgsl"); pub(crate) fn conv2d( - input: WgpuTensor, - weight: WgpuTensor, - bias: Option>, - options: ConvOptions<2>, + input: WgpuTensor, + weight: WgpuTensor, + bias: Option>, + options: ConvOptions<2>, ) -> WgpuTensor { - let input = kernel::into_contiguous(input); - let weight = kernel::into_contiguous(weight); - let [batch_size, _, in_height, in_width] = input.shape.dims; - let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims; + let input = kernel::into_contiguous(input); + let weight = kernel::into_contiguous(weight); + let [batch_size, _, in_height, in_width] = input.shape.dims; + let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims; - let out_0 = calculate_conv_output_size( - kernel_0, - options.stride[0], - options.padding[0], - options.dilation[0], - in_height, - ); - let out_1 = calculate_conv_output_size( - kernel_1, - options.stride[1], - options.padding[1], - options.dilation[1], - in_width, - ); + let out_0 = calculate_conv_output_size( + kernel_0, + options.stride[0], + options.padding[0], + options.dilation[0], + in_height, + ); + let out_1 = calculate_conv_output_size( + kernel_1, + options.stride[1], + options.padding[1], + options.dilation[1], + in_width, + ); - let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]); + let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]); - let output = empty_device( - input.client.clone(), - input.device.clone(), - shape_out.clone(), - ); + let output = empty_device( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + ); - let mut info = build_info(&[&input, &output, &weight]); - info.push(options.stride[0] as u32); - info.push(options.stride[1] as u32); - info.push(options.padding[0] as u32); - info.push(options.padding[1] as u32); - info.push(options.dilation[0] as u32); - info.push(options.dilation[1] as u32); - info.push(options.groups as u32); + let mut info = build_info(&[&input, &output, &weight]); + info.push(options.stride[0] as u32); + info.push(options.stride[1] as u32); + info.push(options.padding[0] as u32); + info.push(options.padding[1] as u32); + info.push(options.dilation[0] as u32); + info.push(options.dilation[1] as u32); + info.push(options.groups as u32); - let bias_handle = bias - .map(|bias| bias.handle) - .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); + let bias_handle = bias + .map(|bias| bias.handle) + .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), - ); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &weight.handle, - &bias_handle, - &output.handle, - &info_handle, - ], - ); + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &weight.handle, + &bias_handle, + &output.handle, + &info_handle, + ], + ); - output + output 
} #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{module, Distribution, Tensor}; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{module, Distribution, Tensor}; - #[test] - fn conv2d_should_work_with_multiple_invocations() { - let input = Tensor::::random([6, 16, 32, 32], Distribution::Default); - let weight = Tensor::::random([12, 8, 3, 3], Distribution::Default); - let bias = Tensor::::random([12], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); - let weight_ref = Tensor::::from_data(weight.to_data()); - let bias_ref = Tensor::::from_data(bias.to_data()); - let options = burn_tensor::ops::ConvOptions::new([2, 3], [2, 3], [2, 3], 2); + #[test] + fn conv2d_should_work_with_multiple_invocations() { + let input = Tensor::::random([6, 16, 32, 32], Distribution::Default); + let weight = Tensor::::random([12, 8, 3, 3], Distribution::Default); + let bias = Tensor::::random([12], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); + let weight_ref = Tensor::::from_data(weight.to_data()); + let bias_ref = Tensor::::from_data(bias.to_data()); + let options = burn_tensor::ops::ConvOptions::new([2, 3], [2, 3], [2, 3], 2); - let output = module::conv2d(input, weight, Some(bias), options.clone()); - let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options); + let output = module::conv2d(input, weight, Some(bias), options.clone()); + let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options); - output - .into_data() - .assert_approx_eq(&output_ref.into_data(), 3); - } + output + .into_data() + .assert_approx_eq(&output_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/conv/conv_transpose2d.rs b/burn-wgpu/src/kernel/conv/conv_transpose2d.rs index b9eae1c861..d59a2cb6f8 100644 --- a/burn-wgpu/src/kernel/conv/conv_transpose2d.rs +++ b/burn-wgpu/src/kernel/conv/conv_transpose2d.rs @@ -1,119 +1,120 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape}; kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl"); pub(crate) fn conv_transpose2d( - input: WgpuTensor, - weight: WgpuTensor, - bias: Option>, - options: ConvTransposeOptions<2>, + input: WgpuTensor, + weight: WgpuTensor, + bias: Option>, + options: ConvTransposeOptions<2>, ) -> WgpuTensor { - let input = kernel::into_contiguous(input); - let weight = kernel::into_contiguous(weight); - let [batch_size, _, in_height, in_width] = input.shape.dims; - let [_, out_channels, kernel_0, kernel_1] = weight.shape.dims; - - let out_0 = (in_height - 1) * options.stride[0] - + options.dilation[0] * (kernel_0 - 1) - + options.padding_out[0] - - 2 * options.padding[0] - + 1; - let out_1 = (in_width - 1) * options.stride[1] - + options.dilation[1] * (kernel_1 - 1) - + options.padding_out[1] - - 2 * options.padding[1] - + 1; - - let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]); - let num_elems = shape_out.num_elements(); - - let output = empty_device( - input.client.clone(), - 
input.device.clone(), - shape_out.clone(), - ); - let mut info = build_info(&[&input, &output, &weight]); - - info.push(options.stride[0] as u32); - info.push(options.stride[1] as u32); - info.push(options.padding[0] as u32); - info.push(options.padding[1] as u32); - info.push(options.dilation[0] as u32); - info.push(options.dilation[1] as u32); - info.push(options.groups as u32); - - let bias_handle = bias - .map(|bias| bias.handle) - .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); - - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &weight.handle, - &bias_handle, - &output.handle, - &info_handle, - ], - ); - - output -} - -#[cfg(test)] -mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, module, Distribution, Tensor}; - - #[test] - fn conv_transpose2d_should_work_with_multiple_invocations() { - TestBackend::seed(0); - - let height = 8; - let width = 8; - let in_channels = 8; - let out_channels = 8; - let batch_size = 32; - let kernel_size_0 = 3; - let kernel_size_1 = 3; - let options = burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1); - - let input = Tensor::::random( - [batch_size, in_channels, height, width], - Distribution::Default, + let input = kernel::into_contiguous(input); + let weight = kernel::into_contiguous(weight); + let [batch_size, _, in_height, in_width] = input.shape.dims; + let [_, out_channels, kernel_0, kernel_1] = weight.shape.dims; + + let out_0 = (in_height - 1) * options.stride[0] + + options.dilation[0] * (kernel_0 - 1) + + options.padding_out[0] + - 2 * options.padding[0] + + 1; + let out_1 = (in_width - 1) * options.stride[1] + + options.dilation[1] * (kernel_1 - 1) + + options.padding_out[1] + - 2 * options.padding[1] + + 1; + + let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]); + let num_elems = shape_out.num_elements(); + + let output = empty_device( + input.client.clone(), + input.device.clone(), + shape_out.clone(), ); - let weight = Tensor::::random( - [ - in_channels, - out_channels / options.groups, - kernel_size_0, - kernel_size_1, - ], - Distribution::Default, + let mut info = build_info(&[&input, &output, &weight]); + + info.push(options.stride[0] as u32); + info.push(options.stride[1] as u32); + info.push(options.padding[0] as u32); + info.push(options.padding[1] as u32); + info.push(options.dilation[0] as u32); + info.push(options.dilation[1] as u32); + info.push(options.groups as u32); + + let bias_handle = bias + .map(|bias| bias.handle) + .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); + + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &weight.handle, + &bias_handle, + &output.handle, + &info_handle, + ], ); - let bias = Tensor::::random([out_channels], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); - let weight_ref = Tensor::::from_data(weight.to_data()); - let bias_ref = Tensor::::from_data(bias.to_data()); - - let output = module::conv_transpose2d(input, weight, Some(bias), options.clone()); - let output_ref = module::conv_transpose2d(input_ref, weight_ref, 
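// The two conv kernels above size their outputs per spatial dimension. `conv_transpose2d`
// spells its formula out inline; `conv2d` delegates to `calculate_conv_output_size`,
// shown here in the standard closed form (an assumption, since the burn-tensor helper is
// not part of this hunk).

fn conv_output_size(kernel: usize, stride: usize, padding: usize, dilation: usize, input: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn conv_transpose_output_size(
    kernel: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    input: usize,
) -> usize {
    // Mirrors the inline computation in `conv_transpose2d` above.
    (input - 1) * stride + dilation * (kernel - 1) + padding_out - 2 * padding + 1
}

fn main() {
    // 32x32 input, 3x3 kernel, stride 2, padding 2, dilation 2 (dim 0 of the conv2d test).
    assert_eq!(conv_output_size(3, 2, 2, 2, 32), 16);
    // 8x8 input, 3x3 kernel, stride 1, padding 1, output padding 0, dilation 1
    // (the conv_transpose2d test): the spatial size is preserved.
    assert_eq!(conv_transpose_output_size(3, 1, 1, 0, 1, 8), 8);
}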
Some(bias_ref), options); output - .into_data() - .assert_approx_eq(&output_ref.into_data(), 3); - } +} + +#[cfg(test)] +mod tests { + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, module, Distribution, Tensor}; + + #[test] + fn conv_transpose2d_should_work_with_multiple_invocations() { + TestBackend::seed(0); + + let height = 8; + let width = 8; + let in_channels = 8; + let out_channels = 8; + let batch_size = 32; + let kernel_size_0 = 3; + let kernel_size_1 = 3; + let options = + burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1); + + let input = Tensor::::random( + [batch_size, in_channels, height, width], + Distribution::Default, + ); + let weight = Tensor::::random( + [ + in_channels, + out_channels / options.groups, + kernel_size_0, + kernel_size_1, + ], + Distribution::Default, + ); + let bias = Tensor::::random([out_channels], Distribution::Default); + let input_ref = Tensor::::from_data(input.to_data()); + let weight_ref = Tensor::::from_data(weight.to_data()); + let bias_ref = Tensor::::from_data(bias.to_data()); + + let output = module::conv_transpose2d(input, weight, Some(bias), options.clone()); + let output_ref = module::conv_transpose2d(input_ref, weight_ref, Some(bias_ref), options); + + output + .into_data() + .assert_approx_eq(&output_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/index/gather.rs b/burn-wgpu/src/kernel/index/gather.rs index a1da2533c0..48c7c5baab 100644 --- a/burn-wgpu/src/kernel/index/gather.rs +++ b/burn-wgpu/src/kernel/index/gather.rs @@ -1,87 +1,88 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(Gather, "../../template/index/gather.wgsl"); pub(crate) fn gather( - dim: usize, - tensor: WgpuTensor, - indices: WgpuTensor, + dim: usize, + tensor: WgpuTensor, + indices: WgpuTensor, ) -> WgpuTensor { - let shape_output = indices.shape.clone(); - let num_elems = shape_output.num_elements(); - let indices = kernel::into_contiguous(indices); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + let shape_output = indices.shape.clone(); + let num_elems = shape_output.num_elements(); + let indices = kernel::into_contiguous(indices); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - let mut info = build_info(&[&tensor, &output]); - info.push(dim as u32); - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + let mut info = build_info(&[&tensor, &output]); + info.push(dim as u32); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - tensor.client.execute( - Box::new(kernel), - &[ - &tensor.handle, - &indices.handle, - &output.handle, - &info_handle, - ], - ); + tensor.client.execute( + Box::new(kernel), + &[ + &tensor.handle, + &indices.handle, + &output.handle, + &info_handle, + ], + ); - output + output } #[cfg(test)] mod tests { - use super::*; - use 
crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Int, Shape, Tensor}; + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Int, Shape, Tensor}; - #[test] - fn gather_should_work_with_multiple_workgroups_dim0() { - test_same_as_ref([6, 256], 0); - } + #[test] + fn gather_should_work_with_multiple_workgroups_dim0() { + test_same_as_ref([6, 256], 0); + } - #[test] - fn gather_should_work_with_multiple_workgroups_dim1() { - test_same_as_ref([6, 256], 1); - } + #[test] + fn gather_should_work_with_multiple_workgroups_dim1() { + test_same_as_ref([6, 256], 1); + } - fn test_same_as_ref(shape: [usize; D], dim: usize) { - TestBackend::seed(0); - let max = shape[dim]; - let shape = Shape::new(shape); - let tensor = Tensor::::random(shape.clone(), Distribution::Default); - let indices = Tensor::::from_data( - Tensor::::random( - [shape.num_elements()], - Distribution::Uniform(0., max as f32), - ) - .into_data() - .convert(), - ) - .reshape(shape); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let indices_ref = Tensor::::from_data(indices.to_data().convert()); + fn test_same_as_ref(shape: [usize; D], dim: usize) { + TestBackend::seed(0); + let max = shape[dim]; + let shape = Shape::new(shape); + let tensor = Tensor::::random(shape.clone(), Distribution::Default); + let indices = Tensor::::from_data( + Tensor::::random( + [shape.num_elements()], + Distribution::Uniform(0., max as f32), + ) + .into_data() + .convert(), + ) + .reshape(shape); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let indices_ref = + Tensor::::from_data(indices.to_data().convert()); - let actual = Tensor::::from_primitive(gather( - dim, - tensor.into_primitive(), - indices.into_primitive(), - )); - let expected = tensor_ref.gather(dim, indices_ref); + let actual = Tensor::::from_primitive(gather( + dim, + tensor.into_primitive(), + indices.into_primitive(), + )); + let expected = tensor_ref.gather(dim, indices_ref); - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/index/scatter.rs b/burn-wgpu/src/kernel/index/scatter.rs index 5b8489f782..2c09a4f1fc 100644 --- a/burn-wgpu/src/kernel/index/scatter.rs +++ b/burn-wgpu/src/kernel/index/scatter.rs @@ -1,138 +1,141 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!(Scatter, "../../template/index/scatter.wgsl"); pub(crate) fn scatter( - dim: usize, - tensor: WgpuTensor, - indices: WgpuTensor, - value: WgpuTensor, + dim: usize, + tensor: WgpuTensor, + indices: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - let indices = kernel::into_contiguous(indices); - let tensor = kernel::into_contiguous(tensor); - let value = kernel::into_contiguous(value); - - let tensor = match tensor.can_mut() { - true => tensor, - false => tensor.copy(), - }; - - let mut info = build_info(&[&tensor, &value]); - let mut strides = [0; D]; - let mut current = 1; - let mut num_elems_per_workgroup = 1; - - tensor - .shape - .dims - .iter() - .enumerate() - .rev() - .filter(|(index, _val)| *index != dim) - 
.for_each(|(index, val)| { - strides[index] = current; - current *= val; - num_elems_per_workgroup *= tensor.shape.dims[index]; - }); - - strides - .into_iter() - .for_each(|stride| info.push(stride as u32)); - - info.push(dim as u32); - - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems_per_workgroup, WORKGROUP_DEFAULT), - ); + let indices = kernel::into_contiguous(indices); + let tensor = kernel::into_contiguous(tensor); + let value = kernel::into_contiguous(value); + + let tensor = match tensor.can_mut() { + true => tensor, + false => tensor.copy(), + }; + + let mut info = build_info(&[&tensor, &value]); + let mut strides = [0; D]; + let mut current = 1; + let mut num_elems_per_workgroup = 1; + + tensor + .shape + .dims + .iter() + .enumerate() + .rev() + .filter(|(index, _val)| *index != dim) + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + num_elems_per_workgroup *= tensor.shape.dims[index]; + }); + + strides + .into_iter() + .for_each(|stride| info.push(stride as u32)); + + info.push(dim as u32); + + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + num_elems_per_workgroup, + WORKGROUP_DEFAULT, + )); - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &indices.handle, &value.handle, &info_handle], - ); + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &indices.handle, &value.handle, &info_handle], + ); - tensor + tensor } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; - - #[test] - fn scatter_should_work_with_multiple_workgroups_2d_dim0() { - same_as_reference_same_shape(0, [256, 32]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_2d_dim1() { - same_as_reference_same_shape(1, [32, 256]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_3d_dim0() { - same_as_reference_same_shape(0, [256, 6, 6]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_3d_dim1() { - same_as_reference_same_shape(1, [6, 256, 6]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_3d_dim2() { - same_as_reference_same_shape(2, [6, 6, 256]); - } - - #[test] - fn scatter_should_work_with_multiple_workgroups_diff_shapes() { - same_as_reference_diff_shape(1, [32, 128], [32, 1]); - } - - fn same_as_reference_diff_shape( - dim: usize, - shape1: [usize; D], - shape2: [usize; D], - ) { - TestBackend::seed(0); - let tensor = Tensor::::random(shape1, Distribution::Default); - let value = Tensor::::random(shape2, Distribution::Default); - let indices = Tensor::::from_data( - Tensor::::random( - [shape2.iter().product()], - Distribution::Uniform(0., shape2[dim] as f32), - ) - .into_data() - .convert(), - ) - .reshape(shape2); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - let indices_ref = Tensor::::from_data(indices.to_data().convert()); - - let actual = Tensor::::from_primitive(scatter( - dim, - tensor.into_primitive(), - indices.into_primitive(), - value.into_primitive(), - )); - let expected = tensor_ref.scatter(dim, indices_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - fn same_as_reference_same_shape(dim: usize, shape: [usize; D]) { - same_as_reference_diff_shape(dim, shape, 
shape); - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; + + #[test] + fn scatter_should_work_with_multiple_workgroups_2d_dim0() { + same_as_reference_same_shape(0, [256, 32]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_2d_dim1() { + same_as_reference_same_shape(1, [32, 256]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_3d_dim0() { + same_as_reference_same_shape(0, [256, 6, 6]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_3d_dim1() { + same_as_reference_same_shape(1, [6, 256, 6]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_3d_dim2() { + same_as_reference_same_shape(2, [6, 6, 256]); + } + + #[test] + fn scatter_should_work_with_multiple_workgroups_diff_shapes() { + same_as_reference_diff_shape(1, [32, 128], [32, 1]); + } + + fn same_as_reference_diff_shape( + dim: usize, + shape1: [usize; D], + shape2: [usize; D], + ) { + TestBackend::seed(0); + let tensor = Tensor::::random(shape1, Distribution::Default); + let value = Tensor::::random(shape2, Distribution::Default); + let indices = Tensor::::from_data( + Tensor::::random( + [shape2.iter().product()], + Distribution::Uniform(0., shape2[dim] as f32), + ) + .into_data() + .convert(), + ) + .reshape(shape2); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + let indices_ref = + Tensor::::from_data(indices.to_data().convert()); + + let actual = Tensor::::from_primitive(scatter( + dim, + tensor.into_primitive(), + indices.into_primitive(), + value.into_primitive(), + )); + let expected = tensor_ref.scatter(dim, indices_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + fn same_as_reference_same_shape(dim: usize, shape: [usize; D]) { + same_as_reference_diff_shape(dim, shape, shape); + } } diff --git a/burn-wgpu/src/kernel/index/select.rs b/burn-wgpu/src/kernel/index/select.rs index 228be8bd05..5000b90608 100644 --- a/burn-wgpu/src/kernel/index/select.rs +++ b/burn-wgpu/src/kernel/index/select.rs @@ -1,172 +1,177 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(IndexSelect, "../../template/index/select.wgsl"); kernel_wgsl!( - SelectAssignInplace, - "../../template/index/select_assign_inplace.wgsl" + SelectAssignInplace, + "../../template/index/select_assign_inplace.wgsl" ); pub(crate) fn select( - tensor: WgpuTensor, - dim: usize, - indices: WgpuTensor, + tensor: WgpuTensor, + dim: usize, + indices: WgpuTensor, ) -> WgpuTensor { - let mut output_shape = tensor.shape.clone(); - output_shape.dims[dim] = indices.shape.dims[0]; - - let num_elems = output_shape.num_elements(); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), output_shape); - - let mut info = build_info(&[&tensor, &output]); - info.push(dim as u32); - - let info_handle = output.client.create(bytemuck::cast_slice(&info)); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - tensor.client.execute( - Box::new(kernel), - &[ - &tensor.handle, - 
&indices.handle, - &output.handle, - &info_handle, - ], - ); - - output + let mut output_shape = tensor.shape.clone(); + output_shape.dims[dim] = indices.shape.dims[0]; + + let num_elems = output_shape.num_elements(); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), output_shape); + + let mut info = build_info(&[&tensor, &output]); + info.push(dim as u32); + + let info_handle = output.client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + + tensor.client.execute( + Box::new(kernel), + &[ + &tensor.handle, + &indices.handle, + &output.handle, + &info_handle, + ], + ); + + output } pub(crate) fn select_assign( - tensor: WgpuTensor, - dim: usize, - indices: WgpuTensor, - value: WgpuTensor, + tensor: WgpuTensor, + dim: usize, + indices: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - let tensor = match tensor.can_mut() { - true => tensor, - false => tensor.copy(), - }; - - let mut info = build_info(&[&tensor, &value]); - let mut strides = [0; D]; - let mut current = 1; - let mut num_elems_per_workgroup = 1; - - tensor - .shape - .dims - .iter() - .enumerate() - .rev() - .filter(|(index, _val)| *index != dim) - .for_each(|(index, val)| { - strides[index] = current; - current *= val; - num_elems_per_workgroup *= tensor.shape.dims[index]; - }); - - strides - .into_iter() - .for_each(|stride| info.push(stride as u32)); - - info.push(dim as u32); - - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - num_elems_per_workgroup, - WORKGROUP_DEFAULT, - )); - - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &indices.handle, &value.handle, &info_handle], - ); - - tensor + let tensor = match tensor.can_mut() { + true => tensor, + false => tensor.copy(), + }; + + let mut info = build_info(&[&tensor, &value]); + let mut strides = [0; D]; + let mut current = 1; + let mut num_elems_per_workgroup = 1; + + tensor + .shape + .dims + .iter() + .enumerate() + .rev() + .filter(|(index, _val)| *index != dim) + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + num_elems_per_workgroup *= tensor.shape.dims[index]; + }); + + strides + .into_iter() + .for_each(|stride| info.push(stride as u32)); + + info.push(dim as u32); + + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + num_elems_per_workgroup, + WORKGROUP_DEFAULT, + )); + + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &indices.handle, &value.handle, &info_handle], + ); + + tensor } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; - - #[test] - fn select_should_work_with_multiple_workgroups() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let indices = Tensor::::arange(0..100); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let indices_ref = Tensor::::from_data(indices.to_data().convert()); - - let actual = select(tensor.into_primitive(), 1, indices.into_primitive()); - let expected = tensor_ref.select(1, indices_ref); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_2d_dim0() { - select_assign_same_as_ref(0, [256, 
6]); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_2d_dim1() { - select_assign_same_as_ref(1, [6, 256]); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_3d_dim0() { - select_assign_same_as_ref(0, [256, 6, 6]); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_3d_dim1() { - select_assign_same_as_ref(1, [6, 256, 6]); - } - - #[test] - fn select_assign_should_work_with_multiple_workgroups_3d_dim2() { - select_assign_same_as_ref(2, [6, 6, 256]); - } - - fn select_assign_same_as_ref(dim: usize, shape: [usize; D]) { - TestBackend::seed(0); - let tensor = Tensor::::random(shape, Distribution::Default); - let value = Tensor::::random(shape, Distribution::Default); - let indices = Tensor::::from_data( - Tensor::::random([shape[dim]], Distribution::Uniform(0., shape[dim] as f32)) - .into_data() - .convert(), - ); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - let indices_ref = Tensor::::from_data(indices.to_data().convert()); - - let actual = Tensor::::from_primitive(select_assign( - tensor.into_primitive(), - dim, - indices.into_primitive(), - value.into_primitive(), - )); - let expected = tensor_ref.select_assign(dim, indices_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Distribution, Int, Tensor}; + + #[test] + fn select_should_work_with_multiple_workgroups() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let indices = Tensor::::arange(0..100); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let indices_ref = + Tensor::::from_data(indices.to_data().convert()); + + let actual = select(tensor.into_primitive(), 1, indices.into_primitive()); + let expected = tensor_ref.select(1, indices_ref); + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_2d_dim0() { + select_assign_same_as_ref(0, [256, 6]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_2d_dim1() { + select_assign_same_as_ref(1, [6, 256]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_3d_dim0() { + select_assign_same_as_ref(0, [256, 6, 6]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_3d_dim1() { + select_assign_same_as_ref(1, [6, 256, 6]); + } + + #[test] + fn select_assign_should_work_with_multiple_workgroups_3d_dim2() { + select_assign_same_as_ref(2, [6, 6, 256]); + } + + fn select_assign_same_as_ref(dim: usize, shape: [usize; D]) { + TestBackend::seed(0); + let tensor = Tensor::::random(shape, Distribution::Default); + let value = Tensor::::random(shape, Distribution::Default); + let indices = Tensor::::from_data( + Tensor::::random( + [shape[dim]], + Distribution::Uniform(0., shape[dim] as f32), + ) + .into_data() + .convert(), + ); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + let indices_ref = + Tensor::::from_data(indices.to_data().convert()); + + let actual = Tensor::::from_primitive(select_assign( + tensor.into_primitive(), + dim, + indices.into_primitive(), + value.into_primitive(), + )); + let expected = tensor_ref.select_assign(dim, indices_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } } diff 
--git a/burn-wgpu/src/kernel/index/slice.rs b/burn-wgpu/src/kernel/index/slice.rs index 62d59d06b6..e431168b68 100644 --- a/burn-wgpu/src/kernel/index/slice.rs +++ b/burn-wgpu/src/kernel/index/slice.rs @@ -1,130 +1,132 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::Shape; use std::ops::Range; kernel_wgsl!(IndexRaw, "../../template/index/slice.wgsl"); kernel_wgsl!( - IndexAssignInplaceRaw, - "../../template/index/slice_assign_inplace.wgsl" + IndexAssignInplaceRaw, + "../../template/index/slice_assign_inplace.wgsl" ); pub(crate) fn slice( - tensor: WgpuTensor, - indices: [Range; D2], + tensor: WgpuTensor, + indices: [Range; D2], ) -> WgpuTensor { - let mut dims = tensor.shape.dims; - for i in 0..D2 { - dims[i] = indices[i].end - indices[i].start; - } - let shape_output = Shape::new(dims); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - slice_on_output(tensor, output, indices) + let mut dims = tensor.shape.dims; + for i in 0..D2 { + dims[i] = indices[i].end - indices[i].start; + } + let shape_output = Shape::new(dims); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + slice_on_output(tensor, output, indices) } pub(crate) fn slice_on_output( - tensor: WgpuTensor, - output: WgpuTensor, - indices: [Range; D2], + tensor: WgpuTensor, + output: WgpuTensor, + indices: [Range; D2], ) -> WgpuTensor { - let mut info = build_info(&[&tensor, &output]); + let mut info = build_info(&[&tensor, &output]); - for i in 0..D1 { - let start = indices.get(i).map(|index| index.start).unwrap_or(0); - info.push(start as u32); - } + for i in 0..D1 { + let start = indices.get(i).map(|index| index.start).unwrap_or(0); + info.push(start as u32); + } - let info_handle = output.client.create(bytemuck::cast_slice(&info)); + let info_handle = output.client.create(bytemuck::cast_slice(&info)); - let kernel = - StaticKernel::>::new( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), - ); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &output.handle, &info_handle], - ); + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &output.handle, &info_handle], + ); - output + output } pub(crate) fn slice_assign( - tensor: WgpuTensor, - indices: [Range; D2], - value: WgpuTensor, + tensor: WgpuTensor, + indices: [Range; D2], + value: WgpuTensor, ) -> WgpuTensor { - let tensor = match tensor.can_mut() { - true => tensor, - false => tensor.copy(), - }; - let num_elems = tensor.shape.num_elements(); - let mut info = build_info(&[&tensor, &value]); - - for i in 0..D1 { - let start = indices.get(i).map(|index| index.start).unwrap_or(0); - info.push(start as u32); - } - - let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - tensor.client.execute( - Box::new(kernel), - &[&tensor.handle, &value.handle, &info_handle], - ); - - tensor + let tensor = match tensor.can_mut() { + true => 
tensor, + false => tensor.copy(), + }; + let num_elems = tensor.shape.num_elements(); + let mut info = build_info(&[&tensor, &value]); + + for i in 0..D1 { + let start = indices.get(i).map(|index| index.start).unwrap_or(0); + info.push(start as u32); + } + + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &value.handle, &info_handle], + ); + + tensor } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; - - #[test] - fn slice_should_work_with_multiple_workgroups() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let indices = [3..5, 45..256]; - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let actual = slice(tensor.into_primitive(), indices.clone()); - let expected = tensor_ref.slice(indices); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } - - #[test] - fn slice_assign_should_work_with_multiple_workgroups() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let value = Tensor::::random([2, 211], Distribution::Default); - let indices = [3..5, 45..256]; - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - - let actual = slice_assign( - tensor.into_primitive(), - indices.clone(), - value.into_primitive(), - ); - let expected = tensor_ref.slice_assign(indices, value_ref); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; + + #[test] + fn slice_should_work_with_multiple_workgroups() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let indices = [3..5, 45..256]; + let tensor_ref = Tensor::::from_data(tensor.to_data()); + + let actual = slice(tensor.into_primitive(), indices.clone()); + let expected = tensor_ref.slice(indices); + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } + + #[test] + fn slice_assign_should_work_with_multiple_workgroups() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let value = Tensor::::random([2, 211], Distribution::Default); + let indices = [3..5, 45..256]; + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + + let actual = slice_assign( + tensor.into_primitive(), + indices.clone(), + value.into_primitive(), + ); + let expected = tensor_ref.slice_assign(indices, value_ref); + + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } } diff --git a/burn-wgpu/src/kernel/mask/base.rs b/burn-wgpu/src/kernel/mask/base.rs index 44b63f73ba..cfd2fe765f 100644 --- a/burn-wgpu/src/kernel/mask/base.rs +++ b/burn-wgpu/src/kernel/mask/base.rs @@ -2,29 +2,29 @@ use crate::{element::WgpuElement, tensor::WgpuTensor}; /// Execute the mask fill kernel. 
pub fn mask_fill( - tensor: WgpuTensor, - mask: WgpuTensor, - value: E, + tensor: WgpuTensor, + mask: WgpuTensor, + value: E, ) -> WgpuTensor { - if tensor.can_mut() { - return super::mask_fill::mask_fill_inplace(tensor, mask, value); - } + if tensor.can_mut() { + return super::mask_fill::mask_fill_inplace(tensor, mask, value); + } - super::mask_fill::mask_fill(tensor, mask, value) + super::mask_fill::mask_fill(tensor, mask, value) } /// Execute the mask where kernel. pub fn mask_where( - tensor: WgpuTensor, - mask: WgpuTensor, - value: WgpuTensor, + tensor: WgpuTensor, + mask: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - if tensor.can_mut_broadcast(&value) { - return super::mask_where::mask_where_inplace(tensor, mask, value, false); - } - if value.can_mut_broadcast(&tensor) { - return super::mask_where::mask_where_inplace(value, mask, tensor, true); - } + if tensor.can_mut_broadcast(&value) { + return super::mask_where::mask_where_inplace(tensor, mask, value, false); + } + if value.can_mut_broadcast(&tensor) { + return super::mask_where::mask_where_inplace(value, mask, tensor, true); + } - super::mask_where::mask_where(tensor, mask, value) + super::mask_where::mask_where(tensor, mask, value) } diff --git a/burn-wgpu/src/kernel/mask/mask_fill.rs b/burn-wgpu/src/kernel/mask/mask_fill.rs index 22679853bd..aed9c33593 100644 --- a/burn-wgpu/src/kernel/mask/mask_fill.rs +++ b/burn-wgpu/src/kernel/mask/mask_fill.rs @@ -1,122 +1,122 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(MaskFill, "../../template/mask/fill.wgsl"); kernel_wgsl!(MaskFillInplace, "../../template/mask/fill_inplace.wgsl"); pub fn mask_fill( - input: WgpuTensor, - mask: WgpuTensor, - value: E, + input: WgpuTensor, + mask: WgpuTensor, + value: E, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let output = empty_device( - input.client.clone(), - input.device.clone(), - input.shape.clone(), - ); - - let value_handle = output.client.create(E::as_bytes(&[value])); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &mask, &output]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &value_handle, - &mask.handle, - &output.handle, - &info_handle, - ], - ); - - output + let num_elems = input.shape.num_elements(); + let output = empty_device( + input.client.clone(), + input.device.clone(), + input.shape.clone(), + ); + + let value_handle = output.client.create(E::as_bytes(&[value])); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &mask, &output]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &value_handle, + &mask.handle, + &output.handle, + &info_handle, + ], + ); + + output } pub fn mask_fill_inplace( - input: 
WgpuTensor, - mask: WgpuTensor, - value: E, + input: WgpuTensor, + mask: WgpuTensor, + value: E, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let value_handle = input.client.create(E::as_bytes(&[value])); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &mask]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &value_handle, &mask.handle, &info_handle], - ); - - input + let num_elems = input.shape.num_elements(); + let value_handle = input.client.create(E::as_bytes(&[value])); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &mask]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &value_handle, &mask.handle, &info_handle], + ); + + input } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Bool, Distribution, Tensor}; - - #[test] - fn mask_fill_should_work_with_multiple_invocations() { - let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); - - let actual = Tensor::::from_primitive(mask_fill::( - tensor.into_primitive(), - mask.into_primitive(), - 4.0, - )); - let expected = tensor_ref.mask_fill(mask_ref, 4.0); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[test] - fn mask_fill_inplace_should_work_with_multiple_invocations() { - let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); - - let actual = Tensor::::from_primitive(mask_fill_inplace::( - tensor.into_primitive(), - mask.into_primitive(), - 4.0, - )); - let expected = tensor_ref.mask_fill(mask_ref, 4.0); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs_mask_fill() -> ( - Tensor, - Tensor, - Tensor, - Tensor, - ) { - let tensor = Tensor::::random([2, 6, 256], Distribution::Default); - let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) - .lower_equal_elem(0.5); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let mask_ref = Tensor::::from_data(mask.to_data()); - - (tensor, mask, tensor_ref, mask_ref) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Bool, Distribution, Tensor}; + + #[test] + fn mask_fill_should_work_with_multiple_invocations() { + let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); + + let actual = Tensor::::from_primitive(mask_fill::( + tensor.into_primitive(), + mask.into_primitive(), + 4.0, + )); + let expected = tensor_ref.mask_fill(mask_ref, 4.0); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[test] + fn mask_fill_inplace_should_work_with_multiple_invocations() { + let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); + + let actual = Tensor::::from_primitive(mask_fill_inplace::( + tensor.into_primitive(), + mask.into_primitive(), + 4.0, + )); + let expected = tensor_ref.mask_fill(mask_ref, 4.0); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs_mask_fill() -> ( + Tensor, + Tensor, + 
Tensor, + Tensor, + ) { + let tensor = Tensor::::random([2, 6, 256], Distribution::Default); + let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) + .lower_equal_elem(0.5); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let mask_ref = Tensor::::from_data(mask.to_data()); + + (tensor, mask, tensor_ref, mask_ref) + } } diff --git a/burn-wgpu/src/kernel/mask/mask_where.rs b/burn-wgpu/src/kernel/mask/mask_where.rs index 9775ed242d..9972554ab8 100644 --- a/burn-wgpu/src/kernel/mask/mask_where.rs +++ b/burn-wgpu/src/kernel/mask/mask_where.rs @@ -1,150 +1,150 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(MaskWhere, "../../template/mask/where.wgsl"); kernel_wgsl!(MaskWhereInplace, "../../template/mask/where_inplace.wgsl"); pub fn mask_where( - input: WgpuTensor, - mask: WgpuTensor, - value: WgpuTensor, + input: WgpuTensor, + mask: WgpuTensor, + value: WgpuTensor, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let output = empty_device( - input.client.clone(), - input.device.clone(), - input.shape.clone(), - ); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let info = build_info(&[&input, &value, &mask, &output]); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[ - &input.handle, - &value.handle, - &mask.handle, - &output.handle, - &info_handle, - ], - ); - - output + let num_elems = input.shape.num_elements(); + let output = empty_device( + input.client.clone(), + input.device.clone(), + input.shape.clone(), + ); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &value, &mask, &output]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[ + &input.handle, + &value.handle, + &mask.handle, + &output.handle, + &info_handle, + ], + ); + + output } pub fn mask_where_inplace( - input: WgpuTensor, - mask: WgpuTensor, - value: WgpuTensor, - reverse: bool, + input: WgpuTensor, + mask: WgpuTensor, + value: WgpuTensor, + reverse: bool, ) -> WgpuTensor { - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - input.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); - let mut info = build_info(&[&input, &value, &mask]); - info.push(match reverse { - true => 1, - false => 0, - }); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &value.handle, &mask.handle, &info_handle], - ); - - input + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + input.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let mut info = build_info(&[&input, &value, &mask]); 
+ info.push(match reverse { + true => 1, + false => 0, + }); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &value.handle, &mask.handle, &info_handle], + ); + + input } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; - - #[test] - fn mask_where_should_work_with_multiple_invocations() { - let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - - let actual = Tensor::::from_primitive(mask_where::( - tensor.into_primitive(), - mask.into_primitive(), - value.into_primitive(), - )); - let expected = tensor_ref.mask_where(mask_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - #[test] - fn mask_where_inplace_direction_1_should_work_with_multiple_invocations() { - let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - - let actual = Tensor::::from_primitive(mask_where_inplace::( - tensor.into_primitive(), - mask.into_primitive(), - value.into_primitive(), - false, - )); - let expected = tensor_ref.mask_where(mask_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[test] - fn mask_where_inplace_direction_0_should_work_with_multiple_invocation() { - let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - - let actual = Tensor::::from_primitive(mask_where_inplace::( - value.into_primitive(), - mask.into_primitive(), - tensor.into_primitive(), - true, - )); - let expected = tensor_ref.mask_where(mask_ref, value_ref); - - expected - .into_data() - .assert_approx_eq(&actual.into_data(), 3); - } - - #[allow(clippy::type_complexity)] - fn inputs_mask_where() -> ( - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - ) { - TestBackend::seed(0); - let tensor = Tensor::::random([2, 6, 256], Distribution::Default); - let value = Tensor::::random([2, 6, 256], Distribution::Default); - let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) - .lower_equal_elem(0.5); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let value_ref = Tensor::::from_data(value.to_data()); - let mask_ref = Tensor::::from_data(mask.to_data()); - assert_eq!(mask.to_data(), mask_ref.to_data()); - - (tensor, value, mask, tensor_ref, value_ref, mask_ref) - } + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, Bool, Distribution, Tensor}; + + #[test] + fn mask_where_should_work_with_multiple_invocations() { + let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); + + let actual = Tensor::::from_primitive(mask_where::( + tensor.into_primitive(), + mask.into_primitive(), + value.into_primitive(), + )); + let expected = tensor_ref.mask_where(mask_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + #[test] + fn mask_where_inplace_direction_1_should_work_with_multiple_invocations() { + let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); + + let actual = Tensor::::from_primitive(mask_where_inplace::( + tensor.into_primitive(), + mask.into_primitive(), + value.into_primitive(), + false, + )); + let expected = tensor_ref.mask_where(mask_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[test] + fn 
mask_where_inplace_direction_0_should_work_with_multiple_invocation() { + let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); + + let actual = Tensor::::from_primitive(mask_where_inplace::( + value.into_primitive(), + mask.into_primitive(), + tensor.into_primitive(), + true, + )); + let expected = tensor_ref.mask_where(mask_ref, value_ref); + + expected + .into_data() + .assert_approx_eq(&actual.into_data(), 3); + } + + #[allow(clippy::type_complexity)] + fn inputs_mask_where() -> ( + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + ) { + TestBackend::seed(0); + let tensor = Tensor::::random([2, 6, 256], Distribution::Default); + let value = Tensor::::random([2, 6, 256], Distribution::Default); + let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.)) + .lower_equal_elem(0.5); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let value_ref = Tensor::::from_data(value.to_data()); + let mask_ref = Tensor::::from_data(mask.to_data()); + assert_eq!(mask.to_data(), mask_ref.to_data()); + + (tensor, value, mask, tensor_ref, value_ref, mask_ref) + } } diff --git a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs index 4d2e0f71e8..c57270a12b 100644 --- a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs +++ b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs @@ -2,201 +2,201 @@ use burn_tensor::Shape; use std::marker::PhantomData; use crate::{ - compute::{DynamicKernel, Kernel, WorkGroup}, - element::WgpuElement, - kernel::{ - build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, - WORKGROUP_DEFAULT, - }, - kernel_wgsl, - tensor::WgpuTensor, + compute::{DynamicKernel, Kernel, WorkGroup}, + element::WgpuElement, + kernel::{ + build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, + WORKGROUP_DEFAULT, + }, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!( - MatmulMemCoalescingRaw, - "../../template/matmul/mem_coalescing.wgsl" + MatmulMemCoalescingRaw, + "../../template/matmul/mem_coalescing.wgsl" ); #[derive(new, Debug)] struct MatmulMemCoalescing { - workgroup_size_x: usize, - workgroup_size_y: usize, - _elem: PhantomData, + workgroup_size_x: usize, + workgroup_size_y: usize, + _elem: PhantomData, } impl DynamicKernelSource for MatmulMemCoalescing { - fn source(&self) -> SourceTemplate { - MatmulMemCoalescingRaw::source() - .register("workgroup_size_x", self.workgroup_size_x.to_string()) - .register("workgroup_size_y", self.workgroup_size_y.to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulMemCoalescingRaw::source() + .register("workgroup_size_x", self.workgroup_size_x.to_string()) + .register("workgroup_size_y", self.workgroup_size_y.to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using memory coalescing algorithm with workgroups of size 16 pub fn matmul_mem_coalescing_default( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT) + matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT) } /// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes pub fn 
matmul_mem_coalescing( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, - workgroup_size_x: usize, - workgroup_size_y: usize, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, + workgroup_size_x: usize, + workgroup_size_y: usize, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); + lhs.assert_is_on_same_device(&rhs); - let lhs = into_contiguous(lhs); - let rhs = into_contiguous(rhs); + let lhs = into_contiguous(lhs); + let rhs = into_contiguous(rhs); - let info = build_info(&[&lhs, &rhs, &output]); + let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - let kernel = matmul_mem_coalescing_kernel::( - &lhs.shape, - &rhs.shape, - &output.shape, - workgroup_size_x, - workgroup_size_y, - ); + let kernel = matmul_mem_coalescing_kernel::( + &lhs.shape, + &rhs.shape, + &output.shape, + workgroup_size_x, + workgroup_size_y, + ); - lhs.client.execute( - kernel, - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], - ); + lhs.client.execute( + kernel, + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + ); - output + output } fn matmul_mem_coalescing_kernel( - lhs_shape: &Shape, - rhs_shape: &Shape, - output_shape: &Shape, - workgroup_size_x: usize, - workgroup_size_y: usize, + lhs_shape: &Shape, + rhs_shape: &Shape, + output_shape: &Shape, + workgroup_size_x: usize, + workgroup_size_y: usize, ) -> Box { - let num_rows = lhs_shape.dims[D - 2]; - let num_cols = rhs_shape.dims[D - 1]; - - // set number of workgroups - let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output_shape.dims[i]; - } - - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); - - Box::new(DynamicKernel::new( - MatmulMemCoalescing::::new(workgroup_size_x, workgroup_size_y), - workgroup, - )) + let num_rows = lhs_shape.dims[D - 2]; + let num_cols = rhs_shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output_shape.dims[i]; + } + + let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); + + Box::new(DynamicKernel::new( + MatmulMemCoalescing::::new(workgroup_size_x, workgroup_size_y), + workgroup, + )) } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_mem_coalescing_straightforward() { - test_with_params::<2, 2>(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_shapes_smaller_than_blocks() { - test_with_params::<16, 16>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_n_smaller_than_m() { - test_with_params::<2, 2>(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_m_smaller_than_n() { - test_with_params::<2, 2>(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_k_smaller_than_m_n() { - test_with_params::<2, 2>(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_k_larger_than_m_n() { - test_with_params::<2, 2>(8, 48, 8, 1, 1); - } - - #[test] - pub fn 
test_matmul_mem_coalescing_multibatch_1_dim() { - test_with_params::<2, 2>(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_mem_coalescing_multibatch_2_dims() { - test_with_params::<2, 2>(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_mem_coalescing_blocks_divide_shapes_unevenly() { - test_with_params::<3, 3>(7, 7, 7, 1, 1); - } - - fn test_with_params( - m: usize, - k: usize, - n: usize, - batch_1: usize, - batch_2: usize, - ) { - let func = |lhs, rhs, out| { - matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y) - }; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_batches_no_padding() { - let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_col_no_padding() { - let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_with_batch_no_padding() { - let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + #[test] + pub fn test_matmul_mem_coalescing_straightforward() { + test_with_params::<2, 2>(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_shapes_smaller_than_blocks() { + test_with_params::<16, 16>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_n_smaller_than_m() { + test_with_params::<2, 2>(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_m_smaller_than_n() { + test_with_params::<2, 2>(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_k_smaller_than_m_n() { + test_with_params::<2, 2>(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_k_larger_than_m_n() { + test_with_params::<2, 2>(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_multibatch_1_dim() { + test_with_params::<2, 2>(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_mem_coalescing_multibatch_2_dims() { + test_with_params::<2, 2>(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_mem_coalescing_blocks_divide_shapes_unevenly() { + test_with_params::<3, 3>(7, 7, 7, 1, 1); + } + + fn test_with_params( + m: usize, + k: usize, + n: usize, + batch_1: usize, + batch_2: usize, + ) { + let func = |lhs, rhs, out| { + matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y) + }; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_batches_no_padding() { + let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, 
swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_col_no_padding() { + let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_with_batch_no_padding() { + let matmul_func = |lhs, rhs, out| matmul_mem_coalescing::(lhs, rhs, out, 2, 2); + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/naive.rs b/burn-wgpu/src/kernel/matmul/naive.rs index a0f2055809..8d7d95b51a 100644 --- a/burn-wgpu/src/kernel/matmul/naive.rs +++ b/burn-wgpu/src/kernel/matmul/naive.rs @@ -1,9 +1,9 @@ use crate::{ - compute::{StaticKernel, WorkGroup}, - element::WgpuElement, - kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource}, - kernel_wgsl, - tensor::WgpuTensor, + compute::{StaticKernel, WorkGroup}, + element::WgpuElement, + kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl"); @@ -11,164 +11,164 @@ kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl"); struct MatmulNaive; impl StaticKernelSource - for MatmulNaive + for MatmulNaive { - fn source() -> SourceTemplate { - MatmulNaiveRaw::source() - .register("block_size_m", WORKGROUP_SIZE_X.to_string()) - .register("block_size_n", WORKGROUP_SIZE_Y.to_string()) - } + fn source() -> SourceTemplate { + MatmulNaiveRaw::source() + .register("block_size_m", WORKGROUP_SIZE_X.to_string()) + .register("block_size_n", WORKGROUP_SIZE_Y.to_string()) + } } /// Matrix multiplication using naive algorithm with workgroups of size 16 pub fn matmul_naive_default( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, ) -> WgpuTensor { - matmul_naive::(lhs, rhs, output) + matmul_naive::(lhs, rhs, output) } /// Matrix multiplication using naive algorithm with custom workgroup sizes pub fn matmul_naive< - E: WgpuElement, - const D: usize, - const WORKGROUP_SIZE_X: usize, - const WORKGROUP_SIZE_Y: usize, + E: WgpuElement, + const D: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, >( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, ) -> WgpuTensor { - lhs.assert_is_on_same_device(&rhs); - - let lhs = into_contiguous(lhs); - let rhs = into_contiguous(rhs); - - let num_rows = lhs.shape.dims[D - 2]; - let num_cols = rhs.shape.dims[D - 1]; - - // set number of workgroups - let blocks_needed_in_x = f32::ceil(num_rows as f32 / WORKGROUP_SIZE_X as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / WORKGROUP_SIZE_Y as f32) as u32; - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output.shape.dims[i]; - } - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); - - let kernel = StaticKernel::< - KernelSettings< - MatmulNaive, - E, - i32, - WORKGROUP_SIZE_X, - WORKGROUP_SIZE_Y, - 1, - >, - >::new(workgroup); - - let info = build_info(&[&lhs, &rhs, &output]); - - let info_handle = 
lhs.client.create(bytemuck::cast_slice(&info)); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], - ); - - output + lhs.assert_is_on_same_device(&rhs); + + let lhs = into_contiguous(lhs); + let rhs = into_contiguous(rhs); + + let num_rows = lhs.shape.dims[D - 2]; + let num_cols = rhs.shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / WORKGROUP_SIZE_X as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / WORKGROUP_SIZE_Y as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output.shape.dims[i]; + } + let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); + + let kernel = StaticKernel::< + KernelSettings< + MatmulNaive, + E, + i32, + WORKGROUP_SIZE_X, + WORKGROUP_SIZE_Y, + 1, + >, + >::new(workgroup); + + let info = build_info(&[&lhs, &rhs, &output]); + + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], + ); + + output } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_naive_straightforward() { - test_with_params::<2, 2>(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_naive_shapes_smaller_than_blocks() { - test_with_params::<16, 16>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_n_smaller_than_m() { - test_with_params::<2, 2>(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_naive_m_smaller_than_n() { - test_with_params::<2, 2>(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_k_smaller_than_m_n() { - test_with_params::<2, 2>(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_k_larger_than_m_n() { - test_with_params::<2, 2>(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_naive_multibatch_1_dim() { - test_with_params::<2, 2>(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_naive_multibatch_2_dims() { - test_with_params::<2, 2>(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_naive_blocks_divide_shapes_unevenly() { - test_with_params::<3, 3>(7, 7, 7, 1, 1); - } - - fn test_with_params( - m: usize, - k: usize, - n: usize, - batch_1: usize, - batch_2: usize, - ) { - let func = matmul_naive::; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_batches_no_padding() { - let matmul_func = matmul_naive::; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_col_no_padding() { - let matmul_func = matmul_naive::; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_naive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_naive::; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + 
#[test] + pub fn test_matmul_naive_straightforward() { + test_with_params::<2, 2>(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_naive_shapes_smaller_than_blocks() { + test_with_params::<16, 16>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_n_smaller_than_m() { + test_with_params::<2, 2>(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_naive_m_smaller_than_n() { + test_with_params::<2, 2>(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_k_smaller_than_m_n() { + test_with_params::<2, 2>(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_k_larger_than_m_n() { + test_with_params::<2, 2>(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_naive_multibatch_1_dim() { + test_with_params::<2, 2>(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_naive_multibatch_2_dims() { + test_with_params::<2, 2>(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_naive_blocks_divide_shapes_unevenly() { + test_with_params::<3, 3>(7, 7, 7, 1, 1); + } + + fn test_with_params( + m: usize, + k: usize, + n: usize, + batch_1: usize, + batch_2: usize, + ) { + let func = matmul_naive::; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_batches_no_padding() { + let matmul_func = matmul_naive::; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_col_no_padding() { + let matmul_func = matmul_naive::; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_naive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_naive::; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs index 8700bc8336..0ce7808d5c 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs @@ -1,10 +1,10 @@ use super::padding::{crop, pad_round, PaddingOutput}; use crate::{ - compute::{DynamicKernel, WgpuHandle, WorkGroup}, - element::WgpuElement, - kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource}, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::{DynamicKernel, WgpuHandle, WorkGroup}, + element::WgpuElement, + kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource}, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::{Element, Shape}; @@ -14,75 +14,75 @@ pub(crate) const B_K: usize = 32; pub(crate) const WORKGROUP_SIZE: usize = 16; pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { - let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; - let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; - let mut num_blocks_z = 1; - for i in 0..D - 2 { - num_blocks_z *= output_shape.dims[i]; - } + let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; + let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; 
+ let mut num_blocks_z = 1; + for i in 0..D - 2 { + num_blocks_z *= output_shape.dims[i]; + } - WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) + WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) } pub(super) fn make_info_handle( - lhs: &WgpuTensor, - rhs: &WgpuTensor, - output: &WgpuTensor, + lhs: &WgpuTensor, + rhs: &WgpuTensor, + output: &WgpuTensor, ) -> WgpuHandle { - let info = build_info(&[lhs, rhs, output]); - rhs.client.create(bytemuck::cast_slice(&info)) + let info = build_info(&[lhs, rhs, output]); + rhs.client.create(bytemuck::cast_slice(&info)) } #[allow(clippy::too_many_arguments)] pub(super) fn matmul_tiling_2d_launch< - E: WgpuElement + Element, - const D: usize, - K: DynamicKernelSource + 'static, + E: WgpuElement + Element, + const D: usize, + K: DynamicKernelSource + 'static, >( - lhs: WgpuTensor, - rhs: WgpuTensor, - output: WgpuTensor, - kernel: K, + lhs: WgpuTensor, + rhs: WgpuTensor, + output: WgpuTensor, + kernel: K, ) -> WgpuTensor { - // A tensor may need to be padded, in which case it will implicitly become contiguous - // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. - // If batches were swapped among themselves, or if the last two dims are transposed, the underlying - // kernel handles it without needing to turn it into contiguous. - let round_lhs = pad_round(lhs, B_M, B_K); - let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_lhs.into_tensor(), - }; - let round_rhs = pad_round(rhs, B_K, B_N); - let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_rhs.into_tensor(), - }; + // A tensor may need to be padded, in which case it will implicitly become contiguous + // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. + // If batches were swapped among themselves, or if the last two dims are transposed, the underlying + // kernel handles it without needing to turn it into contiguous. 
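// Illustrative sketch, not from the patch itself: the contiguity decision described in
// the comment above, isolated from the launch code. `needs_contiguous` is a hypothetical
// helper; its second flag mirrors the `batch_swapped_with_row_col` tensor method used here.
fn needs_contiguous(was_padded: bool, batch_swapped_with_row_col: bool) -> bool {
    // Padding already writes into a fresh contiguous buffer, so an extra copy is only
    // needed when the tensor was left unchanged but a batch dim was swapped with a row
    // or col dim; batch-batch swaps and transposed last dims are handled by the kernel.
    !was_padded && batch_swapped_with_row_col
}

#[test]
fn contiguous_copy_only_when_unchanged_and_swapped() {
    assert!(needs_contiguous(false, true));
    assert!(!needs_contiguous(true, true));
    assert!(!needs_contiguous(false, false));
}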
+ let round_lhs = pad_round(lhs, B_M, B_K); + let lhs = match round_lhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_lhs.into_tensor(), + }; + let round_rhs = pad_round(rhs, B_K, B_N); + let rhs = match round_rhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_rhs.into_tensor(), + }; - let rounded_output_shape = shape_out(&lhs, &rhs); + let rounded_output_shape = shape_out(&lhs, &rhs); - let rounded_output = empty_device( - rhs.client.clone(), - rhs.device.clone(), - rounded_output_shape.clone(), - ); + let rounded_output = empty_device( + rhs.client.clone(), + rhs.device.clone(), + rounded_output_shape.clone(), + ); - let workgroup = make_workgroup(&rounded_output_shape); - let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); + let workgroup = make_workgroup(&rounded_output_shape); + let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); - lhs.client.execute( - Box::new(DynamicKernel::new(kernel, workgroup)), - &[ - &lhs.handle, - &rhs.handle, - &rounded_output.handle, - &info_handle, - ], - ); + lhs.client.execute( + Box::new(DynamicKernel::new(kernel, workgroup)), + &[ + &lhs.handle, + &rhs.handle, + &rounded_output.handle, + &info_handle, + ], + ); - crop(rounded_output, output) + crop(rounded_output, output) } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs index b159ea1f59..290d340783 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs @@ -3,25 +3,25 @@ use std::ops::Range; use burn_tensor::{Element, Shape}; use crate::{ - element::WgpuElement, - kernel::{slice_assign, slice_on_output}, - ops::numeric::zeros_device, - tensor::WgpuTensor, + element::WgpuElement, + kernel::{slice_assign, slice_on_output}, + ops::numeric::zeros_device, + tensor::WgpuTensor, }; // Output of the pad_round function. 
Allows to know explicitly if early return occurred pub(super) enum PaddingOutput { - Padded(WgpuTensor), - Unchanged(WgpuTensor), + Padded(WgpuTensor), + Unchanged(WgpuTensor), } impl PaddingOutput { - pub fn into_tensor(self) -> WgpuTensor { - match self { - PaddingOutput::Padded(tensor) => tensor, - PaddingOutput::Unchanged(tensor) => tensor, + pub fn into_tensor(self) -> WgpuTensor { + match self { + PaddingOutput::Padded(tensor) => tensor, + PaddingOutput::Unchanged(tensor) => tensor, + } } - } } /// Pads tensor with zeros to make tensor number of rows and columns @@ -29,279 +29,280 @@ impl PaddingOutput { /// For instance tensor of shape [1000, 1000] with divisors 64 and 64 /// will be padded to [1024, 1024] with the last 24 elements being zeros pub(super) fn pad_round( - tensor: WgpuTensor, - row_divisor: usize, - col_divisor: usize, + tensor: WgpuTensor, + row_divisor: usize, + col_divisor: usize, ) -> PaddingOutput { - let previous_row_dim = tensor.shape.dims[D - 2]; - let previous_col_dim = tensor.shape.dims[D - 1]; - let row_modulo = previous_row_dim % row_divisor; - let col_modulo = previous_col_dim % col_divisor; - - let new_row_dim = match row_modulo { - 0 => previous_row_dim, - _ => previous_row_dim + row_divisor - row_modulo, - }; - let new_col_dim = match col_modulo { - 0 => previous_col_dim, - _ => previous_col_dim + col_divisor - col_modulo, - }; - if previous_row_dim == new_row_dim && previous_col_dim == new_col_dim { - return PaddingOutput::Unchanged(tensor); - } - - let mut padded_shape = Vec::with_capacity(D); - for i in 0..D - 2 { - padded_shape.push(tensor.shape.dims[i]); - } - padded_shape.push(new_row_dim); - padded_shape.push(new_col_dim); - - PaddingOutput::Padded(padding::(tensor, padded_shape.into())) + let previous_row_dim = tensor.shape.dims[D - 2]; + let previous_col_dim = tensor.shape.dims[D - 1]; + let row_modulo = previous_row_dim % row_divisor; + let col_modulo = previous_col_dim % col_divisor; + + let new_row_dim = match row_modulo { + 0 => previous_row_dim, + _ => previous_row_dim + row_divisor - row_modulo, + }; + let new_col_dim = match col_modulo { + 0 => previous_col_dim, + _ => previous_col_dim + col_divisor - col_modulo, + }; + if previous_row_dim == new_row_dim && previous_col_dim == new_col_dim { + return PaddingOutput::Unchanged(tensor); + } + + let mut padded_shape = Vec::with_capacity(D); + for i in 0..D - 2 { + padded_shape.push(tensor.shape.dims[i]); + } + padded_shape.push(new_row_dim); + padded_shape.push(new_col_dim); + + PaddingOutput::Padded(padding::(tensor, padded_shape.into())) } /// Pads tensor by adding zeros when padded dim is larger than tensor dim fn padding( - tensor: WgpuTensor, - padded_shape: Shape, + tensor: WgpuTensor, + padded_shape: Shape, ) -> WgpuTensor { - let ranges = padded_shape - .dims - .iter() - .map(|dim| 0..*dim) - .collect::>>() - .try_into() - .unwrap(); - - slice_assign::( - zeros_device(tensor.client.clone(), tensor.device.clone(), padded_shape), - ranges, - tensor, - ) + let ranges = padded_shape + .dims + .iter() + .map(|dim| 0..*dim) + .collect::>>() + .try_into() + .unwrap(); + + slice_assign::( + zeros_device(tensor.client.clone(), tensor.device.clone(), padded_shape), + ranges, + tensor, + ) } /// Crops tensor by deleting values when cropped dim is smaller than tensor dim pub(super) fn crop( - tensor: WgpuTensor, - output: WgpuTensor, + tensor: WgpuTensor, + output: WgpuTensor, ) -> WgpuTensor { - let ranges = output - .shape - .dims - .iter() - .map(|dim| 0..*dim) - .collect::>>() - .try_into() - 
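// Illustrative sketch, not from the patch itself: the rounding rule implemented by
// `pad_round` above. `round_up_to_divisor` is a hypothetical helper that reproduces the
// doc-comment example ([1000, 1000] with divisors 64 pads to [1024, 1024]).
fn round_up_to_divisor(dim: usize, divisor: usize) -> usize {
    match dim % divisor {
        0 => dim,
        rem => dim + divisor - rem,
    }
}

#[test]
fn round_up_matches_pad_round_doc_example() {
    assert_eq!(round_up_to_divisor(1000, 64), 1024);
    // Already-divisible dims are left unchanged, which is the `PaddingOutput::Unchanged` case.
    assert_eq!(round_up_to_divisor(1024, 64), 1024);
}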
.unwrap(); - slice_on_output::(tensor, output, ranges) + let ranges = output + .shape + .dims + .iter() + .map(|dim| 0..*dim) + .collect::>>() + .try_into() + .unwrap(); + slice_on_output::(tensor, output, ranges) } #[cfg(test)] mod tests { - use super::*; - use crate::tests::TestTensor; - - #[test] - fn padding_already_round_should_have_same_shape() { - let row = 10; - let row_divisor = 5; - let col = 12; - let col_divisor = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [row, col].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn padding_already_round_should_have_same_values() { - let row = 10; - let row_divisor = 5; - let col = 12; - let col_divisor = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor); - - let padded = TestTensor::from_primitive(padded.into_tensor()); - padded.into_data().assert_approx_eq(&tensor.into_data(), 3); - } - - #[test] - fn padding_not_round_should_have_rounded_shape() { - let row = 10; - let row_divisor = 6; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [12, 15].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn padding_not_round_should_have_same_values() { - let row = 10; - let row_divisor = 6; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor).into_tensor(); - - let padded = TestTensor::from_primitive(padded).to_data(); - let tensor = tensor.into_data(); - for i in 0..row { - for j in 0..col { - assert!(padded.value[i * 15 + j] == tensor.value[i * col + j]); - } + use super::*; + use crate::tests::TestTensor; + + #[test] + fn padding_already_round_should_have_same_shape() { + let row = 10; + let row_divisor = 5; + let col = 12; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row, col].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); } - } - - #[test] - fn padding_not_round_should_have_zero_padding() { - let row = 10; - let row_divisor = 6; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - let padded = TestTensor::from_primitive(padded).to_data(); - - // check right of matrix - for i in 0..row { - for j in col..15 { - assert!(padded.value[i * 15 + j] == 0.0); - } + + #[test] + fn padding_already_round_should_have_same_values() { + let row = 10; + let row_divisor = 5; + let col = 12; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor); + + let padded = TestTensor::from_primitive(padded.into_tensor()); + padded.into_data().assert_approx_eq(&tensor.into_data(), 3); } - // check below matrix, including bottom right - for i in 
row..12 { - for j in 0..15 { - assert!(padded.value[i * 15 + j] == 0.0); - } + + #[test] + fn padding_not_round_should_have_rounded_shape() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [12, 15].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); } - } - - #[test] - fn padding_works_with_batch() { - let row = 10; - let row_divisor = 4; - let col = 12; - let col_divisor = 5; - let tensor = TestTensor::random([2, 3, row, col], burn_tensor::Distribution::Default); - let expected_shape = [2, 3, 12, 15].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn padding_with_row_divisor_larger_than_row() { - let row = 10; - let row_divisor = 32; - let col = 4; - let col_divisor = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [row_divisor, 2 * col_divisor].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn padding_with_row_divisor_equal_to_row_but_col_must_be_padded() { - let row = 32; - let row_divisor = 32; - let col = 4; - let col_divisor = 64; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [32, 64].into(); - - let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); - - assert!(padded.shape == expected_shape); - } - - #[test] - fn crop_same_shape_should_be_unchanged_shape() { - let row = 10; - let col = 12; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [row, col].into(); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([row, col]).into_primitive(), - ); - - assert!(unpadded.shape == expected_shape); - } - - #[test] - fn crop_same_shape_should_have_unchanged_values() { - let row = 10; - let col = 12; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([row, col]).into_primitive(), - ); - - let unpadded = TestTensor::from_primitive(unpadded).to_data(); - let tensor = tensor.into_data(); - for i in 0..row { - for j in 0..col { - assert!(unpadded.value[i * col + j] == tensor.value[i * col + j]); - } + + #[test] + fn padding_not_round_should_have_same_values() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = + pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor).into_tensor(); + + let padded = TestTensor::from_primitive(padded).to_data(); + let tensor = tensor.into_data(); + for i in 0..row { + for j in 0..col { + assert!(padded.value[i * 15 + j] == tensor.value[i * col + j]); + } + } } - } - - #[test] - fn crop_should_decrease_shape() { - let row = 10; - let keep_rows = 8; - let col = 12; - let keep_cols = 10; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - let expected_shape = [keep_rows, keep_cols].into(); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([keep_rows, 
keep_cols]).into_primitive(), - ); - - assert!(unpadded.shape == expected_shape); - } - - #[test] - fn crop_should_keep_same_values() { - let row = 4; - let keep_rows = 3; - let col = 4; - let keep_cols = 3; - let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); - - let unpadded = crop( - tensor.clone().into_primitive(), - TestTensor::empty([keep_rows, keep_cols]).into_primitive(), - ); - - let unpadded = TestTensor::from_primitive(unpadded).to_data(); - let tensor = tensor.into_data(); - - for i in 0..keep_rows { - for j in 0..keep_cols { - assert!(unpadded.value[i * keep_cols + j] == tensor.value[i * col + j]); - } + + #[test] + fn padding_not_round_should_have_zero_padding() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + let padded = TestTensor::from_primitive(padded).to_data(); + + // check right of matrix + for i in 0..row { + for j in col..15 { + assert!(padded.value[i * 15 + j] == 0.0); + } + } + // check below matrix, including bottom right + for i in row..12 { + for j in 0..15 { + assert!(padded.value[i * 15 + j] == 0.0); + } + } + } + + #[test] + fn padding_works_with_batch() { + let row = 10; + let row_divisor = 4; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([2, 3, row, col], burn_tensor::Distribution::Default); + let expected_shape = [2, 3, 12, 15].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_with_row_divisor_larger_than_row() { + let row = 10; + let row_divisor = 32; + let col = 4; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row_divisor, 2 * col_divisor].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_with_row_divisor_equal_to_row_but_col_must_be_padded() { + let row = 32; + let row_divisor = 32; + let col = 4; + let col_divisor = 64; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [32, 64].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor).into_tensor(); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn crop_same_shape_should_be_unchanged_shape() { + let row = 10; + let col = 12; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row, col].into(); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([row, col]).into_primitive(), + ); + + assert!(unpadded.shape == expected_shape); + } + + #[test] + fn crop_same_shape_should_have_unchanged_values() { + let row = 10; + let col = 12; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([row, col]).into_primitive(), + ); + + let unpadded = TestTensor::from_primitive(unpadded).to_data(); + let tensor = tensor.into_data(); + for i in 0..row { + for j in 0..col { + assert!(unpadded.value[i * col + j] == tensor.value[i * col + j]); + } + } + } + + #[test] + fn crop_should_decrease_shape() { + let row = 10; + let keep_rows = 8; + let col = 
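// Illustrative sketch, not from the patch itself: the row-major indexing used by the
// padding tests above. With col = 12 padded to 15 columns, element (i, j) of the padded
// matrix lives at i * 15 + j, which is why the zero checks scan j in col..15 and i in row..12.
fn row_major_index(i: usize, j: usize, num_cols: usize) -> usize {
    i * num_cols + j
}

#[test]
fn row_major_index_matches_padding_tests() {
    // First padded (zero) element to the right of row 0.
    assert_eq!(row_major_index(0, 12, 15), 12);
    // First element of the first fully padded row below the original data.
    assert_eq!(row_major_index(10, 0, 15), 150);
}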
12; + let keep_cols = 10; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [keep_rows, keep_cols].into(); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([keep_rows, keep_cols]).into_primitive(), + ); + + assert!(unpadded.shape == expected_shape); + } + + #[test] + fn crop_should_keep_same_values() { + let row = 4; + let keep_rows = 3; + let col = 4; + let keep_cols = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let unpadded = crop( + tensor.clone().into_primitive(), + TestTensor::empty([keep_rows, keep_cols]).into_primitive(), + ); + + let unpadded = TestTensor::from_primitive(unpadded).to_data(); + let tensor = tensor.into_data(); + + for i in 0..keep_rows { + for j in 0..keep_cols { + assert!(unpadded.value[i * keep_cols + j] == tensor.value[i * col + j]); + } + } } - } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs b/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs index ed8e68ec52..b444bd20bc 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs @@ -1,10 +1,10 @@ use burn_tensor::Element; use crate::{ - compute::DynamicKernel, - element::WgpuElement, - kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::WgpuTensor, + compute::DynamicKernel, + element::WgpuElement, + kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -13,183 +13,183 @@ use crate::kernel_wgsl; use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE}; kernel_wgsl!( - MatmulTiling2DUnpaddedRaw, - "../../../template/matmul/blocktiling_2d/unpadded.wgsl" + MatmulTiling2DUnpaddedRaw, + "../../../template/matmul/blocktiling_2d/unpadded.wgsl" ); #[derive(new, Debug)] struct MatmulTiling2DUnpadded { - _elem: PhantomData, + _elem: PhantomData, } impl DynamicKernelSource for MatmulTiling2DUnpadded { - fn source(&self) -> SourceTemplate { - MatmulTiling2DUnpaddedRaw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulTiling2DUnpaddedRaw::source() + .register("b_m", B_M.to_string()) + .register("b_n", B_N.to_string()) + .register("b_k", B_K.to_string()) + .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) + .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) + .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_z", "1".to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using tiling 2d algorithm with /// vec4 primitive on both lhs and rhs, with no padding needed pub fn matmul_tiling_2d_unpadded( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - let lhs = match 
lhs.batch_swapped_with_row_col() { - true => into_contiguous(lhs), - false => lhs, - }; - let rhs = match rhs.batch_swapped_with_row_col() { - true => into_contiguous(rhs), - false => rhs, - }; - - let workgroup = make_workgroup(&out.shape); - let info_handle = make_info_handle(&lhs, &rhs, &out); - - lhs.client.execute( - Box::new(DynamicKernel::new( - MatmulTiling2DUnpadded::::new(), - workgroup, - )), - &[&lhs.handle, &rhs.handle, &out.handle, &info_handle], - ); - - out + let lhs = match lhs.batch_swapped_with_row_col() { + true => into_contiguous(lhs), + false => lhs, + }; + let rhs = match rhs.batch_swapped_with_row_col() { + true => into_contiguous(rhs), + false => rhs, + }; + + let workgroup = make_workgroup(&out.shape); + let info_handle = make_info_handle(&lhs, &rhs, &out); + + lhs.client.execute( + Box::new(DynamicKernel::new( + MatmulTiling2DUnpadded::::new(), + workgroup, + )), + &[&lhs.handle, &rhs.handle, &out.handle, &info_handle], + ); + + out } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_unpadded_straightforward() { - test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_shapes_equal_blocks() { - test_with_params(64, 32, 64, 2, 2); - } - - #[test] - pub fn test_matmul_unpadded_m_exceeds_block() { - test_with_params(75, 32, 64, 2, 2); - } - - #[test] - pub fn test_matmul_unpadded_k_exceeds_block() { - test_with_params(64, 33, 32, 1, 1); - } - - #[test] - pub fn test_matmul_irregular_shape() { - test_with_params(123, 255, 72, 3, 5); - } - - #[test] - pub fn test64_matmul_unpadded_n_exceeds_block() { - test_with_params(64, 32, 75, 2, 2); - } - - #[test] - pub fn test_matmul_unpadded_n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_unpadded_multibatch_2_dims() { - test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_unpadded_blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn test_matmul_unpadded_large() { - test_with_params(134, 242, 250, 1, 1); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let func = matmul_tiling_2d_unpadded; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_primitive_swapped_batches_no_padding() { - let matmul_func = matmul_tiling_2d_unpadded; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_primitive_swapped_row_col_no_padding() { - let matmul_func = matmul_tiling_2d_unpadded; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let 
shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_primitive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_tiling_2d_unpadded; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + #[test] + pub fn test_matmul_unpadded_straightforward() { + test_with_params(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_shapes_smaller_than_blocks() { + test_with_params(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_shapes_equal_blocks() { + test_with_params(64, 32, 64, 2, 2); + } + + #[test] + pub fn test_matmul_unpadded_m_exceeds_block() { + test_with_params(75, 32, 64, 2, 2); + } + + #[test] + pub fn test_matmul_unpadded_k_exceeds_block() { + test_with_params(64, 33, 32, 1, 1); + } + + #[test] + pub fn test_matmul_irregular_shape() { + test_with_params(123, 255, 72, 3, 5); + } + + #[test] + pub fn test64_matmul_unpadded_n_exceeds_block() { + test_with_params(64, 32, 75, 2, 2); + } + + #[test] + pub fn test_matmul_unpadded_n_smaller_than_m() { + test_with_params(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_m_smaller_than_n() { + test_with_params(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_k_smaller_than_m_n() { + test_with_params(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_k_larger_than_m_n() { + test_with_params(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_multibatch_1_dim() { + test_with_params(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_unpadded_multibatch_2_dims() { + test_with_params(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_unpadded_blocks_divide_shapes_unevenly() { + test_with_params(7, 7, 7, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_medium() { + test_with_params(17, 16, 16, 1, 1); + } + + #[test] + pub fn test_matmul_unpadded_large() { + test_with_params(134, 242, 250, 1, 1); + } + + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { + let func = matmul_tiling_2d_unpadded; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_primitive_swapped_batches_no_padding() { + let matmul_func = matmul_tiling_2d_unpadded; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_primitive_swapped_row_col_no_padding() { + let matmul_func = matmul_tiling_2d_unpadded; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_primitive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_tiling_2d_unpadded; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs 
b/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs index 587a15bb2a..1130ccd742 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs @@ -1,9 +1,9 @@ use burn_tensor::Element; use crate::{ - element::WgpuElement, - kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::WgpuTensor, + element::WgpuElement, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -12,139 +12,139 @@ use crate::kernel_wgsl; use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE}; kernel_wgsl!( - MatmulTiling2Dvec4Raw, - "../../../template/matmul/blocktiling_2d/vec4.wgsl" + MatmulTiling2Dvec4Raw, + "../../../template/matmul/blocktiling_2d/vec4.wgsl" ); #[derive(new, Debug)] struct MatmulTiling2Dvec4 { - _elem: PhantomData, + _elem: PhantomData, } impl DynamicKernelSource for MatmulTiling2Dvec4 { - fn source(&self) -> SourceTemplate { - MatmulTiling2Dvec4Raw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulTiling2Dvec4Raw::source() + .register("b_m", B_M.to_string()) + .register("b_n", B_N.to_string()) + .register("b_k", B_K.to_string()) + .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) + .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) + .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_z", "1".to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using tiling 2d algorithm with /// vec4 primitive on both lhs and rhs pub fn matmul_tiling_2d_vec4( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - let kernel = MatmulTiling2Dvec4::::new(); - matmul_tiling_2d_launch(lhs, rhs, out, kernel) + let kernel = MatmulTiling2Dvec4::::new(); + matmul_tiling_2d_launch(lhs, rhs, out, kernel) } #[cfg(test)] mod tests { - use super::*; - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - #[test] - pub fn test_matmul_vec4_primitive_straightforward() { - test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_2_dims() { - 
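// Illustrative sketch, not from the patch itself: what the "bm_x_bk_4" and "bk_x_bn_4"
// template placeholders registered in these tiling-2d kernels expand to. B_K = 32 and
// WORKGROUP_SIZE = 16 appear in tiling2d/base.rs; the B_M = B_N = 64 values below are
// assumed example values, not taken from this diff.
const EXAMPLE_B_M: usize = 64;
const EXAMPLE_B_N: usize = 64;
const EXAMPLE_B_K: usize = 32;

#[test]
fn vec4_placeholder_arithmetic() {
    // The lhs tile holds B_M * B_K scalars, i.e. B_M * B_K / 4 vec4 elements.
    assert_eq!(EXAMPLE_B_M * EXAMPLE_B_K / 4, 512);
    // The rhs tile holds B_K * B_N scalars, i.e. B_K * B_N / 4 vec4 elements.
    assert_eq!(EXAMPLE_B_K * EXAMPLE_B_N / 4, 512);
}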
test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_large() { - test_with_params(134, 242, 250, 1, 1); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let func = matmul_tiling_2d_vec4; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { - let matmul_func = matmul_tiling_2d_vec4; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { - let matmul_func = matmul_tiling_2d_vec4; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_tiling_2d_vec4; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use super::*; + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + #[test] + pub fn test_matmul_vec4_primitive_straightforward() { + test_with_params(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { + test_with_params(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_n_smaller_than_m() { + test_with_params(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_m_smaller_than_n() { + test_with_params(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { + test_with_params(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { + test_with_params(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_1_dim() { + test_with_params(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_2_dims() { + test_with_params(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { + test_with_params(7, 7, 7, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_medium() { + test_with_params(17, 16, 16, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_large() { + test_with_params(134, 242, 250, 1, 1); + } + + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { + let func = matmul_tiling_2d_vec4; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { + let matmul_func = matmul_tiling_2d_vec4; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn 
test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { + let matmul_func = matmul_tiling_2d_vec4; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_tiling_2d_vec4; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs b/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs index 3dd1a77861..e00db5ec89 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/vec4_lhs.rs @@ -1,9 +1,9 @@ use burn_tensor::Element; use crate::{ - element::WgpuElement, - kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::WgpuTensor, + element::WgpuElement, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, + tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -12,140 +12,140 @@ use crate::kernel_wgsl; use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE}; kernel_wgsl!( - MatmulTiling2DVec4LhsRaw, - "../../../template/matmul/blocktiling_2d/vec4_lhs.wgsl" + MatmulTiling2DVec4LhsRaw, + "../../../template/matmul/blocktiling_2d/vec4_lhs.wgsl" ); #[derive(new, Debug)] struct MatmulTiling2DVec4Lhs { - _elem: PhantomData, + _elem: PhantomData, } impl DynamicKernelSource for MatmulTiling2DVec4Lhs { - fn source(&self) -> SourceTemplate { - MatmulTiling2DVec4LhsRaw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn", (B_K * B_N).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + MatmulTiling2DVec4LhsRaw::source() + .register("b_m", B_M.to_string()) + .register("b_n", B_N.to_string()) + .register("b_k", B_K.to_string()) + .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) + .register("bk_x_bn", (B_K * B_N).to_string()) + .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) + .register("workgroup_size_z", "1".to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + std::format!("{:?}", self) + } } /// Matrix multiplication using tiling 2d algorithm with /// vec4 primitive on lhs only pub fn matmul_tiling_2d_vec4_lhs( - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, ) -> WgpuTensor { - let kernel = MatmulTiling2DVec4Lhs::::new(); - matmul_tiling_2d_launch(lhs, rhs, out, kernel) + let kernel = MatmulTiling2DVec4Lhs::::new(); + matmul_tiling_2d_launch(lhs, rhs, out, kernel) } #[cfg(test)] mod tests { - use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; - - use super::matmul_tiling_2d_vec4_lhs; - - #[test] - pub fn test_matmul_vec4_primitive_straightforward() { - 
test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_multibatch_2_dims() { - test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn test_matmul_vec4_primitive_large() { - test_with_params(134, 242, 250, 1, 1); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let func = matmul_tiling_2d_vec4_lhs; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { - let matmul_func = matmul_tiling_2d_vec4_lhs; - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { - let matmul_func = matmul_tiling_2d_vec4_lhs; - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } - - #[test] - fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { - let matmul_func = matmul_tiling_2d_vec4_lhs; - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); - } + use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims}; + + use super::matmul_tiling_2d_vec4_lhs; + + #[test] + pub fn test_matmul_vec4_primitive_straightforward() { + test_with_params(1, 2, 1, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() { + test_with_params(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_n_smaller_than_m() { + test_with_params(8, 8, 3, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_m_smaller_than_n() { + test_with_params(3, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() { + test_with_params(8, 3, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_k_larger_than_m_n() { + test_with_params(8, 48, 8, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_1_dim() { + test_with_params(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_multibatch_2_dims() { + test_with_params(8, 8, 8, 3, 4); + } + + #[test] + pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() { + test_with_params(7, 7, 7, 1, 1); + } + + #[test] + pub fn 
test_matmul_vec4_primitive_medium() { + test_with_params(17, 16, 16, 1, 1); + } + + #[test] + pub fn test_matmul_vec4_primitive_large() { + test_with_params(134, 242, 250, 1, 1); + } + + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { + let func = matmul_tiling_2d_vec4_lhs; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() { + let matmul_func = matmul_tiling_2d_vec4_lhs; + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() { + let matmul_func = matmul_tiling_2d_vec4_lhs; + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } + + #[test] + fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() { + let matmul_func = matmul_tiling_2d_vec4_lhs; + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs); + } } diff --git a/burn-wgpu/src/kernel/matmul/tune/base.rs b/burn-wgpu/src/kernel/matmul/tune/base.rs index 5ae2c261ec..2699e233c4 100644 --- a/burn-wgpu/src/kernel/matmul/tune/base.rs +++ b/burn-wgpu/src/kernel/matmul/tune/base.rs @@ -2,11 +2,11 @@ use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; use crate::{ - compute::WgpuAutotuneKey, - element::WgpuElement, - kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use super::key::MatmulAutotuneKey; @@ -14,162 +14,162 @@ use super::key::MatmulAutotuneKey; /// Set of matmul implementations available for autotune /// Autotune key is given by concatenating the closest upper power of 2 of m, k and n pub struct MatmulAutotuneOperationSet { - key: WgpuAutotuneKey, - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, + key: WgpuAutotuneKey, + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, } impl MatmulAutotuneOperationSet { - fn new(lhs: WgpuTensor, rhs: WgpuTensor, out: WgpuTensor) -> Self { - Self { - key: WgpuAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)), - lhs, - rhs, - out, + fn new(lhs: WgpuTensor, rhs: WgpuTensor, out: WgpuTensor) -> Self { + Self { + key: WgpuAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)), + lhs, + rhs, + out, + } } - } } impl AutotuneOperationSet - for MatmulAutotuneOperationSet + for MatmulAutotuneOperationSet { - fn key(&self) -> WgpuAutotuneKey { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); - let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1); - let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1); - - let out = empty_device( - self.out.client.clone(), - self.out.device.clone(), - self.out.shape.clone(), - ); - - vec![ 
- Box::new(MemoryCoalescingMatmulDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(MemoryCoalescingMatmulW16x16::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(Vec4TilingMatmulDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(Vec4TilingMatmulUnpaddedDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( - lhs.clone(), - rhs.clone(), - out.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - match fastest_index { - 0 => Box::new(MemoryCoalescingMatmulDefault::::new( - self.lhs, self.rhs, self.out, - )), - 1 => Box::new(MemoryCoalescingMatmulW16x16::::new( - self.lhs, self.rhs, self.out, - )), - 2 => Box::new(Vec4TilingMatmulDefault::::new( - self.lhs, self.rhs, self.out, - )), - 3 => Box::new(Vec4TilingMatmulUnpaddedDefault::::new( - self.lhs, self.rhs, self.out, - )), - 4 => Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( - self.lhs, self.rhs, self.out, - )), - _ => panic!("Fastest index is out of bound"), + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1); + let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1); + + let out = empty_device( + self.out.client.clone(), + self.out.device.clone(), + self.out.shape.clone(), + ); + + vec![ + Box::new(MemoryCoalescingMatmulDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(MemoryCoalescingMatmulW16x16::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(Vec4TilingMatmulDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(Vec4TilingMatmulUnpaddedDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + match fastest_index { + 0 => Box::new(MemoryCoalescingMatmulDefault::::new( + self.lhs, self.rhs, self.out, + )), + 1 => Box::new(MemoryCoalescingMatmulW16x16::::new( + self.lhs, self.rhs, self.out, + )), + 2 => Box::new(Vec4TilingMatmulDefault::::new( + self.lhs, self.rhs, self.out, + )), + 3 => Box::new(Vec4TilingMatmulUnpaddedDefault::::new( + self.lhs, self.rhs, self.out, + )), + 4 => Box::new(Vec4LhsOnlyTilingMatmulDefault::::new( + self.lhs, self.rhs, self.out, + )), + _ => panic!("Fastest index is out of bound"), + } } - } } /// Executes autotune on matmul operations pub fn matmul_autotune( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - let client = lhs.client.clone(); + let client = lhs.client.clone(); - let output = init_matmul_output(&lhs, &rhs); + let output = init_matmul_output(&lhs, &rhs); - let operation_set = Box::new(MatmulAutotuneOperationSet::::new( - lhs, - rhs, - output.clone(), - )); + let operation_set = Box::new(MatmulAutotuneOperationSet::::new( + lhs, + rhs, + output.clone(), + )); - client.execute_autotune(operation_set); + client.execute_autotune(operation_set); - output + output } macro_rules! 
matmul_tune_ops { - ($name:ident, $func:expr) => { - #[derive(new)] - pub(crate) struct $name { - lhs: WgpuTensor, - rhs: WgpuTensor, - out: WgpuTensor, - } - - impl AutotuneOperation for $name { - fn execute(self: Box) { - #[allow(clippy::redundant_closure_call)] - $func(self.lhs, self.rhs, self.out); - } - - fn clone(&self) -> Box { - Box::new(Self { - lhs: self.lhs.clone(), - rhs: self.rhs.clone(), - out: self.out.clone(), - }) - } - } - }; + ($name:ident, $func:expr) => { + #[derive(new)] + pub(crate) struct $name { + lhs: WgpuTensor, + rhs: WgpuTensor, + out: WgpuTensor, + } + + impl AutotuneOperation for $name { + fn execute(self: Box) { + #[allow(clippy::redundant_closure_call)] + $func(self.lhs, self.rhs, self.out); + } + + fn clone(&self) -> Box { + Box::new(Self { + lhs: self.lhs.clone(), + rhs: self.rhs.clone(), + out: self.out.clone(), + }) + } + } + }; } // Potentially better for small matrices. matmul_tune_ops!( - MemoryCoalescingMatmulDefault, - crate::kernel::matmul::matmul_mem_coalescing_default + MemoryCoalescingMatmulDefault, + crate::kernel::matmul::matmul_mem_coalescing_default ); // Potentially better for small matrices. matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| { - crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16) + crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16) }); // Maybe the fastest on MacOS. matmul_tune_ops!( - Vec4LhsOnlyTilingMatmulDefault, - crate::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs + Vec4LhsOnlyTilingMatmulDefault, + crate::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs ); // Probably the fastest when fixed sizes. matmul_tune_ops!( - Vec4TilingMatmulDefault, - crate::kernel::matmul::vec4::matmul_tiling_2d_vec4 + Vec4TilingMatmulDefault, + crate::kernel::matmul::vec4::matmul_tiling_2d_vec4 ); // Probably the fastest otherwise. 
matmul_tune_ops!( - Vec4TilingMatmulUnpaddedDefault, - crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded + Vec4TilingMatmulUnpaddedDefault, + crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded ); diff --git a/burn-wgpu/src/kernel/matmul/tune/key.rs b/burn-wgpu/src/kernel/matmul/tune/key.rs index 48d7655f50..37f619dde1 100644 --- a/burn-wgpu/src/kernel/matmul/tune/key.rs +++ b/burn-wgpu/src/kernel/matmul/tune/key.rs @@ -1,119 +1,119 @@ use burn_tensor::Shape; use core::fmt::Debug; use std::{ - cmp::{max, min}, - fmt::Display, - hash::Hash, + cmp::{max, min}, + fmt::Display, + hash::Hash, }; #[derive(Hash, Eq, PartialEq, Debug, Clone)] /// Autotune key representative of matmul versions pub struct MatmulAutotuneKey { - round: bool, // True when all matmul dims are multiples of 64 - broadcast: bool, // True when there are differences in batch size - anchored_m: usize, - anchored_k: usize, - anchored_n: usize, - anchored_batch: usize, + round: bool, // True when all matmul dims are multiples of 64 + broadcast: bool, // True when there are differences in batch size + anchored_m: usize, + anchored_k: usize, + anchored_n: usize, + anchored_batch: usize, } impl Display for MatmulAutotuneKey { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - format!( - "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}", - self.round, - self.broadcast, - self.anchored_m, - self.anchored_k, - self.anchored_n, - self.anchored_batch - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + format!( + "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}", + self.round, + self.broadcast, + self.anchored_m, + self.anchored_k, + self.anchored_n, + self.anchored_batch + ) + .as_str(), + ) + } } impl MatmulAutotuneKey { - /// Create a matmul autotune key from the input shapes - pub fn new(lhs_shape: &Shape, rhs_shape: &Shape) -> Self { - let m = lhs_shape.dims[D - 2]; - let k = lhs_shape.dims[D - 1]; - let n = rhs_shape.dims[D - 1]; + /// Create a matmul autotune key from the input shapes + pub fn new(lhs_shape: &Shape, rhs_shape: &Shape) -> Self { + let m = lhs_shape.dims[D - 2]; + let k = lhs_shape.dims[D - 1]; + let n = rhs_shape.dims[D - 1]; - let mut broadcast = false; - let mut batch_product_lhs = 1; - let mut batch_product_rhs = 1; + let mut broadcast = false; + let mut batch_product_lhs = 1; + let mut batch_product_rhs = 1; - for b in 0..D - 2 { - batch_product_lhs *= lhs_shape.dims[b]; - batch_product_rhs *= rhs_shape.dims[b]; - if lhs_shape.dims[b] != rhs_shape.dims[b] { - broadcast = true; - } - } - let batch_product = max(batch_product_lhs, batch_product_rhs); + for b in 0..D - 2 { + batch_product_lhs *= lhs_shape.dims[b]; + batch_product_rhs *= rhs_shape.dims[b]; + if lhs_shape.dims[b] != rhs_shape.dims[b] { + broadcast = true; + } + } + let batch_product = max(batch_product_lhs, batch_product_rhs); - let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0; + let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0; - Self { - round, - broadcast, - anchored_m: anchor(m, None), - anchored_k: anchor(k, None), - anchored_n: anchor(n, None), - anchored_batch: anchor(batch_product, Some(256)), + Self { + round, + broadcast, + anchored_m: anchor(m, None), + anchored_k: anchor(k, None), + anchored_n: anchor(n, None), + anchored_batch: anchor(batch_product, Some(256)), + } } - } } fn anchor(x: usize, max: Option) -> usize { - let exp = f32::ceil(f32::log2(x as f32)) as u32; - let 
power_of_2 = 2_u32.pow(exp) as usize; - if let Some(max) = max { - min(power_of_2, max) - } else { - power_of_2 - } + let exp = f32::ceil(f32::log2(x as f32)) as u32; + let power_of_2 = 2_u32.pow(exp) as usize; + if let Some(max) = max { + min(power_of_2, max) + } else { + power_of_2 + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn matmul_autotune_key_all_same_and_round() { - let lhs_shape: Shape<3> = [4, 512, 512].into(); - let rhs_shape: Shape<3> = [4, 512, 512].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + #[test] + fn matmul_autotune_key_all_same_and_round() { + let lhs_shape: Shape<3> = [4, 512, 512].into(); + let rhs_shape: Shape<3> = [4, 512, 512].into(); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); - assert!(key.round); - assert!(!key.broadcast); - assert!(key.anchored_m == 512); - assert!(key.anchored_k == 512); - assert!(key.anchored_n == 512); - } + assert!(key.round); + assert!(!key.broadcast); + assert!(key.anchored_m == 512); + assert!(key.anchored_k == 512); + assert!(key.anchored_n == 512); + } - #[test] - fn matmul_autotune_key_all_different() { - let lhs_shape: Shape<4> = [2, 3, 511, 512].into(); - let rhs_shape: Shape<4> = [3, 2, 512, 513].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + #[test] + fn matmul_autotune_key_all_different() { + let lhs_shape: Shape<4> = [2, 3, 511, 512].into(); + let rhs_shape: Shape<4> = [3, 2, 512, 513].into(); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); - assert!(!key.round); - assert!(key.broadcast); - assert!(key.anchored_m == 512); - assert!(key.anchored_k == 512); - assert!(key.anchored_n == 1024); - assert!(key.anchored_batch == 8); - } + assert!(!key.round); + assert!(key.broadcast); + assert!(key.anchored_m == 512); + assert!(key.anchored_k == 512); + assert!(key.anchored_n == 1024); + assert!(key.anchored_batch == 8); + } - #[test] - fn matmul_autotune_key_large_batch() { - let lhs_shape: Shape<4> = [128, 512, 511, 512].into(); - let rhs_shape: Shape<4> = [200, 400, 512, 513].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + #[test] + fn matmul_autotune_key_large_batch() { + let lhs_shape: Shape<4> = [128, 512, 511, 512].into(); + let rhs_shape: Shape<4> = [200, 400, 512, 513].into(); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); - assert!(key.anchored_batch == 256); - } + assert!(key.anchored_batch == 256); + } } diff --git a/burn-wgpu/src/kernel/matmul/utils.rs b/burn-wgpu/src/kernel/matmul/utils.rs index d9b4fee4da..8a791ba359 100644 --- a/burn-wgpu/src/kernel/matmul/utils.rs +++ b/burn-wgpu/src/kernel/matmul/utils.rs @@ -3,86 +3,85 @@ use burn_tensor::Shape; /// Creates an empty output tensor with matmul output shape pub fn init_matmul_output( - lhs: &WgpuTensor, - rhs: &WgpuTensor, + lhs: &WgpuTensor, + rhs: &WgpuTensor, ) -> WgpuTensor { - empty_device(lhs.client.clone(), lhs.device.clone(), shape_out(lhs, rhs)) + empty_device(lhs.client.clone(), lhs.device.clone(), shape_out(lhs, rhs)) } pub(crate) fn shape_out( - lhs: &WgpuTensor, - rhs: &WgpuTensor, + lhs: &WgpuTensor, + rhs: &WgpuTensor, ) -> Shape { - let mut shape_out = [0; D]; - lhs - .shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - shape_out[D - 2] = lhs.shape.dims[D - 2]; - shape_out[D - 1] = rhs.shape.dims[D - 1]; - Shape::new(shape_out) + let mut shape_out = [0; D]; + lhs.shape + .dims + .iter() + 
.zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + shape_out[D - 2] = lhs.shape.dims[D - 2]; + shape_out[D - 1] = rhs.shape.dims[D - 1]; + Shape::new(shape_out) } #[cfg(test)] pub(crate) mod tests { - use crate::tensor::WgpuTensor; - use crate::tests::{ReferenceTensor, TestTensor}; - use burn_tensor::Shape; + use crate::tensor::WgpuTensor; + use crate::tests::{ReferenceTensor, TestTensor}; + use burn_tensor::Shape; - use super::init_matmul_output; + use super::init_matmul_output; - pub(crate) fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) - where - F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, - S: Into>, - { - let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + pub(crate) fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) + where + F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, + S: Into>, + { + let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let x_wgpu = TestTensor::from_data(x.to_data()).into_primitive(); - let y_wgpu = TestTensor::from_data(y.to_data()).into_primitive(); + let x_wgpu = TestTensor::from_data(x.to_data()).into_primitive(); + let y_wgpu = TestTensor::from_data(y.to_data()).into_primitive(); - let z_reference = x.matmul(y); + let z_reference = x.matmul(y); - let out = init_matmul_output(&x_wgpu, &y_wgpu); - let z = func(x_wgpu, y_wgpu, out); - let z = TestTensor::from_primitive(z); + let out = init_matmul_output(&x_wgpu, &y_wgpu); + let z = func(x_wgpu, y_wgpu, out); + let z = TestTensor::from_primitive(z); - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } - pub(crate) fn same_as_reference_swapped_dims( - func: F, - swap_lhs: [usize; 2], - swap_rhs: [usize; 2], - shape_lhs: S, - shape_rhs: S, - ) where - F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, - S: Into>, - { - let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + pub(crate) fn same_as_reference_swapped_dims( + func: F, + swap_lhs: [usize; 2], + swap_rhs: [usize; 2], + shape_lhs: S, + shape_rhs: S, + ) where + F: Fn(WgpuTensor, WgpuTensor, WgpuTensor) -> WgpuTensor, + S: Into>, + { + let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let x_wgpu = TestTensor::from_data(x.to_data()).swap_dims(swap_lhs[0], swap_lhs[1]); - let y_wgpu = TestTensor::from_data(y.to_data()).swap_dims(swap_rhs[0], swap_rhs[1]); + let x_wgpu = TestTensor::from_data(x.to_data()).swap_dims(swap_lhs[0], swap_lhs[1]); + let y_wgpu = TestTensor::from_data(y.to_data()).swap_dims(swap_rhs[0], swap_rhs[1]); - let z_reference = x - .swap_dims(swap_lhs[0], swap_lhs[1]) - .matmul(y.swap_dims(swap_rhs[0], swap_rhs[1])); + let z_reference = x + .swap_dims(swap_lhs[0], swap_lhs[1]) + .matmul(y.swap_dims(swap_rhs[0], swap_rhs[1])); - let out = init_matmul_output( - &x_wgpu.clone().into_primitive(), - &y_wgpu.clone().into_primitive(), - ); - let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive(), 
out); - let z = TestTensor::from_primitive(z); + let out = init_matmul_output( + &x_wgpu.clone().into_primitive(), + &y_wgpu.clone().into_primitive(), + ); + let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive(), out); + let z = TestTensor::from_primitive(z); - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs b/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs index 20162ffab8..3da4c74776 100644 --- a/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs @@ -1,95 +1,95 @@ use crate::{ - compute::{StaticKernel, WgpuHandle}, - element::WgpuElement, - kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::{StaticKernel, WgpuHandle}, + element::WgpuElement, + kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; use burn_tensor::Shape; kernel_wgsl!( - AdaptiveAvgPool2d, - "../../template/pool/adaptive_avg_pool2d.wgsl" + AdaptiveAvgPool2d, + "../../template/pool/adaptive_avg_pool2d.wgsl" ); kernel_wgsl!( - AdaptiveAvgPool2dBackward, - "../../template/pool/adaptive_avg_pool2d_backward.wgsl" + AdaptiveAvgPool2dBackward, + "../../template/pool/adaptive_avg_pool2d_backward.wgsl" ); pub(crate) fn adaptive_avg_pool2d( - x: WgpuTensor, - output_size: [usize; 2], + x: WgpuTensor, + output_size: [usize; 2], ) -> WgpuTensor { - let [batch_size, channels, _, _] = x.shape.dims; + let [batch_size, channels, _, _] = x.shape.dims; - let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); - let output = empty_device(x.client.clone(), x.device.clone(), output_shape); + let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); + let output = empty_device(x.client.clone(), x.device.clone(), output_shape); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); - let info_handle = build_info(&x, &output); - x.client - .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); + let info_handle = build_info(&x, &output); + x.client + .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); - output + output } pub(crate) fn adaptive_avg_pool2d_backward( - x: WgpuTensor, - out_grad: WgpuTensor, + x: WgpuTensor, + out_grad: WgpuTensor, ) -> WgpuTensor { - let output_shape = x.shape.clone(); - let num_elems = output_shape.num_elements(); - let output_buffer = x.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - x.client.clone(), - x.device.clone(), - output_shape, - output_buffer, - ); + let output_shape = x.shape.clone(); + let num_elems = output_shape.num_elements(); + let output_buffer = x.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + x.client.clone(), + x.device.clone(), + output_shape, + output_buffer, + ); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); - let 
info_handle = build_info(&x, &out_grad); + let info_handle = build_info(&x, &out_grad); - x.client.execute( - Box::new(kernel), - &[&out_grad.handle, &output.handle, &info_handle], - ); + x.client.execute( + Box::new(kernel), + &[&out_grad.handle, &output.handle, &info_handle], + ); - output + output } fn build_info(x: &WgpuTensor, output: &WgpuTensor) -> WgpuHandle { - let mut info: [u32; 16] = [0; 16]; - info[0] = x.strides[0] as u32; - info[1] = x.strides[1] as u32; - info[2] = x.strides[2] as u32; - info[3] = x.strides[3] as u32; - info[4] = x.shape.dims[0] as u32; - info[5] = x.shape.dims[1] as u32; - info[6] = x.shape.dims[2] as u32; - info[7] = x.shape.dims[3] as u32; + let mut info: [u32; 16] = [0; 16]; + info[0] = x.strides[0] as u32; + info[1] = x.strides[1] as u32; + info[2] = x.strides[2] as u32; + info[3] = x.strides[3] as u32; + info[4] = x.shape.dims[0] as u32; + info[5] = x.shape.dims[1] as u32; + info[6] = x.shape.dims[2] as u32; + info[7] = x.shape.dims[3] as u32; - info[8] = output.strides[0] as u32; - info[9] = output.strides[1] as u32; - info[10] = output.strides[2] as u32; - info[11] = output.strides[3] as u32; - info[12] = output.shape.dims[0] as u32; - info[13] = output.shape.dims[1] as u32; - info[14] = output.shape.dims[2] as u32; - info[15] = output.shape.dims[3] as u32; + info[8] = output.strides[0] as u32; + info[9] = output.strides[1] as u32; + info[10] = output.strides[2] as u32; + info[11] = output.strides[3] as u32; + info[12] = output.shape.dims[0] as u32; + info[13] = output.shape.dims[1] as u32; + info[14] = output.shape.dims[2] as u32; + info[15] = output.shape.dims[3] as u32; - output.client.create(bytemuck::cast_slice(&info)) + output.client.create(bytemuck::cast_slice(&info)) } diff --git a/burn-wgpu/src/kernel/pool/avg_pool2d.rs b/burn-wgpu/src/kernel/pool/avg_pool2d.rs index 85f2e44f38..05a5d840d3 100644 --- a/burn-wgpu/src/kernel/pool/avg_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/avg_pool2d.rs @@ -1,154 +1,169 @@ use crate::{ - compute::{Kernel, StaticKernel}, - element::WgpuElement, - kernel::{ - self, elemwise_workgroup, - pool::{build_output_and_info_pool2d, build_pool2d_info}, - KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, - }, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::{Kernel, StaticKernel}, + element::WgpuElement, + kernel::{ + self, elemwise_workgroup, + pool::{build_output_and_info_pool2d, build_pool2d_info}, + KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT, + }, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(AvgPool2dRaw, "../../template/pool/avg_pool2d.wgsl"); kernel_wgsl!( - AvgPool2dBackwardRaw, - "../../template/pool/avg_pool2d_backward.wgsl" + AvgPool2dBackwardRaw, + "../../template/pool/avg_pool2d_backward.wgsl" ); struct AvgPool2dBackward; struct AvgPool2d; impl StaticKernelSource for AvgPool2dBackward { - fn source() -> kernel::SourceTemplate { - AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) - } + fn source() -> kernel::SourceTemplate { + AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) + } } impl StaticKernelSource for AvgPool2d { - fn source() -> kernel::SourceTemplate { - AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) - } + fn source() -> kernel::SourceTemplate { + AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) + } } pub(crate) fn avg_pool2d( - x: WgpuTensor, - kernel_size: 
[usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> WgpuTensor { - let (info_handle, output) = - build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]); - - let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); - let kernel: Box = match count_include_pad { - true => Box::new(StaticKernel::< - KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(workgroup)), - false => Box::new(StaticKernel::< - KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(workgroup)), - }; - - x.client - .execute(kernel, &[&x.handle, &output.handle, &info_handle]); - - output + let (info_handle, output) = + build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]); + + let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); + let kernel: Box = match count_include_pad { + true => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, + >::new(workgroup)), + false => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, + >::new(workgroup)), + }; + + x.client + .execute(kernel, &[&x.handle, &output.handle, &info_handle]); + + output } pub(crate) fn avg_pool2d_backward( - x: WgpuTensor, - grad: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, + x: WgpuTensor, + grad: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, ) -> WgpuTensor { - let grad = kernel::into_contiguous(grad); - let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone()); - let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]); - let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); - - let kernel: Box = match count_include_pad { - true => Box::new(StaticKernel::< - KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(workgroup)), - false => Box::new(StaticKernel::< - KernelSettings, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(workgroup)), - }; - - x.client - .execute(kernel, &[&grad.handle, &output.handle, &info_handle]); - - output + let grad = kernel::into_contiguous(grad); + let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone()); + let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]); + let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT); + + let kernel: Box = match count_include_pad { + true => Box::new(StaticKernel::< + KernelSettings< + AvgPool2dBackward, + E, + i32, + WORKGROUP_DEFAULT, + WORKGROUP_DEFAULT, + 1, + >, + >::new(workgroup)), + false => Box::new(StaticKernel::< + KernelSettings< + AvgPool2dBackward, + E, + i32, + WORKGROUP_DEFAULT, + WORKGROUP_DEFAULT, + 1, + >, + >::new(workgroup)), + }; + + x.client + .execute(kernel, &[&grad.handle, &output.handle, &info_handle]); + + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{backend::Backend, module, ops::ModuleOps, Distribution, Tensor}; - - #[test] - fn avg_pool2d_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size 
= [3, 4]; - let stride = [1, 2]; - let padding = [1, 2]; - let count_include_pad = true; - - let pooled = module::avg_pool2d(tensor, kernel_size, stride, padding, count_include_pad); - let pooled_ref = - module::avg_pool2d(tensor_ref, kernel_size, stride, padding, count_include_pad); - - pooled - .into_data() - .assert_approx_eq(&pooled_ref.into_data(), 3); - } - - #[test] - fn avg_pool2d_backward_should_work_with_multiple_invocations() { - TestBackend::seed(0); - ReferenceBackend::seed(0); - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size = [3, 3]; - let stride = [1, 1]; - let padding = [1, 1]; - let count_include_pad = true; - - let shape_out = module::avg_pool2d( - tensor.clone(), - kernel_size, - stride, - padding, - count_include_pad, - ) - .shape(); - let grad_output = Tensor::::random(shape_out, Distribution::Default); - let grad_output_ref = Tensor::::from_data(grad_output.to_data()); - - let grad: Tensor = Tensor::from_primitive(TestBackend::avg_pool2d_backward( - tensor.into_primitive(), - grad_output.into_primitive(), - kernel_size, - stride, - padding, - count_include_pad, - )); - let grad_ref: Tensor = - Tensor::from_primitive(ReferenceBackend::avg_pool2d_backward( - tensor_ref.into_primitive(), - grad_output_ref.into_primitive(), - kernel_size, - stride, - padding, - count_include_pad, - )); - - grad.into_data().assert_approx_eq(&grad_ref.into_data(), 3); - } + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{backend::Backend, module, ops::ModuleOps, Distribution, Tensor}; + + #[test] + fn avg_pool2d_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 4]; + let stride = [1, 2]; + let padding = [1, 2]; + let count_include_pad = true; + + let pooled = module::avg_pool2d(tensor, kernel_size, stride, padding, count_include_pad); + let pooled_ref = + module::avg_pool2d(tensor_ref, kernel_size, stride, padding, count_include_pad); + + pooled + .into_data() + .assert_approx_eq(&pooled_ref.into_data(), 3); + } + + #[test] + fn avg_pool2d_backward_should_work_with_multiple_invocations() { + TestBackend::seed(0); + ReferenceBackend::seed(0); + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 3]; + let stride = [1, 1]; + let padding = [1, 1]; + let count_include_pad = true; + + let shape_out = module::avg_pool2d( + tensor.clone(), + kernel_size, + stride, + padding, + count_include_pad, + ) + .shape(); + let grad_output = Tensor::::random(shape_out, Distribution::Default); + let grad_output_ref = Tensor::::from_data(grad_output.to_data()); + + let grad: Tensor = + Tensor::from_primitive(TestBackend::avg_pool2d_backward( + tensor.into_primitive(), + grad_output.into_primitive(), + kernel_size, + stride, + padding, + count_include_pad, + )); + let grad_ref: Tensor = + Tensor::from_primitive(ReferenceBackend::avg_pool2d_backward( + tensor_ref.into_primitive(), + grad_output_ref.into_primitive(), + kernel_size, + stride, + padding, + count_include_pad, + )); + + grad.into_data().assert_approx_eq(&grad_ref.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/pool/base.rs b/burn-wgpu/src/kernel/pool/base.rs index 9e48b21326..13a16d6ec6 100644 --- a/burn-wgpu/src/kernel/pool/base.rs +++ b/burn-wgpu/src/kernel/pool/base.rs 
@@ -1,72 +1,73 @@ use crate::{ - compute::WgpuHandle, element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor, + compute::WgpuHandle, element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor, }; use burn_tensor::Shape; /// Build basic info to launch pool 2d kernels. pub fn build_output_and_info_pool2d( - x: &WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: &WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (WgpuHandle, WgpuTensor) { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape.dims; + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + let [batch_size, channels, x_height, x_width] = x.shape.dims; - let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) - / stride_height) - + 1; - let out_width = - ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; - let shape_out = Shape::new([batch_size, channels, out_height, out_width]); - let output = empty_device(x.client.clone(), x.device.clone(), shape_out); + let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) + / stride_height) + + 1; + let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) + / stride_width) + + 1; + let shape_out = Shape::new([batch_size, channels, out_height, out_width]); + let output = empty_device(x.client.clone(), x.device.clone(), shape_out); - let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation); + let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation); - (info_buffer, output) + (info_buffer, output) } pub fn build_pool2d_info( - input: &WgpuTensor, - output: &WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + input: &WgpuTensor, + output: &WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> WgpuHandle { - let mut info: [u32; 24] = [0; 24]; - info[0] = input.strides[0] as u32; - info[1] = input.strides[1] as u32; - info[2] = input.strides[2] as u32; - info[3] = input.strides[3] as u32; - info[4] = input.shape.dims[0] as u32; - info[5] = input.shape.dims[1] as u32; - info[6] = input.shape.dims[2] as u32; - info[7] = input.shape.dims[3] as u32; + let mut info: [u32; 24] = [0; 24]; + info[0] = input.strides[0] as u32; + info[1] = input.strides[1] as u32; + info[2] = input.strides[2] as u32; + info[3] = input.strides[3] as u32; + info[4] = input.shape.dims[0] as u32; + info[5] = input.shape.dims[1] as u32; + info[6] = input.shape.dims[2] as u32; + info[7] = input.shape.dims[3] as u32; - info[8] = output.strides[0] as u32; - info[9] = output.strides[1] as u32; - info[10] = output.strides[2] as u32; - info[11] = output.strides[3] as u32; - info[12] = output.shape.dims[0] as u32; - info[13] = output.shape.dims[1] as u32; - info[14] = output.shape.dims[2] as u32; - info[15] = output.shape.dims[3] as u32; + info[8] = output.strides[0] as u32; + info[9] = output.strides[1] as u32; + info[10] = 
output.strides[2] as u32; + info[11] = output.strides[3] as u32; + info[12] = output.shape.dims[0] as u32; + info[13] = output.shape.dims[1] as u32; + info[14] = output.shape.dims[2] as u32; + info[15] = output.shape.dims[3] as u32; - info[16] = kernel_size[0] as u32; - info[17] = kernel_size[1] as u32; - info[18] = stride[0] as u32; - info[19] = stride[1] as u32; - info[20] = padding[0] as u32; - info[21] = padding[1] as u32; - info[22] = dilation[0] as u32; - info[23] = dilation[1] as u32; + info[16] = kernel_size[0] as u32; + info[17] = kernel_size[1] as u32; + info[18] = stride[0] as u32; + info[19] = stride[1] as u32; + info[20] = padding[0] as u32; + info[21] = padding[1] as u32; + info[22] = dilation[0] as u32; + info[23] = dilation[1] as u32; - let info_buffer = input.client.create(bytemuck::cast_slice(&info)); + let info_buffer = input.client.create(bytemuck::cast_slice(&info)); - info_buffer + info_buffer } diff --git a/burn-wgpu/src/kernel/pool/max_pool2d.rs b/burn-wgpu/src/kernel/pool/max_pool2d.rs index e06588755b..77ce5d998b 100644 --- a/burn-wgpu/src/kernel/pool/max_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/max_pool2d.rs @@ -1,189 +1,194 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{ - self, elemwise_workgroup, - pool::{build_output_and_info_pool2d, build_pool2d_info}, - KernelSettings, WORKGROUP_DEFAULT, - }, - kernel_wgsl, - ops::numeric::empty_device, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{ + self, elemwise_workgroup, + pool::{build_output_and_info_pool2d, build_pool2d_info}, + KernelSettings, WORKGROUP_DEFAULT, + }, + kernel_wgsl, + ops::numeric::empty_device, + tensor::WgpuTensor, }; kernel_wgsl!(MaxPool2d, "../../template/pool/max_pool2d.wgsl"); kernel_wgsl!( - MaxPool2dWithIndicesBackward, - "../../template/pool/max_pool2d_with_indices_backward.wgsl" + MaxPool2dWithIndicesBackward, + "../../template/pool/max_pool2d_with_indices_backward.wgsl" ); kernel_wgsl!( - MaxPool2dWithIndices, - "../../template/pool/max_pool2d_with_indices.wgsl" + MaxPool2dWithIndices, + "../../template/pool/max_pool2d_with_indices.wgsl" ); pub(crate) fn max_pool2d( - x: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> WgpuTensor { - let (info_handle, output) = - build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - x.client - .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); - - output + let (info_handle, output) = + build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + + x.client + .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); + + output } pub(crate) fn max_pool2d_with_indices( - x: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> (WgpuTensor, WgpuTensor) { - let (info_handle, output) = - build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); - let indices = empty_device(x.client.clone(), 
x.device, output.shape.clone()); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - x.client.execute( - Box::new(kernel), - &[&x.handle, &output.handle, &indices.handle, &info_handle], - ); - - (output, indices) + let (info_handle, output) = + build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); + let indices = empty_device(x.client.clone(), x.device, output.shape.clone()); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + + x.client.execute( + Box::new(kernel), + &[&x.handle, &output.handle, &indices.handle, &info_handle], + ); + + (output, indices) } pub(crate) fn max_pool2d_with_indices_backward( - x: WgpuTensor, - grad: WgpuTensor, - indices: WgpuTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], + x: WgpuTensor, + grad: WgpuTensor, + indices: WgpuTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], ) -> WgpuTensor { - let grad = kernel::into_contiguous(grad); - let indices = kernel::into_contiguous(indices); - - let num_elems = x.shape.num_elements(); - let buffer = x.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer); - - let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation); - - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - )); - - x.client.execute( - Box::new(kernel), - &[&indices.handle, &grad.handle, &output.handle, &info_handle], - ); - output + let grad = kernel::into_contiguous(grad); + let indices = kernel::into_contiguous(indices); + + let num_elems = x.shape.num_elements(); + let buffer = x.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer); + + let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation); + + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup( + output.shape.num_elements(), + WORKGROUP_DEFAULT, + )); + + x.client.execute( + Box::new(kernel), + &[&indices.handle, &grad.handle, &output.handle, &info_handle], + ); + output } #[cfg(test)] mod tests { - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor}; - - #[test] - pub fn max_pool2d_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size = [3, 3]; - let stride = [2, 2]; - let padding = [1, 1]; - let dilation = [1, 1]; - - let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation); - let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation); - - pooled - .into_data() - .assert_approx_eq(&pooled_ref.into_data(), 3); - } - - #[test] - pub fn max_pool2d_with_indices_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let kernel_size = [3, 3]; - let stride = [2, 2]; - let padding = [1, 1]; - let dilation = [1, 1]; - - let (pooled, indices) = - module::max_pool2d_with_indices(tensor, kernel_size, 
stride, padding, dilation); - let (pooled_ref, indices_ref) = - module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation); - - pooled - .into_data() - .assert_approx_eq(&pooled_ref.into_data(), 3); - assert_eq!(indices.into_data(), indices_ref.into_data().convert()); - } - - #[test] - pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); - let grad_output = Tensor::::random([32, 32, 16, 16], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let grad_output_ref = Tensor::::from_data(grad_output.to_data()); - let kernel_size = [3, 3]; - let stride = [2, 2]; - let padding = [1, 1]; - let dilation = [1, 1]; - - let (_, indices) = - module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding, dilation); - let (_, indices_ref) = - module::max_pool2d_with_indices(tensor_ref.clone(), kernel_size, stride, padding, dilation); - let grad = TestBackend::max_pool2d_with_indices_backward( - tensor.into_primitive(), - kernel_size, - stride, - padding, - dilation, - grad_output.into_primitive(), - indices.into_primitive(), - ) - .x_grad; - let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward( - tensor_ref.into_primitive(), - kernel_size, - stride, - padding, - dilation, - grad_output_ref.into_primitive(), - indices_ref.into_primitive(), - ) - .x_grad; - - Tensor::::from_primitive(grad) - .into_data() - .assert_approx_eq( - &Tensor::::from_primitive(grad_ref).into_data(), - 3, - ); - } + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor}; + + #[test] + pub fn max_pool2d_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 3]; + let stride = [2, 2]; + let padding = [1, 1]; + let dilation = [1, 1]; + + let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation); + let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation); + + pooled + .into_data() + .assert_approx_eq(&pooled_ref.into_data(), 3); + } + + #[test] + pub fn max_pool2d_with_indices_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let kernel_size = [3, 3]; + let stride = [2, 2]; + let padding = [1, 1]; + let dilation = [1, 1]; + + let (pooled, indices) = + module::max_pool2d_with_indices(tensor, kernel_size, stride, padding, dilation); + let (pooled_ref, indices_ref) = + module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation); + + pooled + .into_data() + .assert_approx_eq(&pooled_ref.into_data(), 3); + assert_eq!(indices.into_data(), indices_ref.into_data().convert()); + } + + #[test] + pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default); + let grad_output = Tensor::::random([32, 32, 16, 16], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let grad_output_ref = Tensor::::from_data(grad_output.to_data()); + let kernel_size = [3, 3]; + let stride = [2, 2]; + let padding = [1, 1]; + let dilation = [1, 1]; + + let (_, indices) = + module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding, dilation); 
+ let (_, indices_ref) = module::max_pool2d_with_indices( + tensor_ref.clone(), + kernel_size, + stride, + padding, + dilation, + ); + let grad = TestBackend::max_pool2d_with_indices_backward( + tensor.into_primitive(), + kernel_size, + stride, + padding, + dilation, + grad_output.into_primitive(), + indices.into_primitive(), + ) + .x_grad; + let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward( + tensor_ref.into_primitive(), + kernel_size, + stride, + padding, + dilation, + grad_output_ref.into_primitive(), + indices_ref.into_primitive(), + ) + .x_grad; + + Tensor::::from_primitive(grad) + .into_data() + .assert_approx_eq( + &Tensor::::from_primitive(grad_ref).into_data(), + 3, + ); + } } diff --git a/burn-wgpu/src/kernel/prng/base.rs b/burn-wgpu/src/kernel/prng/base.rs index fd5338bdc4..c0478db0e5 100644 --- a/burn-wgpu/src/kernel/prng/base.rs +++ b/burn-wgpu/src/kernel/prng/base.rs @@ -1,7 +1,7 @@ use crate::{ - compute::{WgpuComputeClient, WgpuHandle}, - element::WgpuElement, - kernel_wgsl, SEED, + compute::{WgpuComputeClient, WgpuHandle}, + element::WgpuElement, + kernel_wgsl, SEED, }; use burn_common::rand::get_seeded_rng; use rand::Rng; @@ -9,86 +9,86 @@ use rand::Rng; kernel_wgsl!(Prng, "../../template/prng/prng.wgsl"); pub(crate) fn get_seeds() -> Vec { - let mut seed = SEED.lock().unwrap(); - let mut rng = match seed.as_ref() { - Some(rng_seeded) => rng_seeded.clone(), - None => get_seeded_rng(), - }; - let mut seeds: Vec = Vec::with_capacity(4); - for _ in 0..4 { - seeds.push(rng.gen()); - } - *seed = Some(rng); - seeds + let mut seed = SEED.lock().unwrap(); + let mut rng = match seed.as_ref() { + Some(rng_seeded) => rng_seeded.clone(), + None => get_seeded_rng(), + }; + let mut seeds: Vec = Vec::with_capacity(4); + for _ in 0..4 { + seeds.push(rng.gen()); + } + *seed = Some(rng); + seeds } pub(crate) fn make_info_buffer( - client: WgpuComputeClient, - n_values_per_thread: usize, + client: WgpuComputeClient, + n_values_per_thread: usize, ) -> WgpuHandle { - let mut info = get_seeds(); - info.insert(0, n_values_per_thread as u32); - client.create(bytemuck::cast_slice(&info)) + let mut info = get_seeds(); + info.insert(0, n_values_per_thread as u32); + client.create(bytemuck::cast_slice(&info)) } pub(crate) fn make_args_buffer( - client: WgpuComputeClient, - args: &[E], + client: WgpuComputeClient, + args: &[E], ) -> WgpuHandle { - client.create(E::as_bytes(args)) + client.create(E::as_bytes(args)) } #[cfg(test)] pub mod tests { - use burn_tensor::Element; + use burn_tensor::Element; - #[derive(Default, Copy, Clone)] - pub struct BinStats { - pub count: usize, - pub n_runs: usize, // Number of sequences of same bin - } + #[derive(Default, Copy, Clone)] + pub struct BinStats { + pub count: usize, + pub n_runs: usize, // Number of sequences of same bin + } - pub fn calculate_bin_stats( - numbers: Vec, - number_of_bins: usize, - low: f32, - high: f32, - ) -> Vec { - let range = (high - low) / number_of_bins as f32; - let mut output: Vec = (0..number_of_bins).map(|_| Default::default()).collect(); - let mut initialized = false; - let mut current_runs = number_of_bins; // impossible value for starting point - for number in numbers { - let num = number.elem::(); - if num < low || num > high { - continue; - } - let index = f32::floor((num - low) / range) as usize; - output[index].count += 1; - if initialized && index != current_runs { + pub fn calculate_bin_stats( + numbers: Vec, + number_of_bins: usize, + low: f32, + high: f32, + ) -> Vec { + let range = (high - low) / 
number_of_bins as f32; + let mut output: Vec = (0..number_of_bins).map(|_| Default::default()).collect(); + let mut initialized = false; + let mut current_runs = number_of_bins; // impossible value for starting point + for number in numbers { + let num = number.elem::(); + if num < low || num > high { + continue; + } + let index = f32::floor((num - low) / range) as usize; + output[index].count += 1; + if initialized && index != current_runs { + output[current_runs].n_runs += 1; + } + initialized = true; + current_runs = index; + } output[current_runs].n_runs += 1; - } - initialized = true; - current_runs = index; + output } - output[current_runs].n_runs += 1; - output - } - #[test] - fn test_count_bins() { - let numbers = vec![0., 1., 1.5, 2., 2.5, 3., 2.5, 1.5, 3.5]; - let number_of_bins = 4; - let low = 0.; - let high = 4.; - let stats = calculate_bin_stats(numbers, number_of_bins, low, high); - assert_eq!(stats[0].count, 1); - assert_eq!(stats[0].n_runs, 1); - assert_eq!(stats[1].count, 3); - assert_eq!(stats[1].n_runs, 2); - assert_eq!(stats[2].count, 3); - assert_eq!(stats[2].n_runs, 2); - assert_eq!(stats[3].count, 2); - assert_eq!(stats[3].n_runs, 2); - } + #[test] + fn test_count_bins() { + let numbers = vec![0., 1., 1.5, 2., 2.5, 3., 2.5, 1.5, 3.5]; + let number_of_bins = 4; + let low = 0.; + let high = 4.; + let stats = calculate_bin_stats(numbers, number_of_bins, low, high); + assert_eq!(stats[0].count, 1); + assert_eq!(stats[0].n_runs, 1); + assert_eq!(stats[1].count, 3); + assert_eq!(stats[1].n_runs, 2); + assert_eq!(stats[2].count, 3); + assert_eq!(stats[2].n_runs, 2); + assert_eq!(stats[3].count, 2); + assert_eq!(stats[3].n_runs, 2); + } } diff --git a/burn-wgpu/src/kernel/prng/bernoulli.rs b/burn-wgpu/src/kernel/prng/bernoulli.rs index 1a2324420c..69c4cbf21e 100644 --- a/burn-wgpu/src/kernel/prng/bernoulli.rs +++ b/burn-wgpu/src/kernel/prng/bernoulli.rs @@ -1,13 +1,13 @@ use crate::{ - compute::{compute_client, StaticKernel}, - element::WgpuElement, - kernel::{ - prng::base::{make_args_buffer, make_info_buffer}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, - }, - ops::numeric::empty_device, - tensor::WgpuTensor, - GraphicsApi, WgpuDevice, + compute::{compute_client, StaticKernel}, + element::WgpuElement, + kernel::{ + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, + }, + ops::numeric::empty_device, + tensor::WgpuTensor, + GraphicsApi, WgpuDevice, }; use burn_tensor::Shape; @@ -16,115 +16,122 @@ use super::base::Prng; struct BernoulliPrng; impl StaticKernelSource for BernoulliPrng { - fn source() -> SourceTemplate { - Prng::source() - .register("num_args", "1") - .register( - "prng_loop", - include_str!("../../template/prng/bernoulli_inner_loop.wgsl"), - ) - .add_template("fn cast_elem(e: bool) -> {{ elem }} {return {{elem}}(e);}") - } + fn source() -> SourceTemplate { + Prng::source() + .register("num_args", "1") + .register( + "prng_loop", + include_str!("../../template/prng/bernoulli_inner_loop.wgsl"), + ) + .add_template("fn cast_elem(e: bool) -> {{ elem }} {return {{elem}}(e);}") + } } /// Pseudo-random generator for bernoulli pub fn random_bernoulli( - shape: Shape, - device: &WgpuDevice, - prob: E, + shape: Shape, + device: &WgpuDevice, + prob: E, ) -> WgpuTensor { - const N_VALUES_PER_THREAD: usize = 128; - - let client = compute_client::(device); - let output = empty_device(client.clone(), device.clone(), shape.clone()); - let 
info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); - let args_handle = make_args_buffer(client.clone(), &[prob]); - let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); - - client.execute( - Box::new(kernel), - &[&output.handle, &info_handle, &args_handle], - ); - - output + const N_VALUES_PER_THREAD: usize = 128; + + let client = compute_client::(device); + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[prob]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); + + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], + ); + + output } #[cfg(test)] mod tests { - use core::f32; - - use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; - use serial_test::serial; - - use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; - - #[test] - #[serial] - fn subsequent_calls_give_different_tensors() { - TestBackend::seed(0); - let shape: Shape<2> = [40, 40].into(); - let device = WgpuDevice::default(); - - let tensor_1 = - Tensor::::random_device(shape.clone(), Distribution::Bernoulli(0.5), &device); - let tensor_2 = - Tensor::::random_device(shape.clone(), Distribution::Bernoulli(0.5), &device); - let mut diff_exists = false; - for i in 0..shape.num_elements() { - if tensor_1.to_data().value[i] != tensor_2.to_data().value[i] { - diff_exists = true; - break; - } + use core::f32; + + use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; + use serial_test::serial; + + use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; + + #[test] + #[serial] + fn subsequent_calls_give_different_tensors() { + TestBackend::seed(0); + let shape: Shape<2> = [40, 40].into(); + let device = WgpuDevice::default(); + + let tensor_1 = Tensor::::random_device( + shape.clone(), + Distribution::Bernoulli(0.5), + &device, + ); + let tensor_2 = Tensor::::random_device( + shape.clone(), + Distribution::Bernoulli(0.5), + &device, + ); + let mut diff_exists = false; + for i in 0..shape.num_elements() { + if tensor_1.to_data().value[i] != tensor_2.to_data().value[i] { + diff_exists = true; + break; + } + } + assert!(diff_exists); } - assert!(diff_exists); - } - - #[test] - #[serial] - fn number_of_1_proportional_to_prob() { - TestBackend::seed(0); - let shape: Shape<2> = [40, 40].into(); - let device = WgpuDevice::default(); - let prob = 0.7; - - let tensor_1 = Tensor::::random_device( - shape.clone(), - Distribution::Bernoulli(prob), - &device, - ); - // High bound slightly over 1 so 1.0 is included in second bin - let bin_stats = calculate_bin_stats(tensor_1.into_data().value, 2, 0., 1.1); - assert!( - f32::abs((bin_stats[1].count as f32 / shape.num_elements() as f32) - prob as f32) < 0.05 - ); - } - - #[test] - #[serial] - fn runs_test() { - TestBackend::seed(0); - let shape = Shape::new([512, 512]); - let device = WgpuDevice::default(); - let tensor = - Tensor::::random_device(shape, Distribution::Bernoulli(0.5), &device); - - let numbers = tensor.into_data().value; - let stats = calculate_bin_stats(numbers, 2, 0., 1.1); - let n_0 = stats[0].count as f32; - let n_1 = stats[1].count as f32; - let n_runs = 
(stats[0].n_runs + stats[1].n_runs) as f32; - - let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; - let variance = - ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.)); - let z = (n_runs - expectation) / variance.sqrt(); - - // below 2 means we can have good confidence in the randomness - // we put 2.5 to make sure it passes even when very unlucky - assert!(z.abs() < 2.5); - } + #[test] + #[serial] + fn number_of_1_proportional_to_prob() { + TestBackend::seed(0); + let shape: Shape<2> = [40, 40].into(); + let device = WgpuDevice::default(); + let prob = 0.7; + + let tensor_1 = Tensor::::random_device( + shape.clone(), + Distribution::Bernoulli(prob), + &device, + ); + + // High bound slightly over 1 so 1.0 is included in second bin + let bin_stats = calculate_bin_stats(tensor_1.into_data().value, 2, 0., 1.1); + assert!( + f32::abs((bin_stats[1].count as f32 / shape.num_elements() as f32) - prob as f32) + < 0.05 + ); + } + + #[test] + #[serial] + fn runs_test() { + TestBackend::seed(0); + let shape = Shape::new([512, 512]); + let device = WgpuDevice::default(); + let tensor = + Tensor::::random_device(shape, Distribution::Bernoulli(0.5), &device); + + let numbers = tensor.into_data().value; + let stats = calculate_bin_stats(numbers, 2, 0., 1.1); + let n_0 = stats[0].count as f32; + let n_1 = stats[1].count as f32; + let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32; + + let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; + let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) + / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.)); + let z = (n_runs - expectation) / variance.sqrt(); + + // below 2 means we can have good confidence in the randomness + // we put 2.5 to make sure it passes even when very unlucky + assert!(z.abs() < 2.5); + } } diff --git a/burn-wgpu/src/kernel/prng/normal.rs b/burn-wgpu/src/kernel/prng/normal.rs index dd88e5f5fe..fc80f3f1e7 100644 --- a/burn-wgpu/src/kernel/prng/normal.rs +++ b/burn-wgpu/src/kernel/prng/normal.rs @@ -1,15 +1,15 @@ use burn_tensor::Shape; use crate::{ - compute::{compute_client, StaticKernel}, - element::WgpuElement, - kernel::{ - prng::base::{make_args_buffer, make_info_buffer}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, - }, - ops::numeric::empty_device, - tensor::WgpuTensor, - GraphicsApi, WgpuDevice, + compute::{compute_client, StaticKernel}, + element::WgpuElement, + kernel::{ + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, + }, + ops::numeric::empty_device, + tensor::WgpuTensor, + GraphicsApi, WgpuDevice, }; use super::base::Prng; @@ -17,107 +17,110 @@ use super::base::Prng; struct NormalPrng; impl StaticKernelSource for NormalPrng { - fn source() -> SourceTemplate { - Prng::source() - .register("num_args", "2") - .register( - "prng_loop", - include_str!("../../template/prng/normal_inner_loop.wgsl"), - ) - .add_template(include_str!( - "../../template/prng/box_muller_transform.wgsl" - )) - } + fn source() -> SourceTemplate { + Prng::source() + .register("num_args", "2") + .register( + "prng_loop", + include_str!("../../template/prng/normal_inner_loop.wgsl"), + ) + .add_template(include_str!( + "../../template/prng/box_muller_transform.wgsl" + )) + } } /// Pseudo-random generator for normal distribution pub fn random_normal( - shape: Shape, - device: &WgpuDevice, - mean: E, - std: E, + shape: Shape, + device: &WgpuDevice, + mean: E, + std: E, ) 
-> WgpuTensor { - const N_VALUES_PER_THREAD: usize = 128; // must be even + const N_VALUES_PER_THREAD: usize = 128; // must be even - let client = compute_client::(device); - let output = empty_device(client.clone(), device.clone(), shape.clone()); - let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); - let args_handle = make_args_buffer(client.clone(), &[mean, std]); - let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); + let client = compute_client::(device); + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[mean, std]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); - client.execute( - Box::new(kernel), - &[&output.handle, &info_handle, &args_handle], - ); + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], + ); - output + output } #[cfg(test)] mod tests { - use burn_tensor::{backend::Backend, Data, Distribution, Shape, Tensor}; - use serial_test::serial; + use burn_tensor::{backend::Backend, Data, Distribution, Shape, Tensor}; + use serial_test::serial; - use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; + use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; - #[test] - #[serial] - fn subsequent_calls_give_different_tensors() { - TestBackend::seed(0); - let shape = [4, 5]; - let device = WgpuDevice::default(); + #[test] + #[serial] + fn subsequent_calls_give_different_tensors() { + TestBackend::seed(0); + let shape = [4, 5]; + let device = WgpuDevice::default(); - let tensor_1 = - Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); - let tensor_2 = - Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); - for i in 0..20 { - assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); + let tensor_1 = + Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); + let tensor_2 = + Tensor::::random_device(shape, Distribution::Normal(0., 1.), &device); + for i in 0..20 { + assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); + } } - } - #[test] - #[serial] - fn empirical_mean_close_to_expectation() { - TestBackend::seed(0); - let shape = [128, 128]; - let device = WgpuDevice::default(); - let mean = 10.; - let tensor = - Tensor::::random_device(shape, Distribution::Normal(mean, 2.), &device); - let empirical_mean = tensor.mean().into_data(); - empirical_mean.assert_approx_eq(&Data::from([mean as f32]), 1); - } + #[test] + #[serial] + fn empirical_mean_close_to_expectation() { + TestBackend::seed(0); + let shape = [128, 128]; + let device = WgpuDevice::default(); + let mean = 10.; + let tensor = + Tensor::::random_device(shape, Distribution::Normal(mean, 2.), &device); + let empirical_mean = tensor.mean().into_data(); + empirical_mean.assert_approx_eq(&Data::from([mean as f32]), 1); + } - #[test] - #[serial] - fn normal_respects_68_95_99_rule() { - // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule - let shape: Shape<2> = [1000, 1000].into(); - let device = WgpuDevice::default(); - let mu = 0.; - let s = 1.; - let tensor = - Tensor::::random_device(shape.clone(), Distribution::Normal(mu, 
s), &device); - let stats = calculate_bin_stats( - tensor.into_data().value, - 6, - (mu - 3. * s) as f32, - (mu + 3. * s) as f32, - ); - let assert_approx_eq = |count, percent| { - let expected = percent * shape.num_elements() as f32 / 100.; - assert!(f32::abs(count as f32 - expected) < 2000.); - }; - assert_approx_eq(stats[0].count, 2.1); - assert_approx_eq(stats[1].count, 13.6); - assert_approx_eq(stats[2].count, 34.1); - assert_approx_eq(stats[3].count, 34.1); - assert_approx_eq(stats[4].count, 13.6); - assert_approx_eq(stats[5].count, 2.1); - } + #[test] + #[serial] + fn normal_respects_68_95_99_rule() { + // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule + let shape: Shape<2> = [1000, 1000].into(); + let device = WgpuDevice::default(); + let mu = 0.; + let s = 1.; + let tensor = Tensor::::random_device( + shape.clone(), + Distribution::Normal(mu, s), + &device, + ); + let stats = calculate_bin_stats( + tensor.into_data().value, + 6, + (mu - 3. * s) as f32, + (mu + 3. * s) as f32, + ); + let assert_approx_eq = |count, percent| { + let expected = percent * shape.num_elements() as f32 / 100.; + assert!(f32::abs(count as f32 - expected) < 2000.); + }; + assert_approx_eq(stats[0].count, 2.1); + assert_approx_eq(stats[1].count, 13.6); + assert_approx_eq(stats[2].count, 34.1); + assert_approx_eq(stats[3].count, 34.1); + assert_approx_eq(stats[4].count, 13.6); + assert_approx_eq(stats[5].count, 2.1); + } } diff --git a/burn-wgpu/src/kernel/prng/uniform.rs b/burn-wgpu/src/kernel/prng/uniform.rs index bf9880ba35..ec9f8e00a7 100644 --- a/burn-wgpu/src/kernel/prng/uniform.rs +++ b/burn-wgpu/src/kernel/prng/uniform.rs @@ -1,15 +1,15 @@ use burn_tensor::Shape; use crate::{ - compute::{compute_client, StaticKernel, WgpuComputeClient}, - element::WgpuElement, - kernel::{ - prng::base::{make_args_buffer, make_info_buffer}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, - }, - ops::numeric::empty_device, - tensor::WgpuTensor, - GraphicsApi, WgpuDevice, + compute::{compute_client, StaticKernel, WgpuComputeClient}, + element::WgpuElement, + kernel::{ + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, + }, + ops::numeric::empty_device, + tensor::WgpuTensor, + GraphicsApi, WgpuDevice, }; use super::base::Prng; @@ -17,149 +17,154 @@ use super::base::Prng; struct UniformPrng; impl StaticKernelSource for UniformPrng { - fn source() -> SourceTemplate { - Prng::source().register("num_args", "2").register( - "prng_loop", - include_str!("../../template/prng/uniform_inner_loop.wgsl"), - ) - } + fn source() -> SourceTemplate { + Prng::source().register("num_args", "2").register( + "prng_loop", + include_str!("../../template/prng/uniform_inner_loop.wgsl"), + ) + } } /// Pseudo-random generator for uniform distribution pub fn random_uniform( - shape: Shape, - device: &WgpuDevice, - low: E, - high: E, + shape: Shape, + device: &WgpuDevice, + low: E, + high: E, ) -> WgpuTensor { - let client = compute_client::(device); - uniform_kernel(client, device, &shape, low, high) + let client = compute_client::(device); + uniform_kernel(client, device, &shape, low, high) } /// Pseudo-random generator for uniform distribution, based on /// another tensor's client, device and shape pub fn random_like_uniform( - tensor: &WgpuTensor, - low: E, - high: E, + tensor: &WgpuTensor, + low: E, + high: E, ) -> WgpuTensor { - uniform_kernel( - tensor.client.clone(), - &tensor.device, - 
&tensor.shape, - low, - high, - ) + uniform_kernel( + tensor.client.clone(), + &tensor.device, + &tensor.shape, + low, + high, + ) } fn uniform_kernel( - client: WgpuComputeClient, - device: &WgpuDevice, - shape: &Shape, - low: E, - high: E, + client: WgpuComputeClient, + device: &WgpuDevice, + shape: &Shape, + low: E, + high: E, ) -> WgpuTensor { - const N_VALUES_PER_THREAD: usize = 128; - - let output = empty_device(client.clone(), device.clone(), shape.clone()); - let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); - let args_handle = make_args_buffer(client.clone(), &[low, high]); - let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); - - client.execute( - Box::new(kernel), - &[&output.handle, &info_handle, &args_handle], - ); - - output + const N_VALUES_PER_THREAD: usize = 128; + + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[low, high]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); + + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], + ); + + output } #[cfg(test)] mod tests { - use core::f32; + use core::f32; + + use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; + use serial_test::serial; + + use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; + + #[test] + #[serial] + fn subsequent_calls_give_different_tensors() { + TestBackend::seed(0); + let shape = [4, 5]; + let device = WgpuDevice::default(); + + let tensor_1 = + Tensor::::random_device(shape, Distribution::Default, &device); + let tensor_2 = + Tensor::::random_device(shape, Distribution::Default, &device); + for i in 0..20 { + assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); + } + } - use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; - use serial_test::serial; + #[test] + #[serial] + fn values_all_within_interval_default() { + TestBackend::seed(0); + let shape = [24, 24]; + let device = WgpuDevice::default(); - use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice}; + let tensor = Tensor::::random_device(shape, Distribution::Default, &device); + tensor.to_data().assert_within_range(0..1); + } - #[test] - #[serial] - fn subsequent_calls_give_different_tensors() { - TestBackend::seed(0); - let shape = [4, 5]; - let device = WgpuDevice::default(); + #[test] + #[serial] + fn values_all_within_interval_uniform() { + TestBackend::seed(0); + let shape = [24, 24]; + let device = WgpuDevice::default(); + + let tensor = + Tensor::::random_device(shape, Distribution::Uniform(5., 17.), &device); + tensor.to_data().assert_within_range(5..17); + } + + #[test] + #[serial] + fn at_least_one_value_per_bin_uniform() { + TestBackend::seed(0); + let shape = [64, 64]; + let device = WgpuDevice::default(); + + let tensor = Tensor::::random_device( + shape, + Distribution::Uniform(-5., 10.), + &device, + ); + let numbers = tensor.into_data().value; + let stats = calculate_bin_stats(numbers, 3, -5., 10.); + assert!(stats[0].count >= 1); + assert!(stats[1].count >= 1); + assert!(stats[2].count >= 1); + } - let tensor_1 = Tensor::::random_device(shape, Distribution::Default, &device); - 
let tensor_2 = Tensor::::random_device(shape, Distribution::Default, &device); - for i in 0..20 { - assert!(tensor_1.to_data().value[i] != tensor_2.to_data().value[i]); + #[test] + #[serial] + fn runs_test() { + TestBackend::seed(0); + let shape = Shape::new([512, 512]); + let device = WgpuDevice::default(); + let tensor = Tensor::::random_device(shape, Distribution::Default, &device); + + let numbers = tensor.into_data().value; + let stats = calculate_bin_stats(numbers, 2, 0., 1.); + let n_0 = stats[0].count as f32; + let n_1 = stats[1].count as f32; + let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32; + + let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; + let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) + / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.)); + let z = (n_runs - expectation) / variance.sqrt(); + + // below 2 means we can have good confidence in the randomness + // we put 2.5 to make sure it passes even when very unlucky + assert!(z.abs() < 2.5); } - } - - #[test] - #[serial] - fn values_all_within_interval_default() { - TestBackend::seed(0); - let shape = [24, 24]; - let device = WgpuDevice::default(); - - let tensor = Tensor::::random_device(shape, Distribution::Default, &device); - tensor.to_data().assert_within_range(0..1); - } - - #[test] - #[serial] - fn values_all_within_interval_uniform() { - TestBackend::seed(0); - let shape = [24, 24]; - let device = WgpuDevice::default(); - - let tensor = - Tensor::::random_device(shape, Distribution::Uniform(5., 17.), &device); - tensor.to_data().assert_within_range(5..17); - } - - #[test] - #[serial] - fn at_least_one_value_per_bin_uniform() { - TestBackend::seed(0); - let shape = [64, 64]; - let device = WgpuDevice::default(); - - let tensor = - Tensor::::random_device(shape, Distribution::Uniform(-5., 10.), &device); - let numbers = tensor.into_data().value; - let stats = calculate_bin_stats(numbers, 3, -5., 10.); - assert!(stats[0].count >= 1); - assert!(stats[1].count >= 1); - assert!(stats[2].count >= 1); - } - - #[test] - #[serial] - fn runs_test() { - TestBackend::seed(0); - let shape = Shape::new([512, 512]); - let device = WgpuDevice::default(); - let tensor = Tensor::::random_device(shape, Distribution::Default, &device); - - let numbers = tensor.into_data().value; - let stats = calculate_bin_stats(numbers, 2, 0., 1.); - let n_0 = stats[0].count as f32; - let n_1 = stats[1].count as f32; - let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32; - - let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0; - let variance = - ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1)) / ((n_0 + n_1).powf(2.) 
* (n_0 + n_1 - 1.)); - let z = (n_runs - expectation) / variance.sqrt(); - - // below 2 means we can have good confidence in the randomness - // we put 2.5 to make sure it passes even when very unlucky - assert!(z.abs() < 2.5); - } } diff --git a/burn-wgpu/src/kernel/reduce/base.rs b/burn-wgpu/src/kernel/reduce/base.rs index bf50288116..0f58369607 100644 --- a/burn-wgpu/src/kernel/reduce/base.rs +++ b/burn-wgpu/src/kernel/reduce/base.rs @@ -2,21 +2,21 @@ use crate::{element::WgpuElement, tensor::WgpuTensor}; /// Creates an empty output tensor with reduce output shape pub fn init_reduce_output( - input: &WgpuTensor, - reduce_dim: usize, + input: &WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[reduce_dim] = 1; + let mut shape_out = input.shape.clone(); + shape_out.dims[reduce_dim] = 1; - // Create output handle - let num_elems_output = shape_out.num_elements(); - let handle = input - .client - .empty(num_elems_output * core::mem::size_of::()); - WgpuTensor::new( - input.client.clone(), - input.device.clone(), - shape_out.clone(), - handle, - ) + // Create output handle + let num_elems_output = shape_out.num_elements(); + let handle = input + .client + .empty(num_elems_output * core::mem::size_of::()); + WgpuTensor::new( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + handle, + ) } diff --git a/burn-wgpu/src/kernel/reduce/reduction.rs b/burn-wgpu/src/kernel/reduce/reduction.rs index aa8bb3f9e4..432f678827 100644 --- a/burn-wgpu/src/kernel/reduce/reduction.rs +++ b/burn-wgpu/src/kernel/reduce/reduction.rs @@ -1,18 +1,18 @@ use crate::{ - compute::StaticKernel, - element::WgpuElement, - kernel::{ - build_info, elemwise_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, - WORKGROUP_DEFAULT, - }, - kernel_wgsl, - tensor::WgpuTensor, + compute::StaticKernel, + element::WgpuElement, + kernel::{ + build_info, elemwise_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, + WORKGROUP_DEFAULT, + }, + kernel_wgsl, + tensor::WgpuTensor, }; use burn_tensor::Shape; kernel_wgsl!( - RecursiveSumRaw, - "../../template/reduction/recursive_sum.wgsl" + RecursiveSumRaw, + "../../template/reduction/recursive_sum.wgsl" ); kernel_wgsl!(ReductionDimRaw, "../../template/reduction/reduce_dim.wgsl"); kernel_wgsl!(ReductionArgsRaw, "../../template/reduction/args.wgsl"); @@ -23,199 +23,199 @@ pub(crate) struct SumDim; pub(crate) struct MeanDim; impl StaticKernelSource for SumDim { - fn source() -> SourceTemplate { - ReductionDimRaw::source().register("assign", "output[id] = sum;") - } + fn source() -> SourceTemplate { + ReductionDimRaw::source().register("assign", "output[id] = sum;") + } } impl StaticKernelSource for MeanDim { - fn source() -> SourceTemplate { - ReductionDimRaw::source() - .add_template( - "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { + fn source() -> SourceTemplate { + ReductionDimRaw::source() + .add_template( + "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { return sum / {{ elem }}(dim); }", - ) - .register("assign", "output[id] = mean_dim(sum, shape_dim);") - } + ) + .register("assign", "output[id] = mean_dim(sum, shape_dim);") + } } impl StaticKernelSource for ArgsMax { - fn source() -> SourceTemplate { - ReductionArgsRaw::source() - .register("cmp", ">") - .register("initial", (-32767).to_string()) - } + fn source() -> SourceTemplate { + ReductionArgsRaw::source() + .register("cmp", ">") + .register("initial", (-32767).to_string()) + } } impl StaticKernelSource for ArgsMin 
{ - fn source() -> SourceTemplate { - ReductionArgsRaw::source() - .register("cmp", "<") - .register("initial", 32767.to_string()) - } + fn source() -> SourceTemplate { + ReductionArgsRaw::source() + .register("cmp", "<") + .register("initial", 32767.to_string()) + } } /// Sum all elements in the input buffer. pub fn sum(input: WgpuTensor) -> WgpuTensor { - let mut input_handle = input.handle; - let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT); + let mut input_handle = input.handle; + let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT); - loop { - let num_invocations = workgroup.num_invocations(); - let handle = input - .client - .empty(core::mem::size_of::() * num_invocations); + loop { + let num_invocations = workgroup.num_invocations(); + let handle = input + .client + .empty(core::mem::size_of::() * num_invocations); - let kernel = StaticKernel::< - KernelSettings, - >::new(workgroup); + let kernel = StaticKernel::< + KernelSettings, + >::new(workgroup); - input - .client - .execute(Box::new(kernel), &[&input_handle, &handle]); + input + .client + .execute(Box::new(kernel), &[&input_handle, &handle]); - if num_invocations <= 1 { - return WgpuTensor::new(input.client, input.device, Shape::new([1]), handle); - } + if num_invocations <= 1 { + return WgpuTensor::new(input.client, input.device, Shape::new([1]), handle); + } - input_handle = handle; - workgroup = elemwise_workgroup(num_invocations, WORKGROUP_DEFAULT); - } + input_handle = handle; + workgroup = elemwise_workgroup(num_invocations, WORKGROUP_DEFAULT); + } } /// Execute the sum dim kernel. pub fn sum_dim( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim::(input, output, dim) + reduction_dim::(input, output, dim) } /// Execute the mean dim kernel. pub fn mean_dim( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim::(input, output, dim) + reduction_dim::(input, output, dim) } fn reduction_dim( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - let kernel = - StaticKernel::>::new( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), + let kernel = + StaticKernel::>::new( + elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), + ); + + let mut info = build_info(&[&input, &output]); + info.push(dim as u32); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], ); - let mut info = build_info(&[&input, &output]); - info.push(dim as u32); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_handle], - ); - - output + output } /// Execute the argmax kernel. pub fn argmax( - input: WgpuTensor, - dim: usize, + input: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_args_dim::(input, dim) + reduction_args_dim::(input, dim) } /// Execute the argmin kernel. 
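// A minimal CPU-only sketch of the multi-pass pattern used by `sum` above: run a
// partial reduction, shrink the buffer to one value per invocation, and repeat until a
// single element remains. The chunk size below is a stand-in for the real workgroup
// sizing (WORKGROUP_DEFAULT), not the value the kernel actually uses.
fn recursive_sum_sketch(mut values: Vec<f32>) -> f32 {
    const CHUNK: usize = 256; // hypothetical; mirrors "one partial sum per invocation"
    if values.is_empty() {
        return 0.0;
    }
    while values.len() > 1 {
        values = values
            .chunks(CHUNK)
            .map(|chunk| chunk.iter().sum())
            .collect();
    }
    values[0]
}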
pub fn argmin( - input: WgpuTensor, - dim: usize, + input: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_args_dim::(input, dim) + reduction_args_dim::(input, dim) } fn reduction_args_dim( - input: WgpuTensor, - dim: usize, + input: WgpuTensor, + dim: usize, ) -> WgpuTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[dim] = 1; - let num_elems = shape_out.num_elements(); - let buffer = input.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - input.client.clone(), - input.device.clone(), - shape_out, - buffer, - ); - - let kernel = - StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), + let mut shape_out = input.shape.clone(); + shape_out.dims[dim] = 1; + let num_elems = shape_out.num_elements(); + let buffer = input.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + input.client.clone(), + input.device.clone(), + shape_out, + buffer, ); - let mut info = build_info(&[&input, &output]); - info.push(dim as u32); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_handle], - ); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), + ); + let mut info = build_info(&[&input, &output]); + info.push(dim as u32); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], + ); - WgpuTensor::new(output.client, output.device, output.shape, output.handle) + WgpuTensor::new(output.client, output.device, output.shape, output.handle) } #[cfg(test)] mod tests { - use super::*; - use crate::{ - kernel::reduce::init_reduce_output, - tests::{ReferenceBackend, TestBackend}, - }; - use burn_tensor::{Distribution, Int, Tensor}; - - #[test] - fn reduction_sum_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let val = Tensor::::from_primitive(sum(tensor.into_primitive())); - let val_ref = tensor_ref.sum(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_sum_dim_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 1024], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 1; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(reduction_dim::( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_args_dim_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 1024], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let val = Tensor::::from_primitive(argmax(tensor.into_primitive(), 1)); - let val_ref = tensor_ref.argmax(1); - - assert_eq!(val_ref.into_data().convert(), val.into_data()); - } + use super::*; + use crate::{ + kernel::reduce::init_reduce_output, + tests::{ReferenceBackend, TestBackend}, + }; + use burn_tensor::{Distribution, Int, Tensor}; + + #[test] + fn reduction_sum_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + + let val = 
Tensor::::from_primitive(sum(tensor.into_primitive())); + let val_ref = tensor_ref.sum(); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 1024], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 1; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(reduction_dim::( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(1); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_args_dim_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 1024], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + + let val = Tensor::::from_primitive(argmax(tensor.into_primitive(), 1)); + let val_ref = tensor_ref.argmax(1); + + assert_eq!(val_ref.into_data().convert(), val.into_data()); + } } diff --git a/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs b/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs index cd6d7dfa91..4d4fb43e3a 100644 --- a/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs +++ b/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs @@ -1,168 +1,170 @@ use crate::{ - compute::{StaticKernel, WorkGroup}, - element::WgpuElement, - kernel::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}, - kernel_wgsl, - tensor::WgpuTensor, + compute::{StaticKernel, WorkGroup}, + element::WgpuElement, + kernel::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}, + kernel_wgsl, + tensor::WgpuTensor, }; kernel_wgsl!( - ReductionDimSharedMemoryRaw, - "../../template/reduction/reduce_dim_shared_memory.wgsl" + ReductionDimSharedMemoryRaw, + "../../template/reduction/reduce_dim_shared_memory.wgsl" ); pub(crate) struct SumDimSharedMemory; pub(crate) struct MeanDimSharedMemory; impl StaticKernelSource for SumDimSharedMemory { - fn source() -> SourceTemplate { - ReductionDimSharedMemoryRaw::source() - .register( - "shared_size", - (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), - ) - .register("initial", 0.0.to_string()) - .register("update", "shared_memory[local_id] += value; ") - .register("assign", "output[output_position] = final_value; ") - } + fn source() -> SourceTemplate { + ReductionDimSharedMemoryRaw::source() + .register( + "shared_size", + (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), + ) + .register("initial", 0.0.to_string()) + .register("update", "shared_memory[local_id] += value; ") + .register("assign", "output[output_position] = final_value; ") + } } impl StaticKernelSource for MeanDimSharedMemory { - fn source() -> SourceTemplate { - ReductionDimSharedMemoryRaw::source() - .register( - "shared_size", - (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), - ) - .register("initial", 0.0.to_string()) - .register("update", "shared_memory[local_id] += value; ") - .add_template( - "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { + fn source() -> SourceTemplate { + ReductionDimSharedMemoryRaw::source() + .register( + "shared_size", + (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), + ) + .register("initial", 0.0.to_string()) + .register("update", "shared_memory[local_id] += value; ") + .add_template( + "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { return sum / {{ elem }}(dim); }", - ) - .register( - "assign", - 
"output[output_position] = mean_dim(final_value, shape_input_dim_reduce);", - ) - } + ) + .register( + "assign", + "output[output_position] = mean_dim(final_value, shape_input_dim_reduce);", + ) + } } /// Execute the sum dim kernel leveraging shared memory /// Probably more efficient on tensors where the dimension to reduced /// is much larger than the others pub fn sum_dim_shared_memory( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim_shared_memory::(input, output, dim) + reduction_dim_shared_memory::(input, output, dim) } /// Execute the mean dim kernel leveraging shared memory /// Probably more efficient on tensors where the dimension to reduced /// is much larger than the others pub fn mean_dim_shared_memory( - input: WgpuTensor, - output: WgpuTensor, - dim: usize, + input: WgpuTensor, + output: WgpuTensor, + dim: usize, ) -> WgpuTensor { - reduction_dim_shared_memory::(input, output, dim) + reduction_dim_shared_memory::(input, output, dim) } fn reduction_dim_shared_memory( - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let num_elems_output = output.shape.num_elements(); - let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x); - let grid = WorkGroup::new(n_workgroups_x as u32, n_workgroups_y as u32, 1); + let num_elems_output = output.shape.num_elements(); + let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32)); + let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x); + let grid = WorkGroup::new(n_workgroups_x as u32, n_workgroups_y as u32, 1); - let kernel = - StaticKernel::>::new(grid); + let kernel = + StaticKernel::>::new( + grid, + ); - // Build info - let mut info = build_info(&[&input, &output]); + // Build info + let mut info = build_info(&[&input, &output]); - // Reduce groups are elements that are aligned along the reduce dim - let reduce_group_size = input.shape.dims[reduce_dim]; - let n_invocation_per_workgroup = WORKGROUP_DEFAULT * WORKGROUP_DEFAULT; - let n_reduce_elements_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32; + // Reduce groups are elements that are aligned along the reduce dim + let reduce_group_size = input.shape.dims[reduce_dim]; + let n_invocation_per_workgroup = WORKGROUP_DEFAULT * WORKGROUP_DEFAULT; + let n_reduce_elements_per_thread = + f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32; - // Add dimension of reduction and how many reduce elements are treated per thread - info.push(reduce_dim as u32); - info.push(n_reduce_elements_per_thread); + // Add dimension of reduction and how many reduce elements are treated per thread + info.push(reduce_dim as u32); + info.push(n_reduce_elements_per_thread); - let info_handle = input.client.create(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &info_handle], - ); + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], + ); - output + output } #[cfg(test)] mod tests { - use super::*; - use crate::{ - kernel::reduce::init_reduce_output, - tests::{ReferenceBackend, TestBackend}, - }; - use burn_tensor::{Distribution, Tensor}; - - #[test] - fn 
reduction_sum_dim_shared_memory_small() { - let tensor = Tensor::::random([700], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 0; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(sum_dim_shared_memory( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_sum_dim_shared_memory_medium() { - let tensor = Tensor::::random([6, 1024], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 1; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(sum_dim_shared_memory( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } - - #[test] - fn reduction_sum_dim_shared_memory_large() { - let tensor = Tensor::::random([4, 1024, 50], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - let reduce_dim = 2; - let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); - - let val = Tensor::::from_primitive(sum_dim_shared_memory( - tensor.into_primitive(), - output, - reduce_dim, - )); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 3); - } + use super::*; + use crate::{ + kernel::reduce::init_reduce_output, + tests::{ReferenceBackend, TestBackend}, + }; + use burn_tensor::{Distribution, Tensor}; + + #[test] + fn reduction_sum_dim_shared_memory_small() { + let tensor = Tensor::::random([700], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 0; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_shared_memory_medium() { + let tensor = Tensor::::random([6, 1024], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 1; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_shared_memory_large() { + let tensor = Tensor::::random([4, 1024, 50], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 2; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } } diff --git a/burn-wgpu/src/kernel/reduce/tune/base.rs b/burn-wgpu/src/kernel/reduce/tune/base.rs index 02d1f1b4de..d52bf37dcb 100644 --- a/burn-wgpu/src/kernel/reduce/tune/base.rs +++ b/burn-wgpu/src/kernel/reduce/tune/base.rs @@ -1,27 +1,27 @@ #[macro_export] /// Generate an 
autotune operation for a reduce kernel macro_rules! reduce_tune_ops { - ($name:ident, $func:expr) => { - #[derive(new)] - pub(crate) struct $name { - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, - } + ($name:ident, $func:expr) => { + #[derive(new)] + pub(crate) struct $name { + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, + } - impl AutotuneOperation for $name { - fn execute(self: Box) { - #[allow(clippy::redundant_closure_call)] - $func(self.input, self.output, self.reduce_dim); - } + impl AutotuneOperation for $name { + fn execute(self: Box) { + #[allow(clippy::redundant_closure_call)] + $func(self.input, self.output, self.reduce_dim); + } - fn clone(&self) -> Box { - Box::new(Self { - input: self.input.clone(), - output: self.output.clone(), - reduce_dim: self.reduce_dim.clone(), - }) - } - } - }; + fn clone(&self) -> Box { + Box::new(Self { + input: self.input.clone(), + output: self.output.clone(), + reduce_dim: self.reduce_dim.clone(), + }) + } + } + }; } diff --git a/burn-wgpu/src/kernel/reduce/tune/key.rs b/burn-wgpu/src/kernel/reduce/tune/key.rs index 44f61647b5..db5e4b21bf 100644 --- a/burn-wgpu/src/kernel/reduce/tune/key.rs +++ b/burn-wgpu/src/kernel/reduce/tune/key.rs @@ -5,45 +5,45 @@ use burn_tensor::Shape; #[derive(Hash, Eq, PartialEq, Debug, Clone)] /// Autotune key representative of reduce versions pub struct ReduceAutotuneKey { - reduce_dim_length: usize, - others_product: usize, + reduce_dim_length: usize, + others_product: usize, } impl Display for ReduceAutotuneKey { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - format!( - "Reduce - reduce_dim_length: {:?} others_product: {:?}", - self.reduce_dim_length, self.others_product - ) - .as_str(), - ) - } + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + format!( + "Reduce - reduce_dim_length: {:?} others_product: {:?}", + self.reduce_dim_length, self.others_product + ) + .as_str(), + ) + } } impl ReduceAutotuneKey { - /// Create a reduce autotune key from the input shape and reduce dim - pub fn new(shape: &Shape, reduce_dim: usize) -> Self { - let reduce_dim_length = shape.dims[reduce_dim]; - let mut others_product = 1; - for d in 0..D { - if d != reduce_dim { - others_product *= shape.dims[d] - } - } - Self { - reduce_dim_length: anchor(reduce_dim_length, None), - others_product: anchor(others_product, None), + /// Create a reduce autotune key from the input shape and reduce dim + pub fn new(shape: &Shape, reduce_dim: usize) -> Self { + let reduce_dim_length = shape.dims[reduce_dim]; + let mut others_product = 1; + for d in 0..D { + if d != reduce_dim { + others_product *= shape.dims[d] + } + } + Self { + reduce_dim_length: anchor(reduce_dim_length, None), + others_product: anchor(others_product, None), + } } - } } fn anchor(x: usize, max: Option) -> usize { - let exp = f32::ceil(f32::log2(x as f32)) as u32; - let power_of_2 = 2_u32.pow(exp) as usize; - if let Some(max) = max { - min(power_of_2, max) - } else { - power_of_2 - } + let exp = f32::ceil(f32::log2(x as f32)) as u32; + let power_of_2 = 2_u32.pow(exp) as usize; + if let Some(max) = max { + min(power_of_2, max) + } else { + power_of_2 + } } diff --git a/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs index 8d272529b8..a19fc7cf34 100644 --- a/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs +++ b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs @@ -2,15 +2,15 @@ use burn_compute::tune::{AutotuneOperation, 
AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; use crate::{ - compute::WgpuAutotuneKey, - element::WgpuElement, - kernel::{ - prng::random_like_uniform, - reduce::{init_reduce_output, mean_dim, mean_dim_shared_memory}, - }, - ops::numeric::empty_device, - reduce_tune_ops, - tensor::WgpuTensor, + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{ + prng::random_like_uniform, + reduce::{init_reduce_output, mean_dim, mean_dim_shared_memory}, + }, + ops::numeric::empty_device, + reduce_tune_ops, + tensor::WgpuTensor, }; use super::ReduceAutotuneKey; @@ -19,90 +19,90 @@ use super::ReduceAutotuneKey; /// Autotune key is given by concatenating the closest upper power of 2 of /// dim to reduce, and product of others pub struct MeanDimAutotuneOperationSet { - key: WgpuAutotuneKey, - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, + key: WgpuAutotuneKey, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, } impl MeanDimAutotuneOperationSet { - fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { - Self { - key: WgpuAutotuneKey::MeanDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), - input, - output, - reduce_dim, + fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { + Self { + key: WgpuAutotuneKey::MeanDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), + input, + output, + reduce_dim, + } } - } } impl AutotuneOperationSet - for MeanDimAutotuneOperationSet + for MeanDimAutotuneOperationSet { - fn key(&self) -> WgpuAutotuneKey { - self.key.clone() - } + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } - fn autotunables(&self) -> Vec> { - let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); - let output = empty_device( - self.output.client.clone(), - self.output.device.clone(), - self.output.shape.clone(), - ); + let output = empty_device( + self.output.client.clone(), + self.output.device.clone(), + self.output.shape.clone(), + ); - vec![ - Box::new(MeanDimAutotune::::new( - input.clone(), - output.clone(), - self.reduce_dim, - )), - Box::new(MeanDimSharedMemoryAutotune::::new( - input.clone(), - output.clone(), - self.reduce_dim, - )), - ] - } + vec![ + Box::new(MeanDimAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + Box::new(MeanDimSharedMemoryAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + ] + } - fn fastest(self: Box, fastest_index: usize) -> Box { - // Warning: since AutotuneOperationSet shares his key with SumDimAutotuneOperationSet - // we must make sure the order here is correlated with SumDim - match fastest_index { - 0 => Box::new(MeanDimAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - 1 => Box::new(MeanDimSharedMemoryAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - _ => panic!("Fastest index is out of bound"), + fn fastest(self: Box, fastest_index: usize) -> Box { + // Warning: since AutotuneOperationSet shares his key with SumDimAutotuneOperationSet + // we must make sure the order here is correlated with SumDim + match fastest_index { + 0 => Box::new(MeanDimAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + 1 => Box::new(MeanDimSharedMemoryAutotune::::new( + 
self.input, + self.output, + self.reduce_dim, + )), + _ => panic!("Fastest index is out of bound"), + } } - } } /// Executes autotune on mean_dim operation pub fn mean_dim_autotune( - input: WgpuTensor, - reduce_dim: usize, + input: WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let client = input.client.clone(); + let client = input.client.clone(); - let output = init_reduce_output(&input, reduce_dim); + let output = init_reduce_output(&input, reduce_dim); - let operation_set = Box::new(MeanDimAutotuneOperationSet::::new( - input, - output.clone(), - reduce_dim, - )); + let operation_set = Box::new(MeanDimAutotuneOperationSet::::new( + input, + output.clone(), + reduce_dim, + )); - client.execute_autotune(operation_set); + client.execute_autotune(operation_set); - output + output } // Probably better on balanced tensor shapes diff --git a/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs index 541a73d6db..a5831d7016 100644 --- a/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs +++ b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs @@ -2,15 +2,15 @@ use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; use crate::{ - compute::WgpuAutotuneKey, - element::WgpuElement, - kernel::{ - prng::random_like_uniform, - reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory}, - }, - ops::numeric::empty_device, - reduce_tune_ops, - tensor::WgpuTensor, + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{ + prng::random_like_uniform, + reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory}, + }, + ops::numeric::empty_device, + reduce_tune_ops, + tensor::WgpuTensor, }; use super::ReduceAutotuneKey; @@ -19,90 +19,90 @@ use super::ReduceAutotuneKey; /// Autotune key is given by concatenating the closest upper power of 2 of /// dim to reduce, and product of others pub struct SumDimAutotuneOperationSet { - key: WgpuAutotuneKey, - input: WgpuTensor, - output: WgpuTensor, - reduce_dim: usize, + key: WgpuAutotuneKey, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, } impl SumDimAutotuneOperationSet { - fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { - Self { - key: WgpuAutotuneKey::SumDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), - input, - output, - reduce_dim, + fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { + Self { + key: WgpuAutotuneKey::SumDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), + input, + output, + reduce_dim, + } } - } } impl AutotuneOperationSet - for SumDimAutotuneOperationSet + for SumDimAutotuneOperationSet { - fn key(&self) -> WgpuAutotuneKey { - self.key.clone() - } + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } - fn autotunables(&self) -> Vec> { - let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); - let output = empty_device( - self.output.client.clone(), - self.output.device.clone(), - self.output.shape.clone(), - ); + let output = empty_device( + self.output.client.clone(), + self.output.device.clone(), + self.output.shape.clone(), + ); - vec![ - Box::new(SumDimAutotune::::new( - input.clone(), - output.clone(), - self.reduce_dim, - )), - Box::new(SumDimSharedMemoryAutotune::::new( - input.clone(), - 
output.clone(), - self.reduce_dim, - )), - ] - } + vec![ + Box::new(SumDimAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + Box::new(SumDimSharedMemoryAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + ] + } - fn fastest(self: Box, fastest_index: usize) -> Box { - // Warning: since AutotuneOperationSet shares his key with MeanDimAutotuneOperationSet - // we must make sure the order here is correlated with MeanDim - match fastest_index { - 0 => Box::new(SumDimAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - 1 => Box::new(SumDimSharedMemoryAutotune::::new( - self.input, - self.output, - self.reduce_dim, - )), - _ => panic!("Fastest index is out of bound"), + fn fastest(self: Box, fastest_index: usize) -> Box { + // Warning: since AutotuneOperationSet shares his key with MeanDimAutotuneOperationSet + // we must make sure the order here is correlated with MeanDim + match fastest_index { + 0 => Box::new(SumDimAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + 1 => Box::new(SumDimSharedMemoryAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + _ => panic!("Fastest index is out of bound"), + } } - } } /// Executes autotune on sum_dim operation pub fn sum_dim_autotune( - input: WgpuTensor, - reduce_dim: usize, + input: WgpuTensor, + reduce_dim: usize, ) -> WgpuTensor { - let client = input.client.clone(); + let client = input.client.clone(); - let output = init_reduce_output(&input, reduce_dim); + let output = init_reduce_output(&input, reduce_dim); - let operation_set = Box::new(SumDimAutotuneOperationSet::::new( - input, - output.clone(), - reduce_dim, - )); + let operation_set = Box::new(SumDimAutotuneOperationSet::::new( + input, + output.clone(), + reduce_dim, + )); - client.execute_autotune(operation_set); + client.execute_autotune(operation_set); - output + output } // Probably better on balanced tensor shapes diff --git a/burn-wgpu/src/kernel/source.rs b/burn-wgpu/src/kernel/source.rs index 6bbc0b9752..b13c2f6a50 100644 --- a/burn-wgpu/src/kernel/source.rs +++ b/burn-wgpu/src/kernel/source.rs @@ -6,64 +6,64 @@ use std::collections::HashMap; /// They will be updated with their proper value when `generate` is called. #[derive(Debug)] pub struct SourceTemplate { - items: HashMap, - templates: Vec, + items: HashMap, + templates: Vec, } impl SourceTemplate { - /// Create a new source template. - pub fn new(template: S) -> Self - where - S: Into, - { - Self { - items: HashMap::new(), - templates: vec![template.into()], + /// Create a new source template. + pub fn new(template: S) -> Self + where + S: Into, + { + Self { + items: HashMap::new(), + templates: vec![template.into()], + } } - } - /// Register the value for a placeholder item. - /// - /// # Notes - /// - /// The value can't have placeholders, since it would require recursive templating with - /// possibly circular dependencies. If you want to add a value that has some - /// placeholders, consider adding a new template to the source using - /// [add_template](SourceTemplate::add_template). The added template can be a function, and you can - /// register the function call instead. - pub fn register(mut self, name: Name, value: Value) -> Self - where - Name: Into, - Value: Into, - { - self.items.insert(name.into(), value.into()); - self - } + /// Register the value for a placeholder item. 
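// A small usage sketch of the SourceTemplate API defined in this file: placeholders in
// the template text are filled by `register`, and `complete` produces the final source.
// The WGSL snippet and placeholder names below are illustrative only.
fn scale_kernel_source_sketch() -> String {
    use crate::kernel::SourceTemplate; // same path the kernel macros in this crate use
    SourceTemplate::new("fn scale(x: {{ elem }}) -> {{ elem }} { return x * {{ factor }}; }")
        .register("elem", "f32")
        .register("factor", "2.0")
        .complete()
}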
+ /// + /// # Notes + /// + /// The value can't have placeholders, since it would require recursive templating with + /// possibly circular dependencies. If you want to add a value that has some + /// placeholders, consider adding a new template to the source using + /// [add_template](SourceTemplate::add_template). The added template can be a function, and you can + /// register the function call instead. + pub fn register(mut self, name: Name, value: Value) -> Self + where + Name: Into, + Value: Into, + { + self.items.insert(name.into(), value.into()); + self + } - /// Add a new template. - pub fn add_template(mut self, template: S) -> Self - where - S: Into, - { - self.templates.push(template.into()); - self - } + /// Add a new template. + pub fn add_template(mut self, template: S) -> Self + where + S: Into, + { + self.templates.push(template.into()); + self + } - /// Complete the template and returns the source code. - pub fn complete(mut self) -> String { - let mut source = self.templates.remove(0); + /// Complete the template and returns the source code. + pub fn complete(mut self) -> String { + let mut source = self.templates.remove(0); - for s in self.templates.into_iter() { - source.push_str(&s); - } + for s in self.templates.into_iter() { + source.push_str(&s); + } - let template = text_placeholder::Template::new(&source); - let mut context = HashMap::new(); + let template = text_placeholder::Template::new(&source); + let mut context = HashMap::new(); - for (key, value) in self.items.iter() { - context.insert(key.as_str(), value.as_str()); - } + for (key, value) in self.items.iter() { + context.insert(key.as_str(), value.as_str()); + } - template.fill_with_hashmap(&context) - } + template.fill_with_hashmap(&context) + } } diff --git a/burn-wgpu/src/kernel/unary.rs b/burn-wgpu/src/kernel/unary.rs index 7584be7b1f..e1f28dac12 100644 --- a/burn-wgpu/src/kernel/unary.rs +++ b/burn-wgpu/src/kernel/unary.rs @@ -7,202 +7,202 @@ kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl"); /// Creates a unary kernel. #[macro_export] macro_rules! 
unary { - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - let source = $crate::kernel::UnaryRaw::source(); - source.register("body", format!("output[id] = {}(input[id]);", $func)) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + let source = $crate::kernel::UnaryRaw::source(); + source.register("body", format!("output[id] = {}(input[id]);", $func)) + } + } + }; + ( $struct:ident, body $body:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source().register("body", $body) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryRaw::source().register("body", $body) + } + } + }; + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source() - .register("body", format!("output[id] = {}(input[id]);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryRaw::source() + .register("body", format!("output[id] = {}(input[id]);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Creates a unary inplace kernel. #[macro_export] macro_rules! unary_inplace { - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source() - .register("body", format!("input[id] = {}(input[id]);", $func)) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source() + .register("body", format!("input[id] = {}(input[id]);", $func)) + } + } + }; + ( $struct:ident, body $body:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source().register("body", $body) - } - } - }; - ( + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source().register("body", $body) + } + } + }; + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source() - .register("body", format!("input[id] = {}(input[id]);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source() + .register("body", format!("input[id] = {}(input[id]);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Execute a unary kernel using the default settings. 
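// Typical invocations of the macros above; the struct names here are illustrative, but
// the "log" kernel is exactly what the tests at the bottom of this file generate.
unary!(LogKernel, func "log");
unary_inplace!(LogKernelInplace, func "log");
// LogKernel::source() registers the body `output[id] = log(input[id]);` into the unary
// template, while the inplace variant writes `input[id] = log(input[id]);` back in place.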
pub fn unary_default( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - unary::(input) + unary::(input) } /// Execute a unary inplace kernel using the default settings. pub fn unary_inplace_default( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - unary_inplace::(input) + unary_inplace::(input) } /// Execute a unary inplace kernel using the provided WORKGROUP. pub fn unary_inplace< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); + let num_elems = input.shape.num_elements(); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); - input.client.execute(Box::new(kernel), &[&input.handle]); + input.client.execute(Box::new(kernel), &[&input.handle]); - input + input } /// Execute a unary kernel using the provided WORKGROUP. pub fn unary( - input: WgpuTensor, + input: WgpuTensor, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let buffer = input.client.empty(num_elems * core::mem::size_of::()); - let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer); - // Since we don't handle the stride inside the kernel, the output tensor have the same strides - // as the input tensor. It might not be in the default format. - output.strides = input.strides; - - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - input - .client - .execute(Box::new(kernel), &[&input.handle, &output.handle]); - - output + let num_elems = input.shape.num_elements(); + let buffer = input.client.empty(num_elems * core::mem::size_of::()); + let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer); + // Since we don't handle the stride inside the kernel, the output tensor have the same strides + // as the input tensor. It might not be in the default format. 
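// A sketch of what "the output tensor keeps the input strides" means. Strides are the
// flat-buffer step per index; swapping two dims (as `swap_dims` does later in this
// patch) swaps strides without moving data, so the result is no longer row-major and an
// element-wise kernel that ignores strides must reproduce that layout verbatim.
fn strides_after_swap_sketch() {
    let shape = [2usize, 3];
    let strides = [3usize, 1]; // row-major: one row skips 3 elements, one column skips 1
    let swapped_shape = [shape[1], shape[0]];
    let swapped_strides = [strides[1], strides[0]];
    assert_eq!(swapped_shape, [3, 2]);
    assert_eq!(swapped_strides, [1, 3]); // transposed view over the same buffer
}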
+ output.strides = input.strides; + + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + input + .client + .execute(Box::new(kernel), &[&input.handle, &output.handle]); + + output } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; - unary!(TestKernel, func "log"); - unary_inplace!(TestKernelInplace, func "log"); + unary!(TestKernel, func "log"); + unary_inplace!(TestKernelInplace, func "log"); - #[test] - fn unary_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary::(tensor.into_primitive()); - let expected = tensor_ref.log(); + let actual = unary::(tensor.into_primitive()); + let expected = tensor_ref.log(); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } - #[test] - fn unary_inplace_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_inplace_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary_inplace::(tensor.into_primitive()); - let expected = tensor_ref.log(); + let actual = unary_inplace::(tensor.into_primitive()); + let expected = tensor_ref.log(); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } - - #[test] - fn tanh_should_not_have_numerical_bugs_on_macos() { - fn tanh_one_value(input: f32) -> f32 { - let tensor = Tensor::::ones([1]) * input; - let output = tensor.tanh().into_primitive(); - Tensor::::from_primitive(output) - .into_data() - .value[0] + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); } - let ok = tanh_one_value(43.0); // metal tanh gives 1.0 which is the right answer - let zero = tanh_one_value(44.0); // metal tanh gives zero when within 43.67..44.36 - let nan = tanh_one_value(45.0); // metal tanh gives nan when over 44.36 - let neg = tanh_one_value(-45.0); // metal works correctly here - - assert!(!ok.is_nan() && ok == 1.0); - assert!(!zero.is_nan() && zero == 1.0); - assert!(!nan.is_nan() && nan == 1.0); - assert!(!neg.is_nan() && neg == -1.0); - } + #[test] + fn tanh_should_not_have_numerical_bugs_on_macos() { + fn tanh_one_value(input: f32) -> f32 { + let tensor = Tensor::::ones([1]) * input; + let output = tensor.tanh().into_primitive(); + Tensor::::from_primitive(output) + .into_data() + .value[0] + } + + let ok = tanh_one_value(43.0); // metal tanh gives 1.0 which is the right answer + let zero = tanh_one_value(44.0); // metal tanh gives zero when within 43.67..44.36 + let nan = tanh_one_value(45.0); // metal tanh gives nan when over 44.36 + let neg = tanh_one_value(-45.0); // metal works correctly here + + assert!(!ok.is_nan() && ok == 1.0); + assert!(!zero.is_nan() && zero == 1.0); + 
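// This test documents a Metal tanh quirk (zero or NaN for large inputs). The sketch
// below is NOT how burn-wgpu fixes it; it only illustrates a generic clamp workaround:
// tanh saturates to +/-1 well before |x| = 10, so limiting the argument preserves the
// result while avoiding the range where the builtin misbehaves.
fn clamped_tanh_sketch(x: f32) -> f32 {
    const LIMIT: f32 = 10.0; // tanh(10) already rounds to 1.0 in f32
    x.clamp(-LIMIT, LIMIT).tanh()
}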
assert!(!nan.is_nan() && nan == 1.0); + assert!(!neg.is_nan() && neg == -1.0); + } } diff --git a/burn-wgpu/src/kernel/unary_scalar.rs b/burn-wgpu/src/kernel/unary_scalar.rs index fabaa468a9..dc68443df5 100644 --- a/burn-wgpu/src/kernel/unary_scalar.rs +++ b/burn-wgpu/src/kernel/unary_scalar.rs @@ -3,218 +3,218 @@ use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::Wg kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl"); kernel_wgsl!( - UnaryScalarInplaceRaw, - "../template/unary_scalar_inplace.wgsl" + UnaryScalarInplaceRaw, + "../template/unary_scalar_inplace.wgsl" ); /// Creates a unary scalar kernel. #[macro_export] macro_rules! unary_scalar { - ( + ( $struct:ident, ops $ops:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = lhs[id] {} rhs;", $ops)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() + .register("body", format!("output[id] = lhs[id] {} rhs;", $ops)) + } + } + }; - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() + .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) + } + } + }; - ( + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() + .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Creates a unary scalar inplace kernel. #[macro_export] macro_rules! 
unary_scalar_inplace { - ( + ( $struct:ident, ops $ops:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() + .register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops)) + } + } + }; - ( + ( $struct:ident, body $body:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body) + } + } + }; - ( + ( $struct:ident, func $func:expr ) => { - pub struct $struct; + pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) - } - } - }; + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() + .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) + } + } + }; - ( + ( $struct:ident, func $func:expr, include $file:expr ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) - .add_template(include_str!($file)) - } - } - }; + pub struct $struct; + + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() + .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) + .add_template(include_str!($file)) + } + } + }; } /// Execute a unary scalar kernel using the default settings. pub fn unary_scalar_default( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - unary_scalar::(lhs, scalar) + unary_scalar::(lhs, scalar) } /// Execute a unary scalar kernel using the provided WORKGROUP. 
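// Typical invocations of the scalar macros above; the struct names are illustrative
// (the tests below build the same kind of kernels with `ops "*"`).
unary_scalar!(AddScalar, ops "+");
unary_scalar_inplace!(AddScalarInplace, ops "+");
// AddScalar::source() produces the body `output[id] = lhs[id] + rhs;`, and the inplace
// variant rewrites `lhs[id] = lhs[id] + rhs;` directly in the input buffer.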
pub fn unary_scalar< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); - let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs_handle, &output.handle], - ); - - output + let num_elems = lhs.shape.num_elements(); + let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs_handle, &output.handle], + ); + + output } /// Execute a unary scalar inplace kernel using the default settings. pub fn unary_scalar_inplace_default( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - unary_scalar_inplace::(lhs, scalar) + unary_scalar_inplace::(lhs, scalar) } /// Execute a unary scalar inplace kernel using the provided WORKGROUP. pub fn unary_scalar_inplace< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, >( - lhs: WgpuTensor, - scalar: E, + lhs: WgpuTensor, + scalar: E, ) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); + let num_elems = lhs.shape.num_elements(); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); - lhs - .client - .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); - lhs + lhs } #[cfg(test)] mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; + use super::*; + use crate::tests::{ReferenceBackend, TestBackend}; + use burn_tensor::{Distribution, Tensor}; - unary_scalar!(TestKernel, ops "*"); - unary_scalar_inplace!(TestKernelInplace, ops "*"); + unary_scalar!(TestKernel, ops "*"); + unary_scalar_inplace!(TestKernelInplace, ops "*"); - #[test] - fn unary_scalar_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_scalar_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary_scalar::(tensor.into_primitive(), 5.0); - let expected = tensor_ref.mul_scalar(5.0); + let actual = unary_scalar::(tensor.into_primitive(), 5.0); + let expected = tensor_ref.mul_scalar(5.0); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + expected.into_data().assert_approx_eq( + 
&Tensor::::from_primitive(actual).into_data(), + 3, + ); + } - #[test] - fn unary_scalar_inplace_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); + #[test] + fn unary_scalar_inplace_should_work_with_multiple_invocations() { + let tensor = Tensor::::random([6, 256], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary_scalar_inplace::(tensor.into_primitive(), 5.0); - let expected = tensor_ref.mul_scalar(5.0); + let actual = + unary_scalar_inplace::(tensor.into_primitive(), 5.0); + let expected = tensor_ref.mul_scalar(5.0); - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } + expected.into_data().assert_approx_eq( + &Tensor::::from_primitive(actual).into_data(), + 3, + ); + } } diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs index 52ab6703b7..d04b282eda 100644 --- a/burn-wgpu/src/lib.rs +++ b/burn-wgpu/src/lib.rs @@ -32,15 +32,15 @@ mod fusion; #[cfg(test)] mod tests { - use super::*; + use super::*; - pub type TestBackend = Wgpu; - pub type ReferenceBackend = burn_ndarray::NdArray; + pub type TestBackend = Wgpu; + pub type ReferenceBackend = burn_ndarray::NdArray; - pub type TestTensor = burn_tensor::Tensor; - pub type ReferenceTensor = burn_tensor::Tensor; - pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensor = burn_tensor::Tensor; + pub type ReferenceTensor = burn_tensor::Tensor; + pub type TestTensorInt = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!(); + burn_autodiff::testgen_all!(); } diff --git a/burn-wgpu/src/ops/activation_ops.rs b/burn-wgpu/src/ops/activation_ops.rs index ff6dea19a0..2256628602 100644 --- a/burn-wgpu/src/ops/activation_ops.rs +++ b/burn-wgpu/src/ops/activation_ops.rs @@ -1,25 +1,25 @@ use burn_tensor::ops::{ActivationOps, FloatTensor}; use crate::{ - element::{FloatElement, IntElement}, - kernel::{unary_default, unary_inplace_default}, - unary, unary_inplace, GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel::{unary_default, unary_inplace_default}, + unary, unary_inplace, GraphicsApi, Wgpu, }; impl ActivationOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn relu(tensor: FloatTensor) -> FloatTensor { - unary!(Relu, body "output[id] = max(input[id], 0.0);"); - unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);"); + fn relu(tensor: FloatTensor) -> FloatTensor { + unary!(Relu, body "output[id] = max(input[id], 0.0);"); + unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);"); - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } - unary_default::(tensor) - } + unary_default::(tensor) + } } diff --git a/burn-wgpu/src/ops/base.rs b/burn-wgpu/src/ops/base.rs index bf06dcb824..93ea576a73 100644 --- a/burn-wgpu/src/ops/base.rs +++ b/burn-wgpu/src/ops/base.rs @@ -1,78 +1,78 @@ use crate::{ - compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi, - WgpuDevice, + compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi, + WgpuDevice, }; use burn_tensor::{Data, Reader, Shape}; pub fn from_data( - data: Data, - device: &WgpuDevice, + data: Data, + device: &WgpuDevice, ) -> 
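// Illustration: the byte-level round trip performed by `from_data` and `into_data`
// nearby. Tensor elements travel to and from the GPU buffer as raw bytes; f32 and
// little-endian are shown for concreteness, while the crate abstracts this behind
// its element trait (`E::as_bytes` / `E::from_bytes`).
fn f32s_to_bytes(values: &[f32]) -> Vec<u8> {
    values.iter().flat_map(|v| v.to_le_bytes()).collect()
}

fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}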
WgpuTensor { - let client = compute_client::(device); - let buffer = client.create(E::as_bytes(&data.value)); + let client = compute_client::(device); + let buffer = client.create(E::as_bytes(&data.value)); - WgpuTensor::new(client, device.clone(), data.shape, buffer) + WgpuTensor::new(client, device.clone(), data.shape, buffer) } pub fn into_data(tensor: WgpuTensor) -> Reader> { - let tensor = kernel::into_contiguous(tensor); + let tensor = kernel::into_contiguous(tensor); - tensor - .client - .read(&tensor.handle) - .map(|bytes| Data::new(E::from_bytes(&bytes).to_vec(), tensor.shape)) + tensor + .client + .read(&tensor.handle) + .map(|bytes| Data::new(E::from_bytes(&bytes).to_vec(), tensor.shape)) } pub fn bool_into_data(tensor: WgpuTensor) -> Reader> { - let tensor = kernel::into_contiguous(tensor); + let tensor = kernel::into_contiguous(tensor); - tensor.client.read(&tensor.handle).map(|bytes| { - Data::new( - u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), - tensor.shape, - ) - }) + tensor.client.read(&tensor.handle).map(|bytes| { + Data::new( + u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), + tensor.shape, + ) + }) } pub fn to_device( - tensor: WgpuTensor, - device: &WgpuDevice, + tensor: WgpuTensor, + device: &WgpuDevice, ) -> WgpuTensor { - if &tensor.device == device { - return tensor; - } + if &tensor.device == device { + return tensor; + } - let client = compute_client::(device); - tensor.to_client(client, device.clone()) + let client = compute_client::(device); + tensor.to_client(client, device.clone()) } pub fn empty( - shape: Shape, - device: &WgpuDevice, + shape: Shape, + device: &WgpuDevice, ) -> WgpuTensor { - let client = compute_client::(device); - let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); + let client = compute_client::(device); + let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - WgpuTensor::new(client, device.clone(), shape, buffer) + WgpuTensor::new(client, device.clone(), shape, buffer) } pub fn swap_dims( - mut tensor: WgpuTensor, - dim1: usize, - dim2: usize, + mut tensor: WgpuTensor, + dim1: usize, + dim2: usize, ) -> WgpuTensor { - tensor.strides.swap(dim1, dim2); - tensor.shape.dims.swap(dim1, dim2); + tensor.strides.swap(dim1, dim2); + tensor.shape.dims.swap(dim1, dim2); - tensor + tensor } pub fn reshape( - tensor: WgpuTensor, - shape: Shape, + tensor: WgpuTensor, + shape: Shape, ) -> WgpuTensor { - // TODO: Not force standard layout all the time (improve performance). - let tensor = kernel::into_contiguous(tensor); + // TODO: Not force standard layout all the time (improve performance). 
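// Illustration: why `reshape` forces a contiguous copy first (the TODO noted above).
// Reinterpreting the same buffer under a new shape is only sound when the elements
// already sit in standard row-major order. A quick check for that, assuming
// row-major strides as computed elsewhere in this crate:
fn is_standard_layout(dims: &[usize], strides: &[usize]) -> bool {
    let mut expected = 1;
    for (dim, stride) in dims.iter().zip(strides.iter()).rev() {
        if *stride != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

// Example: a [2, 3] tensor with strides [3, 1] passes; after swap_dims(0, 1) it has
// shape [3, 2] with strides [1, 3], which fails and would need the copy.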
+ let tensor = kernel::into_contiguous(tensor); - WgpuTensor::new(tensor.client, tensor.device, shape, tensor.handle) + WgpuTensor::new(tensor.client, tensor.device, shape, tensor.handle) } diff --git a/burn-wgpu/src/ops/bool_ops.rs b/burn-wgpu/src/ops/bool_ops.rs index 938054b116..82470694f2 100644 --- a/burn-wgpu/src/ops/bool_ops.rs +++ b/burn-wgpu/src/ops/bool_ops.rs @@ -1,8 +1,8 @@ use crate::{ - element::{FloatElement, IntElement}, - kernel, - tensor::WgpuTensor, - GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel, + tensor::WgpuTensor, + GraphicsApi, Wgpu, }; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; use burn_tensor::{ops::BoolTensorOps, Data, Shape}; @@ -11,117 +11,116 @@ use std::ops::Range; impl BoolTensorOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - super::empty::(shape, device) - } - - fn bool_shape(tensor: &BoolTensor) -> Shape { - tensor.shape.clone() - } - - fn bool_into_data(tensor: BoolTensor) -> Reader> { - super::bool_into_data(tensor) - } - - fn bool_from_data( - data: Data, - device: &Device, - ) -> BoolTensor { - let data: Data = Data::new( - data - .value - .into_iter() - .map(|c| match c { - true => 1, - false => 0, - }) - .collect(), - data.shape, - ); - super::from_data::(data, device) - } - - fn bool_into_int(tensor: BoolTensor) -> IntTensor { - if std::mem::size_of::() == std::mem::size_of::() { - return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + super::empty::(shape, device) } - let device = Self::bool_device(&tensor); - let data = Self::bool_into_data(tensor) - .read_sync() - .expect("Can't convert bool to int with a different type size async") - .convert::(); - - Self::int_from_data(data, &device) - } - - fn bool_device(tensor: &BoolTensor) -> Device { - tensor.device.clone() - } - - fn bool_to_device( - tensor: BoolTensor, - device: &Device, - ) -> BoolTensor { - super::to_device::(tensor, device) - } - - fn bool_reshape( - tensor: BoolTensor, - shape: Shape, - ) -> BoolTensor { - super::reshape(tensor, shape) - } - - fn bool_slice( - tensor: BoolTensor, - ranges: [Range; D2], - ) -> BoolTensor { - kernel::slice(tensor, ranges) - } - - fn bool_slice_assign( - tensor: BoolTensor, - ranges: [Range; D2], - value: BoolTensor, - ) -> BoolTensor { - kernel::slice_assign(tensor, ranges, value) - } - - fn bool_cat( - tensors: Vec>, - dim: usize, - ) -> BoolTensor { - kernel::cat(tensors, dim) - } - - fn bool_equal( - lhs: BoolTensor, - rhs: BoolTensor, - ) -> BoolTensor { - kernel::equal(lhs, rhs) - } - - fn bool_not(tensor: BoolTensor) -> BoolTensor { - kernel::equal_elem(tensor, 0) - } - - fn bool_into_float(tensor: BoolTensor) -> FloatTensor { - kernel::cast(tensor) - } - - fn bool_swap_dims( - mut tensor: BoolTensor, - dim1: usize, - dim2: usize, - ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { - tensor.strides.swap(dim1, dim2); - tensor.shape.dims.swap(dim1, dim2); - - tensor - } + fn bool_shape(tensor: &BoolTensor) -> Shape { + tensor.shape.clone() + } + + fn bool_into_data(tensor: BoolTensor) -> Reader> { + super::bool_into_data(tensor) + } + + fn bool_from_data( + data: Data, + device: &Device, + ) -> BoolTensor { + let data: Data = Data::new( + data.value + .into_iter() + .map(|c| match c { + true => 1, + false => 0, + }) + .collect(), + 
data.shape, + ); + super::from_data::(data, device) + } + + fn bool_into_int(tensor: BoolTensor) -> IntTensor { + if std::mem::size_of::() == std::mem::size_of::() { + return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); + } + + let device = Self::bool_device(&tensor); + let data = Self::bool_into_data(tensor) + .read_sync() + .expect("Can't convert bool to int with a different type size async") + .convert::(); + + Self::int_from_data(data, &device) + } + + fn bool_device(tensor: &BoolTensor) -> Device { + tensor.device.clone() + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + super::to_device::(tensor, device) + } + + fn bool_reshape( + tensor: BoolTensor, + shape: Shape, + ) -> BoolTensor { + super::reshape(tensor, shape) + } + + fn bool_slice( + tensor: BoolTensor, + ranges: [Range; D2], + ) -> BoolTensor { + kernel::slice(tensor, ranges) + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: [Range; D2], + value: BoolTensor, + ) -> BoolTensor { + kernel::slice_assign(tensor, ranges, value) + } + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> BoolTensor { + kernel::cat(tensors, dim) + } + + fn bool_equal( + lhs: BoolTensor, + rhs: BoolTensor, + ) -> BoolTensor { + kernel::equal(lhs, rhs) + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + kernel::equal_elem(tensor, 0) + } + + fn bool_into_float(tensor: BoolTensor) -> FloatTensor { + kernel::cast(tensor) + } + + fn bool_swap_dims( + mut tensor: BoolTensor, + dim1: usize, + dim2: usize, + ) -> as burn_tensor::backend::Backend>::BoolTensorPrimitive { + tensor.strides.swap(dim1, dim2); + tensor.shape.dims.swap(dim1, dim2); + + tensor + } } diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs index 900bddbef6..35dcbe12f3 100644 --- a/burn-wgpu/src/ops/float_ops.rs +++ b/burn-wgpu/src/ops/float_ops.rs @@ -9,13 +9,13 @@ use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; #[cfg(not(feature = "autotune"))] use crate::kernel::reduce::init_reduce_output; use crate::kernel::{ - self, reduce, unary_default, unary_inplace_default, unary_scalar_default, - unary_scalar_inplace_default, + self, reduce, unary_default, unary_inplace_default, unary_scalar_default, + unary_scalar_inplace_default, }; use crate::{unary, unary_inplace, unary_scalar, FloatElement, GraphicsApi, IntElement, Wgpu}; use crate::{unary_scalar_inplace, WgpuDevice}; use burn_tensor::ops::{ - BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, + BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, }; use burn_tensor::{ops::TensorOps, Data, Distribution, Shape}; use burn_tensor::{ElementConversion, Reader}; @@ -24,499 +24,503 @@ use std::ops::Range; impl TensorOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn from_data( - data: Data, D>, - device: &Device, - ) -> FloatTensor { - super::from_data::(data, device) - } - - fn random( - shape: Shape, - distribution: Distribution>, - device: &Device, - ) -> FloatTensor { - match distribution { - Distribution::Default => random_uniform::(shape, device, 0.elem(), 1.elem()), - Distribution::Uniform(low, high) => random_uniform::(shape, device, low, high), - Distribution::Bernoulli(prob) => random_bernoulli::(shape, device, prob.elem()), - Distribution::Normal(mean, std) => { - random_normal::(shape, device, mean.elem(), std.elem()) - } - } - } - - fn 
shape(tensor: &FloatTensor) -> Shape { - tensor.shape.clone() - } - - fn into_data(tensor: FloatTensor) -> Reader, D>> { - super::into_data(tensor) - } - - fn device(tensor: &FloatTensor) -> Device { - tensor.device.clone() - } - - fn to_device( - tensor: FloatTensor, - device: &Device, - ) -> FloatTensor { - super::to_device::(tensor, device) - } - - fn empty(shape: Shape, device: &Device) -> FloatTensor { - super::empty::(shape, device) - } - - fn add( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::add(lhs, rhs) - } - - fn add_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::add_scalar(lhs, rhs) - } - - fn zeros(shape: Shape, device: &Device) -> FloatTensor { - numeric::zeros::(shape, device) - } - - fn full( - shape: Shape, - fill_value: FloatElem, - device: &WgpuDevice, - ) -> FloatTensor { - numeric::full::(shape, device, fill_value) - } - - fn ones(shape: Shape, device: &Device) -> FloatTensor { - numeric::ones::(shape, device) - } - - fn sub( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::sub(lhs, rhs) - } - - fn sub_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::sub_scalar(lhs, rhs) - } - - fn mul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::mul(lhs, rhs) - } - - fn mul_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::mul_scalar(lhs, rhs) - } - - fn div( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - numeric::div(lhs, rhs) - } - - fn div_scalar( - lhs: FloatTensor, - rhs: FloatElem, - ) -> FloatTensor { - numeric::div_scalar(lhs, rhs) - } - - fn matmul( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> FloatTensor { - #[cfg(feature = "autotune")] - { - matmul_autotune(lhs, rhs) - } - - #[cfg(not(feature = "autotune"))] - { - let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_vec4(lhs, rhs, out) - } - } - - fn swap_dims( - tensor: FloatTensor, - dim1: usize, - dim2: usize, - ) -> FloatTensor { - super::swap_dims(tensor, dim1, dim2) - } - - fn reshape( - tensor: FloatTensor, - shape: Shape, - ) -> FloatTensor { - super::reshape(tensor, shape) - } - - fn gather( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - ) -> FloatTensor { - kernel::gather(dim, tensor, indices) - } - - fn scatter( - dim: usize, - tensor: FloatTensor, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - kernel::scatter(dim, tensor, indices, value) - } - - fn select( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - ) -> FloatTensor { - kernel::select(tensor, dim, indices) - } - - fn select_assign( - tensor: FloatTensor, - dim: usize, - indices: IntTensor, - value: FloatTensor, - ) -> FloatTensor { - kernel::select_assign(tensor, dim, indices, value) - } - - fn slice( - tensor: FloatTensor, - ranges: [Range; D2], - ) -> FloatTensor { - kernel::slice(tensor, ranges) - } - - fn slice_assign( - tensor: FloatTensor, - ranges: [Range; D2], - value: FloatTensor, - ) -> FloatTensor { - kernel::slice_assign(tensor, ranges, value) - } - - fn mask_where( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatTensor, - ) -> FloatTensor { - kernel::mask_where(tensor, mask, value) - } - - fn mask_fill( - tensor: FloatTensor, - mask: BoolTensor, - value: FloatElem, - ) -> FloatTensor { - kernel::mask_fill(tensor, mask, value) - } - - fn equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::equal(lhs, rhs) - } - - fn equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> 
BoolTensor { - kernel::equal_elem(lhs, rhs) - } - - fn greater( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::greater(lhs, rhs) - } - - fn greater_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::greater_elem(lhs, rhs) - } - - fn greater_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::greater_equal(lhs, rhs) - } - - fn greater_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::greater_equal_elem(lhs, rhs) - } - - fn lower( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::lower(lhs, rhs) - } - - fn lower_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::lower_elem(lhs, rhs) - } - - fn lower_equal( - lhs: FloatTensor, - rhs: FloatTensor, - ) -> BoolTensor { - kernel::lower_equal(lhs, rhs) - } - - fn lower_equal_elem( - lhs: FloatTensor, - rhs: FloatElem, - ) -> BoolTensor { - kernel::lower_equal_elem(lhs, rhs) - } - - fn sum(tensor: FloatTensor) -> FloatTensor { - reduce::sum(tensor) - } - - fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[cfg(feature = "autotune")] - { - reduce::sum_dim_autotune(tensor, dim) - } - - #[cfg(not(feature = "autotune"))] - { - let output = init_reduce_output(&tensor, dim); - reduce::sum_dim(tensor, output, dim) - } - } + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor { + super::from_data::(data, device) + } + + fn random( + shape: Shape, + distribution: Distribution>, + device: &Device, + ) -> FloatTensor { + match distribution { + Distribution::Default => random_uniform::(shape, device, 0.elem(), 1.elem()), + Distribution::Uniform(low, high) => random_uniform::(shape, device, low, high), + Distribution::Bernoulli(prob) => { + random_bernoulli::(shape, device, prob.elem()) + } + Distribution::Normal(mean, std) => { + random_normal::(shape, device, mean.elem(), std.elem()) + } + } + } + + fn shape(tensor: &FloatTensor) -> Shape { + tensor.shape.clone() + } + + fn into_data(tensor: FloatTensor) -> Reader, D>> { + super::into_data(tensor) + } + + fn device(tensor: &FloatTensor) -> Device { + tensor.device.clone() + } + + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor { + super::to_device::(tensor, device) + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + super::empty::(shape, device) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::add(lhs, rhs) + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + numeric::add_scalar(lhs, rhs) + } + + fn zeros(shape: Shape, device: &Device) -> FloatTensor { + numeric::zeros::(shape, device) + } + + fn full( + shape: Shape, + fill_value: FloatElem, + device: &WgpuDevice, + ) -> FloatTensor { + numeric::full::(shape, device, fill_value) + } + + fn ones(shape: Shape, device: &Device) -> FloatTensor { + numeric::ones::(shape, device) + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::sub(lhs, rhs) + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + numeric::sub_scalar(lhs, rhs) + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::mul(lhs, rhs) + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + numeric::mul_scalar(lhs, rhs) + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + numeric::div(lhs, rhs) + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor 
{ + numeric::div_scalar(lhs, rhs) + } + + fn matmul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + #[cfg(feature = "autotune")] + { + matmul_autotune(lhs, rhs) + } + + #[cfg(not(feature = "autotune"))] + { + let out = init_matmul_output(&lhs, &rhs); + matmul_tiling_2d_vec4(lhs, rhs, out) + } + } + + fn swap_dims( + tensor: FloatTensor, + dim1: usize, + dim2: usize, + ) -> FloatTensor { + super::swap_dims(tensor, dim1, dim2) + } + + fn reshape( + tensor: FloatTensor, + shape: Shape, + ) -> FloatTensor { + super::reshape(tensor, shape) + } + + fn gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + kernel::gather(dim, tensor, indices) + } + + fn scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + kernel::scatter(dim, tensor, indices, value) + } + + fn select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + kernel::select(tensor, dim, indices) + } + + fn select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + kernel::select_assign(tensor, dim, indices, value) + } + + fn slice( + tensor: FloatTensor, + ranges: [Range; D2], + ) -> FloatTensor { + kernel::slice(tensor, ranges) + } + + fn slice_assign( + tensor: FloatTensor, + ranges: [Range; D2], + value: FloatTensor, + ) -> FloatTensor { + kernel::slice_assign(tensor, ranges, value) + } + + fn mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor { + kernel::mask_where(tensor, mask, value) + } + + fn mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + kernel::mask_fill(tensor, mask, value) + } + + fn equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::equal(lhs, rhs) + } + + fn equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::equal_elem(lhs, rhs) + } + + fn greater( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::greater(lhs, rhs) + } + + fn greater_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::greater_elem(lhs, rhs) + } + + fn greater_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::greater_equal(lhs, rhs) + } - fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - #[cfg(feature = "autotune")] - { - reduce::mean_dim_autotune(tensor, dim) + fn greater_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::greater_equal_elem(lhs, rhs) } - #[cfg(not(feature = "autotune"))] - { - let output = init_reduce_output(&tensor, dim); - reduce::mean_dim(tensor, output, dim) + fn lower( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::lower(lhs, rhs) } - } - fn to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - kernel::cast(tensor.clone()) - } + fn lower_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::lower_elem(lhs, rhs) + } - fn from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - kernel::cast(tensor) - } + fn lower_equal( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> BoolTensor { + kernel::lower_equal(lhs, rhs) + } - fn exp(lhs: FloatTensor) -> FloatTensor { - unary!(Exp, func "exp"); - unary_inplace!(ExpInplace, func "exp"); + fn lower_equal_elem( + lhs: FloatTensor, + rhs: FloatElem, + ) -> BoolTensor { + kernel::lower_equal_elem(lhs, rhs) + } - if lhs.can_mut() { - return unary_inplace_default::(lhs); + fn 
sum(tensor: FloatTensor) -> FloatTensor { + reduce::sum(tensor) } - unary_default::(lhs) - } + fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[cfg(feature = "autotune")] + { + reduce::sum_dim_autotune(tensor, dim) + } + + #[cfg(not(feature = "autotune"))] + { + let output = init_reduce_output(&tensor, dim); + reduce::sum_dim(tensor, output, dim) + } + } - fn log(tensor: FloatTensor) -> FloatTensor { - unary!(Log, func "log"); - unary_inplace!(LogInplace, func "log"); + fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + #[cfg(feature = "autotune")] + { + reduce::mean_dim_autotune(tensor, dim) + } + + #[cfg(not(feature = "autotune"))] + { + let output = init_reduce_output(&tensor, dim); + reduce::mean_dim(tensor, output, dim) + } + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + fn to_full_precision( + tensor: &FloatTensor, + ) -> FloatTensor, D> { + kernel::cast(tensor.clone()) } - unary_default::(tensor) - } + fn from_full_precision( + tensor: FloatTensor, D>, + ) -> FloatTensor { + kernel::cast(tensor) + } + + fn exp(lhs: FloatTensor) -> FloatTensor { + unary!(Exp, func "exp"); + unary_inplace!(ExpInplace, func "exp"); - fn log1p(tensor: FloatTensor) -> FloatTensor { - unary!(Log1p, body "output[id] = log(1.0 + input[id]);"); - unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);"); + if lhs.can_mut() { + return unary_inplace_default::(lhs); + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + unary_default::(lhs) } - unary_default::(tensor) - } + fn log(tensor: FloatTensor) -> FloatTensor { + unary!(Log, func "log"); + unary_inplace!(LogInplace, func "log"); - fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { - unary_scalar!(Powf, func "powf", include "../template/powf.wgsl"); - unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl"); + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs.elem()); + unary_default::(tensor) } - unary_scalar_default::(lhs, rhs.elem()) - } + fn log1p(tensor: FloatTensor) -> FloatTensor { + unary!(Log1p, body "output[id] = log(1.0 + input[id]);"); + unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);"); - fn sqrt(tensor: FloatTensor) -> FloatTensor { - unary!(Sqrt, func "sqrt"); - unary_inplace!(SqrtInplace, func "sqrt"); + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + unary_default::(tensor) } - unary_default::(tensor) - } + fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { + unary_scalar!(Powf, func "powf", include "../template/powf.wgsl"); + unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl"); - fn abs(tensor: FloatTensor) -> FloatTensor { - unary!(Abs, func "abs"); - unary_inplace!(AbsInplace, func "abs"); + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs.elem()); + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + unary_scalar_default::(lhs, rhs.elem()) } - unary_default::(tensor) - } + fn sqrt(tensor: FloatTensor) -> FloatTensor { + unary!(Sqrt, func "sqrt"); + unary_inplace!(SqrtInplace, func "sqrt"); - fn cos(tensor: FloatTensor) -> FloatTensor { - unary!(Cos, func "cos"); - unary_inplace!(CosInplace, func "cos"); + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + unary_default::(tensor) } - 
unary_default::(tensor) - } + fn abs(tensor: FloatTensor) -> FloatTensor { + unary!(Abs, func "abs"); + unary_inplace!(AbsInplace, func "abs"); - fn sin(tensor: FloatTensor) -> FloatTensor { - unary!(Sin, func "sin"); - unary_inplace!(SinInplace, func "sin"); + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + unary_default::(tensor) } - unary_default::(tensor) - } + fn cos(tensor: FloatTensor) -> FloatTensor { + unary!(Cos, func "cos"); + unary_inplace!(CosInplace, func "cos"); + + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } + + unary_default::(tensor) + } - fn tanh(tensor: FloatTensor) -> FloatTensor { - // Metal has a weird numerical behaviour with tanh which require a new function - #[cfg(target_os = "macos")] - unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl"); - #[cfg(target_os = "macos")] - unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl"); + fn sin(tensor: FloatTensor) -> FloatTensor { + unary!(Sin, func "sin"); + unary_inplace!(SinInplace, func "sin"); - #[cfg(not(target_os = "macos"))] - unary!(Tanh, func "tanh"); - #[cfg(not(target_os = "macos"))] - unary_inplace!(TanhInplace, func "tanh"); + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + unary_default::(tensor) } - unary_default::(tensor) - } + fn tanh(tensor: FloatTensor) -> FloatTensor { + // Metal has a weird numerical behaviour with tanh which require a new function + #[cfg(target_os = "macos")] + unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl"); + #[cfg(target_os = "macos")] + unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl"); - fn erf(tensor: FloatTensor) -> FloatTensor { - unary!(Erf, func "erf", include "../template/erf.wgsl"); - unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl"); + #[cfg(not(target_os = "macos"))] + unary!(Tanh, func "tanh"); + #[cfg(not(target_os = "macos"))] + unary_inplace!(TanhInplace, func "tanh"); - if tensor.can_mut() { - return unary_inplace_default::(tensor); + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } + + unary_default::(tensor) } - unary_default::(tensor) - } + fn erf(tensor: FloatTensor) -> FloatTensor { + unary!(Erf, func "erf", include "../template/erf.wgsl"); + unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl"); - fn cat(tensors: Vec>, dim: usize) -> FloatTensor { - kernel::cat(tensors, dim) - } + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } - fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - reduce::argmax(tensor, dim) - } + unary_default::(tensor) + } - fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - reduce::argmin(tensor, dim) - } + fn cat(tensors: Vec>, dim: usize) -> FloatTensor { + kernel::cat(tensors, dim) + } - fn into_int(tensor: FloatTensor) -> IntTensor { - kernel::cast(tensor) - } + fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + reduce::argmax(tensor, dim) + } - fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - kernel::clamp_min(tensor, min) - } + fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + reduce::argmin(tensor, dim) + } - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - kernel::clamp_max(tensor, max) - } + fn into_int(tensor: FloatTensor) -> IntTensor { + kernel::cast(tensor) + } 
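// Illustration: the intent behind the macOS-only `safe_tanh` template selected in
// `tanh` above. A naive exp-based tanh can produce NaN for large |x| on some Metal
// drivers, so the wrapper can saturate explicitly once tanh is ±1 to within f32
// precision. The threshold here is illustrative; the shipped WGSL may use a
// different cutoff or formulation.
fn safe_tanh_sketch(x: f32) -> f32 {
    if x.abs() > 10.0 {
        x.signum() // tanh(±10) already rounds to ±1.0 in f32
    } else {
        x.tanh()
    }
}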
- fn clamp( - tensor: FloatTensor, - min: FloatElem, - max: FloatElem, - ) -> FloatTensor { - kernel::clamp(tensor, min, max) - } + fn clamp_min( + tensor: FloatTensor, + min: FloatElem, + ) -> FloatTensor { + kernel::clamp_min(tensor, min) + } - fn recip(tensor: FloatTensor, D>) -> FloatTensor, D> { - unary!(Recip, func "1.0 /"); - unary_inplace!(RecipInplace, func "1.0 /"); + fn clamp_max( + tensor: FloatTensor, + max: FloatElem, + ) -> FloatTensor { + kernel::clamp_max(tensor, max) + } - if tensor.can_mut() { - return unary_inplace_default::(tensor); + fn clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + kernel::clamp(tensor, min, max) } - unary_default::(tensor) - } + fn recip( + tensor: FloatTensor, D>, + ) -> FloatTensor, D> { + unary!(Recip, func "1.0 /"); + unary_inplace!(RecipInplace, func "1.0 /"); + + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } + + unary_default::(tensor) + } } diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs index c79c601fcd..bbef2dd6a6 100644 --- a/burn-wgpu/src/ops/int_ops.rs +++ b/burn-wgpu/src/ops/int_ops.rs @@ -3,8 +3,8 @@ use super::numeric; use crate::kernel::reduce::{self, init_reduce_output}; use crate::kernel::{unary_default, unary_inplace_default}; use crate::{ - element::{FloatElement, IntElement}, - kernel, unary, unary_inplace, GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel, unary, unary_inplace, GraphicsApi, Wgpu, }; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; @@ -14,314 +14,317 @@ use std::ops::Range; impl IntTensorOps> for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn int_empty(shape: Shape, device: &Device) -> IntTensor { - super::empty::(shape, device) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - tensor.shape.clone() - } - - fn int_into_data(tensor: IntTensor) -> Reader> { - super::into_data(tensor) - } - - fn int_from_data(data: Data, device: &Device) -> IntTensor { - super::from_data::(data, device) - } - - fn int_device(tensor: &IntTensor) -> Device { - tensor.device.clone() - } - - fn int_to_device( - tensor: IntTensor, - device: &Device, - ) -> IntTensor { - super::to_device::(tensor, device) - } - - fn int_reshape( - tensor: IntTensor, - shape: Shape, - ) -> IntTensor { - super::reshape(tensor, shape) - } - - fn int_slice( - tensor: IntTensor, - ranges: [Range; D2], - ) -> IntTensor { - kernel::slice(tensor, ranges) - } - - fn int_slice_assign( - tensor: IntTensor, - ranges: [Range; D2], - value: IntTensor, - ) -> IntTensor { - kernel::slice_assign(tensor, ranges, value) - } - - fn int_mask_where( - tensor: IntTensor, - mask: BoolTensor, - value: IntTensor, - ) -> IntTensor { - kernel::mask_where(tensor, mask, value) - } - - fn int_mask_fill( - tensor: IntTensor, - mask: BoolTensor, - value: IntElem, - ) -> IntTensor { - kernel::mask_fill(tensor, mask, value) - } - - fn int_gather( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - ) -> IntTensor { - kernel::gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, - indices: IntTensor, - value: IntTensor, - ) -> IntTensor { - kernel::scatter(dim, tensor, indices, value) - } - - fn int_select( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - ) -> IntTensor { - kernel::select(tensor, dim, indices) - } - - fn int_select_assign( - tensor: IntTensor, - dim: usize, - indices: IntTensor, - value: 
IntTensor, - ) -> IntTensor { - kernel::select_assign(tensor, dim, indices, value) - } - - fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { - kernel::cat(tensors, dim) - } - - fn int_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::equal::(lhs, rhs) - } - - fn int_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::equal_elem::(lhs, rhs) - } - - fn int_greater( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::greater::(lhs, rhs) - } - - fn int_greater_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::greater_elem::(lhs, rhs) - } - - fn int_greater_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::greater_equal::(lhs, rhs) - } - - fn int_greater_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::greater_equal_elem::(lhs, rhs) - } - - fn int_lower( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::lower::(lhs, rhs) - } - - fn int_lower_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::lower_elem::(lhs, rhs) - } - - fn int_lower_equal( - lhs: IntTensor, - rhs: IntTensor, - ) -> BoolTensor { - kernel::lower_equal::(lhs, rhs) - } - - fn int_lower_equal_elem( - lhs: IntTensor, - rhs: IntElem, - ) -> BoolTensor { - kernel::lower_equal_elem::(lhs, rhs) - } - - fn int_add( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::add::(lhs, rhs) - } - - fn int_add_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::add_scalar(lhs, rhs) - } - - fn int_sub( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::sub(lhs, rhs) - } - - fn int_sub_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::sub_scalar(lhs, rhs) - } - - fn int_mul( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::mul(lhs, rhs) - } - - fn int_mul_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::mul_scalar(lhs, rhs) - } - - fn int_div( - lhs: IntTensor, - rhs: IntTensor, - ) -> IntTensor { - numeric::div(lhs, rhs) - } - - fn int_div_scalar( - lhs: IntTensor, - rhs: IntElem, - ) -> IntTensor { - numeric::div_scalar(lhs, rhs) - } - - fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - numeric::zeros::(shape, device) - } - - fn int_ones(shape: Shape, device: &Device) -> IntTensor { - numeric::ones::(shape, device) - } - - fn int_sum(tensor: IntTensor) -> IntTensor { - kernel::reduce::sum(tensor) - } - - fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let output = init_reduce_output(&tensor, dim); - reduce::sum_dim(tensor, output, dim) - } - - fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - let output = init_reduce_output(&tensor, dim); - reduce::mean_dim(tensor, output, dim) - } - - fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax(tensor, dim) - } - - fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin(tensor, dim) - } - - fn int_clamp_min( - tensor: IntTensor, - min: IntElem, - ) -> IntTensor { - kernel::clamp_min(tensor, min) - } - - fn int_clamp_max( - tensor: IntTensor, - max: IntElem, - ) -> IntTensor { - kernel::clamp_max(tensor, max) - } - - fn int_clamp( - tensor: IntTensor, - min: IntElem, - max: IntElem, - ) -> IntTensor { - kernel::clamp(tensor, min, max) - } - - fn int_abs(tensor: IntTensor) -> IntTensor { - unary!(IntAbs, func "abs"); - unary_inplace!(IntAbsInplace, func "abs"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } 
- - unary_default::(tensor) - } - - fn int_into_float(tensor: IntTensor) -> FloatTensor { - kernel::cast(tensor) - } - - fn int_swap_dims( - mut tensor: IntTensor, - dim1: usize, - dim2: usize, - ) -> IntTensor { - tensor.strides.swap(dim1, dim2); - tensor.shape.dims.swap(dim1, dim2); - - tensor - } + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + super::empty::(shape, device) + } + + fn int_shape(tensor: &IntTensor) -> Shape { + tensor.shape.clone() + } + + fn int_into_data(tensor: IntTensor) -> Reader> { + super::into_data(tensor) + } + + fn int_from_data( + data: Data, + device: &Device, + ) -> IntTensor { + super::from_data::(data, device) + } + + fn int_device(tensor: &IntTensor) -> Device { + tensor.device.clone() + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + super::to_device::(tensor, device) + } + + fn int_reshape( + tensor: IntTensor, + shape: Shape, + ) -> IntTensor { + super::reshape(tensor, shape) + } + + fn int_slice( + tensor: IntTensor, + ranges: [Range; D2], + ) -> IntTensor { + kernel::slice(tensor, ranges) + } + + fn int_slice_assign( + tensor: IntTensor, + ranges: [Range; D2], + value: IntTensor, + ) -> IntTensor { + kernel::slice_assign(tensor, ranges, value) + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + value: IntTensor, + ) -> IntTensor { + kernel::mask_where(tensor, mask, value) + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor { + kernel::mask_fill(tensor, mask, value) + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + kernel::gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + kernel::scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + kernel::select(tensor, dim, indices) + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + kernel::select_assign(tensor, dim, indices, value) + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + kernel::cat(tensors, dim) + } + + fn int_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::equal::(lhs, rhs) + } + + fn int_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::equal_elem::(lhs, rhs) + } + + fn int_greater( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::greater::(lhs, rhs) + } + + fn int_greater_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::greater_elem::(lhs, rhs) + } + + fn int_greater_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::greater_equal::(lhs, rhs) + } + + fn int_greater_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::greater_equal_elem::(lhs, rhs) + } + + fn int_lower( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::lower::(lhs, rhs) + } + + fn int_lower_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::lower_elem::(lhs, rhs) + } + + fn int_lower_equal( + lhs: IntTensor, + rhs: IntTensor, + ) -> BoolTensor { + kernel::lower_equal::(lhs, rhs) + } + + fn int_lower_equal_elem( + lhs: IntTensor, + rhs: IntElem, + ) -> BoolTensor { + kernel::lower_equal_elem::(lhs, rhs) + } + + fn int_add( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::add::(lhs, rhs) + } + + fn int_add_scalar( + 
lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::add_scalar(lhs, rhs) + } + + fn int_sub( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::sub(lhs, rhs) + } + + fn int_sub_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::sub_scalar(lhs, rhs) + } + + fn int_mul( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::mul(lhs, rhs) + } + + fn int_mul_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::mul_scalar(lhs, rhs) + } + + fn int_div( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + numeric::div(lhs, rhs) + } + + fn int_div_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + numeric::div_scalar(lhs, rhs) + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + numeric::zeros::(shape, device) + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + numeric::ones::(shape, device) + } + + fn int_sum(tensor: IntTensor) -> IntTensor { + kernel::reduce::sum(tensor) + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let output = init_reduce_output(&tensor, dim); + reduce::sum_dim(tensor, output, dim) + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let output = init_reduce_output(&tensor, dim); + reduce::mean_dim(tensor, output, dim) + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + kernel::reduce::argmax(tensor, dim) + } + + fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + kernel::reduce::argmin(tensor, dim) + } + + fn int_clamp_min( + tensor: IntTensor, + min: IntElem, + ) -> IntTensor { + kernel::clamp_min(tensor, min) + } + + fn int_clamp_max( + tensor: IntTensor, + max: IntElem, + ) -> IntTensor { + kernel::clamp_max(tensor, max) + } + + fn int_clamp( + tensor: IntTensor, + min: IntElem, + max: IntElem, + ) -> IntTensor { + kernel::clamp(tensor, min, max) + } + + fn int_abs(tensor: IntTensor) -> IntTensor { + unary!(IntAbs, func "abs"); + unary_inplace!(IntAbsInplace, func "abs"); + + if tensor.can_mut() { + return unary_inplace_default::(tensor); + } + + unary_default::(tensor) + } + + fn int_into_float(tensor: IntTensor) -> FloatTensor { + kernel::cast(tensor) + } + + fn int_swap_dims( + mut tensor: IntTensor, + dim1: usize, + dim2: usize, + ) -> IntTensor { + tensor.strides.swap(dim1, dim2); + tensor.shape.dims.swap(dim1, dim2); + + tensor + } } diff --git a/burn-wgpu/src/ops/module_ops.rs b/burn-wgpu/src/ops/module_ops.rs index 0d6b171f31..e523c581fd 100644 --- a/burn-wgpu/src/ops/module_ops.rs +++ b/burn-wgpu/src/ops/module_ops.rs @@ -1,110 +1,113 @@ use burn_tensor::ops::{ - ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; use crate::{ - element::{FloatElement, IntElement}, - kernel, GraphicsApi, Wgpu, + element::{FloatElement, IntElement}, + kernel, GraphicsApi, Wgpu, }; use burn_tensor::ops::{FloatTensor, IntTensor}; impl ModuleOps for Wgpu where - G: GraphicsApi + 'static, - F: FloatElement, - I: IntElement, + G: GraphicsApi + 'static, + F: FloatElement, + I: IntElement, { - fn conv2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<2>, - ) -> FloatTensor { - kernel::conv::conv2d(x, weight, bias, options) - } + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor { + kernel::conv::conv2d(x, weight, bias, options) + } - fn conv_transpose2d( - x: FloatTensor, - 
weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor { - kernel::conv::conv_transpose2d(x, weight, bias, options) - } + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + kernel::conv::conv_transpose2d(x, weight, bias, options) + } - fn avg_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - kernel::pool::avg_pool2d(x, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + kernel::pool::avg_pool2d(x, kernel_size, stride, padding, count_include_pad) + } - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor { - kernel::pool::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) - } + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + kernel::pool::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) + } - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor { - kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation) - } + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor { + kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation) + } - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - let (output, indices) = - kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + let (output, indices) = + kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); - MaxPool2dWithIndices::new(output, indices) - } + MaxPool2dWithIndices::new(output, indices) + } - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, - indices: IntTensor, - ) -> MaxPool2dBackward> { - MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward( - x, - output_grad, - indices, - kernel_size, - stride, - padding, - dilation, - )) - } + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward> { + MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward( + x, + output_grad, + indices, + kernel_size, + stride, + padding, + dilation, + )) + } - fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { - kernel::pool::adaptive_avg_pool2d(x, output_size) - } + fn adaptive_avg_pool2d( + x: FloatTensor, + output_size: [usize; 2], + ) -> FloatTensor { + 
kernel::pool::adaptive_avg_pool2d(x, output_size) + } - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - kernel::pool::adaptive_avg_pool2d_backward(x, grad) - } + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + kernel::pool::adaptive_avg_pool2d_backward(x, grad) + } } diff --git a/burn-wgpu/src/ops/numeric.rs b/burn-wgpu/src/ops/numeric.rs index 0ff3fb890a..f4b57e83d8 100644 --- a/burn-wgpu/src/ops/numeric.rs +++ b/burn-wgpu/src/ops/numeric.rs @@ -1,197 +1,197 @@ use crate::compute::{compute_client, WgpuComputeClient}; use crate::kernel::{ - binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default, - unary_scalar_inplace_default, + binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default, + unary_scalar_inplace_default, }; use crate::{ - binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary_scalar, - unary_scalar_inplace, + binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, + unary_scalar, unary_scalar_inplace, }; use crate::{GraphicsApi, WgpuDevice}; use burn_tensor::{Element, ElementConversion, Shape}; pub fn full( - shape: Shape, - device: &WgpuDevice, - value: E, + shape: Shape, + device: &WgpuDevice, + value: E, ) -> WgpuTensor { - let client = compute_client::(device); + let client = compute_client::(device); - full_device(client, shape, device.clone(), value) + full_device(client, shape, device.clone(), value) } pub fn full_device( - client: WgpuComputeClient, - shape: Shape, - device: WgpuDevice, - value: E, + client: WgpuComputeClient, + shape: Shape, + device: WgpuDevice, + value: E, ) -> WgpuTensor { - let empty = empty_device(client, device, shape); + let empty = empty_device(client, device, shape); - unary_scalar_inplace!(Full, body "lhs[id] = rhs;"); - unary_scalar_inplace_default::(empty, value) + unary_scalar_inplace!(Full, body "lhs[id] = rhs;"); + unary_scalar_inplace_default::(empty, value) } pub fn zeros( - shape: Shape, - device: &WgpuDevice, + shape: Shape, + device: &WgpuDevice, ) -> WgpuTensor { - let client = compute_client::(device); + let client = compute_client::(device); - zeros_device(client, device.clone(), shape) + zeros_device(client, device.clone(), shape) } pub fn zeros_device( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, ) -> WgpuTensor { - full_device::(client, shape, device, 0.elem()) + full_device::(client, shape, device, 0.elem()) } pub fn ones( - shape: Shape, - device: &WgpuDevice, + shape: Shape, + device: &WgpuDevice, ) -> WgpuTensor { - let client = compute_client::(device); + let client = compute_client::(device); - ones_device(client, device.clone(), shape) + ones_device(client, device.clone(), shape) } pub fn ones_device( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, ) -> WgpuTensor { - full_device::(client, shape, device, 1.elem()) + full_device::(client, shape, device, 1.elem()) } pub fn empty_device( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, ) -> WgpuTensor { - let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); + let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - WgpuTensor::new(client, device, shape, buffer) + 
WgpuTensor::new(client, device, shape, buffer) } pub fn add( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Add, "+"); - binary_elemwise_inplace!(AddInplace, "+"); + binary_elemwise!(Add, "+"); + binary_elemwise_inplace!(AddInplace, "+"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - if rhs.can_mut_broadcast(&lhs) { - return binary_elemwise_inplace_default::(rhs, lhs); - } + if rhs.can_mut_broadcast(&lhs) { + return binary_elemwise_inplace_default::(rhs, lhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn add_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(AddScalar, ops "+"); - unary_scalar_inplace!(AddScalarInplace, ops "+"); + unary_scalar!(AddScalar, ops "+"); + unary_scalar_inplace!(AddScalarInplace, ops "+"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } pub fn sub( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Sub, "-"); - binary_elemwise_inplace!(SubInplace, "-"); + binary_elemwise!(Sub, "-"); + binary_elemwise_inplace!(SubInplace, "-"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn sub_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(SubScalar, ops "-"); - unary_scalar_inplace!(SubScalarInplace, ops "-"); + unary_scalar!(SubScalar, ops "-"); + unary_scalar_inplace!(SubScalarInplace, ops "-"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } pub fn mul( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Mul, "*"); - binary_elemwise_inplace!(MulInplace, "*"); + binary_elemwise!(Mul, "*"); + binary_elemwise_inplace!(MulInplace, "*"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - if rhs.can_mut_broadcast(&lhs) { - return binary_elemwise_inplace_default::(rhs, lhs); - } + if rhs.can_mut_broadcast(&lhs) { + return binary_elemwise_inplace_default::(rhs, lhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn mul_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(MulScalar, ops "*"); - unary_scalar_inplace!(MulScalarInplace, ops "*"); + unary_scalar!(MulScalar, ops "*"); + unary_scalar_inplace!(MulScalarInplace, ops "*"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } pub fn div( - lhs: WgpuTensor, - rhs: WgpuTensor, + lhs: WgpuTensor, + rhs: 
WgpuTensor, ) -> WgpuTensor { - binary_elemwise!(Div, "/"); - binary_elemwise_inplace!(DivInplace, "/"); + binary_elemwise!(Div, "/"); + binary_elemwise_inplace!(DivInplace, "/"); - if lhs.can_mut_broadcast(&rhs) { - return binary_elemwise_inplace_default::(lhs, rhs); - } + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace_default::(lhs, rhs); + } - binary_elemwise_default::(lhs, rhs) + binary_elemwise_default::(lhs, rhs) } pub fn div_scalar( - lhs: WgpuTensor, - rhs: E, + lhs: WgpuTensor, + rhs: E, ) -> WgpuTensor { - unary_scalar!(DivScalar, ops "/"); - unary_scalar_inplace!(DivScalarInplace, ops "/"); + unary_scalar!(DivScalar, ops "/"); + unary_scalar_inplace!(DivScalarInplace, ops "/"); - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } + if lhs.can_mut() { + return unary_scalar_inplace_default::(lhs, rhs); + } - unary_scalar_default::(lhs, rhs) + unary_scalar_default::(lhs, rhs) } diff --git a/burn-wgpu/src/tensor/base.rs b/burn-wgpu/src/tensor/base.rs index 9388f2f3fd..f939a0fcdf 100644 --- a/burn-wgpu/src/tensor/base.rs +++ b/burn-wgpu/src/tensor/base.rs @@ -1,6 +1,6 @@ use crate::{ - compute::{WgpuComputeClient, WgpuHandle}, - unary, WgpuDevice, + compute::{WgpuComputeClient, WgpuHandle}, + unary, WgpuDevice, }; use crate::{element::WgpuElement, kernel::unary_default}; use burn_tensor::Shape; @@ -9,135 +9,135 @@ use std::marker::PhantomData; /// The basic tensor primitive struct. #[derive(Debug, Clone)] pub struct WgpuTensor { - /// Compute client for wgpu. - pub client: WgpuComputeClient, - /// The buffer where the data are stored. - pub handle: WgpuHandle, - /// The shape of the current tensor. - pub shape: Shape, - /// The device of the current tensor. - pub device: WgpuDevice, - /// The strides of the current tensor. - pub strides: [usize; D], - pub(crate) elem: PhantomData, + /// Compute client for wgpu. + pub client: WgpuComputeClient, + /// The buffer where the data are stored. + pub handle: WgpuHandle, + /// The shape of the current tensor. + pub shape: Shape, + /// The device of the current tensor. + pub device: WgpuDevice, + /// The strides of the current tensor. + pub strides: [usize; D], + pub(crate) elem: PhantomData, } impl WgpuTensor { - /// Create a new tensor. - pub fn new( - client: WgpuComputeClient, - device: WgpuDevice, - shape: Shape, - handle: WgpuHandle, - ) -> Self { - let mut strides = [0; D]; - - let mut current = 1; - shape - .dims - .iter() - .enumerate() - .rev() - .for_each(|(index, val)| { - strides[index] = current; - current *= val; - }); - - Self { - client, - handle, - shape, - strides, - device, - elem: PhantomData, + /// Create a new tensor. + pub fn new( + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, + handle: WgpuHandle, + ) -> Self { + let mut strides = [0; D]; + + let mut current = 1; + shape + .dims + .iter() + .enumerate() + .rev() + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); + + Self { + client, + handle, + shape, + strides, + device, + elem: PhantomData, + } } - } - - /// Change the context of the current tensor and return the newly transferred tensor. 
- pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self { - let bytes = self - .client - .read(&self.handle) - .read_sync() - .expect("Can only change client synchronously"); - let handle = client.create(&bytes); - - Self { - client, - handle, - shape: self.shape.clone(), - strides: self.strides, - device, - elem: PhantomData, + + /// Change the context of the current tensor and return the newly transferred tensor. + pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self { + let bytes = self + .client + .read(&self.handle) + .read_sync() + .expect("Can only change client synchronously"); + let handle = client.create(&bytes); + + Self { + client, + handle, + shape: self.shape.clone(), + strides: self.strides, + device, + elem: PhantomData, + } } - } - pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor) -> bool { - if !self.handle.can_mut() { - return false; + pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor) -> bool { + if !self.handle.can_mut() { + return false; + } + + for i in 0..D { + // Output tensor will be different from the mutable tensor. + if self.shape.dims[i] < tensor_other.shape.dims[i] { + return false; + } + } + + true } - for i in 0..D { - // Output tensor will be different from the mutable tensor. - if self.shape.dims[i] < tensor_other.shape.dims[i] { - return false; - } + /// Copy the current tensor. + pub fn copy(&self) -> Self { + // Seems like using the copy buffer from the `wgpu` API leads to race condition when they + // are used inplace afterward. + // + // To avoid them we need to execute the whole pipeline, which leads to significant + // slowdowns. + // + // The solution is just to use a simple unary compute shader. + unary!(CopyBuffer, body "output[id] = input[id];"); + unary_default::(self.clone()) } - true - } - - /// Copy the current tensor. - pub fn copy(&self) -> Self { - // Seems like using the copy buffer from the `wgpu` API leads to race condition when they - // are used inplace afterward. - // - // To avoid them we need to execute the whole pipeline, which leads to significant - // slowdowns. - // - // The solution is just to use a simple unary compute shader. - unary!(CopyBuffer, body "output[id] = input[id];"); - unary_default::(self.clone()) - } - - /// Check if the tensor is safe to mutate. - pub fn can_mut(&self) -> bool { - self.handle.can_mut() - } - - /// Assert that both tensors are on the same device. - pub fn assert_is_on_same_device(&self, other: &Self) { - if self.device != other.device { - panic!( - "Both tensors should be on the same device {:?} != {:?}", - self.device, other.device - ); + /// Check if the tensor is safe to mutate. + pub fn can_mut(&self) -> bool { + self.handle.can_mut() } - } - /// Check if the current tensor is contiguous. - pub fn is_contiguous(&self) -> bool { - let mut current_stride = 0; - for d in 0..D { - let stride = self.strides[D - 1 - d]; + /// Assert that both tensors are on the same device. + pub fn assert_is_on_same_device(&self, other: &Self) { + if self.device != other.device { + panic!( + "Both tensors should be on the same device {:?} != {:?}", + self.device, other.device + ); + } + } - if stride < current_stride { - return false; - } + /// Check if the current tensor is contiguous. 
+ pub fn is_contiguous(&self) -> bool { + let mut current_stride = 0; + for d in 0..D { + let stride = self.strides[D - 1 - d]; - current_stride = stride; - } + if stride < current_stride { + return false; + } + + current_stride = stride; + } - true - } + true + } - pub(crate) fn batch_swapped_with_row_col(&self) -> bool { - for d in 0..D - 2 { - let stride = self.strides[d]; - if stride < self.strides[D - 2] || stride < self.strides[D - 1] { - return true; - } + pub(crate) fn batch_swapped_with_row_col(&self) -> bool { + for d in 0..D - 2 { + let stride = self.strides[d]; + if stride < self.strides[D - 2] || stride < self.strides[D - 1] { + return true; + } + } + false } - false - } } diff --git a/burn/src/lib.rs b/burn/src/lib.rs index 616d7755d8..10028b6674 100644 --- a/burn/src/lib.rs +++ b/burn/src/lib.rs @@ -12,5 +12,5 @@ pub use burn_core::*; /// Train module #[cfg(any(feature = "train", feature = "train-minimal"))] pub mod train { - pub use burn_train::*; + pub use burn_train::*; } diff --git a/examples/custom-renderer/examples/custom-renderer.rs b/examples/custom-renderer/examples/custom-renderer.rs index a2de145f19..94ce8e3e6f 100644 --- a/examples/custom-renderer/examples/custom-renderer.rs +++ b/examples/custom-renderer/examples/custom-renderer.rs @@ -2,5 +2,5 @@ use burn::backend::wgpu::WgpuDevice; use burn::backend::{Autodiff, Wgpu}; fn main() { - custom_renderer::run::>(WgpuDevice::default()); + custom_renderer::run::>(WgpuDevice::default()); } diff --git a/examples/custom-renderer/src/lib.rs b/examples/custom-renderer/src/lib.rs index c08d5088e3..f1e066aff4 100644 --- a/examples/custom-renderer/src/lib.rs +++ b/examples/custom-renderer/src/lib.rs @@ -2,82 +2,82 @@ use burn::data::dataset::source::huggingface::MNISTDataset; use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress}; use burn::train::LearnerBuilder; use burn::{ - config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig, - tensor::backend::AutodiffBackend, + config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig, + tensor::backend::AutodiffBackend, }; use guide::{data::MNISTBatcher, model::ModelConfig}; #[derive(Config)] pub struct MnistTrainingConfig { - #[config(default = 10)] - pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, - #[config(default = 42)] - pub seed: u64, - #[config(default = 1e-4)] - pub lr: f64, - pub model: ModelConfig, - pub optimizer: AdamConfig, + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1e-4)] + pub lr: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, } struct CustomRenderer {} impl MetricsRenderer for CustomRenderer { - fn update_train(&mut self, _state: MetricState) {} + fn update_train(&mut self, _state: MetricState) {} - fn update_valid(&mut self, _state: MetricState) {} + fn update_valid(&mut self, _state: MetricState) {} - fn render_train(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_train(&mut self, item: TrainingProgress) { + dbg!(item); + } - fn render_valid(&mut self, item: TrainingProgress) { - dbg!(item); - } + fn render_valid(&mut self, item: TrainingProgress) { + dbg!(item); + } } pub fn run(device: B::Device) { - // Create the configuration. 
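`is_contiguous` above only requires strides to be non-increasing from outer to inner dimensions, and `batch_swapped_with_row_col` flags layouts whose batch stride is smaller than the row or column stride. A standalone sketch of both checks over plain stride arrays (example values assumed):

    /// Same check as `is_contiguous` above: walking from the innermost dimension
    /// outwards, every stride must be >= the previous (inner) one.
    fn is_contiguous<const D: usize>(strides: [usize; D]) -> bool {
        let mut current_stride = 0;
        for d in 0..D {
            let stride = strides[D - 1 - d];
            if stride < current_stride {
                return false;
            }
            current_stride = stride;
        }
        true
    }

    /// Same check as `batch_swapped_with_row_col` above: a batch dimension whose
    /// stride is smaller than the row or column stride indicates a transposed batch.
    fn batch_swapped_with_row_col<const D: usize>(strides: [usize; D]) -> bool {
        (0..D.saturating_sub(2)).any(|d| strides[d] < strides[D - 2] || strides[d] < strides[D - 1])
    }

    fn main() {
        assert!(is_contiguous([12, 4, 1])); // standard [2, 3, 4] layout
        assert!(!is_contiguous([4, 12, 1])); // dims 0 and 1 swapped
        assert!(batch_swapped_with_row_col([1, 8, 2])); // batch stride moved innermost
    }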
- let config_model = ModelConfig::new(10, 1024); - let config_optimizer = AdamConfig::new(); - let config = MnistTrainingConfig::new(config_model, config_optimizer); + // Create the configuration. + let config_model = ModelConfig::new(10, 1024); + let config_optimizer = AdamConfig::new(); + let config = MnistTrainingConfig::new(config_model, config_optimizer); - B::seed(config.seed); + B::seed(config.seed); - // Create the model and optimizer. - let model = config.model.init(); - let optim = config.optimizer.init(); + // Create the model and optimizer. + let model = config.model.init(); + let optim = config.optimizer.init(); - // Create the batcher. - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); + // Create the batcher. + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); - // Create the dataloaders. - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); + // Create the dataloaders. + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); - // artifact dir does not need to be provided when log_to_file is false - let builder = LearnerBuilder::new("") - .devices(vec![device]) - .num_epochs(config.num_epochs) - .renderer(CustomRenderer {}) - .log_to_file(false); - // can be used to interrupt training - let _interrupter = builder.interrupter(); + // artifact dir does not need to be provided when log_to_file is false + let builder = LearnerBuilder::new("") + .devices(vec![device]) + .num_epochs(config.num_epochs) + .renderer(CustomRenderer {}) + .log_to_file(false); + // can be used to interrupt training + let _interrupter = builder.interrupter(); - let learner = builder.build(model, optim, config.lr); + let learner = builder.build(model, optim, config.lr); - let _model_trained = learner.fit(dataloader_train, dataloader_test); + let _model_trained = learner.fit(dataloader_train, dataloader_test); } diff --git a/examples/custom-training-loop/examples/custom-training-loop.rs b/examples/custom-training-loop/examples/custom-training-loop.rs index 9e0e5ade36..1b264c527c 100644 --- a/examples/custom-training-loop/examples/custom-training-loop.rs +++ b/examples/custom-training-loop/examples/custom-training-loop.rs @@ -2,5 +2,5 @@ use burn::backend::wgpu::WgpuDevice; use burn::backend::{Autodiff, Wgpu}; fn main() { - custom_training_loop::run::>(WgpuDevice::default()); + custom_training_loop::run::>(WgpuDevice::default()); } diff --git a/examples/custom-training-loop/src/lib.rs b/examples/custom-training-loop/src/lib.rs index 80a72a262a..9eb8b154cb 100644 --- a/examples/custom-training-loop/src/lib.rs +++ b/examples/custom-training-loop/src/lib.rs @@ -2,171 +2,171 @@ use std::marker::PhantomData; use burn::data::dataset::source::huggingface::MNISTDataset; use burn::{ - config::Config, - data::dataloader::DataLoaderBuilder, - 
module::AutodiffModule, - nn::loss::CrossEntropyLoss, - optim::{AdamConfig, GradientsParams, Optimizer}, - tensor::{ - backend::{AutodiffBackend, Backend}, - ElementConversion, Int, Tensor, - }, + config::Config, + data::dataloader::DataLoaderBuilder, + module::AutodiffModule, + nn::loss::CrossEntropyLoss, + optim::{AdamConfig, GradientsParams, Optimizer}, + tensor::{ + backend::{AutodiffBackend, Backend}, + ElementConversion, Int, Tensor, + }, }; use guide::{ - data::{MNISTBatch, MNISTBatcher}, - model::{Model, ModelConfig}, + data::{MNISTBatch, MNISTBatcher}, + model::{Model, ModelConfig}, }; #[derive(Config)] pub struct MnistTrainingConfig { - #[config(default = 10)] - pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, - #[config(default = 42)] - pub seed: u64, - #[config(default = 1e-4)] - pub lr: f64, - pub model: ModelConfig, - pub optimizer: AdamConfig, + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1e-4)] + pub lr: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, } pub fn run(device: B::Device) { - // Create the configuration. - let config_model = ModelConfig::new(10, 1024); - let config_optimizer = AdamConfig::new(); - let config = MnistTrainingConfig::new(config_model, config_optimizer); - - B::seed(config.seed); - - // Create the model and optimizer. - let mut model = config.model.init(); - let mut optim = config.optimizer.init(); - - // Create the batcher. - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); - - // Create the dataloaders. - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); - - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); - - // Iterate over our training and validation loop for X epochs. - for epoch in 1..config.num_epochs + 1 { - // Implement our training loop. - for (iteration, batch) in dataloader_train.iter().enumerate() { - let output = model.forward(batch.images); - let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); - let accuracy = accuracy(output, batch.targets); - - println!( - "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", - epoch, - iteration, - loss.clone().into_scalar(), - accuracy, - ); - - // Gradients for the current backward pass - let grads = loss.backward(); - // Gradients linked to each parameter of the model. - let grads = GradientsParams::from_grads(grads, &model); - // Update the model using the optimizer. - model = optim.step(config.lr, model, grads); + // Create the configuration. + let config_model = ModelConfig::new(10, 1024); + let config_optimizer = AdamConfig::new(); + let config = MnistTrainingConfig::new(config_model, config_optimizer); + + B::seed(config.seed); + + // Create the model and optimizer. + let mut model = config.model.init(); + let mut optim = config.optimizer.init(); + + // Create the batcher. + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); + + // Create the dataloaders. 
+ let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); + + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); + + // Iterate over our training and validation loop for X epochs. + for epoch in 1..config.num_epochs + 1 { + // Implement our training loop. + for (iteration, batch) in dataloader_train.iter().enumerate() { + let output = model.forward(batch.images); + let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); + let accuracy = accuracy(output, batch.targets); + + println!( + "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", + epoch, + iteration, + loss.clone().into_scalar(), + accuracy, + ); + + // Gradients for the current backward pass + let grads = loss.backward(); + // Gradients linked to each parameter of the model. + let grads = GradientsParams::from_grads(grads, &model); + // Update the model using the optimizer. + model = optim.step(config.lr, model, grads); + } + + // Get the model without autodiff. + let model_valid = model.valid(); + + // Implement our validation loop. + for (iteration, batch) in dataloader_test.iter().enumerate() { + let output = model_valid.forward(batch.images); + let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); + let accuracy = accuracy(output, batch.targets); + + println!( + "[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}", + iteration, + epoch, + loss.clone().into_scalar(), + accuracy, + ); + } } - - // Get the model without autodiff. - let model_valid = model.valid(); - - // Implement our validation loop. - for (iteration, batch) in dataloader_test.iter().enumerate() { - let output = model_valid.forward(batch.images); - let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone()); - let accuracy = accuracy(output, batch.targets); - - println!( - "[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}", - iteration, - epoch, - loss.clone().into_scalar(), - accuracy, - ); - } - } } /// Create out own accuracy metric calculation. 
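Every iteration of the training loop above performs the same three calls: backward on the loss, gathering of the parameter gradients, and an optimizer step that returns the updated module. A minimal generic helper wrapping those calls (hypothetical name `train_step`; it only restates the API already used above):

    use burn::module::AutodiffModule;
    use burn::optim::{GradientsParams, Optimizer};
    use burn::tensor::{backend::AutodiffBackend, Tensor};

    /// One optimizer update, mirroring the body of the training loop above.
    fn train_step<B, M, O>(model: M, optim: &mut O, loss: Tensor<B, 1>, lr: f64) -> M
    where
        B: AutodiffBackend,
        M: AutodiffModule<B>,
        O: Optimizer<M, B>,
    {
        // Gradients for the current backward pass.
        let grads = loss.backward();
        // Keep only the gradients linked to the model's parameters.
        let grads = GradientsParams::from_grads(grads, &model);
        // The optimizer consumes the module and returns the updated one.
        optim.step(lr, model, grads)
    }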
fn accuracy(output: Tensor, targets: Tensor) -> f32 { - let predictions = output.argmax(1).squeeze(1); - let num_predictions: usize = targets.dims().iter().product(); - let num_corrects = predictions.equal(targets).int().sum().into_scalar(); + let predictions = output.argmax(1).squeeze(1); + let num_predictions: usize = targets.dims().iter().product(); + let num_corrects = predictions.equal(targets).int().sum().into_scalar(); - num_corrects.elem::() / num_predictions as f32 * 100.0 + num_corrects.elem::() / num_predictions as f32 * 100.0 } #[allow(dead_code)] struct Learner1 where - B: AutodiffBackend, + B: AutodiffBackend, { - model: Model, - optim: O, + model: Model, + optim: O, } #[allow(dead_code)] struct Learner2 { - model: M, - optim: O, + model: M, + optim: O, } #[allow(dead_code)] struct Learner3 { - model: M, - optim: O, - _b: PhantomData, + model: M, + optim: O, + _b: PhantomData, } #[allow(dead_code)] impl Learner1 where - B: AutodiffBackend, - O: Optimizer, B>, + B: AutodiffBackend, + O: Optimizer, B>, { - pub fn step1(&mut self, _batch: MNISTBatch) { - // - } + pub fn step1(&mut self, _batch: MNISTBatch) { + // + } } #[allow(dead_code)] impl Learner2, O> where - B: AutodiffBackend, - O: Optimizer, B>, + B: AutodiffBackend, + O: Optimizer, B>, { - pub fn step2(&mut self, _batch: MNISTBatch) { - // - } + pub fn step2(&mut self, _batch: MNISTBatch) { + // + } } #[allow(dead_code)] impl Learner2 { - pub fn step3(&mut self, _batch: MNISTBatch) - where - B: AutodiffBackend, - M: AutodiffModule, - O: Optimizer, - { - // - } + pub fn step3(&mut self, _batch: MNISTBatch) + where + B: AutodiffBackend, + M: AutodiffModule, + O: Optimizer, + { + // + } } diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index 67912aa1a6..9a2e9d96fc 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -1,76 +1,76 @@ use burn::tensor::{Distribution, Tensor}; use custom_wgpu_kernel::{ - matmul_add_relu_custom, matmul_add_relu_reference, AutodiffBackend, Backend, + matmul_add_relu_custom, matmul_add_relu_reference, AutodiffBackend, Backend, }; fn inference() { - let lhs = Tensor::::random([1, 32, 32], Distribution::Default); - let rhs = Tensor::random([32, 32, 32], Distribution::Default); - let bias = Tensor::random([32, 32, 32], Distribution::Default); + let lhs = Tensor::::random([1, 32, 32], Distribution::Default); + let rhs = Tensor::random([32, 32, 32], Distribution::Default); + let bias = Tensor::random([32, 32, 32], Distribution::Default); - let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()) - .into_data() - .convert::(); - let custom = matmul_add_relu_custom(lhs, rhs, bias) - .into_data() - .convert::(); + let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()) + .into_data() + .convert::(); + let custom = matmul_add_relu_custom(lhs, rhs, bias) + .into_data() + .convert::(); - reference.assert_approx_eq(&custom, 3); + reference.assert_approx_eq(&custom, 3); - println!("Both reference and the custom fused kernel have the same output"); + println!("Both reference and the custom fused kernel have the same output"); } fn autodiff() { - let lhs = Tensor::::random([1, 32, 32], Distribution::Default).require_grad(); - let rhs = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); - let bias = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); + let lhs = 
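The accuracy metric above boils down to the percentage of rows whose argmax matches the target label. A self-contained version over plain slices, with assumed example values:

    /// Accuracy in percent: fraction of rows whose argmax equals the target label.
    fn accuracy(logits: &[Vec<f32>], targets: &[usize]) -> f32 {
        assert_eq!(logits.len(), targets.len());
        let mut correct = 0usize;
        for (row, &target) in logits.iter().zip(targets.iter()) {
            // argmax over the class dimension
            let predicted = row
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.total_cmp(b.1))
                .map(|(index, _)| index)
                .unwrap_or(0);
            if predicted == target {
                correct += 1;
            }
        }
        correct as f32 / targets.len() as f32 * 100.0
    }

    fn main() {
        let logits = vec![vec![0.1, 0.8, 0.1], vec![0.7, 0.2, 0.1]];
        let targets = vec![1, 2];
        assert_eq!(accuracy(&logits, &targets), 50.0); // one of two rows is correct
    }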
Tensor::::random([1, 32, 32], Distribution::Default).require_grad(); + let rhs = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); + let bias = Tensor::random([32, 32, 32], Distribution::Default).require_grad(); - let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()); + let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()); - let mut gradients = reference.backward(); + let mut gradients = reference.backward(); - let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap(); - let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap(); - let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap(); + let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap(); + let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap(); + let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap(); - let lhs = lhs.detach(); - let rhs = rhs.detach(); - let bias = bias.detach(); + let lhs = lhs.detach(); + let rhs = rhs.detach(); + let bias = bias.detach(); - let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone()); + let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone()); - let mut gradients = custom.backward(); + let mut gradients = custom.backward(); - let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap(); - let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap(); - let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap(); + let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap(); + let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap(); + let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap(); - lhs_grad_ref - .into_data() - .convert::() - .assert_approx_eq(&lhs_grad_custom.into_data().convert(), 3); + lhs_grad_ref + .into_data() + .convert::() + .assert_approx_eq(&lhs_grad_custom.into_data().convert(), 3); - println!("Both reference and the custom fused kernel have the same lhs gradient"); + println!("Both reference and the custom fused kernel have the same lhs gradient"); - rhs_grad_ref - .into_data() - .convert::() - .assert_approx_eq(&rhs_grad_custom.into_data().convert(), 3); + rhs_grad_ref + .into_data() + .convert::() + .assert_approx_eq(&rhs_grad_custom.into_data().convert(), 3); - println!("Both reference and the custom fused kernel have the same rhs gradient"); + println!("Both reference and the custom fused kernel have the same rhs gradient"); - bias_grad_ref - .into_data() - .convert::() - .assert_approx_eq(&bias_grad_custom.into_data().convert(), 3); + bias_grad_ref + .into_data() + .convert::() + .assert_approx_eq(&bias_grad_custom.into_data().convert(), 3); - println!("Both reference and the custom fused kernel have the same bias gradient"); + println!("Both reference and the custom fused kernel have the same bias gradient"); } fn main() { - type MyBackend = burn::backend::Wgpu; - type MyAutodiffBackend = burn::backend::Autodiff; + type MyBackend = burn::backend::Wgpu; + type MyAutodiffBackend = burn::backend::Autodiff; - inference::(); - autodiff::(); + inference::(); + autodiff::(); } diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index 6cc755f55d..fd50df6079 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -2,9 +2,9 @@ use crate::FloatTensor; use super::{AutodiffBackend, Backend}; use burn::backend::autodiff::{ - grads::Gradients, - ops::{broadcast_shape, Backward, Ops, OpsKind}, - Autodiff, + 
grads::Gradients, + ops::{broadcast_shape, Backward, Ops, OpsKind}, + Autodiff, }; use burn::backend::wgpu::{FloatElement, GraphicsApi, IntElement, Wgpu}; use burn::tensor::Shape; @@ -17,103 +17,106 @@ impl AutodiffBackend for Autodif // also implements our own API. This would allow us to call any function only implemented for Wgpu // and potentially call a custom kernel crafted only for this task. impl Backend for Autodiff { - fn fused_matmul_add_relu( - lhs: FloatTensor, - rhs: FloatTensor, - bias: FloatTensor, - ) -> FloatTensor { - // Create our zero-sized type that will implement the Backward trait. - #[derive(Debug)] - struct FusedMatmulAddReluBackward; + fn fused_matmul_add_relu( + lhs: FloatTensor, + rhs: FloatTensor, + bias: FloatTensor, + ) -> FloatTensor { + // Create our zero-sized type that will implement the Backward trait. + #[derive(Debug)] + struct FusedMatmulAddReluBackward; - // Implement the backward trait for the given backend B, the node gradient being of rank D - // with three other gradients to calculate (lhs, rhs, and bias). - impl Backward for FusedMatmulAddReluBackward { - // Our state that we must build during the forward pass to compute the backward pass. - // - // Note that we could improve the performance further by only keeping the state of - // tensors that are tracked, improving memory management, but for simplicity, we avoid - // that part. - type State = ( - FloatTensor, - FloatTensor, - FloatTensor, - Shape, - ); + // Implement the backward trait for the given backend B, the node gradient being of rank D + // with three other gradients to calculate (lhs, rhs, and bias). + impl Backward for FusedMatmulAddReluBackward { + // Our state that we must build during the forward pass to compute the backward pass. + // + // Note that we could improve the performance further by only keeping the state of + // tensors that are tracked, improving memory management, but for simplicity, we avoid + // that part. + type State = ( + FloatTensor, + FloatTensor, + FloatTensor, + Shape, + ); - fn backward(self, ops: Ops, grads: &mut Gradients) { - // Get the nodes of each variable. - let [node_lhs, node_rhs, node_bias] = ops.parents; - // Fetch the gradient for the current node. - let grad = grads.consume::(&ops.node); + fn backward(self, ops: Ops, grads: &mut Gradients) { + // Get the nodes of each variable. + let [node_lhs, node_rhs, node_bias] = ops.parents; + // Fetch the gradient for the current node. + let grad = grads.consume::(&ops.node); - // Set our state. - let (lhs, rhs, output, shape_bias) = ops.state; + // Set our state. + let (lhs, rhs, output, shape_bias) = ops.state; - // Fetch shapes of our tensor to support broadcasting. - let shape_lhs = B::shape(&lhs); - let shape_rhs = B::shape(&rhs); + // Fetch shapes of our tensor to support broadcasting. + let shape_lhs = B::shape(&lhs); + let shape_rhs = B::shape(&rhs); - // Compute the gradient of the output using the already existing `relu_backward` - // function in the basic Burn backend trait. - let grad_output = B::relu_backward(output, grad); + // Compute the gradient of the output using the already existing `relu_backward` + // function in the basic Burn backend trait. + let grad_output = B::relu_backward(output, grad); - // Compute the lhs gradient, which is the derivative of matmul with support for - // broadcasting. 
- let grad_lhs = broadcast_shape::( - B::matmul(grad_output.clone(), B::transpose(rhs)), - &shape_lhs, - ); - // Compute the rhs gradient, which is the derivative of matmul with support for - // broadcasting. - let grad_rhs = broadcast_shape::( - B::matmul(B::transpose(lhs), grad_output.clone()), - &shape_rhs, - ); - // The add derivative is only 1, so we just need to support broadcasting to - // compute the bias gradient. - let grad_bias = broadcast_shape::(grad_output, &shape_bias); + // Compute the lhs gradient, which is the derivative of matmul with support for + // broadcasting. + let grad_lhs = broadcast_shape::( + B::matmul(grad_output.clone(), B::transpose(rhs)), + &shape_lhs, + ); + // Compute the rhs gradient, which is the derivative of matmul with support for + // broadcasting. + let grad_rhs = broadcast_shape::( + B::matmul(B::transpose(lhs), grad_output.clone()), + &shape_rhs, + ); + // The add derivative is only 1, so we just need to support broadcasting to + // compute the bias gradient. + let grad_bias = broadcast_shape::(grad_output, &shape_bias); - // Register the gradient for each variable based on whether they are marked as - // `tracked`. - if let Some(node) = node_bias { - grads.register::(node, grad_bias); + // Register the gradient for each variable based on whether they are marked as + // `tracked`. + if let Some(node) = node_bias { + grads.register::(node, grad_bias); + } + if let Some(node) = node_lhs { + grads.register::(node, grad_lhs); + } + if let Some(node) = node_rhs { + grads.register::(node, grad_rhs); + } + } } - if let Some(node) = node_lhs { - grads.register::(node, grad_lhs); - } - if let Some(node) = node_rhs { - grads.register::(node, grad_rhs); - } - } - } - // Prepare a stateful operation with each variable node and corresponding graph. - // - // Each node can be fetched with `ops.parents` in the same order as defined here. - match FusedMatmulAddReluBackward - .prepare( - [lhs.node, rhs.node, bias.node], - [lhs.graph, rhs.graph, bias.graph], - ) - .stateful() - { - OpsKind::Tracked(prep) => { - // When at least one node is tracked, we should register our backward step. - // We compute the output and the state before finishing the preparation. - let bias_shape = B::shape(&bias.primitive); - let output = - B::fused_matmul_add_relu(lhs.primitive.clone(), rhs.primitive.clone(), bias.primitive); + // Prepare a stateful operation with each variable node and corresponding graph. + // + // Each node can be fetched with `ops.parents` in the same order as defined here. + match FusedMatmulAddReluBackward + .prepare( + [lhs.node, rhs.node, bias.node], + [lhs.graph, rhs.graph, bias.graph], + ) + .stateful() + { + OpsKind::Tracked(prep) => { + // When at least one node is tracked, we should register our backward step. + // We compute the output and the state before finishing the preparation. + let bias_shape = B::shape(&bias.primitive); + let output = B::fused_matmul_add_relu( + lhs.primitive.clone(), + rhs.primitive.clone(), + bias.primitive, + ); - let state = (lhs.primitive, rhs.primitive, output.clone(), bias_shape); - prep.finish(state, output) - } - OpsKind::UnTracked(prep) => { - // When no node is tracked, we can just compute the original operation without - // keeping any state. 
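The backward pass above uses the standard matmul gradients: after `relu_backward` masks the upstream gradient, `grad_lhs = grad_output * rhs^T`, `grad_rhs = lhs^T * grad_output`, and `grad_bias` is `grad_output` reduced back to the bias shape. A tiny 2x2 worked example with plain arrays (assumed values, no broadcasting):

    // 2x2 matrix helpers, just enough for the worked example.
    type Mat2 = [[f32; 2]; 2];

    fn matmul(a: Mat2, b: Mat2) -> Mat2 {
        let mut out = [[0.0; 2]; 2];
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    out[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        out
    }

    fn transpose(a: Mat2) -> Mat2 {
        [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]
    }

    fn main() {
        let lhs: Mat2 = [[1.0, 2.0], [3.0, 4.0]];
        let rhs: Mat2 = [[5.0, 6.0], [7.0, 8.0]];
        // Pretend the upstream gradient (already passed through relu_backward) is all ones.
        let grad_output: Mat2 = [[1.0, 1.0], [1.0, 1.0]];

        // grad_lhs = grad_output x rhs^T, grad_rhs = lhs^T x grad_output,
        // grad_bias = grad_output (no broadcasting in this toy case).
        let grad_lhs = matmul(grad_output, transpose(rhs));
        let grad_rhs = matmul(transpose(lhs), grad_output);
        let grad_bias = grad_output;

        assert_eq!(grad_lhs, [[11.0, 15.0], [11.0, 15.0]]);
        assert_eq!(grad_rhs, [[4.0, 4.0], [6.0, 6.0]]);
        println!("{grad_lhs:?} {grad_rhs:?} {grad_bias:?}");
    }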
- let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive); - prep.finish(output) - } + let state = (lhs.primitive, rhs.primitive, output.clone(), bias_shape); + prep.finish(state, output) + } + OpsKind::UnTracked(prep) => { + // When no node is tracked, we can just compute the original operation without + // keeping any state. + let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive); + prep.finish(output) + } + } } - } } diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index d92dede8f3..01e2e7db83 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -2,11 +2,13 @@ use crate::FloatTensor; use super::Backend; use burn::backend::wgpu::{ - compute::{DynamicKernel, WorkGroup}, - kernel::{build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, - kernel_wgsl, - tensor::WgpuTensor, - FloatElement, GraphicsApi, IntElement, Wgpu, + compute::{DynamicKernel, WorkGroup}, + kernel::{ + build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, + }, + kernel_wgsl, + tensor::WgpuTensor, + FloatElement, GraphicsApi, IntElement, Wgpu, }; use burn::tensor::Shape; use derive_new::new; @@ -18,95 +20,95 @@ kernel_wgsl!(FusedMatmulAddReluRaw, "./kernel.wgsl"); // Define our kernel type with workgroup information. #[derive(new, Debug)] struct FusedMatmulAddRelu { - workgroup_size_x: usize, - workgroup_size_y: usize, - _elem: PhantomData, + workgroup_size_x: usize, + workgroup_size_y: usize, + _elem: PhantomData, } // Implement the dynamic kernel trait for our kernel type. impl DynamicKernelSource for FusedMatmulAddRelu { - fn source(&self) -> SourceTemplate { - // Extend our raw kernel with workgroup size information using the - // `SourceTemplate` trait. - FusedMatmulAddReluRaw::source() - .register("workgroup_size_x", self.workgroup_size_x.to_string()) - .register("workgroup_size_y", self.workgroup_size_y.to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - format!("{:?}", self) - } + fn source(&self) -> SourceTemplate { + // Extend our raw kernel with workgroup size information using the + // `SourceTemplate` trait. + FusedMatmulAddReluRaw::source() + .register("workgroup_size_x", self.workgroup_size_x.to_string()) + .register("workgroup_size_y", self.workgroup_size_y.to_string()) + .register("elem", E::type_name()) + .register("int", "i32") + } + + fn id(&self) -> String { + format!("{:?}", self) + } } /// Implement our custom backend trait for the existing backend `WgpuBackend`. impl Backend for Wgpu { - fn fused_matmul_add_relu( - lhs: FloatTensor, - rhs: FloatTensor, - bias: FloatTensor, - ) -> WgpuTensor { - // Define workgroup size, hardcoded for simplicity. - let workgroup_size_x = 16; - let workgroup_size_y = 16; - - lhs.assert_is_on_same_device(&rhs); - lhs.assert_is_on_same_device(&bias); - - // For simplicity, make sure each tensor is continuous. - let lhs = into_contiguous(lhs); - let rhs = into_contiguous(rhs); - let bias = into_contiguous(bias); - - // Get the matmul relevant shapes. - let num_rows = lhs.shape.dims[D - 2]; - let num_cols = rhs.shape.dims[D - 1]; - - // Compute shape of output, while tracking number of batches. 
- let mut num_batches = 1; - let mut shape_out = [0; D]; - for i in shape_out.into_iter().take(D - 2) { - shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]); - num_batches *= shape_out[i]; + fn fused_matmul_add_relu( + lhs: FloatTensor, + rhs: FloatTensor, + bias: FloatTensor, + ) -> WgpuTensor { + // Define workgroup size, hardcoded for simplicity. + let workgroup_size_x = 16; + let workgroup_size_y = 16; + + lhs.assert_is_on_same_device(&rhs); + lhs.assert_is_on_same_device(&bias); + + // For simplicity, make sure each tensor is continuous. + let lhs = into_contiguous(lhs); + let rhs = into_contiguous(rhs); + let bias = into_contiguous(bias); + + // Get the matmul relevant shapes. + let num_rows = lhs.shape.dims[D - 2]; + let num_cols = rhs.shape.dims[D - 1]; + + // Compute shape of output, while tracking number of batches. + let mut num_batches = 1; + let mut shape_out = [0; D]; + for i in shape_out.into_iter().take(D - 2) { + shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]); + num_batches *= shape_out[i]; + } + shape_out[D - 2] = num_rows; + shape_out[D - 1] = num_cols; + let shape_out = Shape::new(shape_out); + + // Create a buffer for the output tensor. + let buffer = lhs + .client + .empty(shape_out.num_elements() * core::mem::size_of::()); + + // Create the output tensor primitive. + let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); + + // Create the kernel. + let kernel = FusedMatmulAddRelu::::new(workgroup_size_x, workgroup_size_y); + + // Build info buffer with tensor information needed by the kernel, such as shapes and strides. + let info = build_info(&[&lhs, &rhs, &output]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + // Declare the wgsl workgroup with the number of blocks in x, y and z. + let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; + let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_batches as u32); + + // Execute lazily the kernel with the launch information and the given buffers. + lhs.client.execute( + Box::new(DynamicKernel::new(kernel, workgroup)), + &[ + &lhs.handle, + &rhs.handle, + &bias.handle, + &output.handle, + &info_handle, + ], + ); + + // Return the output tensor. + output } - shape_out[D - 2] = num_rows; - shape_out[D - 1] = num_cols; - let shape_out = Shape::new(shape_out); - - // Create a buffer for the output tensor. - let buffer = lhs - .client - .empty(shape_out.num_elements() * core::mem::size_of::()); - - // Create the output tensor primitive. - let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); - - // Create the kernel. - let kernel = FusedMatmulAddRelu::::new(workgroup_size_x, workgroup_size_y); - - // Build info buffer with tensor information needed by the kernel, such as shapes and strides. - let info = build_info(&[&lhs, &rhs, &output]); - let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - - // Declare the wgsl workgroup with the number of blocks in x, y and z. - let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_batches as u32); - - // Execute lazily the kernel with the launch information and the given buffers. 
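The kernel dispatch above rounds the output dimensions up to whole workgroups: each workgroup covers workgroup_size_x by workgroup_size_y output elements, and the number of batches becomes the z dimension. A small standalone version of that calculation, with assumed sizes:

    /// Number of workgroups needed to cover `len` elements when each workgroup
    /// handles `workgroup_size` of them (rounded up, as in the kernel launch above).
    fn blocks_needed(len: usize, workgroup_size: usize) -> u32 {
        f32::ceil(len as f32 / workgroup_size as f32) as u32
    }

    fn main() {
        // Assumed example: a 1000 x 512 matmul output with 16 x 16 workgroups and 3 batches.
        let (num_rows, num_cols, num_batches) = (1000usize, 512usize, 3usize);
        let workgroup = (
            blocks_needed(num_rows, 16), // x: 63 workgroups
            blocks_needed(num_cols, 16), // y: 32 workgroups
            num_batches as u32,          // z: one layer per batch
        );
        assert_eq!(workgroup, (63, 32, 3));
    }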
- lhs.client.execute( - Box::new(DynamicKernel::new(kernel, workgroup)), - &[ - &lhs.handle, - &rhs.handle, - &bias.handle, - &output.handle, - &info_handle, - ], - ); - - // Return the output tensor. - output - } } diff --git a/examples/custom-wgpu-kernel/src/lib.rs b/examples/custom-wgpu-kernel/src/lib.rs index ac1f6ecc58..eb8cfd1570 100644 --- a/examples/custom-wgpu-kernel/src/lib.rs +++ b/examples/custom-wgpu-kernel/src/lib.rs @@ -8,11 +8,11 @@ pub type FloatTensor = : /// We create our own Backend trait that extends the Burn backend trait. pub trait Backend: burn::tensor::backend::Backend { - fn fused_matmul_add_relu( - lhs: FloatTensor, - rhs: FloatTensor, - bias: FloatTensor, - ) -> FloatTensor; + fn fused_matmul_add_relu( + lhs: FloatTensor, + rhs: FloatTensor, + bias: FloatTensor, + ) -> FloatTensor; } /// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait. @@ -20,26 +20,26 @@ pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {} /// We define our custom implementation using the added function on our custom backend. pub fn matmul_add_relu_custom( - lhs: Tensor, - rhs: Tensor, - bias: Tensor, + lhs: Tensor, + rhs: Tensor, + bias: Tensor, ) -> Tensor { - let output = B::fused_matmul_add_relu( - lhs.into_primitive(), - rhs.into_primitive(), - bias.into_primitive(), - ); + let output = B::fused_matmul_add_relu( + lhs.into_primitive(), + rhs.into_primitive(), + bias.into_primitive(), + ); - Tensor::from_primitive(output) + Tensor::from_primitive(output) } /// We define a reference implementation using basic tensor operations. pub fn matmul_add_relu_reference( - lhs: Tensor, - rhs: Tensor, - bias: Tensor, + lhs: Tensor, + rhs: Tensor, + bias: Tensor, ) -> Tensor { - let x = lhs.matmul(rhs) + bias; + let x = lhs.matmul(rhs) + bias; - activation::relu(x) + activation::relu(x) } diff --git a/examples/guide/examples/guide.rs b/examples/guide/examples/guide.rs index 682502987b..274e511cfe 100644 --- a/examples/guide/examples/guide.rs +++ b/examples/guide/examples/guide.rs @@ -5,21 +5,21 @@ use burn::optim::AdamConfig; use guide::{model::ModelConfig, training::TrainingConfig}; fn main() { - type MyBackend = Wgpu; - type MyAutodiffBackend = Autodiff; + type MyBackend = Wgpu; + type MyAutodiffBackend = Autodiff; - let device = burn::backend::wgpu::WgpuDevice::default(); - let artifact_dir = "/tmp/guide"; - guide::training::train::( - artifact_dir, - TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()), - device.clone(), - ); - guide::inference::infer::( - artifact_dir, - device, - burn::data::dataset::source::huggingface::MNISTDataset::test() - .get(42) - .unwrap(), - ); + let device = burn::backend::wgpu::WgpuDevice::default(); + let artifact_dir = "/tmp/guide"; + guide::training::train::( + artifact_dir, + TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()), + device.clone(), + ); + guide::inference::infer::( + artifact_dir, + device, + burn::data::dataset::source::huggingface::MNISTDataset::test() + .get(42) + .unwrap(), + ); } diff --git a/examples/guide/src/data.rs b/examples/guide/src/data.rs index d089bb7226..fb43b4a08e 100644 --- a/examples/guide/src/data.rs +++ b/examples/guide/src/data.rs @@ -1,45 +1,45 @@ use burn::{ - data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, - tensor::{backend::Backend, Data, ElementConversion, Int, Tensor}, + data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, + tensor::{backend::Backend, Data, ElementConversion, 
Int, Tensor}, }; pub struct MNISTBatcher { - device: B::Device, + device: B::Device, } impl MNISTBatcher { - pub fn new(device: B::Device) -> Self { - Self { device } - } + pub fn new(device: B::Device) -> Self { + Self { device } + } } #[derive(Clone, Debug)] pub struct MNISTBatch { - pub images: Tensor, - pub targets: Tensor, + pub images: Tensor, + pub targets: Tensor, } impl Batcher> for MNISTBatcher { - fn batch(&self, items: Vec) -> MNISTBatch { - let images = items - .iter() - .map(|item| Data::::from(item.image)) - .map(|data| Tensor::::from_data(data.convert())) - .map(|tensor| tensor.reshape([1, 28, 28])) - // normalize: make between [0,1] and make the mean = 0 and std = 1 - // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example - // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 - .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) - .collect(); + fn batch(&self, items: Vec) -> MNISTBatch { + let images = items + .iter() + .map(|item| Data::::from(item.image)) + .map(|data| Tensor::::from_data(data.convert())) + .map(|tensor| tensor.reshape([1, 28, 28])) + // normalize: make between [0,1] and make the mean = 0 and std = 1 + // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example + // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) + .collect(); - let targets = items - .iter() - .map(|item| Tensor::::from_data([(item.label as i64).elem()])) - .collect(); + let targets = items + .iter() + .map(|item| Tensor::::from_data([(item.label as i64).elem()])) + .collect(); - let images = Tensor::cat(images, 0).to_device(&self.device); - let targets = Tensor::cat(targets, 0).to_device(&self.device); + let images = Tensor::cat(images, 0).to_device(&self.device); + let targets = Tensor::cat(targets, 0).to_device(&self.device); - MNISTBatch { images, targets } - } + MNISTBatch { images, targets } + } } diff --git a/examples/guide/src/inference.rs b/examples/guide/src/inference.rs index 9d665dcfd7..c2f0fbb494 100644 --- a/examples/guide/src/inference.rs +++ b/examples/guide/src/inference.rs @@ -1,27 +1,27 @@ use crate::{data::MNISTBatcher, training::TrainingConfig}; use burn::data::dataset::source::huggingface::MNISTItem; use burn::{ - config::Config, - data::dataloader::batcher::Batcher, - module::Module, - record::{CompactRecorder, Recorder}, - tensor::backend::Backend, + config::Config, + data::dataloader::batcher::Batcher, + module::Module, + record::{CompactRecorder, Recorder}, + tensor::backend::Backend, }; pub fn infer(artifact_dir: &str, device: B::Device, item: MNISTItem) { - let config = - TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists"); - let record = CompactRecorder::new() - .load(format!("{artifact_dir}/model").into()) - .expect("Failed to load trained model"); + let config = + TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into()) + .expect("Failed to load trained model"); - let model = config.model.init_with::(record).to_device(&device); + let model = config.model.init_with::(record).to_device(&device); - let label = item.label; - let batcher = MNISTBatcher::new(device); - let batch = batcher.batch(vec![item]); - let output = model.forward(batch.images); - let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar(); + let label = 
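The batcher above scales raw u8 pixels into [0, 1] and then standardizes them with the usual MNIST statistics (mean 0.1307, std 0.3081). The same transform on a single pixel, as a quick sanity check:

    /// Normalization used by the MNIST batcher above:
    /// scale to [0, 1], then subtract the dataset mean and divide by its std.
    fn normalize_pixel(raw: u8) -> f32 {
        (raw as f32 / 255.0 - 0.1307) / 0.3081
    }

    fn main() {
        // Black and white pixels land at roughly -0.42 and 2.82 after normalization.
        assert!((normalize_pixel(0) - (-0.4242)).abs() < 1e-3);
        assert!((normalize_pixel(255) - 2.8215).abs() < 1e-3);
    }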
item.label; + let batcher = MNISTBatcher::new(device); + let batch = batcher.batch(vec![item]); + let output = model.forward(batch.images); + let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar(); - println!("Predicted {} Expected {}", predicted, label); + println!("Predicted {} Expected {}", predicted, label); } diff --git a/examples/guide/src/model.rs b/examples/guide/src/model.rs index 8d842fbb65..665612b50b 100644 --- a/examples/guide/src/model.rs +++ b/examples/guide/src/model.rs @@ -1,82 +1,83 @@ use burn::{ - config::Config, - module::Module, - nn::{ - conv::{Conv2d, Conv2dConfig}, - pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, - Dropout, DropoutConfig, Linear, LinearConfig, ReLU, - }, - tensor::{backend::Backend, Tensor}, + config::Config, + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, + Dropout, DropoutConfig, Linear, LinearConfig, ReLU, + }, + tensor::{backend::Backend, Tensor}, }; #[derive(Module, Debug)] pub struct Model { - conv1: Conv2d, - conv2: Conv2d, - pool: AdaptiveAvgPool2d, - dropout: Dropout, - linear1: Linear, - linear2: Linear, - activation: ReLU, + conv1: Conv2d, + conv2: Conv2d, + pool: AdaptiveAvgPool2d, + dropout: Dropout, + linear1: Linear, + linear2: Linear, + activation: ReLU, } #[derive(Config, Debug)] pub struct ModelConfig { - num_classes: usize, - hidden_size: usize, - #[config(default = "0.5")] - dropout: f64, + num_classes: usize, + hidden_size: usize, + #[config(default = "0.5")] + dropout: f64, } impl ModelConfig { - /// Returns the initialized model. - pub fn init(&self) -> Model { - Model { - conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), - conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), - pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), - activation: ReLU::new(), - linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), - linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), - dropout: DropoutConfig::new(self.dropout).init(), + /// Returns the initialized model. + pub fn init(&self) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: ReLU::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), + linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), + dropout: DropoutConfig::new(self.dropout).init(), + } } - } - /// Returns the initialized model using the recorded weights. - pub fn init_with(&self, record: ModelRecord) -> Model { - Model { - conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1), - conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2), - pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), - activation: ReLU::new(), - linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1), - linear2: LinearConfig::new(self.hidden_size, self.num_classes).init_with(record.linear2), - dropout: DropoutConfig::new(self.dropout).init(), + /// Returns the initialized model using the recorded weights. 
+ pub fn init_with(&self, record: ModelRecord) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: ReLU::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1), + linear2: LinearConfig::new(self.hidden_size, self.num_classes) + .init_with(record.linear2), + dropout: DropoutConfig::new(self.dropout).init(), + } } - } } impl Model { - /// # Shapes - /// - Images [batch_size, height, width] - /// - Output [batch_size, class_prob] - pub fn forward(&self, images: Tensor) -> Tensor { - let [batch_size, height, width] = images.dims(); + /// # Shapes + /// - Images [batch_size, height, width] + /// - Output [batch_size, class_prob] + pub fn forward(&self, images: Tensor) -> Tensor { + let [batch_size, height, width] = images.dims(); - // Create a channel. - let x = images.reshape([batch_size, 1, height, width]); + // Create a channel. + let x = images.reshape([batch_size, 1, height, width]); - let x = self.conv1.forward(x); // [batch_size, 8, _, _] - let x = self.dropout.forward(x); - let x = self.conv2.forward(x); // [batch_size, 16, _, _] - let x = self.dropout.forward(x); - let x = self.activation.forward(x); + let x = self.conv1.forward(x); // [batch_size, 8, _, _] + let x = self.dropout.forward(x); + let x = self.conv2.forward(x); // [batch_size, 16, _, _] + let x = self.dropout.forward(x); + let x = self.activation.forward(x); - let x = self.pool.forward(x); // [batch_size, 16, 8, 8] - let x = x.reshape([batch_size, 16 * 8 * 8]); - let x = self.linear1.forward(x); - let x = self.dropout.forward(x); - let x = self.activation.forward(x); + let x = self.pool.forward(x); // [batch_size, 16, 8, 8] + let x = x.reshape([batch_size, 16 * 8 * 8]); + let x = self.linear1.forward(x); + let x = self.dropout.forward(x); + let x = self.activation.forward(x); - self.linear2.forward(x) // [batch_size, num_classes] - } + self.linear2.forward(x) // [batch_size, num_classes] + } } diff --git a/examples/guide/src/training.rs b/examples/guide/src/training.rs index 7582b459ea..f04d132fd8 100644 --- a/examples/guide/src/training.rs +++ b/examples/guide/src/training.rs @@ -1,109 +1,109 @@ use crate::{ - data::{MNISTBatch, MNISTBatcher}, - model::{Model, ModelConfig}, + data::{MNISTBatch, MNISTBatcher}, + model::{Model, ModelConfig}, }; use burn::data::dataset::source::huggingface::MNISTDataset; use burn::train::{ - metric::{AccuracyMetric, LossMetric}, - ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, + metric::{AccuracyMetric, LossMetric}, + ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, }; use burn::{ - self, - config::Config, - data::dataloader::DataLoaderBuilder, - module::Module, - nn::loss::CrossEntropyLoss, - optim::AdamConfig, - record::CompactRecorder, - tensor::{ - backend::{AutodiffBackend, Backend}, - Int, Tensor, - }, + self, + config::Config, + data::dataloader::DataLoaderBuilder, + module::Module, + nn::loss::CrossEntropyLoss, + optim::AdamConfig, + record::CompactRecorder, + tensor::{ + backend::{AutodiffBackend, Backend}, + Int, Tensor, + }, }; impl Model { - pub fn forward_classification( - &self, - images: Tensor, - targets: Tensor, - ) -> ClassificationOutput { - let output = self.forward(images); - let loss = CrossEntropyLoss::default().forward(output.clone(), targets.clone()); + pub fn forward_classification( + &self, + images: Tensor, + 
targets: Tensor, + ) -> ClassificationOutput { + let output = self.forward(images); + let loss = CrossEntropyLoss::default().forward(output.clone(), targets.clone()); - ClassificationOutput::new(loss, output, targets) - } + ClassificationOutput::new(loss, output, targets) + } } impl TrainStep, ClassificationOutput> for Model { - fn step(&self, batch: MNISTBatch) -> TrainOutput> { - let item = self.forward_classification(batch.images, batch.targets); + fn step(&self, batch: MNISTBatch) -> TrainOutput> { + let item = self.forward_classification(batch.images, batch.targets); - TrainOutput::new(self, item.loss.backward(), item) - } + TrainOutput::new(self, item.loss.backward(), item) + } } impl ValidStep, ClassificationOutput> for Model { - fn step(&self, batch: MNISTBatch) -> ClassificationOutput { - self.forward_classification(batch.images, batch.targets) - } + fn step(&self, batch: MNISTBatch) -> ClassificationOutput { + self.forward_classification(batch.images, batch.targets) + } } #[derive(Config)] pub struct TrainingConfig { - pub model: ModelConfig, - pub optimizer: AdamConfig, - #[config(default = 10)] - pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, - #[config(default = 42)] - pub seed: u64, - #[config(default = 1.0e-4)] - pub learning_rate: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1.0e-4)] + pub learning_rate: f64, } pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { - std::fs::create_dir_all(artifact_dir).ok(); - config - .save(format!("{artifact_dir}/config.json")) - .expect("Save without error"); + std::fs::create_dir_all(artifact_dir).ok(); + config + .save(format!("{artifact_dir}/config.json")) + .expect("Save without error"); - B::seed(config.seed); + B::seed(config.seed); - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); - let learner = LearnerBuilder::new(artifact_dir) - .metric_train_numeric(AccuracyMetric::new()) - .metric_valid_numeric(AccuracyMetric::new()) - .metric_train_numeric(LossMetric::new()) - .metric_valid_numeric(LossMetric::new()) - .with_file_checkpointer(CompactRecorder::new()) - .devices(vec![device]) - .num_epochs(config.num_epochs) - .build( - config.model.init::(), - config.optimizer.init(), - config.learning_rate, - ); + let learner = LearnerBuilder::new(artifact_dir) + 
.metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .devices(vec![device]) + .num_epochs(config.num_epochs) + .build( + config.model.init::(), + config.optimizer.init(), + config.learning_rate, + ); - let model_trained = learner.fit(dataloader_train, dataloader_test); + let model_trained = learner.fit(dataloader_train, dataloader_test); - model_trained - .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) - .expect("Failed to save trained model"); + model_trained + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Failed to save trained model"); } diff --git a/examples/image-classification-web/build.rs b/examples/image-classification-web/build.rs index 627a3ea84c..e9c4ad2fa1 100644 --- a/examples/image-classification-web/build.rs +++ b/examples/image-classification-web/build.rs @@ -13,45 +13,45 @@ const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx"; const OUT_DIR: &str = "model/"; fn main() { - // Re-run the build script if model files change. - println!("cargo:rerun-if-changed=src/model"); + // Re-run the build script if model files change. + println!("cargo:rerun-if-changed=src/model"); - // Check if half precision is enabled. - let half_precision = cfg!(feature = "half_precision"); + // Check if half precision is enabled. + let half_precision = cfg!(feature = "half_precision"); - // Generate the model code from the ONNX file. - ModelGen::new() - .input(INPUT_ONNX_FILE) - .out_dir(OUT_DIR) - .record_type(RecordType::Bincode) - .embed_states(true) - .half_precision(half_precision) - .run_from_script(); + // Generate the model code from the ONNX file. + ModelGen::new() + .input(INPUT_ONNX_FILE) + .out_dir(OUT_DIR) + .record_type(RecordType::Bincode) + .embed_states(true) + .half_precision(half_precision) + .run_from_script(); - // Generate the labels from the synset.txt file. - generate_labels_from_txt_file().unwrap(); + // Generate the labels from the synset.txt file. + generate_labels_from_txt_file().unwrap(); } /// Read labels from synset.txt and store them in a vector of strings in a Rust file. fn generate_labels_from_txt_file() -> std::io::Result<()> { - let out_dir = env::var("OUT_DIR").unwrap(); - let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE); - let mut f = File::create(dest_path)?; + let out_dir = env::var("OUT_DIR").unwrap(); + let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE); + let mut f = File::create(dest_path)?; - let file = File::open(LABEL_SOURCE_FILE)?; - let reader = BufReader::new(file); + let file = File::open(LABEL_SOURCE_FILE)?; + let reader = BufReader::new(file); - writeln!(f, "pub static LABELS: &[&str] = &[")?; - for line in reader.lines() { - writeln!( - f, - " \"{}\",", - extract_simple_label(line.unwrap()).unwrap() - )?; - } - writeln!(f, "];")?; + writeln!(f, "pub static LABELS: &[&str] = &[")?; + for line in reader.lines() { + writeln!( + f, + " \"{}\",", + extract_simple_label(line.unwrap()).unwrap() + )?; + } + writeln!(f, "];")?; - Ok(()) + Ok(()) } /// Extract the simple label from the full label. 
@@ -59,17 +59,17 @@ fn generate_labels_from_txt_file() -> std::io::Result<()> { /// The full label is of the form: "n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea" /// The simple label is of the form: "indigo bunting" fn extract_simple_label(input: String) -> Option { - // Split the string based on the space character. - let mut parts = input.split(' '); + // Split the string based on the space character. + let mut parts = input.split(' '); - // Skip the first part (the alphanumeric code). - parts.next()?; + // Skip the first part (the alphanumeric code). + parts.next()?; - // Get the remaining string. - let remaining = parts.collect::>().join(" "); + // Get the remaining string. + let remaining = parts.collect::>().join(" "); - // Find the first comma, if it exists, and take the substring before it. - let end_index = remaining.find(',').unwrap_or(remaining.len()); + // Find the first comma, if it exists, and take the substring before it. + let end_index = remaining.find(',').unwrap_or(remaining.len()); - Some(remaining[0..end_index].to_string()) + Some(remaining[0..end_index].to_string()) } diff --git a/examples/image-classification-web/src/model/normalizer.rs b/examples/image-classification-web/src/model/normalizer.rs index d55cbd100a..7e1e6019ce 100644 --- a/examples/image-classification-web/src/model/normalizer.rs +++ b/examples/image-classification-web/src/model/normalizer.rs @@ -7,32 +7,32 @@ const STD: [f32; 3] = [0.229, 0.224, 0.225]; /// Normalizer for the imagenet dataset. pub struct Normalizer { - pub mean: Tensor, - pub std: Tensor, + pub mean: Tensor, + pub std: Tensor, } impl Normalizer { - /// Creates a new normalizer. - pub fn new() -> Self { - let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]); - let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]); - Self { mean, std } - } + /// Creates a new normalizer. + pub fn new() -> Self { + let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]); + let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]); + Self { mean, std } + } - /// Normalizes the input image according to the imagenet dataset. - /// - /// The input image should be in the range [0, 1]. - /// The output image will be in the range [-1, 1]. - /// - /// The normalization is done according to the following formula: - /// `input = (input - mean) / std` - pub fn normalize(&self, input: Tensor) -> Tensor { - (input - self.mean.clone()) / self.std.clone() - } + /// Normalizes the input image according to the imagenet dataset. + /// + /// The input image should be in the range [0, 1]. + /// The output image will be in the range [-1, 1]. 
+ /// + /// The normalization is done according to the following formula: + /// `input = (input - mean) / std` + pub fn normalize(&self, input: Tensor) -> Tensor { + (input - self.mean.clone()) / self.std.clone() + } } impl Default for Normalizer { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } diff --git a/examples/image-classification-web/src/model/squeezenet.rs b/examples/image-classification-web/src/model/squeezenet.rs index c729aa0177..d796ae629f 100644 --- a/examples/image-classification-web/src/model/squeezenet.rs +++ b/examples/image-classification-web/src/model/squeezenet.rs @@ -1,6 +1,6 @@ // Generated model from squeezenet1.onnx mod internal_model { - include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs")); + include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs")); } pub use internal_model::*; diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index df093a0ce4..52e50c5653 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -1,19 +1,19 @@ #![allow(clippy::new_without_default)] use alloc::{ - string::{String, ToString}, - vec::Vec, + string::{String, ToString}, + vec::Vec, }; use core::convert::Into; use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel}; use burn::{ - backend::{ - wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice}, - NdArray, - }, - tensor::{activation::softmax, backend::Backend, Tensor}, + backend::{ + wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice}, + NdArray, + }, + tensor::{activation::softmax, backend::Backend, Tensor}, }; use burn_candle::Candle; @@ -24,14 +24,14 @@ use wasm_timer::Instant; #[allow(clippy::large_enum_variant)] /// The model is loaded to a specific backend pub enum ModelType { - /// The model is loaded to the Candle backend - WithCandleBackend(Model>), + /// The model is loaded to the Candle backend + WithCandleBackend(Model>), - /// The model is loaded to the NdArray backend - WithNdArrayBackend(Model>), + /// The model is loaded to the NdArray backend + WithNdArrayBackend(Model>), - /// The model is loaded to the Wgpu backend - WithWgpuBackend(Model>), + /// The model is loaded to the Wgpu backend + WithWgpuBackend(Model>), } /// The image is 224x224 pixels with 3 channels (RGB) @@ -42,150 +42,150 @@ const CHANNELS: usize = 3; /// The image classifier #[wasm_bindgen] pub struct ImageClassifier { - model: ModelType, + model: ModelType, } #[wasm_bindgen] impl ImageClassifier { - /// Constructor called by JavaScripts with the new keyword. - #[wasm_bindgen(constructor)] - pub fn new() -> Self { - // Initialize the logger so that the logs are printed to the console - wasm_logger::init(wasm_logger::Config::default()); + /// Constructor called by JavaScripts with the new keyword. 
+ #[wasm_bindgen(constructor)] + pub fn new() -> Self { + // Initialize the logger so that the logs are printed to the console + wasm_logger::init(wasm_logger::Config::default()); - log::info!("Initializing the image classifier"); + log::info!("Initializing the image classifier"); - Self { - model: ModelType::WithNdArrayBackend(Model::new()), + Self { + model: ModelType::WithNdArrayBackend(Model::new()), + } + } + + /// Runs inference on the image + pub async fn inference(&self, input: &[f32]) -> Result { + log::info!("Running inference on the image"); + + let start = Instant::now(); + + let result = match self.model { + ModelType::WithCandleBackend(ref model) => model.forward(input).await, + ModelType::WithNdArrayBackend(ref model) => model.forward(input).await, + ModelType::WithWgpuBackend(ref model) => model.forward(input).await, + }; + + let duration = start.elapsed(); + + log::debug!("Inference is completed in {:?}", duration); + + top_5_classes(result) + } + + /// Sets the backend to Candle + pub async fn set_backend_candle(&mut self) -> Result<(), JsValue> { + log::info!("Loading the model to the Candle backend"); + let start = Instant::now(); + self.model = ModelType::WithCandleBackend(Model::new()); + let duration = start.elapsed(); + log::debug!("Model is loaded to the Candle backend in {:?}", duration); + Ok(()) + } + + /// Sets the backend to NdArray + pub async fn set_backend_ndarray(&mut self) -> Result<(), JsValue> { + log::info!("Loading the model to the NdArray backend"); + let start = Instant::now(); + self.model = ModelType::WithNdArrayBackend(Model::new()); + let duration = start.elapsed(); + log::debug!("Model is loaded to the NdArray backend in {:?}", duration); + Ok(()) + } + + /// Sets the backend to Wgpu + pub async fn set_backend_wgpu(&mut self) -> Result<(), JsValue> { + log::info!("Loading the model to the Wgpu backend"); + let start = Instant::now(); + init_async::(&WgpuDevice::default()).await; + self.model = ModelType::WithWgpuBackend(Model::new()); + let duration = start.elapsed(); + log::debug!("Model is loaded to the Wgpu backend in {:?}", duration); + + log::debug!("Warming up the model"); + let start = Instant::now(); + let _ = self.inference(&[0.0; HEIGHT * WIDTH * CHANNELS]).await; + let duration = start.elapsed(); + log::debug!("Warming up is completed in {:?}", duration); + Ok(()) } - } - - /// Runs inference on the image - pub async fn inference(&self, input: &[f32]) -> Result { - log::info!("Running inference on the image"); - - let start = Instant::now(); - - let result = match self.model { - ModelType::WithCandleBackend(ref model) => model.forward(input).await, - ModelType::WithNdArrayBackend(ref model) => model.forward(input).await, - ModelType::WithWgpuBackend(ref model) => model.forward(input).await, - }; - - let duration = start.elapsed(); - - log::debug!("Inference is completed in {:?}", duration); - - top_5_classes(result) - } - - /// Sets the backend to Candle - pub async fn set_backend_candle(&mut self) -> Result<(), JsValue> { - log::info!("Loading the model to the Candle backend"); - let start = Instant::now(); - self.model = ModelType::WithCandleBackend(Model::new()); - let duration = start.elapsed(); - log::debug!("Model is loaded to the Candle backend in {:?}", duration); - Ok(()) - } - - /// Sets the backend to NdArray - pub async fn set_backend_ndarray(&mut self) -> Result<(), JsValue> { - log::info!("Loading the model to the NdArray backend"); - let start = Instant::now(); - self.model = 
ModelType::WithNdArrayBackend(Model::new()); - let duration = start.elapsed(); - log::debug!("Model is loaded to the NdArray backend in {:?}", duration); - Ok(()) - } - - /// Sets the backend to Wgpu - pub async fn set_backend_wgpu(&mut self) -> Result<(), JsValue> { - log::info!("Loading the model to the Wgpu backend"); - let start = Instant::now(); - init_async::(&WgpuDevice::default()).await; - self.model = ModelType::WithWgpuBackend(Model::new()); - let duration = start.elapsed(); - log::debug!("Model is loaded to the Wgpu backend in {:?}", duration); - - log::debug!("Warming up the model"); - let start = Instant::now(); - let _ = self.inference(&[0.0; HEIGHT * WIDTH * CHANNELS]).await; - let duration = start.elapsed(); - log::debug!("Warming up is completed in {:?}", duration); - Ok(()) - } } /// The image classifier model pub struct Model { - model: SqueezenetModel, - normalizer: Normalizer, + model: SqueezenetModel, + normalizer: Normalizer, } impl Model { - /// Constructor - pub fn new() -> Self { - Self { - model: SqueezenetModel::from_embedded(), - normalizer: Normalizer::new(), + /// Constructor + pub fn new() -> Self { + Self { + model: SqueezenetModel::from_embedded(), + normalizer: Normalizer::new(), + } } - } - /// Normalizes input and runs inference on the image - pub async fn forward(&self, input: &[f32]) -> Vec { - // Reshape from the 1D array to 3d tensor [ width, height, channels] - let input: Tensor = Tensor::from_floats(input).reshape([1, CHANNELS, HEIGHT, WIDTH]); + /// Normalizes input and runs inference on the image + pub async fn forward(&self, input: &[f32]) -> Vec { + // Reshape from the 1D array to 3d tensor [ width, height, channels] + let input: Tensor = Tensor::from_floats(input).reshape([1, CHANNELS, HEIGHT, WIDTH]); - // Normalize input: make between [-1,1] and make the mean=0 and std=1 - let input = self.normalizer.normalize(input); + // Normalize input: make between [-1,1] and make the mean=0 and std=1 + let input = self.normalizer.normalize(input); - // Run the tensor input through the model - let output = self.model.forward(input); + // Run the tensor input through the model + let output = self.model.forward(input); - // Convert the model output into probability distribution using softmax formula - let probabilies = softmax(output, 1); + // Convert the model output into probability distribution using softmax formula + let probabilies = softmax(output, 1); - #[cfg(not(target_family = "wasm"))] - let result = probabilies.into_data().convert::().value; + #[cfg(not(target_family = "wasm"))] + let result = probabilies.into_data().convert::().value; - // Forces the result to be computed - #[cfg(target_family = "wasm")] - let result = probabilies.into_data().await.convert::().value; + // Forces the result to be computed + #[cfg(target_family = "wasm")] + let result = probabilies.into_data().await.convert::().value; - result - } + result + } } #[wasm_bindgen] #[derive(Serialize)] pub struct InferenceResult { - index: usize, - probability: f32, - label: String, + index: usize, + probability: f32, + label: String, } /// Returns the top 5 classes and convert them into a JsValue fn top_5_classes(probabilies: Vec) -> Result { - // Convert the probabilities into a vector of (index, probability) - let mut probabilies: Vec<_> = probabilies.iter().enumerate().collect(); - - // Sort the probabilities in descending order - probabilies.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); - - // Take the top 5 probabilities - probabilies.truncate(5); - - // Convert the 
probabilities into InferenceResult - let result: Vec = probabilies - .into_iter() - .map(|(index, probability)| InferenceResult { - index, - probability: *probability, - label: LABELS[index].to_string(), - }) - .collect(); - - // Convert the InferenceResult into a JsValue - Ok(serde_wasm_bindgen::to_value(&result)?) + // Convert the probabilities into a vector of (index, probability) + let mut probabilies: Vec<_> = probabilies.iter().enumerate().collect(); + + // Sort the probabilities in descending order + probabilies.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + + // Take the top 5 probabilities + probabilies.truncate(5); + + // Convert the probabilities into InferenceResult + let result: Vec = probabilies + .into_iter() + .map(|(index, probability)| InferenceResult { + index, + probability: *probability, + label: LABELS[index].to_string(), + }) + .collect(); + + // Convert the InferenceResult into a JsValue + Ok(serde_wasm_bindgen::to_value(&result)?) } diff --git a/examples/mnist-inference-web/src/model.rs b/examples/mnist-inference-web/src/model.rs index 10cc05ae50..66054100a3 100644 --- a/examples/mnist-inference-web/src/model.rs +++ b/examples/mnist-inference-web/src/model.rs @@ -3,94 +3,94 @@ // Originally copied from the burn/examples/mnist package use burn::{ - module::Module, - nn::{self, BatchNorm, PaddingConfig2d}, - tensor::{backend::Backend, Tensor}, + module::Module, + nn::{self, BatchNorm, PaddingConfig2d}, + tensor::{backend::Backend, Tensor}, }; #[derive(Module, Debug)] pub struct Model { - conv1: ConvBlock, - conv2: ConvBlock, - conv3: ConvBlock, - dropout: nn::Dropout, - fc1: nn::Linear, - fc2: nn::Linear, - activation: nn::GELU, + conv1: ConvBlock, + conv2: ConvBlock, + conv3: ConvBlock, + dropout: nn::Dropout, + fc1: nn::Linear, + fc2: nn::Linear, + activation: nn::GELU, } const NUM_CLASSES: usize = 10; impl Model { - pub fn new() -> Self { - let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] - let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] - let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] - let hidden_size = 24 * 22 * 22; - let fc1 = nn::LinearConfig::new(hidden_size, 32) - .with_bias(false) - .init(); - let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) - .with_bias(false) - .init(); - - let dropout = nn::DropoutConfig::new(0.5).init(); - - Self { - conv1, - conv2, - conv3, - fc1, - fc2, - dropout, - activation: nn::GELU::new(), + pub fn new() -> Self { + let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] + let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] + let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] + let hidden_size = 24 * 22 * 22; + let fc1 = nn::LinearConfig::new(hidden_size, 32) + .with_bias(false) + .init(); + let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) + .with_bias(false) + .init(); + + let dropout = nn::DropoutConfig::new(0.5).init(); + + Self { + conv1, + conv2, + conv3, + fc1, + fc2, + dropout, + activation: nn::GELU::new(), + } } - } - pub fn forward(&self, input: Tensor) -> Tensor { - let [batch_size, height, width] = input.dims(); + pub fn forward(&self, input: Tensor) -> Tensor { + let [batch_size, height, width] = input.dims(); - let x = input.reshape([batch_size, 1, height, width]).detach(); - let x = self.conv1.forward(x); - let x = self.conv2.forward(x); - let x = self.conv3.forward(x); + let x = input.reshape([batch_size, 1, height, width]).detach(); + let x = self.conv1.forward(x); + let x = self.conv2.forward(x); 
+ let x = self.conv3.forward(x); - let [batch_size, channels, height, width] = x.dims(); - let x = x.reshape([batch_size, channels * height * width]); + let [batch_size, channels, height, width] = x.dims(); + let x = x.reshape([batch_size, channels * height * width]); - let x = self.dropout.forward(x); - let x = self.fc1.forward(x); - let x = self.activation.forward(x); + let x = self.dropout.forward(x); + let x = self.fc1.forward(x); + let x = self.activation.forward(x); - self.fc2.forward(x) - } + self.fc2.forward(x) + } } #[derive(Module, Debug)] pub struct ConvBlock { - conv: nn::conv::Conv2d, - norm: BatchNorm, - activation: nn::GELU, + conv: nn::conv::Conv2d, + norm: BatchNorm, + activation: nn::GELU, } impl ConvBlock { - pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { - let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) - .with_padding(PaddingConfig2d::Valid) - .init(); - let norm = nn::BatchNormConfig::new(channels[1]).init(); - - Self { - conv, - norm, - activation: nn::GELU::new(), + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) + .with_padding(PaddingConfig2d::Valid) + .init(); + let norm = nn::BatchNormConfig::new(channels[1]).init(); + + Self { + conv, + norm, + activation: nn::GELU::new(), + } } - } - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.conv.forward(input); - let x = self.norm.forward(x); + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.conv.forward(input); + let x = self.norm.forward(x); - self.activation.forward(x) - } + self.activation.forward(x) + } } diff --git a/examples/mnist-inference-web/src/state.rs b/examples/mnist-inference-web/src/state.rs index d1b1311cae..5fabc868e4 100644 --- a/examples/mnist-inference-web/src/state.rs +++ b/examples/mnist-inference-web/src/state.rs @@ -17,13 +17,13 @@ static STATE_ENCODED: &[u8] = include_bytes!("../model.bin"); /// Builds and loads trained parameters into the model. pub async fn build_and_load_model() -> Model { - #[cfg(feature = "wgpu")] - init_async::(&WgpuDevice::default()).await; + #[cfg(feature = "wgpu")] + init_async::(&WgpuDevice::default()).await; - let model: Model = Model::new(); - let record = BinBytesRecorder::::default() - .load(STATE_ENCODED.to_vec()) - .expect("Failed to decode state"); + let model: Model = Model::new(); + let record = BinBytesRecorder::::default() + .load(STATE_ENCODED.to_vec()) + .expect("Failed to decode state"); - model.load_record(record) + model.load_record(record) } diff --git a/examples/mnist-inference-web/src/web.rs b/examples/mnist-inference-web/src/web.rs index 7c69308d3e..d15aa4a257 100644 --- a/examples/mnist-inference-web/src/web.rs +++ b/examples/mnist-inference-web/src/web.rs @@ -15,63 +15,63 @@ use burn::tensor::Tensor; /// See:[exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html) #[cfg_attr(target_family = "wasm", wasm_bindgen)] pub struct Mnist { - model: Option>, + model: Option>, } #[cfg_attr(target_family = "wasm", wasm_bindgen)] impl Mnist { - /// Constructor called by JavaScripts with the new keyword. - #[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))] - pub fn new() -> Self { - Self { model: None } - } - - /// Returns the inference results. - /// - /// This method is called from JavaScript via generated wrapper code by wasm-bindgen. 
- /// - /// # Arguments - /// - /// * `input` - A f32 slice of input 28x28 image - /// - /// See bindgen support types for passing and returning arrays: - /// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html) - /// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html) - /// - pub async fn inference(&mut self, input: &[f32]) -> Result { - if self.model.is_none() { - self.model = Some(build_and_load_model().await); + /// Constructor called by JavaScripts with the new keyword. + #[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))] + pub fn new() -> Self { + Self { model: None } } - let model = self.model.as_ref().unwrap(); - - // Reshape from the 1D array to 3d tensor [batch, height, width] - let input: Tensor = Tensor::from_floats(input).reshape([1, 28, 28]); - - // Normalize input: make between [0,1] and make the mean=0 and std=1 - // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example - // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 - - let input = ((input / 255) - 0.1307) / 0.3081; - - // Run the tensor input through the model - let output: Tensor = model.forward(input); - - // Convert the model output into probability distribution using softmax formula - let output = burn::tensor::activation::softmax(output, 1); - - // Flatten output tensor with [1, 10] shape into boxed slice of [f32] - #[cfg(not(target_family = "wasm"))] - let output = output.into_data().convert::().value; - - #[cfg(target_family = "wasm")] - let output = output.into_data().await.convert::().value; - - let array = Array::new(); - for value in output { - array.push(&value.into()); + /// Returns the inference results. + /// + /// This method is called from JavaScript via generated wrapper code by wasm-bindgen. 
+ /// + /// # Arguments + /// + /// * `input` - A f32 slice of input 28x28 image + /// + /// See bindgen support types for passing and returning arrays: + /// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html) + /// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html) + /// + pub async fn inference(&mut self, input: &[f32]) -> Result { + if self.model.is_none() { + self.model = Some(build_and_load_model().await); + } + + let model = self.model.as_ref().unwrap(); + + // Reshape from the 1D array to 3d tensor [batch, height, width] + let input: Tensor = Tensor::from_floats(input).reshape([1, 28, 28]); + + // Normalize input: make between [0,1] and make the mean=0 and std=1 + // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example + // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + + let input = ((input / 255) - 0.1307) / 0.3081; + + // Run the tensor input through the model + let output: Tensor = model.forward(input); + + // Convert the model output into probability distribution using softmax formula + let output = burn::tensor::activation::softmax(output, 1); + + // Flatten output tensor with [1, 10] shape into boxed slice of [f32] + #[cfg(not(target_family = "wasm"))] + let output = output.into_data().convert::().value; + + #[cfg(target_family = "wasm")] + let output = output.into_data().await.convert::().value; + + let array = Array::new(); + for value in output { + array.push(&value.into()); + } + + Ok(array) } - - Ok(array) - } } diff --git a/examples/mnist/examples/mnist.rs b/examples/mnist/examples/mnist.rs index 3f0d16d63c..e22a209a0e 100644 --- a/examples/mnist/examples/mnist.rs +++ b/examples/mnist/examples/mnist.rs @@ -1,72 +1,72 @@ #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - let device = NdArrayDevice::Cpu; - training::run::>(device); - } + pub fn run() { + let device = NdArrayDevice::Cpu; + training::run::>(device); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - training::run::>(device); - } + training::run::>(device); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{Wgpu, WgpuDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::wgpu::{Wgpu, WgpuDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - let device = WgpuDevice::default(); - training::run::>(device); - } + pub fn run() { + let device = WgpuDevice::default(); 
+ training::run::>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; - use mnist::training; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; + use mnist::training; - pub fn run() { - let device = LibTorchDevice::Cpu; - training::run::>(device); - } + pub fn run() { + let device = LibTorchDevice::Cpu; + training::run::>(device); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/mnist/src/data.rs b/examples/mnist/src/data.rs index 0f253e948d..d424b9e038 100644 --- a/examples/mnist/src/data.rs +++ b/examples/mnist/src/data.rs @@ -1,45 +1,45 @@ use burn::{ - data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, - tensor::{backend::Backend, Data, ElementConversion, Int, Tensor}, + data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem}, + tensor::{backend::Backend, Data, ElementConversion, Int, Tensor}, }; pub struct MNISTBatcher { - device: B::Device, + device: B::Device, } #[derive(Clone, Debug)] pub struct MNISTBatch { - pub images: Tensor, - pub targets: Tensor, + pub images: Tensor, + pub targets: Tensor, } impl MNISTBatcher { - pub fn new(device: B::Device) -> Self { - Self { device } - } + pub fn new(device: B::Device) -> Self { + Self { device } + } } impl Batcher> for MNISTBatcher { - fn batch(&self, items: Vec) -> MNISTBatch { - let images = items - .iter() - .map(|item| Data::::from(item.image)) - .map(|data| Tensor::::from_data(data.convert())) - .map(|tensor| tensor.reshape([1, 28, 28])) - // normalize: make between [0,1] and make the mean = 0 and std = 1 - // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example - // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 - .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) - .collect(); + fn batch(&self, items: Vec) -> MNISTBatch { + let images = items + .iter() + .map(|item| Data::::from(item.image)) + .map(|data| Tensor::::from_data(data.convert())) + .map(|tensor| tensor.reshape([1, 28, 28])) + // normalize: make between [0,1] and make the mean = 0 and std = 1 + // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example + // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) + .collect(); - let targets = items - .iter() - .map(|item| Tensor::::from_data(Data::from([(item.label as i64).elem()]))) - .collect(); + let targets = items + .iter() + .map(|item| Tensor::::from_data(Data::from([(item.label as i64).elem()]))) + .collect(); - let images = Tensor::cat(images, 0).to_device(&self.device); - let targets = Tensor::cat(targets, 0).to_device(&self.device); + let images = Tensor::cat(images, 0).to_device(&self.device); + let targets = Tensor::cat(targets, 
0).to_device(&self.device); - MNISTBatch { images, targets } - } + MNISTBatch { images, targets } + } } diff --git a/examples/mnist/src/model.rs b/examples/mnist/src/model.rs index 02efcf2906..eca5be9b14 100644 --- a/examples/mnist/src/model.rs +++ b/examples/mnist/src/model.rs @@ -1,130 +1,130 @@ use crate::data::MNISTBatch; use burn::{ - module::Module, - nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d}, - tensor::{ - backend::{AutodiffBackend, Backend}, - Tensor, - }, - train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, + module::Module, + nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d}, + tensor::{ + backend::{AutodiffBackend, Backend}, + Tensor, + }, + train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, }; #[derive(Module, Debug)] pub struct Model { - conv1: ConvBlock, - conv2: ConvBlock, - conv3: ConvBlock, - dropout: nn::Dropout, - fc1: nn::Linear, - fc2: nn::Linear, - activation: nn::GELU, + conv1: ConvBlock, + conv2: ConvBlock, + conv3: ConvBlock, + dropout: nn::Dropout, + fc1: nn::Linear, + fc2: nn::Linear, + activation: nn::GELU, } impl Default for Model { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } const NUM_CLASSES: usize = 10; impl Model { - pub fn new() -> Self { - let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] - let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] - let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] - let hidden_size = 24 * 22 * 22; - let fc1 = nn::LinearConfig::new(hidden_size, 32) - .with_bias(false) - .init(); - let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) - .with_bias(false) - .init(); - - let dropout = nn::DropoutConfig::new(0.5).init(); - - Self { - conv1, - conv2, - conv3, - dropout, - fc1, - fc2, - activation: nn::GELU::new(), + pub fn new() -> Self { + let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] + let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] + let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] + let hidden_size = 24 * 22 * 22; + let fc1 = nn::LinearConfig::new(hidden_size, 32) + .with_bias(false) + .init(); + let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) + .with_bias(false) + .init(); + + let dropout = nn::DropoutConfig::new(0.5).init(); + + Self { + conv1, + conv2, + conv3, + dropout, + fc1, + fc2, + activation: nn::GELU::new(), + } } - } - - pub fn forward(&self, input: Tensor) -> Tensor { - let [batch_size, height, width] = input.dims(); - let x = input.reshape([batch_size, 1, height, width]).detach(); - let x = self.conv1.forward(x); - let x = self.conv2.forward(x); - let x = self.conv3.forward(x); + pub fn forward(&self, input: Tensor) -> Tensor { + let [batch_size, height, width] = input.dims(); - let [batch_size, channels, height, width] = x.dims(); - let x = x.reshape([batch_size, channels * height * width]); + let x = input.reshape([batch_size, 1, height, width]).detach(); + let x = self.conv1.forward(x); + let x = self.conv2.forward(x); + let x = self.conv3.forward(x); - let x = self.dropout.forward(x); - let x = self.fc1.forward(x); - let x = self.activation.forward(x); + let [batch_size, channels, height, width] = x.dims(); + let x = x.reshape([batch_size, channels * height * width]); - self.fc2.forward(x) - } + let x = self.dropout.forward(x); + let x = self.fc1.forward(x); + let x = self.activation.forward(x); - pub fn forward_classification(&self, item: MNISTBatch) -> ClassificationOutput { - let 
targets = item.targets; - let output = self.forward(item.images); - let loss = CrossEntropyLoss::default(); - let loss = loss.forward(output.clone(), targets.clone()); + self.fc2.forward(x) + } - ClassificationOutput { - loss, - output, - targets, + pub fn forward_classification(&self, item: MNISTBatch) -> ClassificationOutput { + let targets = item.targets; + let output = self.forward(item.images); + let loss = CrossEntropyLoss::default(); + let loss = loss.forward(output.clone(), targets.clone()); + + ClassificationOutput { + loss, + output, + targets, + } } - } } #[derive(Module, Debug)] pub struct ConvBlock { - conv: nn::conv::Conv2d, - norm: BatchNorm, - activation: nn::GELU, + conv: nn::conv::Conv2d, + norm: BatchNorm, + activation: nn::GELU, } impl ConvBlock { - pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { - let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) - .with_padding(PaddingConfig2d::Valid) - .init(); - let norm = nn::BatchNormConfig::new(channels[1]).init(); - - Self { - conv, - norm, - activation: nn::GELU::new(), + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) + .with_padding(PaddingConfig2d::Valid) + .init(); + let norm = nn::BatchNormConfig::new(channels[1]).init(); + + Self { + conv, + norm, + activation: nn::GELU::new(), + } } - } - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.conv.forward(input); - let x = self.norm.forward(x); + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.conv.forward(input); + let x = self.norm.forward(x); - self.activation.forward(x) - } + self.activation.forward(x) + } } impl TrainStep, ClassificationOutput> for Model { - fn step(&self, item: MNISTBatch) -> TrainOutput> { - let item = self.forward_classification(item); + fn step(&self, item: MNISTBatch) -> TrainOutput> { + let item = self.forward_classification(item); - TrainOutput::new(self, item.loss.backward(), item) - } + TrainOutput::new(self, item.loss.backward(), item) + } } impl ValidStep, ClassificationOutput> for Model { - fn step(&self, item: MNISTBatch) -> ClassificationOutput { - self.forward_classification(item) - } + fn step(&self, item: MNISTBatch) -> ClassificationOutput { + self.forward_classification(item) + } } diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs index a1f2687401..2f7a5bed94 100644 --- a/examples/mnist/src/training.rs +++ b/examples/mnist/src/training.rs @@ -9,88 +9,88 @@ use burn::train::metric::store::{Aggregate, Direction, Split}; use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse}; use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition}; use burn::{ - config::Config, - data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset}, - tensor::backend::AutodiffBackend, - train::{ - metric::{AccuracyMetric, LossMetric}, - LearnerBuilder, - }, + config::Config, + data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset}, + tensor::backend::AutodiffBackend, + train::{ + metric::{AccuracyMetric, LossMetric}, + LearnerBuilder, + }, }; static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist"; #[derive(Config)] pub struct MnistTrainingConfig { - #[config(default = 10)] - pub num_epochs: usize, + #[config(default = 10)] + pub num_epochs: usize, - #[config(default = 64)] - pub batch_size: usize, + #[config(default = 64)] + pub batch_size: usize, - #[config(default = 4)] - pub num_workers: usize, + #[config(default = 4)] + pub 
num_workers: usize, - #[config(default = 42)] - pub seed: u64, + #[config(default = 42)] + pub seed: u64, - pub optimizer: AdamConfig, + pub optimizer: AdamConfig, } pub fn run(device: B::Device) { - // Config - let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))); - let config = MnistTrainingConfig::new(config_optimizer); - B::seed(config.seed); + // Config + let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))); + let config = MnistTrainingConfig::new(config_optimizer); + B::seed(config.seed); - // Data - let batcher_train = MNISTBatcher::::new(device.clone()); - let batcher_valid = MNISTBatcher::::new(device.clone()); + // Data + let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); - let dataloader_train = DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::train()); - let dataloader_test = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .shuffle(config.seed) - .num_workers(config.num_workers) - .build(MNISTDataset::test()); + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::train()); + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::test()); - // Model - let learner = LearnerBuilder::new(ARTIFACT_DIR) - .metric_train_numeric(AccuracyMetric::new()) - .metric_valid_numeric(AccuracyMetric::new()) - .metric_train_numeric(CpuUse::new()) - .metric_valid_numeric(CpuUse::new()) - .metric_train_numeric(CpuMemory::new()) - .metric_valid_numeric(CpuMemory::new()) - .metric_train_numeric(CpuTemperature::new()) - .metric_valid_numeric(CpuTemperature::new()) - .metric_train_numeric(LossMetric::new()) - .metric_valid_numeric(LossMetric::new()) - .with_file_checkpointer(CompactRecorder::new()) - .early_stopping(MetricEarlyStoppingStrategy::new::>( - Aggregate::Mean, - Direction::Lowest, - Split::Valid, - StoppingCondition::NoImprovementSince { n_epochs: 1 }, - )) - .devices(vec![device]) - .num_epochs(config.num_epochs) - .build(Model::new(), config.optimizer.init(), 1e-4); + // Model + let learner = LearnerBuilder::new(ARTIFACT_DIR) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(CpuUse::new()) + .metric_valid_numeric(CpuUse::new()) + .metric_train_numeric(CpuMemory::new()) + .metric_valid_numeric(CpuMemory::new()) + .metric_train_numeric(CpuTemperature::new()) + .metric_valid_numeric(CpuTemperature::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .early_stopping(MetricEarlyStoppingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + StoppingCondition::NoImprovementSince { n_epochs: 1 }, + )) + .devices(vec![device]) + .num_epochs(config.num_epochs) + .build(Model::new(), config.optimizer.init(), 1e-4); - let model_trained = learner.fit(dataloader_train, dataloader_test); + let model_trained = learner.fit(dataloader_train, dataloader_test); - config - .save(format!("{ARTIFACT_DIR}/config.json").as_str()) - .unwrap(); + config + .save(format!("{ARTIFACT_DIR}/config.json").as_str()) + .unwrap(); - 
model_trained - .save_file( - format!("{ARTIFACT_DIR}/model"), - &NoStdTrainingRecorder::new(), - ) - .expect("Failed to save trained model"); + model_trained + .save_file( + format!("{ARTIFACT_DIR}/model"), + &NoStdTrainingRecorder::new(), + ) + .expect("Failed to save trained model"); } diff --git a/examples/named-tensor/examples/named-tensor.rs b/examples/named-tensor/examples/named-tensor.rs index 967f75712f..7ea0dd159d 100644 --- a/examples/named-tensor/examples/named-tensor.rs +++ b/examples/named-tensor/examples/named-tensor.rs @@ -1,3 +1,3 @@ fn main() { - named_tensor::run::>(); + named_tensor::run::>(); } diff --git a/examples/named-tensor/src/lib.rs b/examples/named-tensor/src/lib.rs index f2f7ee0e16..3aa637a587 100644 --- a/examples/named-tensor/src/lib.rs +++ b/examples/named-tensor/src/lib.rs @@ -6,40 +6,42 @@ NamedDim!(SeqLength); NamedDim!(DModel); pub fn run() { - let batch_size = 32; - let seq_length = 48; - let d_model = 24; - - let weights = - NamedTensor::::random([1, d_model, d_model], Distribution::Default); - - let input = NamedTensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - ); - - // Doesn't compile - // - // mismatched types - // expected reference `&NamedTensor` - // found reference `&NamedTensor` - // let output = weights.matmul(&input); - - let output = input.clone().matmul(weights.clone()); - - // Doesn't compile - // - // mismatched types - // expected reference `&NamedTensor` - // found reference `&NamedTensor` - // let output = output.mul(&weights); - - let output = output.mul(input.clone()); - - let permut = output.clone().swap_dims::<_, 1, 2>(); - - println!("Weights => {weights}"); - println!("Input => {input}"); - println!("Output => {output}"); - println!("Permut => {permut}"); + let batch_size = 32; + let seq_length = 48; + let d_model = 24; + + let weights = NamedTensor::::random( + [1, d_model, d_model], + Distribution::Default, + ); + + let input = NamedTensor::::random( + [batch_size, seq_length, d_model], + Distribution::Default, + ); + + // Doesn't compile + // + // mismatched types + // expected reference `&NamedTensor` + // found reference `&NamedTensor` + // let output = weights.matmul(&input); + + let output = input.clone().matmul(weights.clone()); + + // Doesn't compile + // + // mismatched types + // expected reference `&NamedTensor` + // found reference `&NamedTensor` + // let output = output.mul(&weights); + + let output = output.mul(input.clone()); + + let permut = output.clone().swap_dims::<_, 1, 2>(); + + println!("Weights => {weights}"); + println!("Input => {input}"); + println!("Output => {output}"); + println!("Permut => {permut}"); } diff --git a/examples/onnx-inference/build.rs b/examples/onnx-inference/build.rs index 5425853542..174d3e517b 100644 --- a/examples/onnx-inference/build.rs +++ b/examples/onnx-inference/build.rs @@ -1,21 +1,21 @@ use burn_import::onnx::{ModelGen, RecordType}; fn main() { - // Generate the model code from the ONNX file. + // Generate the model code from the ONNX file. - if cfg!(feature = "embedded-model") { - // If the embedded-model, then model is bundled into the binary. - ModelGen::new() - .input("src/model/mnist.onnx") - .out_dir("model/") - .record_type(RecordType::Bincode) - .embed_states(true) - .run_from_script(); - } else { - // If not embedded-model, then model is loaded from the file system (default). 
- ModelGen::new() - .input("src/model/mnist.onnx") - .out_dir("model/") - .run_from_script(); - } + if cfg!(feature = "embedded-model") { + // If the embedded-model, then model is bundled into the binary. + ModelGen::new() + .input("src/model/mnist.onnx") + .out_dir("model/") + .record_type(RecordType::Bincode) + .embed_states(true) + .run_from_script(); + } else { + // If not embedded-model, then model is loaded from the file system (default). + ModelGen::new() + .input("src/model/mnist.onnx") + .out_dir("model/") + .run_from_script(); + } } diff --git a/examples/onnx-inference/src/bin/mnist_inference.rs b/examples/onnx-inference/src/bin/mnist_inference.rs index ca74984b8c..82913f4693 100644 --- a/examples/onnx-inference/src/bin/mnist_inference.rs +++ b/examples/onnx-inference/src/bin/mnist_inference.rs @@ -11,53 +11,53 @@ use onnx_inference::mnist::Model; const IMAGE_INX: usize = 42; // <- Change this to test a different image fn main() { - // Get image index argument (first) from command line + // Get image index argument (first) from command line - let image_index = if let Some(image_index) = args().nth(1) { - println!("Image index: {}", image_index); - image_index - .parse::() - .expect("Failed to parse image index") - } else { - println!("No image index provided; Using default image index: {IMAGE_INX}"); - IMAGE_INX - }; + let image_index = if let Some(image_index) = args().nth(1) { + println!("Image index: {}", image_index); + image_index + .parse::() + .expect("Failed to parse image index") + } else { + println!("No image index provided; Using default image index: {IMAGE_INX}"); + IMAGE_INX + }; - assert!(image_index < 10000, "Image index must be less than 10000"); + assert!(image_index < 10000, "Image index must be less than 10000"); - type Backend = NdArray; + type Backend = NdArray; - // Create a new model and load the state - let model: Model = Model::default(); + // Create a new model and load the state + let model: Model = Model::default(); - // Load the MNIST dataset and get an item - let dataset = MNISTDataset::test(); - let item = dataset.get(image_index).unwrap(); + // Load the MNIST dataset and get an item + let dataset = MNISTDataset::test(); + let item = dataset.get(image_index).unwrap(); - // Create a tensor from the image data - let image_data = item.image.iter().copied().flatten().collect::>(); - let mut input: Tensor = - Tensor::from_floats(image_data.as_slice()).reshape([1, 1, 28, 28]); + // Create a tensor from the image data + let image_data = item.image.iter().copied().flatten().collect::>(); + let mut input: Tensor = + Tensor::from_floats(image_data.as_slice()).reshape([1, 1, 28, 28]); - // Normalize the input - input = ((input / 255) - 0.1307) / 0.3081; + // Normalize the input + input = ((input / 255) - 0.1307) / 0.3081; - // Run the model on the input - let output = model.forward(input); + // Run the model on the input + let output = model.forward(input); - // Get the index of the maximum value - let arg_max = output.argmax(1).into_scalar() as usize; + // Get the index of the maximum value + let arg_max = output.argmax(1).into_scalar() as usize; - // Check if the index matches the label - assert!(arg_max == item.label); + // Check if the index matches the label + assert!(arg_max == item.label); - println!("Success!"); - println!("Predicted: {}", arg_max); - println!("Actual: {}", item.label); + println!("Success!"); + println!("Predicted: {}", arg_max); + println!("Actual: {}", item.label); - // Print the image URL if the image index is less than 100 (the 
online dataset only has 100 images) - if image_index < 100 { - println!("See the image online, click the link below:"); - println!("https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/{image_index}/image/image.jpg"); - } + // Print the image URL if the image index is less than 100 (the online dataset only has 100 images) + if image_index < 100 { + println!("See the image online, click the link below:"); + println!("https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/{image_index}/image/image.jpg"); + } } diff --git a/examples/onnx-inference/src/model/mod.rs b/examples/onnx-inference/src/model/mod.rs index adf789cdd1..4c821cafd4 100644 --- a/examples/onnx-inference/src/model/mod.rs +++ b/examples/onnx-inference/src/model/mod.rs @@ -1,3 +1,3 @@ pub mod mnist { - include!(concat!(env!("OUT_DIR"), "/model/mnist.rs")); + include!(concat!(env!("OUT_DIR"), "/model/mnist.rs")); } diff --git a/examples/text-classification/examples/ag-news-infer.rs b/examples/text-classification/examples/ag-news-infer.rs index abde7a1750..a2bfa0ce1a 100644 --- a/examples/text-classification/examples/ag-news-infer.rs +++ b/examples/text-classification/examples/ag-news-infer.rs @@ -9,7 +9,7 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - text_classification::inference::infer::( + text_classification::inference::infer::( device, "/tmp/text-classification-ag-news", // Samples from the test dataset, but you are free to test with your own text. @@ -22,75 +22,75 @@ pub fn launch(device: B::Device) { } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::Autodiff; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use 
crate::{launch, ElemType}; - pub fn run() { - launch::>>(WgpuDevice::default()); - } + pub fn run() { + launch::>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index f9c8ec0685..4b336c2700 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -12,89 +12,89 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - let config = ExperimentConfig::new( - TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), - AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), - ); - - text_classification::training::train::( - device, - AgNewsDataset::train(), - AgNewsDataset::test(), - config, - "/tmp/text-classification-ag-news", - ); + let config = ExperimentConfig::new( + TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), + AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), + ); + + text_classification::training::train::( + device, + AgNewsDataset::train(), + AgNewsDataset::test(), + config, + "/tmp/text-classification-ag-news", + ); } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + 
launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use crate::{launch, ElemType}; - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::{Autodiff, Fusion}; + use crate::{launch, ElemType}; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::{Autodiff, Fusion}; - pub fn run() { - launch::>>>(WgpuDevice::default()); - } + pub fn run() { + launch::>>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/examples/db-pedia-infer.rs b/examples/text-classification/examples/db-pedia-infer.rs index f178e5b6e4..be8a2d5dd0 100644 --- a/examples/text-classification/examples/db-pedia-infer.rs +++ b/examples/text-classification/examples/db-pedia-infer.rs @@ -9,7 +9,7 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - text_classification::inference::infer::( + text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", // Samples from the test dataset, but you are free to test with your own text. @@ -22,75 +22,75 @@ pub fn launch(device: B::Device) { } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::tch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::tch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, 
WgpuDevice}; - use burn::backend::Autodiff; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(WgpuDevice::default()); - } + pub fn run() { + launch::>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/examples/db-pedia-train.rs b/examples/text-classification/examples/db-pedia-train.rs index e5478088e7..81319c32cb 100644 --- a/examples/text-classification/examples/db-pedia-train.rs +++ b/examples/text-classification/examples/db-pedia-train.rs @@ -12,89 +12,89 @@ type ElemType = f32; type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { - let config = ExperimentConfig::new( - TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), - AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), - ); - - text_classification::training::train::( - device, - DbPediaDataset::train(), - DbPediaDataset::test(), - config, - "/tmp/text-classification-db-pedia", - ); + let config = ExperimentConfig::new( + TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), + AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), + ); + + text_classification::training::train::( + device, + DbPediaDataset::train(), + DbPediaDataset::test(), + config, + "/tmp/text-classification-db-pedia", + ); } #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", ))] mod ndarray { - use crate::{launch, ElemType}; - use burn::backend::ndarray::{NdArray, NdArrayDevice}; - use burn::backend::Autodiff; + use crate::{launch, ElemType}; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + use burn::backend::Autodiff; - pub fn run() { - launch::>>(NdArrayDevice::Cpu); - } + pub fn run() { + launch::>>(NdArrayDevice::Cpu); + } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - #[cfg(not(target_os = "macos"))] - let device = LibTorchDevice::Cuda(0); - #[cfg(target_os = "macos")] - let device = LibTorchDevice::Mps; + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; - launch::>>(device); - } + launch::>>(device); + } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::libtorch::{LibTorch, LibTorchDevice}; - use burn::backend::Autodiff; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; 
+ use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(LibTorchDevice::Cpu); - } + pub fn run() { + launch::>>(LibTorchDevice::Cpu); + } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::Autodiff; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Autodiff; - use crate::{launch, ElemType}; + use crate::{launch, ElemType}; - pub fn run() { - launch::>>(WgpuDevice::default()); - } + pub fn run() { + launch::>>(WgpuDevice::default()); + } } fn main() { - #[cfg(any( - feature = "ndarray", - feature = "ndarray-blas-netlib", - feature = "ndarray-blas-openblas", - feature = "ndarray-blas-accelerate", - ))] - ndarray::run(); - #[cfg(feature = "tch-gpu")] - tch_gpu::run(); - #[cfg(feature = "tch-cpu")] - tch_cpu::run(); - #[cfg(feature = "wgpu")] - wgpu::run(); + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); } diff --git a/examples/text-classification/src/data/batcher.rs b/examples/text-classification/src/data/batcher.rs index 8b8edc9f0a..75ef1080b0 100644 --- a/examples/text-classification/src/data/batcher.rs +++ b/examples/text-classification/src/data/batcher.rs @@ -12,92 +12,92 @@ use super::{dataset::TextClassificationItem, tokenizer::Tokenizer}; use burn::{ - data::dataloader::batcher::Batcher, - nn::attention::generate_padding_mask, - tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Tensor}, + data::dataloader::batcher::Batcher, + nn::attention::generate_padding_mask, + tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Tensor}, }; use std::sync::Arc; /// Struct for batching text classification items #[derive(new)] pub struct TextClassificationBatcher { - tokenizer: Arc, // Tokenizer for converting text to token IDs - device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) - max_seq_length: usize, // Maximum sequence length for tokenized text + tokenizer: Arc, // Tokenizer for converting text to token IDs + device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) + max_seq_length: usize, // Maximum sequence length for tokenized text } /// Struct for training batch in text classification task #[derive(Debug, Clone, new)] pub struct TextClassificationTrainingBatch { - pub tokens: Tensor, // Tokenized text - pub labels: Tensor, // Labels of the text - pub mask_pad: Tensor, // Padding mask for the tokenized text + pub tokens: Tensor, // Tokenized text + pub labels: Tensor, // Labels of the text + pub mask_pad: Tensor, // Padding mask for the tokenized text } /// Struct for inference batch in text classification task #[derive(Debug, Clone, new)] pub struct TextClassificationInferenceBatch { - pub tokens: Tensor, // Tokenized text - pub mask_pad: Tensor, // Padding mask for the tokenized text + pub tokens: Tensor, // Tokenized text + pub mask_pad: Tensor, // Padding mask for the tokenized text } /// Implement Batcher trait for TextClassificationBatcher struct for training impl Batcher> - for TextClassificationBatcher + for TextClassificationBatcher { - /// Batches a vector of text classification items into a training batch - fn batch(&self, items: Vec) -> 
TextClassificationTrainingBatch { - let mut tokens_list = Vec::with_capacity(items.len()); - let mut labels_list = Vec::with_capacity(items.len()); + /// Batches a vector of text classification items into a training batch + fn batch(&self, items: Vec) -> TextClassificationTrainingBatch { + let mut tokens_list = Vec::with_capacity(items.len()); + let mut labels_list = Vec::with_capacity(items.len()); - // Tokenize text and create label tensor for each item - for item in items { - tokens_list.push(self.tokenizer.encode(&item.text)); - labels_list.push(Tensor::from_data(Data::from([(item.label as i64).elem()]))); - } + // Tokenize text and create label tensor for each item + for item in items { + tokens_list.push(self.tokenizer.encode(&item.text)); + labels_list.push(Tensor::from_data(Data::from([(item.label as i64).elem()]))); + } - // Generate padding mask for tokenized text - let mask = generate_padding_mask( - self.tokenizer.pad_token(), - tokens_list, - Some(self.max_seq_length), - &B::Device::default(), - ); + // Generate padding mask for tokenized text + let mask = generate_padding_mask( + self.tokenizer.pad_token(), + tokens_list, + Some(self.max_seq_length), + &B::Device::default(), + ); - // Create and return training batch - TextClassificationTrainingBatch { - tokens: mask.tensor.to_device(&self.device), - labels: Tensor::cat(labels_list, 0).to_device(&self.device), - mask_pad: mask.mask.to_device(&self.device), + // Create and return training batch + TextClassificationTrainingBatch { + tokens: mask.tensor.to_device(&self.device), + labels: Tensor::cat(labels_list, 0).to_device(&self.device), + mask_pad: mask.mask.to_device(&self.device), + } } - } } /// Implement Batcher trait for TextClassificationBatcher struct for inference impl Batcher> - for TextClassificationBatcher + for TextClassificationBatcher { - /// Batches a vector of strings into an inference batch - fn batch(&self, items: Vec) -> TextClassificationInferenceBatch { - let mut tokens_list = Vec::with_capacity(items.len()); + /// Batches a vector of strings into an inference batch + fn batch(&self, items: Vec) -> TextClassificationInferenceBatch { + let mut tokens_list = Vec::with_capacity(items.len()); - // Tokenize each string - for item in items { - tokens_list.push(self.tokenizer.encode(&item)); - } + // Tokenize each string + for item in items { + tokens_list.push(self.tokenizer.encode(&item)); + } - // Generate padding mask for tokenized text - let mask = generate_padding_mask( - self.tokenizer.pad_token(), - tokens_list, - Some(self.max_seq_length), - &B::Device::default(), - ); + // Generate padding mask for tokenized text + let mask = generate_padding_mask( + self.tokenizer.pad_token(), + tokens_list, + Some(self.max_seq_length), + &B::Device::default(), + ); - // Create and return inference batch - TextClassificationInferenceBatch { - tokens: mask.tensor.to_device(&self.device), - mask_pad: mask.mask.to_device(&self.device), + // Create and return inference batch + TextClassificationInferenceBatch { + tokens: mask.tensor.to_device(&self.device), + mask_pad: mask.mask.to_device(&self.device), + } } - } } diff --git a/examples/text-classification/src/data/dataset.rs b/examples/text-classification/src/data/dataset.rs index 43e868879f..28ae0f2240 100644 --- a/examples/text-classification/src/data/dataset.rs +++ b/examples/text-classification/src/data/dataset.rs @@ -10,163 +10,162 @@ use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset // Define a struct for text classification items 
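The batcher above boils down to three steps: encode each text, turn each label into a one-element Int tensor, and pad the token sequences to a common length while recording a padding mask. Below is a minimal, self-contained sketch of that flow; the NdArray backend, the pad token id, and the toy inputs are assumptions made purely for illustration.

use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::nn::attention::generate_padding_mask;
use burn::tensor::{Data, ElementConversion, Int, Tensor};

type B = NdArray<f32>;

fn batch_sketch() {
    let device = NdArrayDevice::Cpu;

    // Stand-ins for `tokenizer.encode(&item.text)` over two items.
    let tokens_list = vec![vec![101, 7592, 102], vec![101, 2088, 2003, 2146, 102]];
    let raw_labels = [1usize, 3];

    // One 1-D Int tensor per label, concatenated into a single [batch] tensor.
    let labels_list: Vec<Tensor<B, 1, Int>> = raw_labels
        .iter()
        .map(|label| Tensor::from_data(Data::from([(*label as i64).elem()])))
        .collect();
    let labels = Tensor::cat(labels_list, 0);

    // Pad every sequence to a shared length (capped by the max) and build the
    // boolean mask that marks the padded positions.
    let mask = generate_padding_mask::<B>(0, tokens_list, Some(8), &device);

    let tokens: Tensor<B, 2, Int> = mask.tensor.to_device(&device);
    let mask_pad = mask.mask.to_device(&device);

    // These three tensors are what the real batcher packs into a training batch.
    let _ = (tokens, labels, mask_pad);
}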
#[derive(new, Clone, Debug)] pub struct TextClassificationItem { - pub text: String, // The text for classification - pub label: usize, // The label of the text (classification category) + pub text: String, // The text for classification + pub label: usize, // The label of the text (classification category) } // Trait for text classification datasets pub trait TextClassificationDataset: Dataset { - fn num_classes() -> usize; // Returns the number of unique classes in the dataset - fn class_name(label: usize) -> String; // Returns the name of the class given its label + fn num_classes() -> usize; // Returns the number of unique classes in the dataset + fn class_name(label: usize) -> String; // Returns the name of the class given its label } // Struct for items in the AG News dataset #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct AgNewsItem { - pub text: String, // The text for classification - pub label: usize, // The label of the text (classification category) + pub text: String, // The text for classification + pub label: usize, // The label of the text (classification category) } // Struct for the AG News dataset pub struct AgNewsDataset { - dataset: SqliteDataset, // Underlying SQLite dataset + dataset: SqliteDataset, // Underlying SQLite dataset } // Implement the Dataset trait for the AG News dataset impl Dataset for AgNewsDataset { - /// Returns a specific item from the dataset - fn get(&self, index: usize) -> Option { - self - .dataset - .get(index) - .map(|item| TextClassificationItem::new(item.text, item.label)) // Map AgNewsItems to TextClassificationItems - } - - /// Returns the length of the dataset - fn len(&self) -> usize { - self.dataset.len() - } + /// Returns a specific item from the dataset + fn get(&self, index: usize) -> Option { + self.dataset + .get(index) + .map(|item| TextClassificationItem::new(item.text, item.label)) // Map AgNewsItems to TextClassificationItems + } + + /// Returns the length of the dataset + fn len(&self) -> usize { + self.dataset.len() + } } // Implement methods for constructing the AG News dataset impl AgNewsDataset { - /// Returns the training portion of the dataset - pub fn train() -> Self { - Self::new("train") - } - - /// Returns the testing portion of the dataset - pub fn test() -> Self { - Self::new("test") - } - - /// Constructs the dataset from a split (either "train" or "test") - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("ag_news") - .dataset(split) - .unwrap(); - Self { dataset } - } + /// Returns the training portion of the dataset + pub fn train() -> Self { + Self::new("train") + } + + /// Returns the testing portion of the dataset + pub fn test() -> Self { + Self::new("test") + } + + /// Constructs the dataset from a split (either "train" or "test") + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("ag_news") + .dataset(split) + .unwrap(); + Self { dataset } + } } /// Implements the TextClassificationDataset trait for the AG News dataset impl TextClassificationDataset for AgNewsDataset { - /// Returns the number of unique classes in the dataset - fn num_classes() -> usize { - 4 - } - - /// Returns the name of a class given its label - fn class_name(label: usize) -> String { - match label { - 0 => "World", - 1 => "Sports", - 2 => "Business", - 3 => "Technology", - _ => panic!("invalid class"), + /// Returns the number of unique classes in the dataset + fn num_classes() -> usize { + 4 + } + + /// Returns the 
name of a class given its label + fn class_name(label: usize) -> String { + match label { + 0 => "World", + 1 => "Sports", + 2 => "Business", + 3 => "Technology", + _ => panic!("invalid class"), + } + .to_string() } - .to_string() - } } /// Struct for items in the DbPedia dataset #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { - pub title: String, // The title of the item - pub content: String, // The content of the item - pub label: usize, // The label of the item (classification category) + pub title: String, // The title of the item + pub content: String, // The content of the item + pub label: usize, // The label of the item (classification category) } /// Struct for the DbPedia dataset pub struct DbPediaDataset { - dataset: SqliteDataset, // Underlying SQLite dataset + dataset: SqliteDataset, // Underlying SQLite dataset } /// Implements the Dataset trait for the DbPedia dataset impl Dataset for DbPediaDataset { - /// Returns a specific item from the dataset - fn get(&self, index: usize) -> Option { - self.dataset.get(index).map(|item| { - TextClassificationItem::new( - format!("Title: {} - Content: {}", item.title, item.content), - item.label, - ) - }) - } - - /// Returns the length of the dataset - fn len(&self) -> usize { - self.dataset.len() - } + /// Returns a specific item from the dataset + fn get(&self, index: usize) -> Option { + self.dataset.get(index).map(|item| { + TextClassificationItem::new( + format!("Title: {} - Content: {}", item.title, item.content), + item.label, + ) + }) + } + + /// Returns the length of the dataset + fn len(&self) -> usize { + self.dataset.len() + } } /// Implement methods for constructing the DbPedia dataset impl DbPediaDataset { - /// Returns the training portion of the dataset - pub fn train() -> Self { - Self::new("train") - } - - /// Returns the testing portion of the dataset - pub fn test() -> Self { - Self::new("test") - } - - /// Constructs the dataset from a split (either "train" or "test") - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") - .dataset(split) - .unwrap(); - Self { dataset } - } + /// Returns the training portion of the dataset + pub fn train() -> Self { + Self::new("train") + } + + /// Returns the testing portion of the dataset + pub fn test() -> Self { + Self::new("test") + } + + /// Constructs the dataset from a split (either "train" or "test") + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") + .dataset(split) + .unwrap(); + Self { dataset } + } } /// Implement the TextClassificationDataset trait for the DbPedia dataset impl TextClassificationDataset for DbPediaDataset { - /// Returns the number of unique classes in the dataset - fn num_classes() -> usize { - 14 - } - - /// Returns the name of a class given its label - fn class_name(label: usize) -> String { - match label { - 0 => "Company", - 1 => "EducationalInstitution", - 2 => "Artist", - 3 => "Athlete", - 4 => "OfficeHolder", - 5 => "MeanOfTransportation", - 6 => "Building", - 7 => "NaturalPlace", - 8 => "Village", - 9 => "Animal", - 10 => "Plant", - 11 => "Album", - 12 => "Film", - 13 => "WrittenWork", - _ => panic!("invalid class"), + /// Returns the number of unique classes in the dataset + fn num_classes() -> usize { + 14 + } + + /// Returns the name of a class given its label + fn class_name(label: usize) -> String { + match label { + 0 => "Company", + 1 => "EducationalInstitution", + 2 => "Artist", + 3 
=> "Athlete", + 4 => "OfficeHolder", + 5 => "MeanOfTransportation", + 6 => "Building", + 7 => "NaturalPlace", + 8 => "Village", + 9 => "Animal", + 10 => "Plant", + 11 => "Album", + 12 => "Film", + 13 => "WrittenWork", + _ => panic!("invalid class"), + } + .to_string() } - .to_string() - } } diff --git a/examples/text-classification/src/data/tokenizer.rs b/examples/text-classification/src/data/tokenizer.rs index 1fbe07d3f6..3d1044b365 100644 --- a/examples/text-classification/src/data/tokenizer.rs +++ b/examples/text-classification/src/data/tokenizer.rs @@ -6,62 +6,62 @@ // The `Send + Sync` bounds are necessary for allowing these operations // to work across thread boundaries. pub trait Tokenizer: Send + Sync { - /// Converts a text string into a sequence of tokens. - fn encode(&self, value: &str) -> Vec; + /// Converts a text string into a sequence of tokens. + fn encode(&self, value: &str) -> Vec; - /// Converts a sequence of tokens back into a text string. - fn decode(&self, tokens: &[usize]) -> String; + /// Converts a sequence of tokens back into a text string. + fn decode(&self, tokens: &[usize]) -> String; - /// Gets the size of the tokenizer's vocabulary. - fn vocab_size(&self) -> usize; + /// Gets the size of the tokenizer's vocabulary. + fn vocab_size(&self) -> usize; - /// Gets the token used for padding sequences to a consistent length. - fn pad_token(&self) -> usize; + /// Gets the token used for padding sequences to a consistent length. + fn pad_token(&self) -> usize; - /// Gets the string representation of the padding token. - /// The default implementation uses `decode` on the padding token. - fn pad_token_value(&self) -> String { - self.decode(&[self.pad_token()]) - } + /// Gets the string representation of the padding token. + /// The default implementation uses `decode` on the padding token. + fn pad_token_value(&self) -> String { + self.decode(&[self.pad_token()]) + } } /// Struct represents a specific tokenizer using the BERT cased tokenization strategy. pub struct BertCasedTokenizer { - // The underlying tokenizer from the `tokenizers` library. - tokenizer: tokenizers::Tokenizer, + // The underlying tokenizer from the `tokenizers` library. + tokenizer: tokenizers::Tokenizer, } // Default implementation for creating a new BertCasedTokenizer. // This uses a pretrained BERT cased tokenizer model. impl Default for BertCasedTokenizer { - fn default() -> Self { - Self { - tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), + fn default() -> Self { + Self { + tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), + } } - } } // Implementation of the Tokenizer trait for BertCasedTokenizer. impl Tokenizer for BertCasedTokenizer { - // Convert a text string into a sequence of tokens using the BERT cased tokenization strategy. - fn encode(&self, value: &str) -> Vec { - let tokens = self.tokenizer.encode(value, true).unwrap(); - tokens.get_ids().iter().map(|t| *t as usize).collect() - } + // Convert a text string into a sequence of tokens using the BERT cased tokenization strategy. + fn encode(&self, value: &str) -> Vec { + let tokens = self.tokenizer.encode(value, true).unwrap(); + tokens.get_ids().iter().map(|t| *t as usize).collect() + } - /// Converts a sequence of tokens back into a text string. 
- fn decode(&self, tokens: &[usize]) -> String { - let tokens = tokens.iter().map(|t| *t as u32).collect::>(); - self.tokenizer.decode(&tokens, false).unwrap() - } + /// Converts a sequence of tokens back into a text string. + fn decode(&self, tokens: &[usize]) -> String { + let tokens = tokens.iter().map(|t| *t as u32).collect::>(); + self.tokenizer.decode(&tokens, false).unwrap() + } - /// Gets the size of the BERT cased tokenizer's vocabulary. - fn vocab_size(&self) -> usize { - self.tokenizer.get_vocab_size(true) - } + /// Gets the size of the BERT cased tokenizer's vocabulary. + fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size(true) + } - /// Gets the token used for padding sequences to a consistent length. - fn pad_token(&self) -> usize { - self.tokenizer.token_to_id("[PAD]").unwrap() as usize - } + /// Gets the token used for padding sequences to a consistent length. + fn pad_token(&self) -> usize { + self.tokenizer.token_to_id("[PAD]").unwrap() as usize + } } diff --git a/examples/text-classification/src/inference.rs b/examples/text-classification/src/inference.rs index 8360a5a06d..02ed961809 100644 --- a/examples/text-classification/src/inference.rs +++ b/examples/text-classification/src/inference.rs @@ -5,73 +5,73 @@ // Import required modules and types use crate::{ - data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, - model::TextClassificationModelConfig, - training::ExperimentConfig, + data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, + model::TextClassificationModelConfig, + training::ExperimentConfig, }; use burn::{ - config::Config, - data::dataloader::batcher::Batcher, - module::Module, - record::{CompactRecorder, Recorder}, - tensor::backend::Backend, + config::Config, + data::dataloader::batcher::Batcher, + module::Module, + record::{CompactRecorder, Recorder}, + tensor::backend::Backend, }; use std::sync::Arc; // Define inference function pub fn infer( - device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) - artifact_dir: &str, // Directory containing model and config files - samples: Vec, // Text samples for inference + device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) + artifact_dir: &str, // Directory containing model and config files + samples: Vec, // Text samples for inference ) { - // Load experiment configuration - let config = ExperimentConfig::load(format!("{artifact_dir}/config.json").as_str()) - .expect("Config file present"); + // Load experiment configuration + let config = ExperimentConfig::load(format!("{artifact_dir}/config.json").as_str()) + .expect("Config file present"); - // Initialize tokenizer - let tokenizer = Arc::new(BertCasedTokenizer::default()); + // Initialize tokenizer + let tokenizer = Arc::new(BertCasedTokenizer::default()); - // Get number of classes from dataset - let n_classes = D::num_classes(); + // Get number of classes from dataset + let n_classes = D::num_classes(); - // Initialize batcher for batching samples - let batcher = Arc::new(TextClassificationBatcher::::new( - tokenizer.clone(), - device.clone(), - config.max_seq_length, - )); + // Initialize batcher for batching samples + let batcher = Arc::new(TextClassificationBatcher::::new( + tokenizer.clone(), + device.clone(), + config.max_seq_length, + )); - // Load pre-trained model weights - println!("Loading weights ..."); - let record = CompactRecorder::new() - 
.load(format!("{artifact_dir}/model").into()) - .expect("Trained model weights"); + // Load pre-trained model weights + println!("Loading weights ..."); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into()) + .expect("Trained model weights"); - // Create model using loaded weights - println!("Creating model ..."); - let model = TextClassificationModelConfig::new( - config.transformer, - n_classes, - tokenizer.vocab_size(), - config.max_seq_length, - ) - .init_with::(record) // Initialize model with loaded weights - .to_device(&device); // Move model to computation device + // Create model using loaded weights + println!("Creating model ..."); + let model = TextClassificationModelConfig::new( + config.transformer, + n_classes, + tokenizer.vocab_size(), + config.max_seq_length, + ) + .init_with::(record) // Initialize model with loaded weights + .to_device(&device); // Move model to computation device - // Run inference on the given text samples - println!("Running inference ..."); - let item = batcher.batch(samples.clone()); // Batch samples using the batcher - let predictions = model.infer(item); // Get model predictions + // Run inference on the given text samples + println!("Running inference ..."); + let item = batcher.batch(samples.clone()); // Batch samples using the batcher + let predictions = model.infer(item); // Get model predictions - // Print out predictions for each sample - for (i, text) in samples.into_iter().enumerate() { - #[allow(clippy::single_range_in_vec_init)] - let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample - let logits = prediction.to_data(); // Convert prediction tensor to data - let class_index = prediction.argmax(1).into_data().convert::().value[0]; // Get class index with the highest value - let class = D::class_name(class_index as usize); // Get class name + // Print out predictions for each sample + for (i, text) in samples.into_iter().enumerate() { + #[allow(clippy::single_range_in_vec_init)] + let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample + let logits = prediction.to_data(); // Convert prediction tensor to data + let class_index = prediction.argmax(1).into_data().convert::().value[0]; // Get class index with the highest value + let class = D::class_name(class_index as usize); // Get class name - // Print sample text, predicted logits and predicted class - println!("\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: {class}\n================"); - } + // Print sample text, predicted logits and predicted class + println!("\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: {class}\n================"); + } } diff --git a/examples/text-classification/src/model.rs b/examples/text-classification/src/model.rs index 96fd825367..914b14576a 100644 --- a/examples/text-classification/src/model.rs +++ b/examples/text-classification/src/model.rs @@ -5,173 +5,178 @@ use crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch}; use burn::{ - config::Config, - module::Module, - nn::{ - loss::CrossEntropyLoss, - transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, - Embedding, EmbeddingConfig, Linear, LinearConfig, - }, - tensor::backend::{AutodiffBackend, Backend}, - tensor::{activation::softmax, Tensor}, - train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, + config::Config, + module::Module, + nn::{ + loss::CrossEntropyLoss, + 
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, + Embedding, EmbeddingConfig, Linear, LinearConfig, + }, + tensor::backend::{AutodiffBackend, Backend}, + tensor::{activation::softmax, Tensor}, + train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, }; // Define the model configuration #[derive(Config)] pub struct TextClassificationModelConfig { - transformer: TransformerEncoderConfig, - n_classes: usize, - vocab_size: usize, - max_seq_length: usize, + transformer: TransformerEncoderConfig, + n_classes: usize, + vocab_size: usize, + max_seq_length: usize, } // Define the model structure #[derive(Module, Debug)] pub struct TextClassificationModel { - transformer: TransformerEncoder, - embedding_token: Embedding, - embedding_pos: Embedding, - output: Linear, - n_classes: usize, - max_seq_length: usize, + transformer: TransformerEncoder, + embedding_token: Embedding, + embedding_pos: Embedding, + output: Linear, + n_classes: usize, + max_seq_length: usize, } // Define functions for model initialization impl TextClassificationModelConfig { - /// Initializes a model with default weights - pub fn init(&self) -> TextClassificationModel { - let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(); - let transformer = self.transformer.init(); - let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); - let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); - - TextClassificationModel { - transformer, - embedding_token, - embedding_pos, - output, - n_classes: self.n_classes, - max_seq_length: self.max_seq_length, + /// Initializes a model with default weights + pub fn init(&self) -> TextClassificationModel { + let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(); + let transformer = self.transformer.init(); + let embedding_token = + EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); + let embedding_pos = + EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); + + TextClassificationModel { + transformer, + embedding_token, + embedding_pos, + output, + n_classes: self.n_classes, + max_seq_length: self.max_seq_length, + } } - } - - /// Initializes a model with provided weights - pub fn init_with( - &self, - record: TextClassificationModelRecord, - ) -> TextClassificationModel { - let output = - LinearConfig::new(self.transformer.d_model, self.n_classes).init_with(record.output); - let transformer = self.transformer.init_with(record.transformer); - let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model) - .init_with(record.embedding_token); - let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model) - .init_with(record.embedding_pos); - - TextClassificationModel { - transformer, - embedding_token, - embedding_pos, - output, - n_classes: self.n_classes, - max_seq_length: self.max_seq_length, + + /// Initializes a model with provided weights + pub fn init_with( + &self, + record: TextClassificationModelRecord, + ) -> TextClassificationModel { + let output = + LinearConfig::new(self.transformer.d_model, self.n_classes).init_with(record.output); + let transformer = self.transformer.init_with(record.transformer); + let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model) + .init_with(record.embedding_token); + let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model) + 
.init_with(record.embedding_pos); + + TextClassificationModel { + transformer, + embedding_token, + embedding_pos, + output, + n_classes: self.n_classes, + max_seq_length: self.max_seq_length, + } } - } } /// Define model behavior impl TextClassificationModel { - // Defines forward pass for training - pub fn forward(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { - // Get batch and sequence length, and the device - let [batch_size, seq_length] = item.tokens.dims(); - let device = &self.embedding_token.devices()[0]; - - // Move tensors to the correct device - let tokens = item.tokens.to_device(device); - let labels = item.labels.to_device(device); - let mask_pad = item.mask_pad.to_device(device); - - // Calculate token and position embeddings, and combine them - let index_positions = Tensor::arange_device(0..seq_length, device) - .reshape([1, seq_length]) - .repeat(0, batch_size); - let embedding_positions = self.embedding_pos.forward(index_positions); - let embedding_tokens = self.embedding_token.forward(tokens); - let embedding = (embedding_positions + embedding_tokens) / 2; - - // Perform transformer encoding, calculate output and loss - let encoded = self - .transformer - .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); - let output = self.output.forward(encoded); - - let output_classification = output - .slice([0..batch_size, 0..1]) - .reshape([batch_size, self.n_classes]); - - let loss = CrossEntropyLoss::default(); - let loss = loss.forward(output_classification.clone(), labels.clone()); - - // Return the output and loss - ClassificationOutput { - loss, - output: output_classification, - targets: labels, + // Defines forward pass for training + pub fn forward(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { + // Get batch and sequence length, and the device + let [batch_size, seq_length] = item.tokens.dims(); + let device = &self.embedding_token.devices()[0]; + + // Move tensors to the correct device + let tokens = item.tokens.to_device(device); + let labels = item.labels.to_device(device); + let mask_pad = item.mask_pad.to_device(device); + + // Calculate token and position embeddings, and combine them + let index_positions = Tensor::arange_device(0..seq_length, device) + .reshape([1, seq_length]) + .repeat(0, batch_size); + let embedding_positions = self.embedding_pos.forward(index_positions); + let embedding_tokens = self.embedding_token.forward(tokens); + let embedding = (embedding_positions + embedding_tokens) / 2; + + // Perform transformer encoding, calculate output and loss + let encoded = self + .transformer + .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); + let output = self.output.forward(encoded); + + let output_classification = output + .slice([0..batch_size, 0..1]) + .reshape([batch_size, self.n_classes]); + + let loss = CrossEntropyLoss::default(); + let loss = loss.forward(output_classification.clone(), labels.clone()); + + // Return the output and loss + ClassificationOutput { + loss, + output: output_classification, + targets: labels, + } + } + + /// Defines forward pass for inference + pub fn infer(&self, item: TextClassificationInferenceBatch) -> Tensor { + // Get batch and sequence length, and the device + let [batch_size, seq_length] = item.tokens.dims(); + let device = &self.embedding_token.devices()[0]; + + // Move tensors to the correct device + let tokens = item.tokens.to_device(device); + let mask_pad = item.mask_pad.to_device(device); + + // Calculate token and 
position embeddings, and combine them + let index_positions = Tensor::arange_device(0..seq_length, device) + .reshape([1, seq_length]) + .repeat(0, batch_size); + let embedding_positions = self.embedding_pos.forward(index_positions); + let embedding_tokens = self.embedding_token.forward(tokens); + let embedding = (embedding_positions + embedding_tokens) / 2; + + // Perform transformer encoding, calculate output and apply softmax for prediction + let encoded = self + .transformer + .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); + let output = self.output.forward(encoded); + let output = output + .slice([0..batch_size, 0..1]) + .reshape([batch_size, self.n_classes]); + + softmax(output, 1) } - } - - /// Defines forward pass for inference - pub fn infer(&self, item: TextClassificationInferenceBatch) -> Tensor { - // Get batch and sequence length, and the device - let [batch_size, seq_length] = item.tokens.dims(); - let device = &self.embedding_token.devices()[0]; - - // Move tensors to the correct device - let tokens = item.tokens.to_device(device); - let mask_pad = item.mask_pad.to_device(device); - - // Calculate token and position embeddings, and combine them - let index_positions = Tensor::arange_device(0..seq_length, device) - .reshape([1, seq_length]) - .repeat(0, batch_size); - let embedding_positions = self.embedding_pos.forward(index_positions); - let embedding_tokens = self.embedding_token.forward(tokens); - let embedding = (embedding_positions + embedding_tokens) / 2; - - // Perform transformer encoding, calculate output and apply softmax for prediction - let encoded = self - .transformer - .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); - let output = self.output.forward(encoded); - let output = output - .slice([0..batch_size, 0..1]) - .reshape([batch_size, self.n_classes]); - - softmax(output, 1) - } } /// Define training step impl TrainStep, ClassificationOutput> - for TextClassificationModel + for TextClassificationModel { - fn step(&self, item: TextClassificationTrainingBatch) -> TrainOutput> { - // Run forward pass, calculate gradients and return them along with the output - let item = self.forward(item); - let grads = item.loss.backward(); - - TrainOutput::new(self, grads, item) - } + fn step( + &self, + item: TextClassificationTrainingBatch, + ) -> TrainOutput> { + // Run forward pass, calculate gradients and return them along with the output + let item = self.forward(item); + let grads = item.loss.backward(); + + TrainOutput::new(self, grads, item) + } } /// Define validation step impl ValidStep, ClassificationOutput> - for TextClassificationModel + for TextClassificationModel { - fn step(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { - // Run forward pass and return the output - self.forward(item) - } + fn step(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { + // Run forward pass and return the output + self.forward(item) + } } diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs index 241db1abda..eed1ddc4d0 100644 --- a/examples/text-classification/src/training.rs +++ b/examples/text-classification/src/training.rs @@ -6,109 +6,112 @@ // then saved to the specified directory. 
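In the model's forward pass above, the linear head produces one score per class at every sequence position; only the first position is kept and reshaped into a [batch_size, n_classes] logit matrix, which softmax turns into per-class probabilities at inference time. The sketch below walks through that shape bookkeeping with a dummy tensor; the backend and the sizes are assumptions for illustration.

use burn::backend::ndarray::NdArray;
use burn::tensor::{activation::softmax, Tensor};

fn head_sketch() {
    type B = NdArray<f32>;
    let (batch_size, seq_length, n_classes): (usize, usize, usize) = (2, 16, 4);

    // Stand-in for `self.output.forward(encoded)`: [batch, seq, n_classes].
    let output: Tensor<B, 3> = Tensor::zeros([batch_size, seq_length, n_classes]);

    let logits = output
        .slice([0..batch_size, 0..1]) // keep the first position: [batch, 1, n_classes]
        .reshape([batch_size, n_classes]); // flatten to [batch, n_classes]

    // Inference converts the logits into probabilities over the classes.
    let probs = softmax(logits, 1);
    assert_eq!(probs.dims(), [batch_size, n_classes]);
}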
use crate::{ - data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, - model::TextClassificationModelConfig, + data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, + model::TextClassificationModelConfig, }; use burn::{ - config::Config, - data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset}, - lr_scheduler::noam::NoamLrSchedulerConfig, - module::Module, - nn::transformer::TransformerEncoderConfig, - optim::AdamConfig, - record::{CompactRecorder, Recorder}, - tensor::backend::AutodiffBackend, - train::{ - metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric}, - LearnerBuilder, - }, + config::Config, + data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset}, + lr_scheduler::noam::NoamLrSchedulerConfig, + module::Module, + nn::transformer::TransformerEncoderConfig, + optim::AdamConfig, + record::{CompactRecorder, Recorder}, + tensor::backend::AutodiffBackend, + train::{ + metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric}, + LearnerBuilder, + }, }; use std::sync::Arc; // Define configuration struct for the experiment #[derive(Config)] pub struct ExperimentConfig { - pub transformer: TransformerEncoderConfig, - pub optimizer: AdamConfig, - #[config(default = 256)] - pub max_seq_length: usize, - #[config(default = 32)] - pub batch_size: usize, - #[config(default = 5)] - pub num_epochs: usize, + pub transformer: TransformerEncoderConfig, + pub optimizer: AdamConfig, + #[config(default = 256)] + pub max_seq_length: usize, + #[config(default = 32)] + pub batch_size: usize, + #[config(default = 5)] + pub num_epochs: usize, } // Define train function pub fn train( - device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) - dataset_train: D, // Training dataset - dataset_test: D, // Testing dataset - config: ExperimentConfig, // Experiment configuration - artifact_dir: &str, // Directory to save model and config files + device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) + dataset_train: D, // Training dataset + dataset_test: D, // Testing dataset + config: ExperimentConfig, // Experiment configuration + artifact_dir: &str, // Directory to save model and config files ) { - // Initialize tokenizer - let tokenizer = Arc::new(BertCasedTokenizer::default()); + // Initialize tokenizer + let tokenizer = Arc::new(BertCasedTokenizer::default()); - // Initialize batchers for training and testing data - let batcher_train = - TextClassificationBatcher::::new(tokenizer.clone(), device.clone(), config.max_seq_length); - let batcher_test = TextClassificationBatcher::::new( - tokenizer.clone(), - device.clone(), - config.max_seq_length, - ); + // Initialize batchers for training and testing data + let batcher_train = TextClassificationBatcher::::new( + tokenizer.clone(), + device.clone(), + config.max_seq_length, + ); + let batcher_test = TextClassificationBatcher::::new( + tokenizer.clone(), + device.clone(), + config.max_seq_length, + ); - // Initialize model - let model = TextClassificationModelConfig::new( - config.transformer.clone(), - D::num_classes(), - tokenizer.vocab_size(), - config.max_seq_length, - ) - .init(); + // Initialize model + let model = TextClassificationModelConfig::new( + config.transformer.clone(), + D::num_classes(), + tokenizer.vocab_size(), + config.max_seq_length, + ) + .init(); - // Initialize data loaders for training and testing data - let dataloader_train = 
DataLoaderBuilder::new(batcher_train) - .batch_size(config.batch_size) - .num_workers(4) - .build(SamplerDataset::new(dataset_train, 50_000)); - let dataloader_test = DataLoaderBuilder::new(batcher_test) - .batch_size(config.batch_size) - .num_workers(4) - .build(SamplerDataset::new(dataset_test, 5_000)); + // Initialize data loaders for training and testing data + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .num_workers(4) + .build(SamplerDataset::new(dataset_train, 50_000)); + let dataloader_test = DataLoaderBuilder::new(batcher_test) + .batch_size(config.batch_size) + .num_workers(4) + .build(SamplerDataset::new(dataset_test, 5_000)); - // Initialize optimizer - let optim = config.optimizer.init(); + // Initialize optimizer + let optim = config.optimizer.init(); - // Initialize learning rate scheduler - let lr_scheduler = NoamLrSchedulerConfig::new(0.25) - .with_warmup_steps(1000) - .with_model_size(config.transformer.d_model) - .init(); + // Initialize learning rate scheduler + let lr_scheduler = NoamLrSchedulerConfig::new(0.25) + .with_warmup_steps(1000) + .with_model_size(config.transformer.d_model) + .init(); - // Initialize learner - let learner = LearnerBuilder::new(artifact_dir) - .metric_train(CUDAMetric::new()) - .metric_valid(CUDAMetric::new()) - .metric_train(AccuracyMetric::new()) - .metric_valid(AccuracyMetric::new()) - .metric_train_numeric(LossMetric::new()) - .metric_valid_numeric(LossMetric::new()) - .metric_train_numeric(LearningRateMetric::new()) - .with_file_checkpointer(CompactRecorder::new()) - .devices(vec![device]) - .num_epochs(config.num_epochs) - .build(model, optim, lr_scheduler); + // Initialize learner + let learner = LearnerBuilder::new(artifact_dir) + .metric_train(CUDAMetric::new()) + .metric_valid(CUDAMetric::new()) + .metric_train(AccuracyMetric::new()) + .metric_valid(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .metric_train_numeric(LearningRateMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .devices(vec![device]) + .num_epochs(config.num_epochs) + .build(model, optim, lr_scheduler); - // Train the model - let model_trained = learner.fit(dataloader_train, dataloader_test); + // Train the model + let model_trained = learner.fit(dataloader_train, dataloader_test); - // Save the configuration and the trained model - config.save(format!("{artifact_dir}/config.json")).unwrap(); - CompactRecorder::new() - .record( - model_trained.into_record(), - format!("{artifact_dir}/model").into(), - ) - .unwrap(); + // Save the configuration and the trained model + config.save(format!("{artifact_dir}/config.json")).unwrap(); + CompactRecorder::new() + .record( + model_trained.into_record(), + format!("{artifact_dir}/model").into(), + ) + .unwrap(); } diff --git a/examples/text-generation/examples/text-generation.rs b/examples/text-generation/examples/text-generation.rs index 3b58b793a3..6b6823f5ff 100644 --- a/examples/text-generation/examples/text-generation.rs +++ b/examples/text-generation/examples/text-generation.rs @@ -9,20 +9,21 @@ type Elem = f32; type Backend = burn::backend::Autodiff>; fn main() { - let config = ExperimentConfig::new( - burn::nn::transformer::TransformerEncoderConfig::new(384, 1536, 12, 6).with_norm_first(true), - burn::optim::AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1.0e-6))), - ); + let config = ExperimentConfig::new( + burn::nn::transformer::TransformerEncoderConfig::new(384, 1536, 
12, 6) + .with_norm_first(true), + burn::optim::AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1.0e-6))), + ); - text_generation::training::train::( - if cfg!(target_os = "macos") { - burn::tensor::Device::::Mps - } else { - burn::tensor::Device::::Cuda(0) - }, - DbPediaDataset::train(), - DbPediaDataset::test(), - config, - "/tmp/text-generation", - ); + text_generation::training::train::( + if cfg!(target_os = "macos") { + burn::tensor::Device::::Mps + } else { + burn::tensor::Device::::Cuda(0) + }, + DbPediaDataset::train(), + DbPediaDataset::test(), + config, + "/tmp/text-generation", + ); } diff --git a/examples/text-generation/src/data/batcher.rs b/examples/text-generation/src/data/batcher.rs index acaff5e33a..598676eafa 100644 --- a/examples/text-generation/src/data/batcher.rs +++ b/examples/text-generation/src/data/batcher.rs @@ -1,66 +1,66 @@ use super::{dataset::TextGenerationItem, tokenizer::Tokenizer}; use burn::{ - data::dataloader::batcher::Batcher, - nn::attention::generate_padding_mask, - tensor::{backend::Backend, Bool, Int, Tensor}, + data::dataloader::batcher::Batcher, + nn::attention::generate_padding_mask, + tensor::{backend::Backend, Bool, Int, Tensor}, }; use std::sync::Arc; #[derive(new)] pub struct TextGenerationBatcher { - tokenizer: Arc, - max_seq_length: usize, + tokenizer: Arc, + max_seq_length: usize, } #[derive(Debug, Clone, new)] pub struct TextGenerationBatch { - pub tokens: Tensor, - pub mask_pad: Tensor, + pub tokens: Tensor, + pub mask_pad: Tensor, } #[derive(Debug, Clone, new)] pub struct TrainingTextGenerationBatch { - pub tokens_inputs: Tensor, - pub targets: Tensor, - pub mask_pad: Tensor, + pub tokens_inputs: Tensor, + pub targets: Tensor, + pub mask_pad: Tensor, } impl Batcher> for TextGenerationBatcher { - fn batch(&self, items: Vec) -> TextGenerationBatch { - let mut tokens_list = Vec::with_capacity(items.len()); + fn batch(&self, items: Vec) -> TextGenerationBatch { + let mut tokens_list = Vec::with_capacity(items.len()); - for item in items { - tokens_list.push(self.tokenizer.encode(&item.text, true)); - } + for item in items { + tokens_list.push(self.tokenizer.encode(&item.text, true)); + } - let mask = generate_padding_mask( - self.tokenizer.pad_token(), - tokens_list, - Some(self.max_seq_length), - &B::Device::default(), - ); + let mask = generate_padding_mask( + self.tokenizer.pad_token(), + tokens_list, + Some(self.max_seq_length), + &B::Device::default(), + ); - TextGenerationBatch { - tokens: mask.tensor, - mask_pad: mask.mask, + TextGenerationBatch { + tokens: mask.tensor, + mask_pad: mask.mask, + } } - } } impl Batcher> - for TextGenerationBatcher + for TextGenerationBatcher { - fn batch(&self, items: Vec) -> TrainingTextGenerationBatch { - let item: TextGenerationBatch = self.batch(items); - let [batch_size, seq_length] = item.tokens.dims(); + fn batch(&self, items: Vec) -> TrainingTextGenerationBatch { + let item: TextGenerationBatch = self.batch(items); + let [batch_size, seq_length] = item.tokens.dims(); - let inputs = item - .tokens - .clone() - .slice([0..batch_size, 0..seq_length - 1]); - let targets = item.tokens.slice([0..batch_size, 1..seq_length]); - let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]); + let inputs = item + .tokens + .clone() + .slice([0..batch_size, 0..seq_length - 1]); + let targets = item.tokens.slice([0..batch_size, 1..seq_length]); + let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]); - TrainingTextGenerationBatch::new(inputs, targets, mask_pad) - 
} + TrainingTextGenerationBatch::new(inputs, targets, mask_pad) + } } diff --git a/examples/text-generation/src/data/dataset.rs b/examples/text-generation/src/data/dataset.rs index b22b3f6598..f198143582 100644 --- a/examples/text-generation/src/data/dataset.rs +++ b/examples/text-generation/src/data/dataset.rs @@ -2,43 +2,42 @@ use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset #[derive(new, Clone, Debug)] pub struct TextGenerationItem { - pub text: String, + pub text: String, } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { - pub content: String, + pub content: String, } pub struct DbPediaDataset { - dataset: SqliteDataset, + dataset: SqliteDataset, } impl Dataset for DbPediaDataset { - fn get(&self, index: usize) -> Option { - self - .dataset - .get(index) - .map(|item| TextGenerationItem::new(item.content)) - } + fn get(&self, index: usize) -> Option { + self.dataset + .get(index) + .map(|item| TextGenerationItem::new(item.content)) + } - fn len(&self) -> usize { - self.dataset.len() - } + fn len(&self) -> usize { + self.dataset.len() + } } impl DbPediaDataset { - pub fn train() -> Self { - Self::new("train") - } + pub fn train() -> Self { + Self::new("train") + } - pub fn test() -> Self { - Self::new("test") - } - pub fn new(split: &str) -> Self { - let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") - .dataset(split) - .unwrap(); - Self { dataset } - } + pub fn test() -> Self { + Self::new("test") + } + pub fn new(split: &str) -> Self { + let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") + .dataset(split) + .unwrap(); + Self { dataset } + } } diff --git a/examples/text-generation/src/data/tokenizer.rs b/examples/text-generation/src/data/tokenizer.rs index 25296cb893..cf6fc81bae 100644 --- a/examples/text-generation/src/data/tokenizer.rs +++ b/examples/text-generation/src/data/tokenizer.rs @@ -1,93 +1,93 @@ pub trait Tokenizer: Send + Sync { - fn encode(&self, value: &str, special_tokens: bool) -> Vec; - fn decode(&self, tokens: &[usize]) -> String; - fn vocab_size(&self) -> usize; - fn pad_token(&self) -> usize; - fn start_token(&self) -> usize; - fn end_token(&self) -> usize; - fn pad_token_value(&self) -> String { - self.decode(&[self.pad_token()]) - } - fn start_token_value(&self) -> String { - self.decode(&[self.start_token()]) - } - fn end_token_value(&self) -> String { - self.decode(&[self.end_token()]) - } + fn encode(&self, value: &str, special_tokens: bool) -> Vec; + fn decode(&self, tokens: &[usize]) -> String; + fn vocab_size(&self) -> usize; + fn pad_token(&self) -> usize; + fn start_token(&self) -> usize; + fn end_token(&self) -> usize; + fn pad_token_value(&self) -> String { + self.decode(&[self.pad_token()]) + } + fn start_token_value(&self) -> String { + self.decode(&[self.start_token()]) + } + fn end_token_value(&self) -> String { + self.decode(&[self.end_token()]) + } } pub struct Gpt2Tokenizer { - tokenizer: tokenizers::Tokenizer, + tokenizer: tokenizers::Tokenizer, } impl Default for Gpt2Tokenizer { - fn default() -> Self { - let mut tokenizer = tokenizers::Tokenizer::from_pretrained("gpt2", None).unwrap(); - tokenizer.add_special_tokens(&[ - tokenizers::AddedToken::from("[START]", true), - tokenizers::AddedToken::from("[END]", true), - tokenizers::AddedToken::from("[PAD]", true), - ]); - - Self { tokenizer } - } + fn default() -> Self { + let mut tokenizer = tokenizers::Tokenizer::from_pretrained("gpt2", None).unwrap(); + 
tokenizer.add_special_tokens(&[ + tokenizers::AddedToken::from("[START]", true), + tokenizers::AddedToken::from("[END]", true), + tokenizers::AddedToken::from("[PAD]", true), + ]); + + Self { tokenizer } + } } impl Tokenizer for Gpt2Tokenizer { - fn encode(&self, value: &str, special_tokens: bool) -> Vec { - let text = match special_tokens { - true => "[START]".to_owned() + value + "[END]", - false => value.to_string(), - }; - let tokens = self.tokenizer.encode(text, true).unwrap(); - tokens.get_ids().iter().map(|t| *t as usize).collect() - } - - fn decode(&self, tokens: &[usize]) -> String { - let tokens = tokens.iter().map(|t| *t as u32).collect::>(); - self.tokenizer.decode(&tokens, false).unwrap() - } - - fn vocab_size(&self) -> usize { - self.tokenizer.get_vocab_size(true) - } - - fn pad_token(&self) -> usize { - self.tokenizer.token_to_id("[PAD]").unwrap() as usize - } - - fn start_token(&self) -> usize { - self.tokenizer.token_to_id("[START]").unwrap() as usize - } - - fn end_token(&self) -> usize { - self.tokenizer.token_to_id("[END]").unwrap() as usize - } + fn encode(&self, value: &str, special_tokens: bool) -> Vec { + let text = match special_tokens { + true => "[START]".to_owned() + value + "[END]", + false => value.to_string(), + }; + let tokens = self.tokenizer.encode(text, true).unwrap(); + tokens.get_ids().iter().map(|t| *t as usize).collect() + } + + fn decode(&self, tokens: &[usize]) -> String { + let tokens = tokens.iter().map(|t| *t as u32).collect::>(); + self.tokenizer.decode(&tokens, false).unwrap() + } + + fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size(true) + } + + fn pad_token(&self) -> usize { + self.tokenizer.token_to_id("[PAD]").unwrap() as usize + } + + fn start_token(&self) -> usize { + self.tokenizer.token_to_id("[START]").unwrap() as usize + } + + fn end_token(&self) -> usize { + self.tokenizer.token_to_id("[END]").unwrap() as usize + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_encode_decode() { - let tokenizer = Gpt2Tokenizer::default(); - let text = "A sentence"; + #[test] + fn test_encode_decode() { + let tokenizer = Gpt2Tokenizer::default(); + let text = "A sentence"; - let tokens = tokenizer.encode(text, false); - let decoded = tokenizer.decode(&tokens); + let tokens = tokenizer.encode(text, false); + let decoded = tokenizer.decode(&tokens); - assert_eq!(decoded, text); - } + assert_eq!(decoded, text); + } - #[test] - fn test_add_start_end_token() { - let tokenizer = Gpt2Tokenizer::default(); - let text = "A sentence"; + #[test] + fn test_add_start_end_token() { + let tokenizer = Gpt2Tokenizer::default(); + let text = "A sentence"; - let tokens_without = tokenizer.encode(text, false); - let tokens_with = tokenizer.encode(text, true); + let tokens_without = tokenizer.encode(text, false); + let tokens_with = tokenizer.encode(text, true); - assert_eq!(tokens_with.len() - 2, tokens_without.len()); - } + assert_eq!(tokens_with.len() - 2, tokens_without.len()); + } } diff --git a/examples/text-generation/src/model.rs b/examples/text-generation/src/model.rs index 099811a2e4..6e23121424 100644 --- a/examples/text-generation/src/model.rs +++ b/examples/text-generation/src/model.rs @@ -1,111 +1,116 @@ use crate::data::TrainingTextGenerationBatch; use burn::{ - config::Config, - module::Module, - nn::{ - attention::generate_autoregressive_mask, - loss::CrossEntropyLossConfig, - transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, - Embedding, EmbeddingConfig, Linear, 
LinearConfig, - }, - tensor::backend::{AutodiffBackend, Backend}, - tensor::Tensor, - train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, + config::Config, + module::Module, + nn::{ + attention::generate_autoregressive_mask, + loss::CrossEntropyLossConfig, + transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, + Embedding, EmbeddingConfig, Linear, LinearConfig, + }, + tensor::backend::{AutodiffBackend, Backend}, + tensor::Tensor, + train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, }; #[derive(Config)] pub struct TextGenerationModelConfig { - transformer: TransformerEncoderConfig, - vocab_size: usize, - pad_token: usize, - max_seq_length: usize, + transformer: TransformerEncoderConfig, + vocab_size: usize, + pad_token: usize, + max_seq_length: usize, } #[derive(Module, Debug)] pub struct TextGenerationModel { - transformer: TransformerEncoder, - embedding_token: Embedding, - embedding_pos: Embedding, - output: Linear, - vocab_size: usize, - pad_token: usize, - max_seq_length: usize, + transformer: TransformerEncoder, + embedding_token: Embedding, + embedding_pos: Embedding, + output: Linear, + vocab_size: usize, + pad_token: usize, + max_seq_length: usize, } impl TextGenerationModelConfig { - pub fn init(&self) -> TextGenerationModel { - let output = LinearConfig::new(self.transformer.d_model, self.vocab_size).init(); - let transformer = self.transformer.init(); - let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); - let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); + pub fn init(&self) -> TextGenerationModel { + let output = LinearConfig::new(self.transformer.d_model, self.vocab_size).init(); + let transformer = self.transformer.init(); + let embedding_token = + EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(); + let embedding_pos = + EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(); - TextGenerationModel { - transformer, - embedding_token, - embedding_pos, - output, - vocab_size: self.vocab_size, - pad_token: self.pad_token, - max_seq_length: self.max_seq_length, + TextGenerationModel { + transformer, + embedding_token, + embedding_pos, + output, + vocab_size: self.vocab_size, + pad_token: self.pad_token, + max_seq_length: self.max_seq_length, + } } - } } impl TextGenerationModel { - pub fn forward_training(&self, item: TrainingTextGenerationBatch) -> ClassificationOutput { - let [batch_size, seq_length] = item.tokens_inputs.dims(); - let device = &self.devices()[0]; + pub fn forward_training( + &self, + item: TrainingTextGenerationBatch, + ) -> ClassificationOutput { + let [batch_size, seq_length] = item.tokens_inputs.dims(); + let device = &self.devices()[0]; - let inputs = item.tokens_inputs.to_device(device); - let targets = item.targets.to_device(device); - let mask_pad = item.mask_pad.to_device(device); + let inputs = item.tokens_inputs.to_device(device); + let targets = item.targets.to_device(device); + let mask_pad = item.mask_pad.to_device(device); - let index_positions = Tensor::arange_device(0..seq_length, device) - .reshape([1, seq_length]) - .repeat(0, batch_size); + let index_positions = Tensor::arange_device(0..seq_length, device) + .reshape([1, seq_length]) + .repeat(0, batch_size); - let embedding_positions = self.embedding_pos.forward(index_positions); - let embedding_tokens = self.embedding_token.forward(inputs); - let embedding = (embedding_positions + embedding_tokens) / 
2; + let embedding_positions = self.embedding_pos.forward(index_positions); + let embedding_tokens = self.embedding_token.forward(inputs); + let embedding = (embedding_positions + embedding_tokens) / 2; - let mask_attn = generate_autoregressive_mask::(batch_size, seq_length, device); - let encoded = self.transformer.forward( - TransformerEncoderInput::new(embedding) - .mask_pad(mask_pad) - .mask_attn(mask_attn), - ); + let mask_attn = generate_autoregressive_mask::(batch_size, seq_length, device); + let encoded = self.transformer.forward( + TransformerEncoderInput::new(embedding) + .mask_pad(mask_pad) + .mask_attn(mask_attn), + ); - let output = self.output.forward(encoded); - let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]); - let targets_flatten = targets.reshape([batch_size * seq_length]); + let output = self.output.forward(encoded); + let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]); + let targets_flatten = targets.reshape([batch_size * seq_length]); - let loss = CrossEntropyLossConfig::new() - .with_pad_tokens(Some(vec![self.pad_token])) - .init(); - let loss = loss.forward(output_flatten.clone(), targets_flatten.clone()); + let loss = CrossEntropyLossConfig::new() + .with_pad_tokens(Some(vec![self.pad_token])) + .init(); + let loss = loss.forward(output_flatten.clone(), targets_flatten.clone()); - ClassificationOutput { - loss, - output: output_flatten, - targets: targets_flatten, + ClassificationOutput { + loss, + output: output_flatten, + targets: targets_flatten, + } } - } } impl TrainStep, ClassificationOutput> - for TextGenerationModel + for TextGenerationModel { - fn step(&self, item: TrainingTextGenerationBatch) -> TrainOutput> { - let item = self.forward_training(item); - let grads = item.loss.backward(); + fn step(&self, item: TrainingTextGenerationBatch) -> TrainOutput> { + let item = self.forward_training(item); + let grads = item.loss.backward(); - TrainOutput::new(self, grads, item) - } + TrainOutput::new(self, grads, item) + } } impl ValidStep, ClassificationOutput> - for TextGenerationModel + for TextGenerationModel { - fn step(&self, item: TrainingTextGenerationBatch) -> ClassificationOutput { - self.forward_training(item) - } + fn step(&self, item: TrainingTextGenerationBatch) -> ClassificationOutput { + self.forward_training(item) + } } diff --git a/examples/text-generation/src/training.rs b/examples/text-generation/src/training.rs index e59475e952..782012b8ba 100644 --- a/examples/text-generation/src/training.rs +++ b/examples/text-generation/src/training.rs @@ -1,94 +1,94 @@ use crate::{ - data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer}, - model::TextGenerationModelConfig, + data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer}, + model::TextGenerationModelConfig, }; use burn::data::dataset::transform::SamplerDataset; use burn::{ - config::Config, - data::{dataloader::DataLoaderBuilder, dataset::Dataset}, - lr_scheduler::noam::NoamLrSchedulerConfig, - module::Module, - nn::transformer::TransformerEncoderConfig, - optim::AdamConfig, - record::{CompactRecorder, DefaultRecorder, Recorder}, - tensor::backend::AutodiffBackend, - train::{ - metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric}, - LearnerBuilder, - }, + config::Config, + data::{dataloader::DataLoaderBuilder, dataset::Dataset}, + lr_scheduler::noam::NoamLrSchedulerConfig, + module::Module, + nn::transformer::TransformerEncoderConfig, + optim::AdamConfig, + 
+    record::{CompactRecorder, DefaultRecorder, Recorder},
+    tensor::backend::AutodiffBackend,
+    train::{
+        metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric},
+        LearnerBuilder,
+    },
 };
 use std::sync::Arc;
 
 #[derive(Config)]
 pub struct ExperimentConfig {
-  transformer: TransformerEncoderConfig,
-  optimizer: AdamConfig,
-  #[config(default = 512)]
-  max_seq_length: usize,
-  #[config(default = 6)]
-  batch_size: usize,
-  #[config(default = 50)]
-  num_epochs: usize,
+    transformer: TransformerEncoderConfig,
+    optimizer: AdamConfig,
+    #[config(default = 512)]
+    max_seq_length: usize,
+    #[config(default = 6)]
+    batch_size: usize,
+    #[config(default = 50)]
+    num_epochs: usize,
 }
 
 pub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(
-  device: B::Device,
-  dataset_train: D,
-  dataset_test: D,
-  config: ExperimentConfig,
-  artifact_dir: &str,
+    device: B::Device,
+    dataset_train: D,
+    dataset_test: D,
+    config: ExperimentConfig,
+    artifact_dir: &str,
 ) {
-  let tokenizer = Arc::new(Gpt2Tokenizer::default());
-  let batcher_train = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
-  let batcher_test = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
+    let tokenizer = Arc::new(Gpt2Tokenizer::default());
+    let batcher_train = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
+    let batcher_test = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
 
-  let model = TextGenerationModelConfig::new(
-    config.transformer.clone(),
-    tokenizer.vocab_size(),
-    tokenizer.pad_token(),
-    config.max_seq_length,
-  )
-  .init::<B>();
+    let model = TextGenerationModelConfig::new(
+        config.transformer.clone(),
+        tokenizer.vocab_size(),
+        tokenizer.pad_token(),
+        config.max_seq_length,
+    )
+    .init::<B>();
 
-  let dataloader_train = DataLoaderBuilder::new(batcher_train)
-    .batch_size(config.batch_size)
-    .num_workers(4)
-    .build(SamplerDataset::new(dataset_train, 10_000));
+    let dataloader_train = DataLoaderBuilder::new(batcher_train)
+        .batch_size(config.batch_size)
+        .num_workers(4)
+        .build(SamplerDataset::new(dataset_train, 10_000));
 
-  let dataloader_test = DataLoaderBuilder::new(batcher_test)
-    .batch_size(config.batch_size)
-    .num_workers(4)
-    .build(SamplerDataset::new(dataset_test, 1000));
+    let dataloader_test = DataLoaderBuilder::new(batcher_test)
+        .batch_size(config.batch_size)
+        .num_workers(4)
+        .build(SamplerDataset::new(dataset_test, 1000));
 
-  let accum = 6; // Effective batch size = 6 * 6 = 32.
-  let optim = config.optimizer.init();
-  let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64)
-    .with_warmup_steps(6000)
-    .with_model_size(config.transformer.d_model)
-    .init();
+    let accum = 6; // Effective batch size = 6 * 6 = 32.
+    let optim = config.optimizer.init();
+    let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64)
+        .with_warmup_steps(6000)
+        .with_model_size(config.transformer.d_model)
+        .init();
 
-  let learner = LearnerBuilder::new(artifact_dir)
-    .metric_train(CUDAMetric::new())
-    .metric_valid(CUDAMetric::new())
-    .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
-    .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
-    .metric_train(LossMetric::new())
-    .metric_valid(LossMetric::new())
-    .metric_train_numeric(LearningRateMetric::new())
-    .with_file_checkpointer(CompactRecorder::new())
-    .devices(vec![device])
-    .grads_accumulation(accum)
-    .num_epochs(config.num_epochs)
-    .build(model, optim, lr_scheduler);
+    let learner = LearnerBuilder::new(artifact_dir)
+        .metric_train(CUDAMetric::new())
+        .metric_valid(CUDAMetric::new())
+        .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
+        .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
+        .metric_train(LossMetric::new())
+        .metric_valid(LossMetric::new())
+        .metric_train_numeric(LearningRateMetric::new())
+        .with_file_checkpointer(CompactRecorder::new())
+        .devices(vec![device])
+        .grads_accumulation(accum)
+        .num_epochs(config.num_epochs)
+        .build(model, optim, lr_scheduler);
 
-  let model_trained = learner.fit(dataloader_train, dataloader_test);
+    let model_trained = learner.fit(dataloader_train, dataloader_test);
 
-  config.save(format!("{artifact_dir}/config.json")).unwrap();
+    config.save(format!("{artifact_dir}/config.json")).unwrap();
 
-  DefaultRecorder::new()
-    .record(
-      model_trained.into_record(),
-      format!("{artifact_dir}/model").into(),
-    )
-    .unwrap();
+    DefaultRecorder::new()
+        .record(
+            model_trained.into_record(),
+            format!("{artifact_dir}/model").into(),
+        )
+        .unwrap();
 }
diff --git a/xtask/src/main.rs b/xtask/src/main.rs
index b33db64842..9658a66bb9 100644
--- a/xtask/src/main.rs
+++ b/xtask/src/main.rs
@@ -6,29 +6,29 @@ mod runchecks;
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-  #[command(subcommand)]
-  command: Command,
+    #[command(subcommand)]
+    command: Command,
 }
 
 #[derive(Subcommand)]
 enum Command {
-  /// Publish a crate to crates.io
-  Publish {
-    /// The name of the crate to publish on crates.io
-    name: String,
-  },
-  /// Run the specified `burn` tests and checks locally.
-  RunChecks {
-    /// The environment to run checks against
-    env: runchecks::CheckType,
-  },
+    /// Publish a crate to crates.io
+    Publish {
+        /// The name of the crate to publish on crates.io
+        name: String,
+    },
+    /// Run the specified `burn` tests and checks locally.
+    RunChecks {
+        /// The environment to run checks against
+        env: runchecks::CheckType,
+    },
 }
 
 fn main() -> anyhow::Result<()> {
-  let args = Args::parse();
+    let args = Args::parse();
 
-  match args.command {
-    Command::RunChecks { env } => runchecks::run(env),
-    Command::Publish { name } => publish::run(name),
-  }
+    match args.command {
+        Command::RunChecks { env } => runchecks::run(env),
+        Command::Publish { name } => publish::run(name),
+    }
 }
diff --git a/xtask/src/publish.rs b/xtask/src/publish.rs
index 0d13a3bd29..bff5672166 100644
--- a/xtask/src/publish.rs
+++ b/xtask/src/publish.rs
@@ -13,116 +13,116 @@ const CRATES_IO_API_TOKEN: &str = "CRATES_IO_API_TOKEN";
 
 // Obtain local crate version
 fn local_version(crate_name: &str) -> String {
-  // Obtain local crate version contained in cargo pkgid data
-  let cargo_pkgid_output = Command::new("cargo")
-    .args(["pkgid", "-p", crate_name])
-    .output()
-    .expect("Failed to run cargo pkgid");
-
-  // Convert cargo pkgid output into a str
-  let cargo_pkgid_str =
-    str::from_utf8(&cargo_pkgid_output.stdout).expect("Failed to convert pkgid output into a str");
-
-  // Extract only the local crate version from str
-  let (_, local_version) = cargo_pkgid_str
-    .split_once('#')
-    .expect("Failed to get local crate version");
-
-  local_version.trim_end().to_string()
+    // Obtain local crate version contained in cargo pkgid data
+    let cargo_pkgid_output = Command::new("cargo")
+        .args(["pkgid", "-p", crate_name])
+        .output()
+        .expect("Failed to run cargo pkgid");
+
+    // Convert cargo pkgid output into a str
+    let cargo_pkgid_str = str::from_utf8(&cargo_pkgid_output.stdout)
+        .expect("Failed to convert pkgid output into a str");
+
+    // Extract only the local crate version from str
+    let (_, local_version) = cargo_pkgid_str
+        .split_once('#')
+        .expect("Failed to get local crate version");
+
+    local_version.trim_end().to_string()
 }
 
 // Obtain remote crate version
 fn remote_version(crate_name: &str) -> Option<String> {
-  // Obtain remote crate version contained in cargo search data
-  let cargo_search_output = Command::new("cargo")
-    .args(["search", crate_name, "--limit", "1"])
-    .output()
-    .expect("Failed to run cargo search");
-
-  // Cargo search returns an empty string in case of a crate not present on
-  // crates.io
-  if cargo_search_output.stdout.is_empty() {
-    None
-  } else {
-    // Convert cargo search output into a str
-    let remote_version_str = str::from_utf8(&cargo_search_output.stdout)
-      .expect("Failed to convert cargo search output into a str");
-
-    // Extract only the remote crate version from str
-    remote_version_str
-      .split_once('=')
-      .and_then(|(_, second)| second.trim_start().split_once(' '))
-      .map(|(s, _)| s.trim_matches('"').to_string())
-  }
+    // Obtain remote crate version contained in cargo search data
+    let cargo_search_output = Command::new("cargo")
+        .args(["search", crate_name, "--limit", "1"])
+        .output()
+        .expect("Failed to run cargo search");
+
+    // Cargo search returns an empty string in case of a crate not present on
+    // crates.io
+    if cargo_search_output.stdout.is_empty() {
+        None
+    } else {
+        // Convert cargo search output into a str
+        let remote_version_str = str::from_utf8(&cargo_search_output.stdout)
+            .expect("Failed to convert cargo search output into a str");
+
+        // Extract only the remote crate version from str
+        remote_version_str
+            .split_once('=')
+            .and_then(|(_, second)| second.trim_start().split_once(' '))
+            .map(|(s, _)| s.trim_matches('"').to_string())
+    }
 }
 
 // Run cargo publish
 fn cargo_publish(params: &[&str]) {
-
// Run cargo publish - let mut cargo_publish = Command::new("cargo") - .arg("publish") - .arg("--color=always") - .args(params) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to run cargo publish"); - - // Wait for cargo publish command to finish - let status = cargo_publish - .wait() - .expect("Failed to wait for cargo publish child process"); - - // If exit status is not a success, terminate the process with an error - if !status.success() { - // Use the exit code associated to a command to terminate the process, - // if any exit code had been found, use the default value 1 - std::process::exit(status.code().unwrap_or(1)); - } + // Run cargo publish + let mut cargo_publish = Command::new("cargo") + .arg("publish") + .arg("--color=always") + .args(params) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect("Failed to run cargo publish"); + + // Wait for cargo publish command to finish + let status = cargo_publish + .wait() + .expect("Failed to wait for cargo publish child process"); + + // If exit status is not a success, terminate the process with an error + if !status.success() { + // Use the exit code associated to a command to terminate the process, + // if any exit code had been found, use the default value 1 + std::process::exit(status.code().unwrap_or(1)); + } } // Publishes a crate fn publish(crate_name: String) { - // Run cargo publish --dry-run - cargo_publish(&["-p", &crate_name, "--dry-run"]); + // Run cargo publish --dry-run + cargo_publish(&["-p", &crate_name, "--dry-run"]); - let crates_io_token = - env::var(CRATES_IO_API_TOKEN).expect("Failed to retrieve the crates.io API token"); + let crates_io_token = + env::var(CRATES_IO_API_TOKEN).expect("Failed to retrieve the crates.io API token"); - // Publish crate - cargo_publish(&["-p", &crate_name, "--token", &crates_io_token]); + // Publish crate + cargo_publish(&["-p", &crate_name, "--token", &crates_io_token]); } pub fn run(crate_name: String) -> anyhow::Result<()> { - println!("Publishing {crate_name}...\n"); + println!("Publishing {crate_name}...\n"); - // Retrieve local version for crate - let local_version = local_version(&crate_name); + // Retrieve local version for crate + let local_version = local_version(&crate_name); - // Print local version for crate - println!("{crate_name} local version: {local_version}"); - - // Retrieve remote version for crate - // - // If remote version is None, the crate will be published for the first time - // on crates.io - if let Some(remote_version) = remote_version(&crate_name) { // Print local version for crate - println!("{crate_name} remote version: {remote_version}\n"); - - // If local and remote versions are equal, do not publish - if local_version == remote_version { - println!("Remote version {remote_version} is up to date, skipping deployment"); + println!("{crate_name} local version: {local_version}"); + + // Retrieve remote version for crate + // + // If remote version is None, the crate will be published for the first time + // on crates.io + if let Some(remote_version) = remote_version(&crate_name) { + // Print local version for crate + println!("{crate_name} remote version: {remote_version}\n"); + + // If local and remote versions are equal, do not publish + if local_version == remote_version { + println!("Remote version {remote_version} is up to date, skipping 
deployment"); + } else { + // Publish crate + publish(crate_name); + } } else { - // Publish crate - publish(crate_name); + // Print crate publishing message + println!("\nFirst time publishing {crate_name} on crates.io!\n"); + // Publish crate + publish(crate_name); } - } else { - // Print crate publishing message - println!("\nFirst time publishing {crate_name} on crates.io!\n"); - // Publish crate - publish(crate_name); - } - Ok(()) + Ok(()) } diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 062b7178dc..4d790aa3f7 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -15,169 +15,169 @@ const ARM_TARGET: &str = "thumbv7m-none-eabi"; // Handle child process fn handle_child_process(mut child: Child, error: &str) { - // Wait for the child process to finish - let status = child.wait().expect(error); - - // If exit status is not a success, terminate the process with an error - if !status.success() { - // Use the exit code associated to a command to terminate the process, - // if any exit code had been found, use the default value 1 - std::process::exit(status.code().unwrap_or(1)); - } + // Wait for the child process to finish + let status = child.wait().expect(error); + + // If exit status is not a success, terminate the process with an error + if !status.success() { + // Use the exit code associated to a command to terminate the process, + // if any exit code had been found, use the default value 1 + std::process::exit(status.code().unwrap_or(1)); + } } // Run a command fn run_command(command: &str, args: &[&str], command_error: &str, child_error: &str) { - // Format command - println!("{command} {}\n\n", args.join(" ")); - - // Run command as child process - let command = Command::new(command) - .args(args) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect(command_error); - - // Handle command child process - handle_child_process(command, child_error); + // Format command + println!("{command} {}\n\n", args.join(" ")); + + // Run command as child process + let command = Command::new(command) + .args(args) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect(command_error); + + // Handle command child process + handle_child_process(command, child_error); } // Define and run rustup command fn rustup(command: &str, target: &str) { - run_command( - "rustup", - &[command, "add", target], - "Failed to run rustup", - "Failed to wait for rustup child process", - ) + run_command( + "rustup", + &[command, "add", target], + "Failed to run rustup", + "Failed to wait for rustup child process", + ) } // Define and run a cargo command fn run_cargo(command: &str, params: Params, error: &str) { - // Print cargo command - println!("\ncargo {} {}\n", command, params); - - // Run cargo - let cargo = Command::new("cargo") - .env("CARGO_INCREMENTAL", "0") - .arg(command) - .args(params.params) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect(error); - - // Handle cargo child process - handle_child_process(cargo, "Failed to wait for cargo child process"); + // Print cargo command + println!("\ncargo {} {}\n", command, params); + + // Run cargo + let cargo = Command::new("cargo") + .env("CARGO_INCREMENTAL", "0") + .arg(command) + .args(params.params) + .stdout(Stdio::inherit()) // Send stdout 
directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect(error); + + // Handle cargo child process + handle_child_process(cargo, "Failed to wait for cargo child process"); } // Run cargo build command fn cargo_build(params: Params) { - // Run cargo build - run_cargo( - "build", - params + "--color=always", - "Failed to run cargo build", - ); + // Run cargo build + run_cargo( + "build", + params + "--color=always", + "Failed to run cargo build", + ); } // Run cargo install command fn cargo_install(params: Params) { - // Run cargo install - run_cargo( - "install", - params + "--color=always", - "Failed to run cargo install", - ); + // Run cargo install + run_cargo( + "install", + params + "--color=always", + "Failed to run cargo install", + ); } // Run cargo test command fn cargo_test(params: Params) { - // Run cargo test - run_cargo( - "test", - params + "--color=always" + "--" + "--color=always", - "Failed to run cargo test", - ); + // Run cargo test + run_cargo( + "test", + params + "--color=always" + "--" + "--color=always", + "Failed to run cargo test", + ); } // Run cargo fmt command fn cargo_fmt() { - // Run cargo fmt - run_cargo( - "fmt", - ["--check", "--all", "--", "--color=always"].into(), - "Failed to run cargo fmt", - ); + // Run cargo fmt + run_cargo( + "fmt", + ["--check", "--all", "--", "--color=always"].into(), + "Failed to run cargo fmt", + ); } // Run cargo clippy command fn cargo_clippy() { - if std::env::var("CI_RUN").is_ok() { - return; - } - // Run cargo clippy - run_cargo( - "clippy", - ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), - "Failed to run cargo clippy", - ); + if std::env::var("CI_RUN").is_ok() { + return; + } + // Run cargo clippy + run_cargo( + "clippy", + ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), + "Failed to run cargo clippy", + ); } // Run cargo doc command fn cargo_doc(params: Params) { - // Run cargo doc - run_cargo("doc", params + "--color=always", "Failed to run cargo doc"); + // Run cargo doc + run_cargo("doc", params + "--color=always", "Failed to run cargo doc"); } // Build and test a crate in a no_std environment fn build_and_test_no_std(crate_name: &str, extra_args: [&str; N]) { - println!("\nRun checks for `{}` crate", crate_name); - - // Run cargo build --no-default-features - cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); - - // Run cargo test --no-default-features - cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); - - // Run cargo build --no-default-features --target wasm32-unknown-unknowns - cargo_build( - Params::from([ - "-p", - crate_name, - "--no-default-features", - "--target", - WASM32_TARGET, - ]) + extra_args, - ); - - // Run cargo build --no-default-features --target thumbv7m-none-eabi - cargo_build( - Params::from([ - "-p", - crate_name, - "--no-default-features", - "--target", - ARM_TARGET, - ]) + extra_args, - ); + println!("\nRun checks for `{}` crate", crate_name); + + // Run cargo build --no-default-features + cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); + + // Run cargo test --no-default-features + cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); + + // Run cargo build --no-default-features --target wasm32-unknown-unknowns + cargo_build( + Params::from([ + "-p", + crate_name, + "--no-default-features", + "--target", + WASM32_TARGET, + ]) + extra_args, + ); + + // Run cargo build 
--no-default-features --target thumbv7m-none-eabi + cargo_build( + Params::from([ + "-p", + crate_name, + "--no-default-features", + "--target", + ARM_TARGET, + ]) + extra_args, + ); } // Setup code coverage fn setup_coverage() { - // Install llvm-tools-preview - rustup("component", "llvm-tools-preview"); + // Install llvm-tools-preview + rustup("component", "llvm-tools-preview"); - // Set coverage environment variables - env::set_var("RUSTFLAGS", "-Cinstrument-coverage"); - env::set_var("LLVM_PROFILE_FILE", "burn-%p-%m.profraw"); + // Set coverage environment variables + env::set_var("RUSTFLAGS", "-Cinstrument-coverage"); + env::set_var("LLVM_PROFILE_FILE", "burn-%p-%m.profraw"); } // Run grcov to produce lcov.info fn run_grcov() { - // grcov arguments - #[rustfmt::skip] + // grcov arguments + #[rustfmt::skip] let args = [ ".", "--binary-path", "./target/debug/", @@ -191,245 +191,245 @@ fn run_grcov() { "-o", "lcov.info", ]; - run_command( - "grcov", - &args, - "Failed to run grcov", - "Failed to wait for grcov child process", - ); + run_command( + "grcov", + &args, + "Failed to run grcov", + "Failed to wait for grcov child process", + ); } // Run no_std checks fn no_std_checks() { - println!("Checks for no_std environment...\n\n"); - - // Install wasm32 target - rustup("target", WASM32_TARGET); - - // Install ARM target - rustup("target", ARM_TARGET); - - // Run checks for the following crates - build_and_test_no_std("burn", []); - build_and_test_no_std("burn-core", []); - build_and_test_no_std( - "burn-compute", - ["--features", "channel-mutex storage-bytes"], - ); - build_and_test_no_std("burn-common", []); - build_and_test_no_std("burn-tensor", []); - build_and_test_no_std("burn-ndarray", []); - build_and_test_no_std("burn-no-std-tests", []); + println!("Checks for no_std environment...\n\n"); + + // Install wasm32 target + rustup("target", WASM32_TARGET); + + // Install ARM target + rustup("target", ARM_TARGET); + + // Run checks for the following crates + build_and_test_no_std("burn", []); + build_and_test_no_std("burn-core", []); + build_and_test_no_std( + "burn-compute", + ["--features", "channel-mutex storage-bytes"], + ); + build_and_test_no_std("burn-common", []); + build_and_test_no_std("burn-tensor", []); + build_and_test_no_std("burn-ndarray", []); + build_and_test_no_std("burn-no-std-tests", []); } // Test burn-core with tch and wgpu backend fn burn_core_std() { - println!("\n\nRun checks for burn-core crate with tch and wgpu backend"); + println!("\n\nRun checks for burn-core crate with tch and wgpu backend"); - // Run cargo test --features test-tch - cargo_test(["-p", "burn-core", "--features", "test-tch"].into()); + // Run cargo test --features test-tch + cargo_test(["-p", "burn-core", "--features", "test-tch"].into()); - // Run cargo test --features test-wgpu - cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into()); + // Run cargo test --features test-wgpu + cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into()); } // Test burn-dataset features fn burn_dataset_features_std() { - println!("\n\nRun checks for burn-dataset features"); + println!("\n\nRun checks for burn-dataset features"); - // Run cargo build --all-features - cargo_build(["-p", "burn-dataset", "--all-features"].into()); + // Run cargo build --all-features + cargo_build(["-p", "burn-dataset", "--all-features"].into()); - // Run cargo test --all-features - cargo_test(["-p", "burn-dataset", "--all-features"].into()); + // Run cargo test --all-features + cargo_test(["-p", "burn-dataset", 
"--all-features"].into()); - // Run cargo doc --all-features - cargo_doc(["-p", "burn-dataset", "--all-features"].into()); + // Run cargo doc --all-features + cargo_doc(["-p", "burn-dataset", "--all-features"].into()); } fn std_checks() { - // Set RUSTDOCFLAGS environment variable to treat warnings as errors - // for the documentation build - env::set_var("RUSTDOCFLAGS", "-D warnings"); + // Set RUSTDOCFLAGS environment variable to treat warnings as errors + // for the documentation build + env::set_var("RUSTDOCFLAGS", "-D warnings"); - // Check if COVERAGE environment variable is set - let is_coverage = std::env::var("COVERAGE").is_ok(); + // Check if COVERAGE environment variable is set + let is_coverage = std::env::var("COVERAGE").is_ok(); - println!("Running std checks"); + println!("Running std checks"); - // Check format - cargo_fmt(); + // Check format + cargo_fmt(); - // Check clippy lints - cargo_clippy(); + // Check clippy lints + cargo_clippy(); - // Build each workspace - cargo_build(["--workspace", "--exclude=xtask"].into()); + // Build each workspace + cargo_build(["--workspace", "--exclude=xtask"].into()); - // Produce documentation for each workspace - cargo_doc(["--workspace"].into()); + // Produce documentation for each workspace + cargo_doc(["--workspace"].into()); - // Setup code coverage - if is_coverage { - setup_coverage(); - } + // Setup code coverage + if is_coverage { + setup_coverage(); + } - // Test each workspace - cargo_test(["--workspace"].into()); + // Test each workspace + cargo_test(["--workspace"].into()); - // Test burn-dataset features - burn_dataset_features_std(); + // Test burn-dataset features + burn_dataset_features_std(); - // Test burn-core with tch and wgpu backend - burn_core_std(); + // Test burn-core with tch and wgpu backend + burn_core_std(); - // Run grcov and produce lcov.info - if is_coverage { - run_grcov(); - } + // Run grcov and produce lcov.info + if is_coverage { + run_grcov(); + } } fn check_typos() { - // This path defines where typos-cl is installed on different - // operating systems. - let typos_cli_path = std::env::var("CARGO_HOME") - .map(|v| std::path::Path::new(&v).join("bin/typos-cli")) - .unwrap(); - - // Do not run cargo install on CI to speed up the computation. - // Check whether the file has been installed on - if std::env::var("CI_RUN").is_err() && !typos_cli_path.exists() { - // Install typos-cli - cargo_install(["typos-cli", "--version", "1.16.5"].into()); - } - - println!("Running typos check \n\n"); - - // Run typos command as child process - let typos = Command::new("typos") - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to run typos"); - - // Handle typos child process - handle_child_process(typos, "Failed to wait for typos child process"); -} - -fn check_examples() { - println!("Checking examples compile \n\n"); - - std::fs::read_dir("examples").unwrap().for_each(|dir| { - let dir = dir.unwrap(); - let path = dir.path(); - // Skip if not a directory - if !path.is_dir() { - return; - } - if path.file_name().unwrap().to_str().unwrap() == "notebook" { - // not a crate - return; + // This path defines where typos-cl is installed on different + // operating systems. + let typos_cli_path = std::env::var("CARGO_HOME") + .map(|v| std::path::Path::new(&v).join("bin/typos-cli")) + .unwrap(); + + // Do not run cargo install on CI to speed up the computation. 
+ // Check whether the file has been installed on + if std::env::var("CI_RUN").is_err() && !typos_cli_path.exists() { + // Install typos-cli + cargo_install(["typos-cli", "--version", "1.16.5"].into()); } - let path = path.to_str().unwrap(); - println!("Checking {path} \n\n"); - - let child = Command::new("cargo") - .arg("check") - .arg("--examples") - .current_dir(dir.path()) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to check examples"); + + println!("Running typos check \n\n"); + + // Run typos command as child process + let typos = Command::new("typos") + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect("Failed to run typos"); // Handle typos child process - handle_child_process(child, "Failed to wait for examples child process"); - }); + handle_child_process(typos, "Failed to wait for typos child process"); +} + +fn check_examples() { + println!("Checking examples compile \n\n"); + + std::fs::read_dir("examples").unwrap().for_each(|dir| { + let dir = dir.unwrap(); + let path = dir.path(); + // Skip if not a directory + if !path.is_dir() { + return; + } + if path.file_name().unwrap().to_str().unwrap() == "notebook" { + // not a crate + return; + } + let path = path.to_str().unwrap(); + println!("Checking {path} \n\n"); + + let child = Command::new("cargo") + .arg("check") + .arg("--examples") + .current_dir(dir.path()) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect("Failed to check examples"); + + // Handle typos child process + handle_child_process(child, "Failed to wait for examples child process"); + }); } #[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)] pub enum CheckType { - /// Run all checks. - #[default] - All, - /// Run `std` environment checks - Std, - /// Run `no-std` environment checks - NoStd, - /// Check for typos - Typos, - /// Test the examples - Examples, + /// Run all checks. + #[default] + All, + /// Run `std` environment checks + Std, + /// Run `no-std` environment checks + NoStd, + /// Check for typos + Typos, + /// Test the examples + Examples, } pub fn run(env: CheckType) -> anyhow::Result<()> { - // Start time measurement - let start = Instant::now(); - - // The environment can assume ONLY "std", "no_std", "typos", "examples" - // as values. - // - // Depending on the input argument, the respective environment checks - // are run. - // - // If no environment has been passed, run all checks. - match env { - CheckType::Std => std_checks(), - CheckType::NoStd => no_std_checks(), - CheckType::Typos => check_typos(), - CheckType::Examples => check_examples(), - CheckType::All => { - /* Run all checks */ - check_typos(); - std_checks(); - no_std_checks(); - check_examples(); + // Start time measurement + let start = Instant::now(); + + // The environment can assume ONLY "std", "no_std", "typos", "examples" + // as values. + // + // Depending on the input argument, the respective environment checks + // are run. + // + // If no environment has been passed, run all checks. 
+    match env {
+        CheckType::Std => std_checks(),
+        CheckType::NoStd => no_std_checks(),
+        CheckType::Typos => check_typos(),
+        CheckType::Examples => check_examples(),
+        CheckType::All => {
+            /* Run all checks */
+            check_typos();
+            std_checks();
+            no_std_checks();
+            check_examples();
+        }
     }
-  }
 
-  // Stop time measurement
-  //
-  // Compute runtime duration
-  let duration = start.elapsed();
+    // Stop time measurement
+    //
+    // Compute runtime duration
+    let duration = start.elapsed();
 
-  // Print duration
-  println!("Time elapsed for the current execution: {:?}", duration);
+    // Print duration
+    println!("Time elapsed for the current execution: {:?}", duration);
 
-  Ok(())
+    Ok(())
 }
 
 struct Params {
-  params: Vec<String>,
+    params: Vec<String>,
 }
 
 impl<const N: usize> From<[&str; N]> for Params {
-  fn from(value: [&str; N]) -> Self {
-    Self {
-      params: value.iter().map(|v| v.to_string()).collect(),
+    fn from(value: [&str; N]) -> Self {
+        Self {
+            params: value.iter().map(|v| v.to_string()).collect(),
+        }
     }
-  }
 }
 
 impl From<&str> for Params {
-  fn from(value: &str) -> Self {
-    Self {
-      params: vec![value.to_string()],
+    fn from(value: &str) -> Self {
+        Self {
+            params: vec![value.to_string()],
+        }
     }
-  }
 }
 
 impl std::fmt::Display for Params {
-  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-    f.write_str(self.params.join(" ").as_str())
-  }
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(self.params.join(" ").as_str())
+    }
 }
 
 impl<Rhs: Into<Params>> std::ops::Add<Rhs> for Params {
-  type Output = Params;
+    type Output = Params;
 
-  fn add(mut self, rhs: Rhs) -> Self::Output {
-    let rhs: Params = rhs.into();
-    self.params.extend(rhs.params);
-    self
-  }
+    fn add(mut self, rhs: Rhs) -> Self::Output {
+        let rhs: Params = rhs.into();
+        self.params.extend(rhs.params);
+        self
+    }
 }
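The `Params` type reformatted above is the small argument-list builder used by the cargo wrappers in xtask/src/runchecks.rs. Below is a minimal, self-contained sketch of how its `From`, `Add`, and `Display` impls compose a cargo invocation; the trait impls mirror the file above, while the `main` function, the chosen argument values, and the assertion are illustrative assumptions rather than part of the patch.

// Illustrative sketch only: a trimmed copy of the Params builder from
// xtask/src/runchecks.rs, plus a hypothetical main() showing how argument
// lists are composed for the cargo wrappers (e.g. cargo_build).

struct Params {
    params: Vec<String>,
}

impl<const N: usize> From<[&str; N]> for Params {
    fn from(value: [&str; N]) -> Self {
        Self {
            params: value.iter().map(|v| v.to_string()).collect(),
        }
    }
}

impl From<&str> for Params {
    fn from(value: &str) -> Self {
        Self {
            params: vec![value.to_string()],
        }
    }
}

impl std::fmt::Display for Params {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.params.join(" ").as_str())
    }
}

impl<Rhs: Into<Params>> std::ops::Add<Rhs> for Params {
    type Output = Params;

    fn add(mut self, rhs: Rhs) -> Self::Output {
        // Convert the right-hand side (an array or a single &str) and append it.
        let rhs: Params = rhs.into();
        self.params.extend(rhs.params);
        self
    }
}

fn main() {
    // Mirrors how the cargo wrappers assemble arguments: an array seed plus
    // individual flags appended with `+`.
    let params = Params::from(["-p", "burn-dataset", "--all-features"]) + "--color=always";
    assert_eq!(params.to_string(), "-p burn-dataset --all-features --color=always");
    println!("cargo build {params}");
}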