diff --git a/android/.gitignore b/android/.gitignore
new file mode 100644
index 00000000000..b4d6d617645
--- /dev/null
+++ b/android/.gitignore
@@ -0,0 +1,8 @@
+local.properties
+**/*.iml
+.gradle
+gradlew*
+gradle/wrapper
+.idea/*
+.externalNativeBuild
+build
diff --git a/android/build.gradle b/android/build.gradle
new file mode 100644
index 00000000000..c3393e047e4
--- /dev/null
+++ b/android/build.gradle
@@ -0,0 +1,43 @@
+allprojects {
+ buildscript {
+ ext {
+ minSdkVersion = 21
+ targetSdkVersion = 28
+ compileSdkVersion = 28
+ buildToolsVersion = '28.0.3'
+
+ coreVersion = "1.2.0"
+ extJUnitVersion = "1.1.1"
+ runnerVersion = "1.2.0"
+ rulesVersion = "1.2.0"
+ junitVersion = "4.12"
+
+ androidSupportAppCompatV7Version = "28.0.0"
+ fbjniJavaOnlyVersion = "0.0.3"
+ soLoaderNativeLoaderVersion = "0.8.0"
+ }
+
+ repositories {
+ google()
+ mavenCentral()
+ jcenter()
+ }
+
+ dependencies {
+ classpath 'com.android.tools.build:gradle:3.3.2'
+ classpath "com.jfrog.bintray.gradle:gradle-bintray-plugin:${GRADLE_BINTRAY_PLUGIN_VERSION}"
+ classpath "com.github.dcendents:android-maven-gradle-plugin:${ANDROID_MAVEN_GRADLE_PLUGIN_VERSION}"
+ classpath "org.jfrog.buildinfo:build-info-extractor-gradle:4.9.8"
+ }
+ }
+
+ repositories {
+ google()
+ jcenter()
+ }
+}
+
+ext.deps = [
+ jsr305: 'com.google.code.findbugs:jsr305:3.0.1',
+]
+
diff --git a/android/gradle.properties b/android/gradle.properties
new file mode 100644
index 00000000000..bd5ed6bbd98
--- /dev/null
+++ b/android/gradle.properties
@@ -0,0 +1,28 @@
+ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64
+
+VERSION_NAME=0.0.1-SNAPSHOT
+GROUP=org.pytorch
+MAVEN_GROUP=org.pytorch
+POM_URL=https://github.com/pytorch/vision/
+POM_SCM_URL=https://github.com/pytorch/vision.git
+POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/vision
+POM_SCM_DEV_CONNECTION=scm:git:git@github.com:pytorch/vision.git
+POM_LICENSE_NAME=BSD 3-Clause
+POM_LICENSE_URL=https://github.com/pytorch/vision/blob/master/LICENSE
+POM_ISSUES_URL=https://github.com/pytorch/vision/issues
+POM_LICENSE_DIST=repo
+POM_DEVELOPER_ID=pytorch
+POM_DEVELOPER_NAME=pytorch
+syncWithMavenCentral=true
+
+GRADLE_BINTRAY_PLUGIN_VERSION=1.8.0
+GRADLE_VERSIONS_PLUGIN_VERSION=0.15.0
+ANDROID_MAVEN_GRADLE_PLUGIN_VERSION=2.1
+
+# Gradle internals
+android.useAndroidX=true
+android.enableJetifier=true
+
+testAppAllVariantsEnabled=false
+
+org.gradle.jvmargs=-Xmx4096m
diff --git a/android/gradle_scripts/android_maven_install.gradle b/android/gradle_scripts/android_maven_install.gradle
new file mode 100644
index 00000000000..ce80472d79e
--- /dev/null
+++ b/android/gradle_scripts/android_maven_install.gradle
@@ -0,0 +1,38 @@
+apply plugin: 'com.github.dcendents.android-maven'
+
+version = VERSION_NAME
+group = GROUP
+project.archivesBaseName = POM_ARTIFACT_ID
+
+install {
+ repositories.mavenInstaller {
+ pom.project {
+ name POM_NAME
+ artifactId POM_ARTIFACT_ID
+ packaging POM_PACKAGING
+ description POM_DESCRIPTION
+ url projectUrl
+
+ scm {
+ url scmUrl
+ connection scmConnection
+ developerConnection scmDeveloperConnection
+ }
+
+ licenses {
+ license {
+ name = POM_LICENSE_NAME
+ url = POM_LICENSE_URL
+ distribution = POM_LICENSE_DIST
+ }
+ }
+
+ developers {
+ developer {
+ id developerId
+ name developerName
+ }
+ }
+ }
+ }
+}
diff --git a/android/gradle_scripts/android_tasks.gradle b/android/gradle_scripts/android_tasks.gradle
new file mode 100644
index 00000000000..ca188ac72d0
--- /dev/null
+++ b/android/gradle_scripts/android_tasks.gradle
@@ -0,0 +1,95 @@
+
+import java.nio.file.Files
+import java.nio.file.Paths
+import java.io.FileOutputStream
+import java.util.zip.ZipFile
+
+// Android tasks for Javadoc and sources.jar generation
+
+afterEvaluate { project ->
+ if (POM_PACKAGING == 'aar') {
+ task androidJavadoc(type: Javadoc, dependsOn: assembleDebug) {
+ source += files(android.sourceSets.main.java.srcDirs)
+ failOnError false
+ // This task will try to compile *everything* it finds in the above directory and
+ // will choke on text files it doesn't understand.
+ exclude '**/BUCK'
+ exclude '**/*.md'
+ }
+
+ task androidJavadocJar(type: Jar, dependsOn: androidJavadoc) {
+ classifier = 'javadoc'
+ from androidJavadoc.destinationDir
+ }
+
+ task androidSourcesJar(type: Jar) {
+ classifier = 'sources'
+ from android.sourceSets.main.java.srcDirs
+ }
+
+ android.libraryVariants.all { variant ->
+ def name = variant.name.capitalize()
+ task "jar${name}"(type: Jar, dependsOn: variant.javaCompileProvider) {
+ from variant.javaCompileProvider.get().destinationDir
+ }
+
+ androidJavadoc.doFirst {
+ classpath += files(android.bootClasspath)
+ classpath += files(variant.javaCompileProvider.get().classpath.files)
+ // This is generated by `assembleDebug` and holds the JARs generated by the APT.
+ classpath += fileTree(dir: "$buildDir/intermediates/bundles/debug/", include: '**/*.jar')
+
+ // Process AAR dependencies
+ def aarDependencies = classpath.filter { it.name.endsWith('.aar') }
+ classpath -= aarDependencies
+ aarDependencies.each { aar ->
+ // Extract classes.jar from the AAR dependency, and add it to the javadoc classpath
+ def outputPath = "$buildDir/tmp/aarJar/${aar.name.replace('.aar', '.jar')}"
+ classpath += files(outputPath)
+
+ // Use a task so the actual extraction only happens before the javadoc task is run
+ dependsOn task(name: "extract ${aar.name}").doLast {
+ extractEntry(aar, 'classes.jar', outputPath)
+ }
+ }
+ }
+ }
+
+ artifacts.add('archives', androidJavadocJar)
+ artifacts.add('archives', androidSourcesJar)
+ }
+
+ if (POM_PACKAGING == 'jar') {
+ task javadocJar(type: Jar, dependsOn: javadoc) {
+ classifier = 'javadoc'
+ from javadoc.destinationDir
+ }
+
+ task sourcesJar(type: Jar, dependsOn: classes) {
+ classifier = 'sources'
+ from sourceSets.main.allSource
+ }
+
+ artifacts.add('archives', javadocJar)
+ artifacts.add('archives', sourcesJar)
+ }
+}
+
+// Utility method to extract only one entry in a zip file
+private def extractEntry(archive, entryPath, outputPath) {
+ if (!archive.exists()) {
+ throw new GradleException("archive $archive not found")
+ }
+
+ def zip = new ZipFile(archive)
+ zip.entries().each {
+ if (it.name == entryPath) {
+ def path = Paths.get(outputPath)
+ if (!Files.exists(path)) {
+ Files.createDirectories(path.getParent())
+ Files.copy(zip.getInputStream(it), path)
+ }
+ }
+ }
+ zip.close()
+}
diff --git a/android/gradle_scripts/bintray.gradle b/android/gradle_scripts/bintray.gradle
new file mode 100644
index 00000000000..c20073964f7
--- /dev/null
+++ b/android/gradle_scripts/bintray.gradle
@@ -0,0 +1,64 @@
+apply plugin: 'com.jfrog.bintray'
+
+def getBintrayUsername() {
+ return project.hasProperty('bintrayUsername') ? property('bintrayUsername') : System.getenv('BINTRAY_USERNAME')
+}
+
+def getBintrayApiKey() {
+ return project.hasProperty('bintrayApiKey') ? property('bintrayApiKey') : System.getenv('BINTRAY_API_KEY')
+}
+
+def getBintrayGpgPassword() {
+ return project.hasProperty('bintrayGpgPassword') ? property('bintrayGpgPassword') : System.getenv('BINTRAY_GPG_PASSWORD')
+}
+
+def getMavenCentralUsername() {
+ return project.hasProperty('mavenCentralUsername') ? property('mavenCentralUsername') : System.getenv('MAVEN_CENTRAL_USERNAME')
+}
+
+def getMavenCentralPassword() {
+ return project.hasProperty('mavenCentralPassword') ? property('mavenCentralPassword') : System.getenv('MAVEN_CENTRAL_PASSWORD')
+}
+
+def shouldSyncWithMavenCentral() {
+ return project.hasProperty('syncWithMavenCentral') ? property('syncWithMavenCentral').toBoolean() : false
+}
+
+def dryRunOnly() {
+ return project.hasProperty('dryRun') ? property('dryRun').toBoolean() : false
+}
+
+bintray {
+ user = getBintrayUsername()
+ key = getBintrayApiKey()
+ override = false
+ configurations = ['archives']
+ pkg {
+ repo = bintrayRepo
+ userOrg = bintrayUserOrg
+ name = bintrayName
+ desc = bintrayDescription
+ websiteUrl = projectUrl
+ issueTrackerUrl = issuesUrl
+ vcsUrl = scmUrl
+ licenses = [ POM_LICENSE_NAME ]
+ dryRun = dryRunOnly()
+ override = false
+ publish = true
+ publicDownloadNumbers = true
+ version {
+ name = versionName
+ desc = bintrayDescription
+ gpg {
+ sign = true
+ passphrase = getBintrayGpgPassword()
+ }
+ mavenCentralSync {
+ sync = shouldSyncWithMavenCentral()
+ user = getMavenCentralUsername()
+ password = getMavenCentralPassword()
+ close = '1' // If set to 0, you have to manually click release
+ }
+ }
+ }
+}
diff --git a/android/gradle_scripts/gradle_maven_push.gradle b/android/gradle_scripts/gradle_maven_push.gradle
new file mode 100644
index 00000000000..5fdd8fbc6a0
--- /dev/null
+++ b/android/gradle_scripts/gradle_maven_push.gradle
@@ -0,0 +1,99 @@
+apply plugin: 'signing'
+
+version = VERSION_NAME
+group = MAVEN_GROUP
+
+def isReleaseBuild() {
+ return !VERSION_NAME.contains('SNAPSHOT')
+}
+
+def getReleaseRepositoryUrl() {
+ return hasProperty('RELEASE_REPOSITORY_URL') ? RELEASE_REPOSITORY_URL
+ : "https://oss.sonatype.org/service/local/staging/deploy/maven2/"
+}
+
+def getSnapshotRepositoryUrl() {
+ return hasProperty('SNAPSHOT_REPOSITORY_URL') ? SNAPSHOT_REPOSITORY_URL
+ : "https://oss.sonatype.org/content/repositories/snapshots/"
+}
+
+def getRepositoryUsername() {
+ return hasProperty('SONATYPE_NEXUS_USERNAME') ? SONATYPE_NEXUS_USERNAME : ""
+}
+
+def getRepositoryPassword() {
+ return hasProperty('SONATYPE_NEXUS_PASSWORD') ? SONATYPE_NEXUS_PASSWORD : ""
+}
+
+def getHttpProxyHost() {
+ return project.properties['systemProp.http.proxyHost']
+}
+
+def getHttpProxyPort() {
+ return project.properties['systemProp.http.proxyPort']
+}
+
+def needProxy() {
+ return (getHttpProxyHost() != null) && (getHttpProxyPort() != null)
+}
+
+afterEvaluate { project ->
+ uploadArchives {
+ repositories {
+ mavenDeployer {
+ beforeDeployment { MavenDeployment deployment -> signing.signPom(deployment) }
+
+ pom.groupId = MAVEN_GROUP
+ pom.artifactId = POM_ARTIFACT_ID
+ pom.version = VERSION_NAME
+
+ repository(url: getReleaseRepositoryUrl()) {
+ authentication(userName: getRepositoryUsername(), password: getRepositoryPassword())
+ if (needProxy()) {
+ proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http')
+ }
+ }
+ snapshotRepository(url: getSnapshotRepositoryUrl()) {
+ authentication(userName: getRepositoryUsername(), password: getRepositoryPassword())
+ if (needProxy()) {
+ proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http')
+ }
+ }
+
+ pom.project {
+ name POM_NAME
+ packaging POM_PACKAGING
+ description POM_DESCRIPTION
+ url POM_URL
+
+ scm {
+ url POM_SCM_URL
+ connection POM_SCM_CONNECTION
+ developerConnection POM_SCM_DEV_CONNECTION
+ }
+
+ licenses {
+ license {
+ name POM_LICENSE_NAME
+ url POM_LICENSE_URL
+ distribution POM_LICENSE_DIST
+ }
+ }
+
+ developers {
+ developer {
+ id POM_DEVELOPER_ID
+ name POM_DEVELOPER_NAME
+ }
+ }
+ }
+ }
+ }
+ }
+
+ signing {
+ required { isReleaseBuild() && gradle.taskGraph.hasTask('uploadArchives') }
+ sign configurations.archives
+ }
+
+}
diff --git a/android/gradle_scripts/release.gradle b/android/gradle_scripts/release.gradle
new file mode 100644
index 00000000000..d4e3aef4a22
--- /dev/null
+++ b/android/gradle_scripts/release.gradle
@@ -0,0 +1,5 @@
+apply from: rootProject.file('gradle_scripts/android_tasks.gradle')
+
+apply from: rootProject.file('gradle_scripts/release_bintray.gradle')
+
+apply from: rootProject.file('gradle_scripts/gradle_maven_push.gradle')
diff --git a/android/gradle_scripts/release_bintray.gradle b/android/gradle_scripts/release_bintray.gradle
new file mode 100644
index 00000000000..9b2af121a94
--- /dev/null
+++ b/android/gradle_scripts/release_bintray.gradle
@@ -0,0 +1,32 @@
+ext {
+ bintrayRepo = 'maven'
+ bintrayUserOrg = 'pytorch'
+ bintrayName = "${GROUP}:${POM_ARTIFACT_ID}"
+ bintrayDescription = POM_DESCRIPTION
+ projectUrl = POM_URL
+ issuesUrl = POM_ISSUES_URL
+ scmUrl = POM_SCM_URL
+ scmConnection = POM_SCM_CONNECTION
+ scmDeveloperConnection = POM_SCM_DEV_CONNECTION
+
+ publishedGroupId = GROUP
+ libraryName = 'torchvision'
+ artifact = 'torchvision'
+
+ developerId = POM_DEVELOPER_ID
+ developerName = POM_DEVELOPER_NAME
+
+ versionName = VERSION_NAME
+
+ projectLicenses = {
+ license = {
+ name = POM_LICENSE_NAME
+ url = POM_LICENSE_URL
+ distribution = POM_LICENSE_DIST
+ }
+ }
+}
+
+apply from: rootProject.file('gradle_scripts/android_maven_install.gradle')
+
+apply from: rootProject.file('gradle_scripts/bintray.gradle')
diff --git a/android/ops/CMakeLists.txt b/android/ops/CMakeLists.txt
new file mode 100644
index 00000000000..ad42adbfa71
--- /dev/null
+++ b/android/ops/CMakeLists.txt
@@ -0,0 +1,47 @@
+cmake_minimum_required(VERSION 3.4.1)
+set(TARGET torchvision_ops)
+project(${TARGET} CXX)
+set(CMAKE_CXX_STANDARD 14)
+
+string(APPEND CMAKE_CXX_FLAGS " -DMOBILE")
+
+set(build_DIR ${CMAKE_SOURCE_DIR}/build)
+set(root_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
+
+file(GLOB VISION_SRCS
+ ../../torchvision/csrc/ops/cpu/*.h
+ ../../torchvision/csrc/ops/cpu/*.cpp
+ ../../torchvision/csrc/ops/*.h
+ ../../torchvision/csrc/ops/*.cpp)
+
+add_library(${TARGET} SHARED
+ ${VISION_SRCS}
+)
+
+file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
+file(GLOB PYTORCH_INCLUDE_DIRS_CSRC "${build_DIR}/pytorch_android*.aar/headers/torch/csrc/api/include")
+file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
+
+target_compile_options(${TARGET} PRIVATE
+ -fexceptions
+)
+
+set(BUILD_SUBDIR ${ANDROID_ABI})
+
+find_library(PYTORCH_LIBRARY pytorch_jni
+ PATHS ${PYTORCH_LINK_DIRS}
+ NO_CMAKE_FIND_ROOT_PATH)
+
+find_library(FBJNI_LIBRARY fbjni
+ PATHS ${PYTORCH_LINK_DIRS}
+ NO_CMAKE_FIND_ROOT_PATH)
+
+target_include_directories(${TARGET} PRIVATE
+ ${PYTORCH_INCLUDE_DIRS}
+ ${PYTORCH_INCLUDE_DIRS_CSRC}
+)
+
+target_link_libraries(${TARGET} PRIVATE
+ ${PYTORCH_LIBRARY}
+ ${FBJNI_LIBRARY}
+)
diff --git a/android/ops/build.gradle b/android/ops/build.gradle
new file mode 100644
index 00000000000..773e09fb280
--- /dev/null
+++ b/android/ops/build.gradle
@@ -0,0 +1,94 @@
+apply plugin: 'com.android.library'
+apply plugin: 'maven'
+
+repositories {
+ jcenter()
+ maven {
+ url "https://oss.sonatype.org/content/repositories/snapshots"
+ }
+ flatDir {
+ dirs 'aars'
+ }
+}
+
+android {
+ configurations {
+ extractForNativeBuild
+ }
+ compileSdkVersion rootProject.compileSdkVersion
+ buildToolsVersion rootProject.buildToolsVersion
+
+
+ defaultConfig {
+ minSdkVersion rootProject.minSdkVersion
+ targetSdkVersion rootProject.targetSdkVersion
+ versionCode 0
+ versionName "0.1"
+
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+ ndk {
+ abiFilters ABI_FILTERS.split(",")
+ }
+ }
+
+ buildTypes {
+ debug {
+ minifyEnabled false
+ debuggable true
+ }
+ release {
+ minifyEnabled false
+ }
+ }
+
+ externalNativeBuild {
+ cmake {
+ path "CMakeLists.txt"
+ }
+ }
+
+ useLibrary 'android.test.runner'
+ useLibrary 'android.test.base'
+ useLibrary 'android.test.mock'
+}
+
+dependencies {
+ implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version
+
+ implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
+ extractForNativeBuild 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
+
+ // For testing: deps on local aar files
+ //implementation(name: 'pytorch_android-release', ext: 'aar')
+ //extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
+ //implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3'
+}
+
+task extractAARForNativeBuild {
+ doLast {
+ configurations.extractForNativeBuild.files.each {
+ def file = it.absoluteFile
+ copy {
+ from zipTree(file)
+ into "$buildDir/$file.name"
+ include "headers/**"
+ include "jni/**"
+ }
+ }
+ }
+}
+
+tasks.whenTaskAdded { task ->
+ if (task.name.contains('externalNativeBuild')) {
+ task.dependsOn(extractAARForNativeBuild)
+ }
+}
+
+apply from: rootProject.file('gradle_scripts/release.gradle')
+
+task sourcesJar(type: Jar) {
+ from android.sourceSets.main.java.srcDirs
+ classifier = 'sources'
+}
+
+artifacts.add('archives', sourcesJar)
diff --git a/android/ops/gradle.properties b/android/ops/gradle.properties
new file mode 100644
index 00000000000..5a4ea2f3aba
--- /dev/null
+++ b/android/ops/gradle.properties
@@ -0,0 +1,4 @@
+POM_NAME=torchvision ops
+POM_DESCRIPTION=torchvision ops
+POM_ARTIFACT_ID=torchvision_ops
+POM_PACKAGING=aar
diff --git a/android/ops/src/main/AndroidManifest.xml b/android/ops/src/main/AndroidManifest.xml
new file mode 100644
index 00000000000..8ca386493c4
--- /dev/null
+++ b/android/ops/src/main/AndroidManifest.xml
@@ -0,0 +1 @@
+
diff --git a/android/settings.gradle b/android/settings.gradle
new file mode 100644
index 00000000000..6d34eb8d51a
--- /dev/null
+++ b/android/settings.gradle
@@ -0,0 +1,4 @@
+include ':ops', ':test_app'
+
+project(':ops').projectDir = file('ops')
+project(':test_app').projectDir = file('test_app/app')
diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle
new file mode 100644
index 00000000000..e95078c401d
--- /dev/null
+++ b/android/test_app/app/build.gradle
@@ -0,0 +1,138 @@
+apply plugin: 'com.android.application'
+
+repositories {
+ jcenter()
+ maven {
+ url "https://oss.sonatype.org/content/repositories/snapshots"
+ }
+ flatDir {
+ dirs 'aars'
+ }
+}
+
+android {
+ configurations {
+ extractForNativeBuild
+ }
+ compileOptions {
+ sourceCompatibility 1.8
+ targetCompatibility 1.8
+ }
+ compileSdkVersion rootProject.compileSdkVersion
+ buildToolsVersion rootProject.buildToolsVersion
+ defaultConfig {
+ applicationId "org.pytorch.testapp"
+ minSdkVersion rootProject.minSdkVersion
+ targetSdkVersion rootProject.targetSdkVersion
+ versionCode 1
+ versionName "1.0"
+ ndk {
+ abiFilters ABI_FILTERS.split(",")
+ }
+ externalNativeBuild {
+ cmake {
+ abiFilters ABI_FILTERS.split(",")
+ arguments "-DANDROID_STL=c++_shared"
+ }
+ }
+ buildConfigField("String", "MODULE_ASSET_NAME", "\"frcnn_mnetv3.pt\"")
+ buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
+ buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{3, 96, 96}")
+ addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
+ }
+ buildTypes {
+ debug {
+ minifyEnabled false
+ debuggable true
+ }
+ release {
+ minifyEnabled false
+ }
+ }
+ flavorDimensions "model", "activity", "build"
+ productFlavors {
+ frcnnMnetv3 {
+ dimension "model"
+ applicationIdSuffix ".frcnnMnetv3"
+ buildConfigField("String", "MODULE_ASSET_NAME", "\"frcnn_mnetv3.pt\"")
+ addManifestPlaceholders([APP_NAME: "TV_FRCNN_MNETV3"])
+ buildConfigField("String", "LOGCAT_TAG", "\"pytorch-frcnn-mnetv3\"")
+ }
+ camera {
+ dimension "activity"
+ addManifestPlaceholders([APP_NAME: "TV_CAMERA_FRCNN"])
+ addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
+ }
+ base {
+ dimension "activity"
+ }
+ aar {
+ dimension "build"
+ }
+ nightly {
+ dimension "build"
+ }
+ local {
+ dimension "build"
+ }
+ }
+ packagingOptions {
+ doNotStrip '**.so'
+ }
+
+ // Filtering for CI
+ if (!testAppAllVariantsEnabled.toBoolean()) {
+ variantFilter { variant ->
+ def names = variant.flavors*.name
+ if (names.contains("aar")) {
+ setIgnore(true)
+ }
+ }
+ }
+}
+
+tasks.all { task ->
+ // Disable externalNativeBuild for all but nativeBuild variant
+ if (task.name.startsWith('externalNativeBuild')
+ && !task.name.contains('NativeBuild')) {
+ task.enabled = false
+ }
+}
+
+dependencies {
+ implementation 'com.android.support:appcompat-v7:28.0.0'
+ implementation 'com.facebook.soloader:nativeloader:0.8.0'
+ localImplementation project(':ops')
+
+ implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
+ implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
+ implementation 'org.pytorch:torchvision_ops:0.0.1-SNAPSHOT'
+
+ aarImplementation(name: 'pytorch_android-release', ext: 'aar')
+ aarImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
+
+ def camerax_version = "1.0.0-alpha05"
+ implementation "androidx.camera:camera-core:$camerax_version"
+ implementation "androidx.camera:camera-camera2:$camerax_version"
+ implementation 'com.google.android.material:material:1.0.0-beta01'
+}
+
+task extractAARForNativeBuild {
+ doLast {
+ configurations.extractForNativeBuild.files.each {
+ def file = it.absoluteFile
+ copy {
+ from zipTree(file)
+ into "$buildDir/$file.name"
+ include "headers/**"
+ include "jni/**"
+ }
+ }
+ }
+}
+
+tasks.whenTaskAdded { task ->
+ if (task.name.contains('externalNativeBuild')) {
+ task.dependsOn(extractAARForNativeBuild)
+ }
+}
diff --git a/android/test_app/app/src/main/AndroidManifest.xml b/android/test_app/app/src/main/AndroidManifest.xml
new file mode 100644
index 00000000000..a83bf223bda
--- /dev/null
+++ b/android/test_app/app/src/main/AndroidManifest.xml
@@ -0,0 +1,21 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/BBox.java b/android/test_app/app/src/main/java/org/pytorch/testapp/BBox.java
new file mode 100644
index 00000000000..6fd60791864
--- /dev/null
+++ b/android/test_app/app/src/main/java/org/pytorch/testapp/BBox.java
@@ -0,0 +1,22 @@
+package org.pytorch.testapp;
+
+class BBox {
+ public final float score;
+ public final float x0;
+ public final float y0;
+ public final float x1;
+ public final float y1;
+
+ public BBox(float score, float x0, float y0, float x1, float y1) {
+ this.score = score;
+ this.x0 = x0;
+ this.y0 = y0;
+ this.x1 = x1;
+ this.y1 = y1;
+ }
+
+ @Override
+ public String toString() {
+ return String.format("Box{score=%f x0=%f y0=%f x1=%f y1=%f", score, x0, y0, x1, y1);
+ }
+}
diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java
new file mode 100644
index 00000000000..cfbbad4a8d2
--- /dev/null
+++ b/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java
@@ -0,0 +1,440 @@
+package org.pytorch.testapp;
+
+import android.Manifest;
+import android.content.Context;
+import android.content.pm.PackageManager;
+import android.graphics.Bitmap;
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Paint;
+import android.graphics.Rect;
+import android.os.Bundle;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.SystemClock;
+import android.util.DisplayMetrics;
+import android.util.Log;
+import android.util.Size;
+import android.view.TextureView;
+import android.view.ViewStub;
+import android.widget.ImageView;
+import android.widget.TextView;
+import android.widget.Toast;
+
+import androidx.annotation.Nullable;
+import androidx.annotation.UiThread;
+import androidx.annotation.WorkerThread;
+import androidx.appcompat.app.AppCompatActivity;
+import androidx.camera.core.CameraX;
+import androidx.camera.core.ImageAnalysis;
+import androidx.camera.core.ImageAnalysisConfig;
+import androidx.camera.core.ImageProxy;
+import androidx.camera.core.Preview;
+import androidx.camera.core.PreviewConfig;
+import androidx.core.app.ActivityCompat;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import com.facebook.soloader.nativeloader.NativeLoader;
+import com.facebook.soloader.nativeloader.SystemDelegate;
+
+import org.pytorch.IValue;
+import org.pytorch.MemoryFormat;
+import org.pytorch.Module;
+import org.pytorch.PyTorchAndroid;
+import org.pytorch.Tensor;
+
+public class CameraActivity extends AppCompatActivity {
+
+ private static final float BBOX_SCORE_DRAW_THRESHOLD = 0.5f;
+ private static final String TAG = BuildConfig.LOGCAT_TAG;
+ private static final int TEXT_TRIM_SIZE = 4096;
+ private static final int RGB_MAX_CHANNEL_VALUE = 262143;
+
+ private static final int REQUEST_CODE_CAMERA_PERMISSION = 200;
+ private static final String[] PERMISSIONS = {Manifest.permission.CAMERA};
+
+ static {
+ if (!NativeLoader.isInitialized()) {
+ NativeLoader.init(new SystemDelegate());
+ }
+ NativeLoader.loadLibrary("pytorch_jni");
+ NativeLoader.loadLibrary("torchvision_ops");
+ }
+
+ private Bitmap mInputTensorBitmap;
+ private Bitmap mBitmap;
+ private Canvas mCanvas;
+
+ private long mLastAnalysisResultTime;
+
+ protected HandlerThread mBackgroundThread;
+ protected Handler mBackgroundHandler;
+ protected Handler mUIHandler;
+
+ private TextView mTextView;
+ private ImageView mCameraOverlay;
+ private StringBuilder mTextViewStringBuilder = new StringBuilder();
+
+ private Paint mBboxPaint;
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_camera);
+ mTextView = findViewById(R.id.text);
+ mCameraOverlay = findViewById(R.id.camera_overlay);
+ mUIHandler = new Handler(getMainLooper());
+ startBackgroundThread();
+
+ if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA)
+ != PackageManager.PERMISSION_GRANTED) {
+ ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION);
+ } else {
+ setupCameraX();
+ }
+ mBboxPaint = new Paint();
+ mBboxPaint.setAntiAlias(true);
+ mBboxPaint.setDither(true);
+ mBboxPaint.setColor(Color.GREEN);
+ }
+
+ @Override
+ protected void onPostCreate(@Nullable Bundle savedInstanceState) {
+ super.onPostCreate(savedInstanceState);
+ startBackgroundThread();
+ }
+
+ protected void startBackgroundThread() {
+ mBackgroundThread = new HandlerThread("ModuleActivity");
+ mBackgroundThread.start();
+ mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
+ }
+
+ @Override
+ protected void onDestroy() {
+ stopBackgroundThread();
+ super.onDestroy();
+ }
+
+ protected void stopBackgroundThread() {
+ mBackgroundThread.quitSafely();
+ try {
+ mBackgroundThread.join();
+ mBackgroundThread = null;
+ mBackgroundHandler = null;
+ } catch (InterruptedException e) {
+ Log.e(TAG, "Error on stopping background thread", e);
+ }
+ }
+
+ @Override
+ public void onRequestPermissionsResult(
+ int requestCode, String[] permissions, int[] grantResults) {
+ if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) {
+ if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
+ Toast.makeText(
+ this,
+ "You can't use image classification example without granting CAMERA permission",
+ Toast.LENGTH_LONG)
+ .show();
+ finish();
+ } else {
+ setupCameraX();
+ }
+ }
+ }
+
+ private void setupCameraX() {
+ final TextureView textureView =
+ ((ViewStub) findViewById(R.id.camera_texture_view_stub))
+ .inflate()
+ .findViewById(R.id.texture_view);
+ final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
+ final Preview preview = new Preview(previewConfig);
+ preview.setOnPreviewOutputUpdateListener(
+ new Preview.OnPreviewOutputUpdateListener() {
+ @Override
+ public void onUpdated(Preview.PreviewOutput output) {
+ textureView.setSurfaceTexture(output.getSurfaceTexture());
+ }
+ });
+
+ final DisplayMetrics displayMetrics = new DisplayMetrics();
+ getWindowManager().getDefaultDisplay().getMetrics(displayMetrics);
+
+ final ImageAnalysisConfig imageAnalysisConfig =
+ new ImageAnalysisConfig.Builder()
+ .setTargetResolution(new Size(displayMetrics.widthPixels, displayMetrics.heightPixels))
+ .setCallbackHandler(mBackgroundHandler)
+ .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
+ .build();
+ final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
+ imageAnalysis.setAnalyzer(
+ new ImageAnalysis.Analyzer() {
+ @Override
+ public void analyze(ImageProxy image, int rotationDegrees) {
+ if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {
+ return;
+ }
+
+ final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees);
+
+ if (result != null) {
+ mLastAnalysisResultTime = SystemClock.elapsedRealtime();
+ CameraActivity.this.runOnUiThread(
+ new Runnable() {
+ @Override
+ public void run() {
+ CameraActivity.this.handleResult(result);
+ }
+ });
+ }
+ }
+ });
+
+ CameraX.bindToLifecycle(this, preview, imageAnalysis);
+ }
+
+ private Module mModule;
+ private FloatBuffer mInputTensorBuffer;
+ private Tensor mInputTensor;
+
+ private static int clamp0255(int x) {
+ if (x > 255) {
+ return 255;
+ }
+ return x < 0 ? 0 : x;
+ }
+
+ protected void fillInputTensorBuffer(
+ ImageProxy image,
+ int rotationDegrees,
+ FloatBuffer inputTensorBuffer) {
+
+ if (mInputTensorBitmap == null) {
+ final int tensorSize = Math.min(image.getWidth(), image.getHeight());
+ mInputTensorBitmap = Bitmap.createBitmap(tensorSize, tensorSize, Bitmap.Config.ARGB_8888);
+ }
+
+ ImageProxy.PlaneProxy[] planes = image.getPlanes();
+ ImageProxy.PlaneProxy Y = planes[0];
+ ImageProxy.PlaneProxy U = planes[1];
+ ImageProxy.PlaneProxy V = planes[2];
+ ByteBuffer yBuffer = Y.getBuffer();
+ ByteBuffer uBuffer = U.getBuffer();
+ ByteBuffer vBuffer = V.getBuffer();
+ final int imageWidth = image.getWidth();
+ final int imageHeight = image.getHeight();
+ final int tensorSize = Math.min(imageWidth, imageHeight);
+
+ int widthAfterRtn = imageWidth;
+ int heightAfterRtn = imageHeight;
+ boolean oddRotation = rotationDegrees == 90 || rotationDegrees == 270;
+ if (oddRotation) {
+ widthAfterRtn = imageHeight;
+ heightAfterRtn = imageWidth;
+ }
+
+ int minSizeAfterRtn = Math.min(heightAfterRtn, widthAfterRtn);
+ int cropWidthAfterRtn = minSizeAfterRtn;
+ int cropHeightAfterRtn = minSizeAfterRtn;
+
+ int cropWidthBeforeRtn = cropWidthAfterRtn;
+ int cropHeightBeforeRtn = cropHeightAfterRtn;
+ if (oddRotation) {
+ cropWidthBeforeRtn = cropHeightAfterRtn;
+ cropHeightBeforeRtn = cropWidthAfterRtn;
+ }
+
+ int offsetX = (int) ((imageWidth - cropWidthBeforeRtn) / 2.f);
+ int offsetY = (int) ((imageHeight - cropHeightBeforeRtn) / 2.f);
+
+ int yRowStride = Y.getRowStride();
+ int yPixelStride = Y.getPixelStride();
+ int uvRowStride = U.getRowStride();
+ int uvPixelStride = U.getPixelStride();
+
+ float scale = cropWidthAfterRtn / tensorSize;
+ int yIdx, uvIdx, yi, ui, vi;
+ final int channelSize = tensorSize * tensorSize;
+ for (int y = 0; y < tensorSize; y++) {
+ for (int x = 0; x < tensorSize; x++) {
+ final int centerCropX = (int) Math.floor(x * scale);
+ final int centerCropY = (int) Math.floor(y * scale);
+ int srcX = centerCropX + offsetX;
+ int srcY = centerCropY + offsetY;
+
+ if (rotationDegrees == 90) {
+ srcX = offsetX + centerCropY;
+ srcY = offsetY + (minSizeAfterRtn - 1) - centerCropX;
+ } else if (rotationDegrees == 180) {
+ srcX = offsetX + (minSizeAfterRtn - 1) - centerCropX;
+ srcY = offsetY + (minSizeAfterRtn - 1) - centerCropY;
+ } else if (rotationDegrees == 270) {
+ srcX = offsetX + (minSizeAfterRtn - 1) - centerCropY;
+ srcY = offsetY + centerCropX;
+ }
+
+ yIdx = srcY * yRowStride + srcX * yPixelStride;
+ uvIdx = (srcY >> 1) * uvRowStride + (srcX >> 1) * uvPixelStride;
+
+ yi = yBuffer.get(yIdx) & 0xff;
+ ui = uBuffer.get(uvIdx) & 0xff;
+ vi = vBuffer.get(uvIdx) & 0xff;
+
+ yi = (yi - 16) < 0 ? 0 : (yi - 16);
+ ui -= 128;
+ vi -= 128;
+
+ int a0 = 1192 * yi;
+ int ri = (a0 + 1634 * vi);
+ int gi = (a0 - 833 * vi - 400 * ui);
+ int bi = (a0 + 2066 * ui);
+
+ ri = ri > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (ri < 0 ? 0 : ri);
+ gi = gi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (gi < 0 ? 0 : gi);
+ bi = bi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (bi < 0 ? 0 : bi);
+
+ final int color = 0xff000000 | ((ri << 6) & 0xff0000) | ((gi >> 2) & 0xff00) | ((bi >> 10) & 0xff);
+ mInputTensorBitmap.setPixel(x, y, color);
+ inputTensorBuffer.put(0 * channelSize + y * tensorSize + x, clamp0255(ri >> 10) / 255.f);
+ inputTensorBuffer.put(1 * channelSize + y * tensorSize + x, clamp0255(gi >> 10) / 255.f);
+ inputTensorBuffer.put(2 * channelSize + y * tensorSize + x, clamp0255(bi >> 10) / 255.f);
+ }
+ }
+ }
+
+ public static String assetFilePath(Context context, String assetName) {
+ File file = new File(context.getFilesDir(), assetName);
+ if (file.exists() && file.length() > 0) {
+ return file.getAbsolutePath();
+ }
+
+ try (InputStream is = context.getAssets().open(assetName)) {
+ try (OutputStream os = new FileOutputStream(file)) {
+ byte[] buffer = new byte[4 * 1024];
+ int read;
+ while ((read = is.read(buffer)) != -1) {
+ os.write(buffer, 0, read);
+ }
+ os.flush();
+ }
+ return file.getAbsolutePath();
+ } catch (IOException e) {
+ Log.e(TAG, "Error process asset " + assetName + " to file path");
+ }
+ return null;
+ }
+
+ @WorkerThread
+ @Nullable
+ protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
+ Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees));
+ final int tensorSize = Math.min(image.getWidth(), image.getHeight());
+ if (mModule == null) {
+ Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
+ mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * tensorSize * tensorSize);
+ mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{3, tensorSize, tensorSize});
+ final String modelFileAbsoluteFilePath =
+ new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath();
+ mModule = Module.load(modelFileAbsoluteFilePath);
+ }
+
+ final long startTime = SystemClock.elapsedRealtime();
+ fillInputTensorBuffer(image, rotationDegrees, mInputTensorBuffer);
+
+ final long moduleForwardStartTime = SystemClock.elapsedRealtime();
+ final IValue outputTuple = mModule.forward(IValue.listFrom(mInputTensor));
+ final IValue out1 = outputTuple.toTuple()[1];
+ final Map map = out1.toList()[0].toDictStringKey();
+
+ float[] boxesData = new float[]{};
+ float[] scoresData = new float[]{};
+ final List bboxes = new ArrayList<>();
+ if (map.containsKey("boxes")) {
+ final Tensor boxesTensor = map.get("boxes").toTensor();
+ final Tensor scoresTensor = map.get("scores").toTensor();
+ boxesData = boxesTensor.getDataAsFloatArray();
+ scoresData = scoresTensor.getDataAsFloatArray();
+ final int n = scoresData.length;
+ for (int i = 0; i < n; i++) {
+ final BBox bbox = new BBox(
+ scoresData[i],
+ boxesData[4 * i + 0],
+ boxesData[4 * i + 1],
+ boxesData[4 * i + 2],
+ boxesData[4 * i + 3]
+ );
+ android.util.Log.i(TAG, String.format("Forward result %d: %s", i, bbox));
+ bboxes.add(bbox);
+ }
+ } else {
+ android.util.Log.i(TAG, "Forward result empty");
+ }
+
+ final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
+ final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
+ return new Result(tensorSize, bboxes, moduleForwardDuration, analysisDuration);
+ }
+
+ @UiThread
+ protected void handleResult(Result result) {
+ final int W = mCameraOverlay.getMeasuredWidth();
+ final int H = mCameraOverlay.getMeasuredHeight();
+
+ final int size = Math.min(W, H);
+ final int offsetX = (W - size) / 2;
+ final int offsetY = (H - size) / 2;
+
+ float scaleX = (float) size / result.tensorSize;
+ float scaleY = (float) size / result.tensorSize;
+ if (mBitmap == null) {
+ mBitmap = Bitmap.createBitmap(W, H, Bitmap.Config.ARGB_8888);
+ mCanvas = new Canvas(mBitmap);
+ }
+
+ mCanvas.drawBitmap(
+ mInputTensorBitmap,
+ new Rect(0, 0, result.tensorSize, result.tensorSize),
+ new Rect(offsetX, offsetY, offsetX + size, offsetY + size),
+ null
+ );
+
+ for (final BBox bbox : result.bboxes) {
+ if (bbox.score < BBOX_SCORE_DRAW_THRESHOLD) {
+ continue;
+ }
+
+ float c_x0 = offsetX + scaleX * bbox.x0;
+ float c_y0 = offsetY + scaleY * bbox.y0;
+
+ float c_x1 = offsetX + scaleX * bbox.x1;
+ float c_y1 = offsetY + scaleY * bbox.y1;
+
+ mCanvas.drawLine(c_x0, c_y0, c_x1, c_y0, mBboxPaint);
+ mCanvas.drawLine(c_x1, c_y0, c_x1, c_y1, mBboxPaint);
+ mCanvas.drawLine(c_x1, c_y1, c_x0, c_y1, mBboxPaint);
+ mCanvas.drawLine(c_x0, c_y1, c_x0, c_y0, mBboxPaint);
+ mCanvas.drawText(String.format("%.2f", bbox.score), c_x0, c_y0, mBboxPaint);
+ }
+ mCameraOverlay.setImageBitmap(mBitmap);
+
+ String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
+ Log.i(TAG, message);
+ mTextViewStringBuilder.insert(0, '\n').insert(0, message);
+ if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
+ mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
+ }
+ mTextView.setText(mTextViewStringBuilder.toString());
+ }
+}
diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java
new file mode 100644
index 00000000000..1ee16a87a0a
--- /dev/null
+++ b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java
@@ -0,0 +1,158 @@
+package org.pytorch.testapp;
+
+import android.os.Bundle;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.SystemClock;
+import android.util.Log;
+import android.widget.TextView;
+import androidx.annotation.Nullable;
+import androidx.annotation.UiThread;
+import androidx.annotation.WorkerThread;
+import androidx.appcompat.app.AppCompatActivity;
+import com.facebook.soloader.nativeloader.NativeLoader;
+import com.facebook.soloader.nativeloader.SystemDelegate;
+import org.pytorch.IValue;
+import org.pytorch.Module;
+import org.pytorch.PyTorchAndroid;
+import org.pytorch.Tensor;
+
+import java.nio.FloatBuffer;
+import java.util.Map;
+
+public class MainActivity extends AppCompatActivity {
+ static {
+ if (!NativeLoader.isInitialized()) {
+ NativeLoader.init(new SystemDelegate());
+ }
+ NativeLoader.loadLibrary("pytorch_jni");
+ NativeLoader.loadLibrary("torchvision_ops");
+ }
+
+ private static final String TAG = BuildConfig.LOGCAT_TAG;
+ private static final int TEXT_TRIM_SIZE = 4096;
+
+ private TextView mTextView;
+
+ protected HandlerThread mBackgroundThread;
+ protected Handler mBackgroundHandler;
+ private Module mModule;
+ private FloatBuffer mInputTensorBuffer;
+ private Tensor mInputTensor;
+ private StringBuilder mTextViewStringBuilder = new StringBuilder();
+
+ private final Runnable mModuleForwardRunnable =
+ new Runnable() {
+ @Override
+ public void run() {
+ final Result result = doModuleForward();
+ runOnUiThread(
+ () -> {
+ handleResult(result);
+ if (mBackgroundHandler != null) {
+ mBackgroundHandler.post(mModuleForwardRunnable);
+ }
+ });
+ }
+ };
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_main);
+ mTextView = findViewById(R.id.text);
+ startBackgroundThread();
+ mBackgroundHandler.post(mModuleForwardRunnable);
+ }
+
+ protected void startBackgroundThread() {
+ mBackgroundThread = new HandlerThread(TAG + "_bg");
+ mBackgroundThread.start();
+ mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
+ }
+
+ @Override
+ protected void onDestroy() {
+ stopBackgroundThread();
+ super.onDestroy();
+ }
+
+ protected void stopBackgroundThread() {
+ mBackgroundThread.quitSafely();
+ try {
+ mBackgroundThread.join();
+ mBackgroundThread = null;
+ mBackgroundHandler = null;
+ } catch (InterruptedException e) {
+ Log.e(TAG, "Error stopping background thread", e);
+ }
+ }
+
+ @WorkerThread
+ @Nullable
+ protected Result doModuleForward() {
+ if (mModule == null) {
+ final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
+ long numElements = 1;
+ for (int i = 0; i < shape.length; i++) {
+ numElements *= shape[i];
+ }
+ mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
+ mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
+ PyTorchAndroid.setNumThreads(1);
+ mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
+ }
+
+ final long startTime = SystemClock.elapsedRealtime();
+ final long moduleForwardStartTime = SystemClock.elapsedRealtime();
+ final IValue outputTuple = mModule.forward(IValue.listFrom(mInputTensor));
+ final IValue[] outputArray = outputTuple.toTuple();
+ final IValue out0 = outputArray[0];
+ final Map map = out0.toDictStringKey();
+ if (map.containsKey("boxes")) {
+ final Tensor boxes = map.get("boxes").toTensor();
+ final Tensor scores = map.get("scores").toTensor();
+ final float[] boxesData = boxes.getDataAsFloatArray();
+ final float[] scoresData = scores.getDataAsFloatArray();
+ final int n = scoresData.length;
+ for (int i = 0; i < n; i++) {
+ android.util.Log.i(TAG,
+ String.format("Forward result %d: score %f box:(%f, %f, %f, %f)",
+ scoresData[i],
+ boxesData[4 * i + 0],
+ boxesData[4 * i + 1],
+ boxesData[4 * i + 2],
+ boxesData[4 * i + 3]));
+ }
+ } else {
+ android.util.Log.i(TAG, "Forward result empty");
+ }
+
+ final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
+ final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
+ return new Result(new float[]{}, moduleForwardDuration, analysisDuration);
+ }
+
+ static class Result {
+
+ private final float[] scores;
+ private final long totalDuration;
+ private final long moduleForwardDuration;
+
+ public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
+ this.scores = scores;
+ this.moduleForwardDuration = moduleForwardDuration;
+ this.totalDuration = totalDuration;
+ }
+ }
+
+ @UiThread
+ protected void handleResult(Result result) {
+ String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
+ mTextViewStringBuilder.insert(0, '\n').insert(0, message);
+ if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
+ mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
+ }
+ mTextView.setText(mTextViewStringBuilder.toString());
+ }
+}
diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java b/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java
new file mode 100644
index 00000000000..ed7ebd006cd
--- /dev/null
+++ b/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java
@@ -0,0 +1,17 @@
+package org.pytorch.testapp;
+
+import java.util.List;
+
+class Result {
+ public final int tensorSize;
+ public final List bboxes;
+ public final long totalDuration;
+ public final long moduleForwardDuration;
+
+ public Result(int tensorSize, List bboxes, long moduleForwardDuration, long totalDuration) {
+ this.tensorSize = tensorSize;
+ this.bboxes = bboxes;
+ this.moduleForwardDuration = moduleForwardDuration;
+ this.totalDuration = totalDuration;
+ }
+}
diff --git a/android/test_app/app/src/main/res/layout/activity_camera.xml b/android/test_app/app/src/main/res/layout/activity_camera.xml
new file mode 100644
index 00000000000..7ba2e42b7c0
--- /dev/null
+++ b/android/test_app/app/src/main/res/layout/activity_camera.xml
@@ -0,0 +1,28 @@
+
+
+
+
+
+
+
+
+
diff --git a/android/test_app/app/src/main/res/layout/activity_main.xml b/android/test_app/app/src/main/res/layout/activity_main.xml
new file mode 100644
index 00000000000..c0939ebc0eb
--- /dev/null
+++ b/android/test_app/app/src/main/res/layout/activity_main.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/test_app/app/src/main/res/layout/texture_view.xml b/android/test_app/app/src/main/res/layout/texture_view.xml
new file mode 100644
index 00000000000..6518c6c84c6
--- /dev/null
+++ b/android/test_app/app/src/main/res/layout/texture_view.xml
@@ -0,0 +1,5 @@
+
+
diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png
new file mode 100644
index 00000000000..64ba76f75e9
Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png differ
diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png
new file mode 100644
index 00000000000..dae5e082342
Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png differ
diff --git a/android/test_app/app/src/main/res/values/colors.xml b/android/test_app/app/src/main/res/values/colors.xml
new file mode 100644
index 00000000000..69b22338c65
--- /dev/null
+++ b/android/test_app/app/src/main/res/values/colors.xml
@@ -0,0 +1,6 @@
+
+
+ #008577
+ #00574B
+ #D81B60
+
diff --git a/android/test_app/app/src/main/res/values/strings.xml b/android/test_app/app/src/main/res/values/strings.xml
new file mode 100644
index 00000000000..cafbaad1511
--- /dev/null
+++ b/android/test_app/app/src/main/res/values/strings.xml
@@ -0,0 +1,3 @@
+
+ TV_FRCNN
+
diff --git a/android/test_app/app/src/main/res/values/styles.xml b/android/test_app/app/src/main/res/values/styles.xml
new file mode 100644
index 00000000000..5885930df6d
--- /dev/null
+++ b/android/test_app/app/src/main/res/values/styles.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py
new file mode 100644
index 00000000000..7860c759a57
--- /dev/null
+++ b/android/test_app/make_assets.py
@@ -0,0 +1,17 @@
+import torch
+import torchvision
+from torch.utils.mobile_optimizer import optimize_for_mobile
+
+print(torch.__version__)
+
+model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
+ pretrained=True,
+ box_score_thresh=0.7,
+ rpn_post_nms_top_n_test=100,
+ rpn_score_thresh=0.4,
+ rpn_pre_nms_top_n_test=150)
+
+model.eval()
+script_model = torch.jit.script(model)
+opt_script_model = optimize_for_mobile(script_model)
+opt_script_model.save("app/src/main/assets/frcnn_mnetv3.pt")
diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp
index baad319e7c0..1b75de4a754 100644
--- a/torchvision/csrc/vision.cpp
+++ b/torchvision/csrc/vision.cpp
@@ -1,6 +1,8 @@
#include "vision.h"
+#ifndef MOBILE
#include
+#endif
#include
#ifdef WITH_CUDA