Skip to content

Commit

Permalink
Add iOS tests for cosine similarity and identity methods
Browse files Browse the repository at this point in the history
  • Loading branch information
erksch committed Jun 24, 2022
1 parent 413dcbd commit 85bfb4b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
9 changes: 9 additions & 0 deletions build_dummy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class DummyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)
self.cos_sim = nn.CosineSimilarity()

@torch.jit.export
def inference(self, x: torch.Tensor):
Expand All @@ -16,6 +17,14 @@ def inference(self, x: torch.Tensor):
def inference_dict(self, x: Dict[str, torch.Tensor]):
return self.linear(x["x"])

@torch.jit.export
def identity(self, x: torch.Tensor):
return x.float()

@torch.jit.export
def similarity(self, x: torch.Tensor, y: torch.Tensor):
return self.cos_sim(x, y)

def forward(self, x):
return self.linear(x)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,54 @@ class TorchModuleIOSTest {
assertEquals(10, output.data.size)
assertContentEquals(longArrayOf(1, 10), output.shape)
}

@Test
fun testIdentityLong() {
val module = TorchModule(localModulePath)
val data = longArrayOf(3L, 2L, 0L, 0L, 1L, 6L)
val shape = longArrayOf(2, 3)
val tensor = LongTensor(data, shape)
val output = module.runMethod(
"identity",
listOf(tensor)
)
assertEquals(data.toList().map { it.toFloat() }, output.data.toList())
assertEquals(shape.toList(), output.shape.toList())
}

@Test
fun testIdentity() {
val module = TorchModule(localModulePath)
val data = floatArrayOf(0.86F, 1.36F, 0.51F, 0.45F, 0.37F, 1.84F)
val shape = longArrayOf(2, 3)
val tensor = FloatTensor(data, shape)
val output = module.runMethod(
"identity",
listOf(tensor)
)
assertEquals(data.toList(), output.data.toList())
assertEquals(shape.toList(), output.shape.toList())
}

@Test
fun testSimilarity() {
val module = TorchModule(localModulePath)
val output = module.runMethod(
"similarity",
listOf(
FloatTensor(
floatArrayOf(0.86F, 1.36F, 0.51F, 0.45F, 0.37F, 1.84F),
longArrayOf(2, 3),
),
FloatTensor(
floatArrayOf(1.02F, 0.17F, 1.99F, 1.02F, 0.82F, 1.33F),
longArrayOf(2, 3),
)
)
)

assertEquals(listOf(2L), output.shape.toList())
assertEquals(0.56F, output.data[0], 0.01F)
assertEquals(0.89F, output.data[1], 0.01F)
}
}

0 comments on commit 85bfb4b

Please sign in to comment.