From 42bfcc9d131a91a1305399a73fadee637172203b Mon Sep 17 00:00:00 2001 From: hyperfraise Date: Thu, 21 Sep 2023 19:11:38 +0200 Subject: [PATCH] Add support for 3d convolution --- DxDispatch/CMakeLists.txt | 3 +- ...nvolution.json => dml_convolution_2d.json} | 0 DxDispatch/models/dml_convolution_3d.json | 60 +++++++++++++++++++ Libraries/DirectMLX.h | 16 +++-- 4 files changed, 72 insertions(+), 7 deletions(-) rename DxDispatch/models/{dml_convolution.json => dml_convolution_2d.json} (100%) create mode 100644 DxDispatch/models/dml_convolution_3d.json diff --git a/DxDispatch/CMakeLists.txt b/DxDispatch/CMakeLists.txt index e9b0595a..cfde811a 100644 --- a/DxDispatch/CMakeLists.txt +++ b/DxDispatch/CMakeLists.txt @@ -251,7 +251,8 @@ if(DXD_TESTS) set_tests_properties(test_${model_name} PROPERTIES PASS_REGULAR_EXPRESSION ${expected_output}) endfunction() - model_test(dml_convolution "Resource 'output': 6, 8, 12, 14") + model_test(dml_convolution_2d "Resource 'output': 6, 8, 12, 14") + model_test(dml_convolution_3d "Resource 'output': 4, 4, 4, 4, 4, 4, 4, 4") model_test(dml_cumulative_product "Resource 'Out': 2, 8, 64, 192") model_test(dml_element_wise_add "Resource 'Out': 6, 10, -2") model_test(dml_element_wise_add_npy "Resource 'Out': 2, 4, 6, 8, 10, 12") diff --git a/DxDispatch/models/dml_convolution.json b/DxDispatch/models/dml_convolution_2d.json similarity index 100% rename from DxDispatch/models/dml_convolution.json rename to DxDispatch/models/dml_convolution_2d.json diff --git a/DxDispatch/models/dml_convolution_3d.json b/DxDispatch/models/dml_convolution_3d.json new file mode 100644 index 00000000..83344e64 --- /dev/null +++ b/DxDispatch/models/dml_convolution_3d.json @@ -0,0 +1,60 @@ +{ + "$schema": "./_schema.json", + + "resources": + { + "input": + { + "initialValuesDataType": "FLOAT32", + "initialValues": { "valueCount": 27, "valueStart": 1, "valueDelta": 0 } + }, + "filter": + { + "initialValuesDataType": "FLOAT32", + "initialValues": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + }, + "output": + { + "initialValuesDataType": "FLOAT32", + "initialValues": { "valueCount": 4, "value": 0 } + } + }, + + "dispatchables": + { + "conv3d": + { + "type": "DML_OPERATOR_CONVOLUTION", + "desc": + { + "InputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,3,3,3] }, + "FilterTensor": { "DataType": "FLOAT32", "Sizes": [1,1,2,2,2] }, + "OutputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,2,2,2] }, + "Mode": "DML_CONVOLUTION_MODE_CROSS_CORRELATION", + "Direction": "DML_CONVOLUTION_DIRECTION_FORWARD", + "DimensionCount": 3, + "Strides": [1,1,1], + "Dilations": [1,1,1], + "StartPadding": [0,0,0], + "EndPadding": [0,0,0], + "OutputPadding": [0,0,0], + "GroupCount": 1 + } + } + }, + + "commands": + [ + { + "type": "dispatch", + "dispatchable": "conv3d", + "bindings": + { + "InputTensor": "input", + "FilterTensor": "filter", + "OutputTensor": "output" + } + }, + { "type": "print", "resource": "output" } + ] +} \ No newline at end of file diff --git a/Libraries/DirectMLX.h b/Libraries/DirectMLX.h index ff283322..ef4dc2cc 100644 --- a/Libraries/DirectMLX.h +++ b/Libraries/DirectMLX.h @@ -3547,11 +3547,13 @@ namespace dml uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); // todo: support 1d convolution? - assert(dimensionCount == 4); + assert(dimensionCount == 4 || dimensionCount == 5); uint32_t spatialDimensionCount = dimensionCount - 2; - const uint32_t defaultStridesAndDilations[2] = { 1, 1 }; - const uint32_t defaultPadding[2] = { 0, 0 }; + // If the spatial dimension count is 2, we'll just use the first two elements by setting + // DimensionCount = 2 in the desc + const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 }; + const uint32_t defaultPadding[3] = { 0, 0, 0 }; assert(strides.empty() || strides.size() == spatialDimensionCount); assert(dilations.empty() || dilations.size() == spatialDimensionCount); @@ -3659,11 +3661,13 @@ namespace dml uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); // todo: suppord 1d convolution? - assert(dimensionCount == 4); + assert(dimensionCount == 4 || dimensionCount == 5); const uint32_t spatialDimensionCount = dimensionCount - 2; - const uint32_t defaultStridesAndDilations[2] = { 1, 1 }; - const uint32_t defaultPadding[2] = { 1, 1 }; + // If the spatial dimension count is 2, we'll just use the first two elements by setting + // DimensionCount = 2 in the desc + const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 }; + const uint32_t defaultPadding[3] = { 0, 0, 0 }; assert(strides.empty() || strides.size() == spatialDimensionCount); assert(dilations.empty() || dilations.size() == spatialDimensionCount);