Skip to content

Commit

Permalink
vsmigx: allow fp16 input & output (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
abihf authored and WolframRhodium committed Mar 5, 2024
1 parent 3bebe61 commit d414e90
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions vsmigx/vs_migraphx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ static void setDimensions(
std::unique_ptr<VSVideoInfo> & vi,
const std::array<int, 4> & input_shape,
const std::array<int, 4> & output_shape,
int bitsPerSample,
VSCore * core,
const VSAPI * vsapi
) noexcept {
Expand All @@ -71,9 +72,9 @@ static void setDimensions(
vi->width *= output_shape[3] / input_shape[3];

if (output_shape[1] == 1) {
vi->format = vsapi->registerFormat(cmGray, stFloat, 32, 0, 0, core);
vi->format = vsapi->registerFormat(cmGray, stFloat, bitsPerSample, 0, 0, core);
} else if (output_shape[1] == 3) {
vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core);
vi->format = vsapi->registerFormat(cmRGB, stFloat, bitsPerSample, 0, 0, core);
}
}

Expand Down Expand Up @@ -723,8 +724,8 @@ static void VS_CC vsMIGXCreate(
}
migraphx_shape_datatype_t type;
checkError(migraphx_shape_type(&type, input_shape));
if (type != migraphx_shape_float_type) {
return set_error("input type must be float");
if (type != migraphx_shape_float_type && type != migraphx_shape_half_type) {
return set_error("input type must be float or half");
}
const size_t * lengths;
size_t ndim;
Expand Down Expand Up @@ -769,6 +770,7 @@ static void VS_CC vsMIGXCreate(

size_t output_size;
const_migraphx_shape_t output_shape;
int bitsPerSample;
{
migraphx_shapes_t output_shapes;
checkError(migraphx_program_get_output_shapes(&output_shapes, d->program));
Expand All @@ -786,9 +788,10 @@ static void VS_CC vsMIGXCreate(
}
migraphx_shape_datatype_t type;
checkError(migraphx_shape_type(&type, output_shape));
if (type != migraphx_shape_float_type) {
return set_error("output type must be float");
if (type != migraphx_shape_float_type && type != migraphx_shape_half_type) {
return set_error("output type must be float or half");
}
bitsPerSample = type == migraphx_shape_float_type ? 32 : 16;
const size_t * lengths;
size_t ndim;
checkError(migraphx_shape_lengths(&lengths, &ndim, output_shape));
Expand Down Expand Up @@ -838,7 +841,7 @@ static void VS_CC vsMIGXCreate(
return set_error("\"num_streams\" must be 1 for now");
}

setDimensions(d->out_vi, d->src_tile_shape, d->dst_tile_shape, core, vsapi);
setDimensions(d->out_vi, d->src_tile_shape, d->dst_tile_shape, bitsPerSample, core, vsapi);

// per-stream context
d->instances.reserve(num_streams);
Expand Down

0 comments on commit d414e90

Please sign in to comment.