Skip to content

Commit

Permalink
extra axis added to the right
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed May 24, 2023
1 parent d7bfa9a commit 1663cc0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion nnef/src/ops/core/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ fn matmul_load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> T
let b: OutletId = invocation.named_arg_as(builder, "B")?;
let axes: TVec<usize> = invocation.named_arg_as(builder, "axes")?;
let fact = builder.model.outlet_fact(a)?;
builder.wire(EinSum::new(from_legacy_axes_spec(&axes, fact.rank())?, fact.datum_type), &[a, b])
let axes = from_legacy_axes_spec(&axes, fact.rank())?;
builder.wire(EinSum::new(axes, fact.datum_type), &[a, b])
}
2 changes: 1 addition & 1 deletion nnef/src/ops/nnef/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ pub fn matmul(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tr
let name = &*invocation.invocation.id.0;
if a_dt.is_quantized() || b_dt.is_quantized() {
for input in 0..7 {
axes = axes.with_extra_input(input)?;
axes = axes.with_extra_input(2 + input)?;
}
let accum_dt = DatumType::QI32(QParams::ZpScale {
scale: a_dt.zp_scale().1 * b_dt.zp_scale().1,
Expand Down

0 comments on commit 1663cc0

Please sign in to comment.