diff --git a/android/.gitignore b/android/.gitignore new file mode 100644 index 00000000000..b4d6d617645 --- /dev/null +++ b/android/.gitignore @@ -0,0 +1,8 @@ +local.properties +**/*.iml +.gradle +gradlew* +gradle/wrapper +.idea/* +.externalNativeBuild +build diff --git a/android/build.gradle b/android/build.gradle new file mode 100644 index 00000000000..c3393e047e4 --- /dev/null +++ b/android/build.gradle @@ -0,0 +1,43 @@ +allprojects { + buildscript { + ext { + minSdkVersion = 21 + targetSdkVersion = 28 + compileSdkVersion = 28 + buildToolsVersion = '28.0.3' + + coreVersion = "1.2.0" + extJUnitVersion = "1.1.1" + runnerVersion = "1.2.0" + rulesVersion = "1.2.0" + junitVersion = "4.12" + + androidSupportAppCompatV7Version = "28.0.0" + fbjniJavaOnlyVersion = "0.0.3" + soLoaderNativeLoaderVersion = "0.8.0" + } + + repositories { + google() + mavenCentral() + jcenter() + } + + dependencies { + classpath 'com.android.tools.build:gradle:3.3.2' + classpath "com.jfrog.bintray.gradle:gradle-bintray-plugin:${GRADLE_BINTRAY_PLUGIN_VERSION}" + classpath "com.github.dcendents:android-maven-gradle-plugin:${ANDROID_MAVEN_GRADLE_PLUGIN_VERSION}" + classpath "org.jfrog.buildinfo:build-info-extractor-gradle:4.9.8" + } + } + + repositories { + google() + jcenter() + } +} + +ext.deps = [ + jsr305: 'com.google.code.findbugs:jsr305:3.0.1', +] + diff --git a/android/gradle.properties b/android/gradle.properties new file mode 100644 index 00000000000..bd5ed6bbd98 --- /dev/null +++ b/android/gradle.properties @@ -0,0 +1,28 @@ +ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64 + +VERSION_NAME=0.0.1-SNAPSHOT +GROUP=org.pytorch +MAVEN_GROUP=org.pytorch +POM_URL=https://github.com/pytorch/vision/ +POM_SCM_URL=https://github.com/pytorch/vision.git +POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/vision +POM_SCM_DEV_CONNECTION=scm:git:git@github.com:pytorch/vision.git +POM_LICENSE_NAME=BSD 3-Clause +POM_LICENSE_URL=https://github.com/pytorch/vision/blob/master/LICENSE +POM_ISSUES_URL=https://github.com/pytorch/vision/issues +POM_LICENSE_DIST=repo +POM_DEVELOPER_ID=pytorch +POM_DEVELOPER_NAME=pytorch +syncWithMavenCentral=true + +GRADLE_BINTRAY_PLUGIN_VERSION=1.8.0 +GRADLE_VERSIONS_PLUGIN_VERSION=0.15.0 +ANDROID_MAVEN_GRADLE_PLUGIN_VERSION=2.1 + +# Gradle internals +android.useAndroidX=true +android.enableJetifier=true + +testAppAllVariantsEnabled=false + +org.gradle.jvmargs=-Xmx4096m diff --git a/android/gradle_scripts/android_maven_install.gradle b/android/gradle_scripts/android_maven_install.gradle new file mode 100644 index 00000000000..ce80472d79e --- /dev/null +++ b/android/gradle_scripts/android_maven_install.gradle @@ -0,0 +1,38 @@ +apply plugin: 'com.github.dcendents.android-maven' + +version = VERSION_NAME +group = GROUP +project.archivesBaseName = POM_ARTIFACT_ID + +install { + repositories.mavenInstaller { + pom.project { + name POM_NAME + artifactId POM_ARTIFACT_ID + packaging POM_PACKAGING + description POM_DESCRIPTION + url projectUrl + + scm { + url scmUrl + connection scmConnection + developerConnection scmDeveloperConnection + } + + licenses { + license { + name = POM_LICENSE_NAME + url = POM_LICENSE_URL + distribution = POM_LICENSE_DIST + } + } + + developers { + developer { + id developerId + name developerName + } + } + } + } +} diff --git a/android/gradle_scripts/android_tasks.gradle b/android/gradle_scripts/android_tasks.gradle new file mode 100644 index 00000000000..ca188ac72d0 --- /dev/null +++ b/android/gradle_scripts/android_tasks.gradle @@ -0,0 +1,95 @@ + +import java.nio.file.Files +import java.nio.file.Paths +import java.io.FileOutputStream +import java.util.zip.ZipFile + +// Android tasks for Javadoc and sources.jar generation + +afterEvaluate { project -> + if (POM_PACKAGING == 'aar') { + task androidJavadoc(type: Javadoc, dependsOn: assembleDebug) { + source += files(android.sourceSets.main.java.srcDirs) + failOnError false + // This task will try to compile *everything* it finds in the above directory and + // will choke on text files it doesn't understand. + exclude '**/BUCK' + exclude '**/*.md' + } + + task androidJavadocJar(type: Jar, dependsOn: androidJavadoc) { + classifier = 'javadoc' + from androidJavadoc.destinationDir + } + + task androidSourcesJar(type: Jar) { + classifier = 'sources' + from android.sourceSets.main.java.srcDirs + } + + android.libraryVariants.all { variant -> + def name = variant.name.capitalize() + task "jar${name}"(type: Jar, dependsOn: variant.javaCompileProvider) { + from variant.javaCompileProvider.get().destinationDir + } + + androidJavadoc.doFirst { + classpath += files(android.bootClasspath) + classpath += files(variant.javaCompileProvider.get().classpath.files) + // This is generated by `assembleDebug` and holds the JARs generated by the APT. + classpath += fileTree(dir: "$buildDir/intermediates/bundles/debug/", include: '**/*.jar') + + // Process AAR dependencies + def aarDependencies = classpath.filter { it.name.endsWith('.aar') } + classpath -= aarDependencies + aarDependencies.each { aar -> + // Extract classes.jar from the AAR dependency, and add it to the javadoc classpath + def outputPath = "$buildDir/tmp/aarJar/${aar.name.replace('.aar', '.jar')}" + classpath += files(outputPath) + + // Use a task so the actual extraction only happens before the javadoc task is run + dependsOn task(name: "extract ${aar.name}").doLast { + extractEntry(aar, 'classes.jar', outputPath) + } + } + } + } + + artifacts.add('archives', androidJavadocJar) + artifacts.add('archives', androidSourcesJar) + } + + if (POM_PACKAGING == 'jar') { + task javadocJar(type: Jar, dependsOn: javadoc) { + classifier = 'javadoc' + from javadoc.destinationDir + } + + task sourcesJar(type: Jar, dependsOn: classes) { + classifier = 'sources' + from sourceSets.main.allSource + } + + artifacts.add('archives', javadocJar) + artifacts.add('archives', sourcesJar) + } +} + +// Utility method to extract only one entry in a zip file +private def extractEntry(archive, entryPath, outputPath) { + if (!archive.exists()) { + throw new GradleException("archive $archive not found") + } + + def zip = new ZipFile(archive) + zip.entries().each { + if (it.name == entryPath) { + def path = Paths.get(outputPath) + if (!Files.exists(path)) { + Files.createDirectories(path.getParent()) + Files.copy(zip.getInputStream(it), path) + } + } + } + zip.close() +} diff --git a/android/gradle_scripts/bintray.gradle b/android/gradle_scripts/bintray.gradle new file mode 100644 index 00000000000..c20073964f7 --- /dev/null +++ b/android/gradle_scripts/bintray.gradle @@ -0,0 +1,64 @@ +apply plugin: 'com.jfrog.bintray' + +def getBintrayUsername() { + return project.hasProperty('bintrayUsername') ? property('bintrayUsername') : System.getenv('BINTRAY_USERNAME') +} + +def getBintrayApiKey() { + return project.hasProperty('bintrayApiKey') ? property('bintrayApiKey') : System.getenv('BINTRAY_API_KEY') +} + +def getBintrayGpgPassword() { + return project.hasProperty('bintrayGpgPassword') ? property('bintrayGpgPassword') : System.getenv('BINTRAY_GPG_PASSWORD') +} + +def getMavenCentralUsername() { + return project.hasProperty('mavenCentralUsername') ? property('mavenCentralUsername') : System.getenv('MAVEN_CENTRAL_USERNAME') +} + +def getMavenCentralPassword() { + return project.hasProperty('mavenCentralPassword') ? property('mavenCentralPassword') : System.getenv('MAVEN_CENTRAL_PASSWORD') +} + +def shouldSyncWithMavenCentral() { + return project.hasProperty('syncWithMavenCentral') ? property('syncWithMavenCentral').toBoolean() : false +} + +def dryRunOnly() { + return project.hasProperty('dryRun') ? property('dryRun').toBoolean() : false +} + +bintray { + user = getBintrayUsername() + key = getBintrayApiKey() + override = false + configurations = ['archives'] + pkg { + repo = bintrayRepo + userOrg = bintrayUserOrg + name = bintrayName + desc = bintrayDescription + websiteUrl = projectUrl + issueTrackerUrl = issuesUrl + vcsUrl = scmUrl + licenses = [ POM_LICENSE_NAME ] + dryRun = dryRunOnly() + override = false + publish = true + publicDownloadNumbers = true + version { + name = versionName + desc = bintrayDescription + gpg { + sign = true + passphrase = getBintrayGpgPassword() + } + mavenCentralSync { + sync = shouldSyncWithMavenCentral() + user = getMavenCentralUsername() + password = getMavenCentralPassword() + close = '1' // If set to 0, you have to manually click release + } + } + } +} diff --git a/android/gradle_scripts/gradle_maven_push.gradle b/android/gradle_scripts/gradle_maven_push.gradle new file mode 100644 index 00000000000..5fdd8fbc6a0 --- /dev/null +++ b/android/gradle_scripts/gradle_maven_push.gradle @@ -0,0 +1,99 @@ +apply plugin: 'signing' + +version = VERSION_NAME +group = MAVEN_GROUP + +def isReleaseBuild() { + return !VERSION_NAME.contains('SNAPSHOT') +} + +def getReleaseRepositoryUrl() { + return hasProperty('RELEASE_REPOSITORY_URL') ? RELEASE_REPOSITORY_URL + : "https://oss.sonatype.org/service/local/staging/deploy/maven2/" +} + +def getSnapshotRepositoryUrl() { + return hasProperty('SNAPSHOT_REPOSITORY_URL') ? SNAPSHOT_REPOSITORY_URL + : "https://oss.sonatype.org/content/repositories/snapshots/" +} + +def getRepositoryUsername() { + return hasProperty('SONATYPE_NEXUS_USERNAME') ? SONATYPE_NEXUS_USERNAME : "" +} + +def getRepositoryPassword() { + return hasProperty('SONATYPE_NEXUS_PASSWORD') ? SONATYPE_NEXUS_PASSWORD : "" +} + +def getHttpProxyHost() { + return project.properties['systemProp.http.proxyHost'] +} + +def getHttpProxyPort() { + return project.properties['systemProp.http.proxyPort'] +} + +def needProxy() { + return (getHttpProxyHost() != null) && (getHttpProxyPort() != null) +} + +afterEvaluate { project -> + uploadArchives { + repositories { + mavenDeployer { + beforeDeployment { MavenDeployment deployment -> signing.signPom(deployment) } + + pom.groupId = MAVEN_GROUP + pom.artifactId = POM_ARTIFACT_ID + pom.version = VERSION_NAME + + repository(url: getReleaseRepositoryUrl()) { + authentication(userName: getRepositoryUsername(), password: getRepositoryPassword()) + if (needProxy()) { + proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http') + } + } + snapshotRepository(url: getSnapshotRepositoryUrl()) { + authentication(userName: getRepositoryUsername(), password: getRepositoryPassword()) + if (needProxy()) { + proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http') + } + } + + pom.project { + name POM_NAME + packaging POM_PACKAGING + description POM_DESCRIPTION + url POM_URL + + scm { + url POM_SCM_URL + connection POM_SCM_CONNECTION + developerConnection POM_SCM_DEV_CONNECTION + } + + licenses { + license { + name POM_LICENSE_NAME + url POM_LICENSE_URL + distribution POM_LICENSE_DIST + } + } + + developers { + developer { + id POM_DEVELOPER_ID + name POM_DEVELOPER_NAME + } + } + } + } + } + } + + signing { + required { isReleaseBuild() && gradle.taskGraph.hasTask('uploadArchives') } + sign configurations.archives + } + +} diff --git a/android/gradle_scripts/release.gradle b/android/gradle_scripts/release.gradle new file mode 100644 index 00000000000..d4e3aef4a22 --- /dev/null +++ b/android/gradle_scripts/release.gradle @@ -0,0 +1,5 @@ +apply from: rootProject.file('gradle_scripts/android_tasks.gradle') + +apply from: rootProject.file('gradle_scripts/release_bintray.gradle') + +apply from: rootProject.file('gradle_scripts/gradle_maven_push.gradle') diff --git a/android/gradle_scripts/release_bintray.gradle b/android/gradle_scripts/release_bintray.gradle new file mode 100644 index 00000000000..9b2af121a94 --- /dev/null +++ b/android/gradle_scripts/release_bintray.gradle @@ -0,0 +1,32 @@ +ext { + bintrayRepo = 'maven' + bintrayUserOrg = 'pytorch' + bintrayName = "${GROUP}:${POM_ARTIFACT_ID}" + bintrayDescription = POM_DESCRIPTION + projectUrl = POM_URL + issuesUrl = POM_ISSUES_URL + scmUrl = POM_SCM_URL + scmConnection = POM_SCM_CONNECTION + scmDeveloperConnection = POM_SCM_DEV_CONNECTION + + publishedGroupId = GROUP + libraryName = 'torchvision' + artifact = 'torchvision' + + developerId = POM_DEVELOPER_ID + developerName = POM_DEVELOPER_NAME + + versionName = VERSION_NAME + + projectLicenses = { + license = { + name = POM_LICENSE_NAME + url = POM_LICENSE_URL + distribution = POM_LICENSE_DIST + } + } +} + +apply from: rootProject.file('gradle_scripts/android_maven_install.gradle') + +apply from: rootProject.file('gradle_scripts/bintray.gradle') diff --git a/android/ops/CMakeLists.txt b/android/ops/CMakeLists.txt new file mode 100644 index 00000000000..ad42adbfa71 --- /dev/null +++ b/android/ops/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.4.1) +set(TARGET torchvision_ops) +project(${TARGET} CXX) +set(CMAKE_CXX_STANDARD 14) + +string(APPEND CMAKE_CXX_FLAGS " -DMOBILE") + +set(build_DIR ${CMAKE_SOURCE_DIR}/build) +set(root_DIR ${CMAKE_CURRENT_LIST_DIR}/..) + +file(GLOB VISION_SRCS + ../../torchvision/csrc/ops/cpu/*.h + ../../torchvision/csrc/ops/cpu/*.cpp + ../../torchvision/csrc/ops/*.h + ../../torchvision/csrc/ops/*.cpp) + +add_library(${TARGET} SHARED + ${VISION_SRCS} +) + +file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers") +file(GLOB PYTORCH_INCLUDE_DIRS_CSRC "${build_DIR}/pytorch_android*.aar/headers/torch/csrc/api/include") +file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}") + +target_compile_options(${TARGET} PRIVATE + -fexceptions +) + +set(BUILD_SUBDIR ${ANDROID_ABI}) + +find_library(PYTORCH_LIBRARY pytorch_jni + PATHS ${PYTORCH_LINK_DIRS} + NO_CMAKE_FIND_ROOT_PATH) + +find_library(FBJNI_LIBRARY fbjni + PATHS ${PYTORCH_LINK_DIRS} + NO_CMAKE_FIND_ROOT_PATH) + +target_include_directories(${TARGET} PRIVATE + ${PYTORCH_INCLUDE_DIRS} + ${PYTORCH_INCLUDE_DIRS_CSRC} +) + +target_link_libraries(${TARGET} PRIVATE + ${PYTORCH_LIBRARY} + ${FBJNI_LIBRARY} +) diff --git a/android/ops/build.gradle b/android/ops/build.gradle new file mode 100644 index 00000000000..773e09fb280 --- /dev/null +++ b/android/ops/build.gradle @@ -0,0 +1,94 @@ +apply plugin: 'com.android.library' +apply plugin: 'maven' + +repositories { + jcenter() + maven { + url "https://oss.sonatype.org/content/repositories/snapshots" + } + flatDir { + dirs 'aars' + } +} + +android { + configurations { + extractForNativeBuild + } + compileSdkVersion rootProject.compileSdkVersion + buildToolsVersion rootProject.buildToolsVersion + + + defaultConfig { + minSdkVersion rootProject.minSdkVersion + targetSdkVersion rootProject.targetSdkVersion + versionCode 0 + versionName "0.1" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + ndk { + abiFilters ABI_FILTERS.split(",") + } + } + + buildTypes { + debug { + minifyEnabled false + debuggable true + } + release { + minifyEnabled false + } + } + + externalNativeBuild { + cmake { + path "CMakeLists.txt" + } + } + + useLibrary 'android.test.runner' + useLibrary 'android.test.base' + useLibrary 'android.test.mock' +} + +dependencies { + implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version + + implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT' + extractForNativeBuild 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT' + + // For testing: deps on local aar files + //implementation(name: 'pytorch_android-release', ext: 'aar') + //extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar') + //implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3' +} + +task extractAARForNativeBuild { + doLast { + configurations.extractForNativeBuild.files.each { + def file = it.absoluteFile + copy { + from zipTree(file) + into "$buildDir/$file.name" + include "headers/**" + include "jni/**" + } + } + } +} + +tasks.whenTaskAdded { task -> + if (task.name.contains('externalNativeBuild')) { + task.dependsOn(extractAARForNativeBuild) + } +} + +apply from: rootProject.file('gradle_scripts/release.gradle') + +task sourcesJar(type: Jar) { + from android.sourceSets.main.java.srcDirs + classifier = 'sources' +} + +artifacts.add('archives', sourcesJar) diff --git a/android/ops/gradle.properties b/android/ops/gradle.properties new file mode 100644 index 00000000000..5a4ea2f3aba --- /dev/null +++ b/android/ops/gradle.properties @@ -0,0 +1,4 @@ +POM_NAME=torchvision ops +POM_DESCRIPTION=torchvision ops +POM_ARTIFACT_ID=torchvision_ops +POM_PACKAGING=aar diff --git a/android/ops/src/main/AndroidManifest.xml b/android/ops/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..8ca386493c4 --- /dev/null +++ b/android/ops/src/main/AndroidManifest.xml @@ -0,0 +1 @@ + diff --git a/android/settings.gradle b/android/settings.gradle new file mode 100644 index 00000000000..6d34eb8d51a --- /dev/null +++ b/android/settings.gradle @@ -0,0 +1,4 @@ +include ':ops', ':test_app' + +project(':ops').projectDir = file('ops') +project(':test_app').projectDir = file('test_app/app') diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle new file mode 100644 index 00000000000..e95078c401d --- /dev/null +++ b/android/test_app/app/build.gradle @@ -0,0 +1,138 @@ +apply plugin: 'com.android.application' + +repositories { + jcenter() + maven { + url "https://oss.sonatype.org/content/repositories/snapshots" + } + flatDir { + dirs 'aars' + } +} + +android { + configurations { + extractForNativeBuild + } + compileOptions { + sourceCompatibility 1.8 + targetCompatibility 1.8 + } + compileSdkVersion rootProject.compileSdkVersion + buildToolsVersion rootProject.buildToolsVersion + defaultConfig { + applicationId "org.pytorch.testapp" + minSdkVersion rootProject.minSdkVersion + targetSdkVersion rootProject.targetSdkVersion + versionCode 1 + versionName "1.0" + ndk { + abiFilters ABI_FILTERS.split(",") + } + externalNativeBuild { + cmake { + abiFilters ABI_FILTERS.split(",") + arguments "-DANDROID_STL=c++_shared" + } + } + buildConfigField("String", "MODULE_ASSET_NAME", "\"frcnn_mnetv3.pt\"") + buildConfigField("String", "LOGCAT_TAG", "@string/app_name") + buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{3, 96, 96}") + addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"]) + } + buildTypes { + debug { + minifyEnabled false + debuggable true + } + release { + minifyEnabled false + } + } + flavorDimensions "model", "activity", "build" + productFlavors { + frcnnMnetv3 { + dimension "model" + applicationIdSuffix ".frcnnMnetv3" + buildConfigField("String", "MODULE_ASSET_NAME", "\"frcnn_mnetv3.pt\"") + addManifestPlaceholders([APP_NAME: "TV_FRCNN_MNETV3"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-frcnn-mnetv3\"") + } + camera { + dimension "activity" + addManifestPlaceholders([APP_NAME: "TV_CAMERA_FRCNN"]) + addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"]) + } + base { + dimension "activity" + } + aar { + dimension "build" + } + nightly { + dimension "build" + } + local { + dimension "build" + } + } + packagingOptions { + doNotStrip '**.so' + } + + // Filtering for CI + if (!testAppAllVariantsEnabled.toBoolean()) { + variantFilter { variant -> + def names = variant.flavors*.name + if (names.contains("aar")) { + setIgnore(true) + } + } + } +} + +tasks.all { task -> + // Disable externalNativeBuild for all but nativeBuild variant + if (task.name.startsWith('externalNativeBuild') + && !task.name.contains('NativeBuild')) { + task.enabled = false + } +} + +dependencies { + implementation 'com.android.support:appcompat-v7:28.0.0' + implementation 'com.facebook.soloader:nativeloader:0.8.0' + localImplementation project(':ops') + + implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT' + implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT' + implementation 'org.pytorch:torchvision_ops:0.0.1-SNAPSHOT' + + aarImplementation(name: 'pytorch_android-release', ext: 'aar') + aarImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar') + + def camerax_version = "1.0.0-alpha05" + implementation "androidx.camera:camera-core:$camerax_version" + implementation "androidx.camera:camera-camera2:$camerax_version" + implementation 'com.google.android.material:material:1.0.0-beta01' +} + +task extractAARForNativeBuild { + doLast { + configurations.extractForNativeBuild.files.each { + def file = it.absoluteFile + copy { + from zipTree(file) + into "$buildDir/$file.name" + include "headers/**" + include "jni/**" + } + } + } +} + +tasks.whenTaskAdded { task -> + if (task.name.contains('externalNativeBuild')) { + task.dependsOn(extractAARForNativeBuild) + } +} diff --git a/android/test_app/app/src/main/AndroidManifest.xml b/android/test_app/app/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..a83bf223bda --- /dev/null +++ b/android/test_app/app/src/main/AndroidManifest.xml @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/BBox.java b/android/test_app/app/src/main/java/org/pytorch/testapp/BBox.java new file mode 100644 index 00000000000..6fd60791864 --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/BBox.java @@ -0,0 +1,22 @@ +package org.pytorch.testapp; + +class BBox { + public final float score; + public final float x0; + public final float y0; + public final float x1; + public final float y1; + + public BBox(float score, float x0, float y0, float x1, float y1) { + this.score = score; + this.x0 = x0; + this.y0 = y0; + this.x1 = x1; + this.y1 = y1; + } + + @Override + public String toString() { + return String.format("Box{score=%f x0=%f y0=%f x1=%f y1=%f", score, x0, y0, x1, y1); + } +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java new file mode 100644 index 00000000000..cfbbad4a8d2 --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java @@ -0,0 +1,440 @@ +package org.pytorch.testapp; + +import android.Manifest; +import android.content.Context; +import android.content.pm.PackageManager; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.Rect; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.SystemClock; +import android.util.DisplayMetrics; +import android.util.Log; +import android.util.Size; +import android.view.TextureView; +import android.view.ViewStub; +import android.widget.ImageView; +import android.widget.TextView; +import android.widget.Toast; + +import androidx.annotation.Nullable; +import androidx.annotation.UiThread; +import androidx.annotation.WorkerThread; +import androidx.appcompat.app.AppCompatActivity; +import androidx.camera.core.CameraX; +import androidx.camera.core.ImageAnalysis; +import androidx.camera.core.ImageAnalysisConfig; +import androidx.camera.core.ImageProxy; +import androidx.camera.core.Preview; +import androidx.camera.core.PreviewConfig; +import androidx.core.app.ActivityCompat; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; + +import org.pytorch.IValue; +import org.pytorch.MemoryFormat; +import org.pytorch.Module; +import org.pytorch.PyTorchAndroid; +import org.pytorch.Tensor; + +public class CameraActivity extends AppCompatActivity { + + private static final float BBOX_SCORE_DRAW_THRESHOLD = 0.5f; + private static final String TAG = BuildConfig.LOGCAT_TAG; + private static final int TEXT_TRIM_SIZE = 4096; + private static final int RGB_MAX_CHANNEL_VALUE = 262143; + + private static final int REQUEST_CODE_CAMERA_PERMISSION = 200; + private static final String[] PERMISSIONS = {Manifest.permission.CAMERA}; + + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + NativeLoader.loadLibrary("pytorch_jni"); + NativeLoader.loadLibrary("torchvision_ops"); + } + + private Bitmap mInputTensorBitmap; + private Bitmap mBitmap; + private Canvas mCanvas; + + private long mLastAnalysisResultTime; + + protected HandlerThread mBackgroundThread; + protected Handler mBackgroundHandler; + protected Handler mUIHandler; + + private TextView mTextView; + private ImageView mCameraOverlay; + private StringBuilder mTextViewStringBuilder = new StringBuilder(); + + private Paint mBboxPaint; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_camera); + mTextView = findViewById(R.id.text); + mCameraOverlay = findViewById(R.id.camera_overlay); + mUIHandler = new Handler(getMainLooper()); + startBackgroundThread(); + + if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA) + != PackageManager.PERMISSION_GRANTED) { + ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION); + } else { + setupCameraX(); + } + mBboxPaint = new Paint(); + mBboxPaint.setAntiAlias(true); + mBboxPaint.setDither(true); + mBboxPaint.setColor(Color.GREEN); + } + + @Override + protected void onPostCreate(@Nullable Bundle savedInstanceState) { + super.onPostCreate(savedInstanceState); + startBackgroundThread(); + } + + protected void startBackgroundThread() { + mBackgroundThread = new HandlerThread("ModuleActivity"); + mBackgroundThread.start(); + mBackgroundHandler = new Handler(mBackgroundThread.getLooper()); + } + + @Override + protected void onDestroy() { + stopBackgroundThread(); + super.onDestroy(); + } + + protected void stopBackgroundThread() { + mBackgroundThread.quitSafely(); + try { + mBackgroundThread.join(); + mBackgroundThread = null; + mBackgroundHandler = null; + } catch (InterruptedException e) { + Log.e(TAG, "Error on stopping background thread", e); + } + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) { + if (grantResults[0] == PackageManager.PERMISSION_DENIED) { + Toast.makeText( + this, + "You can't use image classification example without granting CAMERA permission", + Toast.LENGTH_LONG) + .show(); + finish(); + } else { + setupCameraX(); + } + } + } + + private void setupCameraX() { + final TextureView textureView = + ((ViewStub) findViewById(R.id.camera_texture_view_stub)) + .inflate() + .findViewById(R.id.texture_view); + final PreviewConfig previewConfig = new PreviewConfig.Builder().build(); + final Preview preview = new Preview(previewConfig); + preview.setOnPreviewOutputUpdateListener( + new Preview.OnPreviewOutputUpdateListener() { + @Override + public void onUpdated(Preview.PreviewOutput output) { + textureView.setSurfaceTexture(output.getSurfaceTexture()); + } + }); + + final DisplayMetrics displayMetrics = new DisplayMetrics(); + getWindowManager().getDefaultDisplay().getMetrics(displayMetrics); + + final ImageAnalysisConfig imageAnalysisConfig = + new ImageAnalysisConfig.Builder() + .setTargetResolution(new Size(displayMetrics.widthPixels, displayMetrics.heightPixels)) + .setCallbackHandler(mBackgroundHandler) + .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE) + .build(); + final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig); + imageAnalysis.setAnalyzer( + new ImageAnalysis.Analyzer() { + @Override + public void analyze(ImageProxy image, int rotationDegrees) { + if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) { + return; + } + + final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees); + + if (result != null) { + mLastAnalysisResultTime = SystemClock.elapsedRealtime(); + CameraActivity.this.runOnUiThread( + new Runnable() { + @Override + public void run() { + CameraActivity.this.handleResult(result); + } + }); + } + } + }); + + CameraX.bindToLifecycle(this, preview, imageAnalysis); + } + + private Module mModule; + private FloatBuffer mInputTensorBuffer; + private Tensor mInputTensor; + + private static int clamp0255(int x) { + if (x > 255) { + return 255; + } + return x < 0 ? 0 : x; + } + + protected void fillInputTensorBuffer( + ImageProxy image, + int rotationDegrees, + FloatBuffer inputTensorBuffer) { + + if (mInputTensorBitmap == null) { + final int tensorSize = Math.min(image.getWidth(), image.getHeight()); + mInputTensorBitmap = Bitmap.createBitmap(tensorSize, tensorSize, Bitmap.Config.ARGB_8888); + } + + ImageProxy.PlaneProxy[] planes = image.getPlanes(); + ImageProxy.PlaneProxy Y = planes[0]; + ImageProxy.PlaneProxy U = planes[1]; + ImageProxy.PlaneProxy V = planes[2]; + ByteBuffer yBuffer = Y.getBuffer(); + ByteBuffer uBuffer = U.getBuffer(); + ByteBuffer vBuffer = V.getBuffer(); + final int imageWidth = image.getWidth(); + final int imageHeight = image.getHeight(); + final int tensorSize = Math.min(imageWidth, imageHeight); + + int widthAfterRtn = imageWidth; + int heightAfterRtn = imageHeight; + boolean oddRotation = rotationDegrees == 90 || rotationDegrees == 270; + if (oddRotation) { + widthAfterRtn = imageHeight; + heightAfterRtn = imageWidth; + } + + int minSizeAfterRtn = Math.min(heightAfterRtn, widthAfterRtn); + int cropWidthAfterRtn = minSizeAfterRtn; + int cropHeightAfterRtn = minSizeAfterRtn; + + int cropWidthBeforeRtn = cropWidthAfterRtn; + int cropHeightBeforeRtn = cropHeightAfterRtn; + if (oddRotation) { + cropWidthBeforeRtn = cropHeightAfterRtn; + cropHeightBeforeRtn = cropWidthAfterRtn; + } + + int offsetX = (int) ((imageWidth - cropWidthBeforeRtn) / 2.f); + int offsetY = (int) ((imageHeight - cropHeightBeforeRtn) / 2.f); + + int yRowStride = Y.getRowStride(); + int yPixelStride = Y.getPixelStride(); + int uvRowStride = U.getRowStride(); + int uvPixelStride = U.getPixelStride(); + + float scale = cropWidthAfterRtn / tensorSize; + int yIdx, uvIdx, yi, ui, vi; + final int channelSize = tensorSize * tensorSize; + for (int y = 0; y < tensorSize; y++) { + for (int x = 0; x < tensorSize; x++) { + final int centerCropX = (int) Math.floor(x * scale); + final int centerCropY = (int) Math.floor(y * scale); + int srcX = centerCropX + offsetX; + int srcY = centerCropY + offsetY; + + if (rotationDegrees == 90) { + srcX = offsetX + centerCropY; + srcY = offsetY + (minSizeAfterRtn - 1) - centerCropX; + } else if (rotationDegrees == 180) { + srcX = offsetX + (minSizeAfterRtn - 1) - centerCropX; + srcY = offsetY + (minSizeAfterRtn - 1) - centerCropY; + } else if (rotationDegrees == 270) { + srcX = offsetX + (minSizeAfterRtn - 1) - centerCropY; + srcY = offsetY + centerCropX; + } + + yIdx = srcY * yRowStride + srcX * yPixelStride; + uvIdx = (srcY >> 1) * uvRowStride + (srcX >> 1) * uvPixelStride; + + yi = yBuffer.get(yIdx) & 0xff; + ui = uBuffer.get(uvIdx) & 0xff; + vi = vBuffer.get(uvIdx) & 0xff; + + yi = (yi - 16) < 0 ? 0 : (yi - 16); + ui -= 128; + vi -= 128; + + int a0 = 1192 * yi; + int ri = (a0 + 1634 * vi); + int gi = (a0 - 833 * vi - 400 * ui); + int bi = (a0 + 2066 * ui); + + ri = ri > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (ri < 0 ? 0 : ri); + gi = gi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (gi < 0 ? 0 : gi); + bi = bi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (bi < 0 ? 0 : bi); + + final int color = 0xff000000 | ((ri << 6) & 0xff0000) | ((gi >> 2) & 0xff00) | ((bi >> 10) & 0xff); + mInputTensorBitmap.setPixel(x, y, color); + inputTensorBuffer.put(0 * channelSize + y * tensorSize + x, clamp0255(ri >> 10) / 255.f); + inputTensorBuffer.put(1 * channelSize + y * tensorSize + x, clamp0255(gi >> 10) / 255.f); + inputTensorBuffer.put(2 * channelSize + y * tensorSize + x, clamp0255(bi >> 10) / 255.f); + } + } + } + + public static String assetFilePath(Context context, String assetName) { + File file = new File(context.getFilesDir(), assetName); + if (file.exists() && file.length() > 0) { + return file.getAbsolutePath(); + } + + try (InputStream is = context.getAssets().open(assetName)) { + try (OutputStream os = new FileOutputStream(file)) { + byte[] buffer = new byte[4 * 1024]; + int read; + while ((read = is.read(buffer)) != -1) { + os.write(buffer, 0, read); + } + os.flush(); + } + return file.getAbsolutePath(); + } catch (IOException e) { + Log.e(TAG, "Error process asset " + assetName + " to file path"); + } + return null; + } + + @WorkerThread + @Nullable + protected Result analyzeImage(ImageProxy image, int rotationDegrees) { + Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees)); + final int tensorSize = Math.min(image.getWidth(), image.getHeight()); + if (mModule == null) { + Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'"); + mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * tensorSize * tensorSize); + mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{3, tensorSize, tensorSize}); + final String modelFileAbsoluteFilePath = + new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath(); + mModule = Module.load(modelFileAbsoluteFilePath); + } + + final long startTime = SystemClock.elapsedRealtime(); + fillInputTensorBuffer(image, rotationDegrees, mInputTensorBuffer); + + final long moduleForwardStartTime = SystemClock.elapsedRealtime(); + final IValue outputTuple = mModule.forward(IValue.listFrom(mInputTensor)); + final IValue out1 = outputTuple.toTuple()[1]; + final Map map = out1.toList()[0].toDictStringKey(); + + float[] boxesData = new float[]{}; + float[] scoresData = new float[]{}; + final List bboxes = new ArrayList<>(); + if (map.containsKey("boxes")) { + final Tensor boxesTensor = map.get("boxes").toTensor(); + final Tensor scoresTensor = map.get("scores").toTensor(); + boxesData = boxesTensor.getDataAsFloatArray(); + scoresData = scoresTensor.getDataAsFloatArray(); + final int n = scoresData.length; + for (int i = 0; i < n; i++) { + final BBox bbox = new BBox( + scoresData[i], + boxesData[4 * i + 0], + boxesData[4 * i + 1], + boxesData[4 * i + 2], + boxesData[4 * i + 3] + ); + android.util.Log.i(TAG, String.format("Forward result %d: %s", i, bbox)); + bboxes.add(bbox); + } + } else { + android.util.Log.i(TAG, "Forward result empty"); + } + + final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime; + final long analysisDuration = SystemClock.elapsedRealtime() - startTime; + return new Result(tensorSize, bboxes, moduleForwardDuration, analysisDuration); + } + + @UiThread + protected void handleResult(Result result) { + final int W = mCameraOverlay.getMeasuredWidth(); + final int H = mCameraOverlay.getMeasuredHeight(); + + final int size = Math.min(W, H); + final int offsetX = (W - size) / 2; + final int offsetY = (H - size) / 2; + + float scaleX = (float) size / result.tensorSize; + float scaleY = (float) size / result.tensorSize; + if (mBitmap == null) { + mBitmap = Bitmap.createBitmap(W, H, Bitmap.Config.ARGB_8888); + mCanvas = new Canvas(mBitmap); + } + + mCanvas.drawBitmap( + mInputTensorBitmap, + new Rect(0, 0, result.tensorSize, result.tensorSize), + new Rect(offsetX, offsetY, offsetX + size, offsetY + size), + null + ); + + for (final BBox bbox : result.bboxes) { + if (bbox.score < BBOX_SCORE_DRAW_THRESHOLD) { + continue; + } + + float c_x0 = offsetX + scaleX * bbox.x0; + float c_y0 = offsetY + scaleY * bbox.y0; + + float c_x1 = offsetX + scaleX * bbox.x1; + float c_y1 = offsetY + scaleY * bbox.y1; + + mCanvas.drawLine(c_x0, c_y0, c_x1, c_y0, mBboxPaint); + mCanvas.drawLine(c_x1, c_y0, c_x1, c_y1, mBboxPaint); + mCanvas.drawLine(c_x1, c_y1, c_x0, c_y1, mBboxPaint); + mCanvas.drawLine(c_x0, c_y1, c_x0, c_y0, mBboxPaint); + mCanvas.drawText(String.format("%.2f", bbox.score), c_x0, c_y0, mBboxPaint); + } + mCameraOverlay.setImageBitmap(mBitmap); + + String message = String.format("forwardDuration:%d", result.moduleForwardDuration); + Log.i(TAG, message); + mTextViewStringBuilder.insert(0, '\n').insert(0, message); + if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) { + mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length()); + } + mTextView.setText(mTextViewStringBuilder.toString()); + } +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java new file mode 100644 index 00000000000..1ee16a87a0a --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java @@ -0,0 +1,158 @@ +package org.pytorch.testapp; + +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.SystemClock; +import android.util.Log; +import android.widget.TextView; +import androidx.annotation.Nullable; +import androidx.annotation.UiThread; +import androidx.annotation.WorkerThread; +import androidx.appcompat.app.AppCompatActivity; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import org.pytorch.IValue; +import org.pytorch.Module; +import org.pytorch.PyTorchAndroid; +import org.pytorch.Tensor; + +import java.nio.FloatBuffer; +import java.util.Map; + +public class MainActivity extends AppCompatActivity { + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + NativeLoader.loadLibrary("pytorch_jni"); + NativeLoader.loadLibrary("torchvision_ops"); + } + + private static final String TAG = BuildConfig.LOGCAT_TAG; + private static final int TEXT_TRIM_SIZE = 4096; + + private TextView mTextView; + + protected HandlerThread mBackgroundThread; + protected Handler mBackgroundHandler; + private Module mModule; + private FloatBuffer mInputTensorBuffer; + private Tensor mInputTensor; + private StringBuilder mTextViewStringBuilder = new StringBuilder(); + + private final Runnable mModuleForwardRunnable = + new Runnable() { + @Override + public void run() { + final Result result = doModuleForward(); + runOnUiThread( + () -> { + handleResult(result); + if (mBackgroundHandler != null) { + mBackgroundHandler.post(mModuleForwardRunnable); + } + }); + } + }; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + mTextView = findViewById(R.id.text); + startBackgroundThread(); + mBackgroundHandler.post(mModuleForwardRunnable); + } + + protected void startBackgroundThread() { + mBackgroundThread = new HandlerThread(TAG + "_bg"); + mBackgroundThread.start(); + mBackgroundHandler = new Handler(mBackgroundThread.getLooper()); + } + + @Override + protected void onDestroy() { + stopBackgroundThread(); + super.onDestroy(); + } + + protected void stopBackgroundThread() { + mBackgroundThread.quitSafely(); + try { + mBackgroundThread.join(); + mBackgroundThread = null; + mBackgroundHandler = null; + } catch (InterruptedException e) { + Log.e(TAG, "Error stopping background thread", e); + } + } + + @WorkerThread + @Nullable + protected Result doModuleForward() { + if (mModule == null) { + final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE; + long numElements = 1; + for (int i = 0; i < shape.length; i++) { + numElements *= shape[i]; + } + mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements); + mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE); + PyTorchAndroid.setNumThreads(1); + mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); + } + + final long startTime = SystemClock.elapsedRealtime(); + final long moduleForwardStartTime = SystemClock.elapsedRealtime(); + final IValue outputTuple = mModule.forward(IValue.listFrom(mInputTensor)); + final IValue[] outputArray = outputTuple.toTuple(); + final IValue out0 = outputArray[0]; + final Map map = out0.toDictStringKey(); + if (map.containsKey("boxes")) { + final Tensor boxes = map.get("boxes").toTensor(); + final Tensor scores = map.get("scores").toTensor(); + final float[] boxesData = boxes.getDataAsFloatArray(); + final float[] scoresData = scores.getDataAsFloatArray(); + final int n = scoresData.length; + for (int i = 0; i < n; i++) { + android.util.Log.i(TAG, + String.format("Forward result %d: score %f box:(%f, %f, %f, %f)", + scoresData[i], + boxesData[4 * i + 0], + boxesData[4 * i + 1], + boxesData[4 * i + 2], + boxesData[4 * i + 3])); + } + } else { + android.util.Log.i(TAG, "Forward result empty"); + } + + final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime; + final long analysisDuration = SystemClock.elapsedRealtime() - startTime; + return new Result(new float[]{}, moduleForwardDuration, analysisDuration); + } + + static class Result { + + private final float[] scores; + private final long totalDuration; + private final long moduleForwardDuration; + + public Result(float[] scores, long moduleForwardDuration, long totalDuration) { + this.scores = scores; + this.moduleForwardDuration = moduleForwardDuration; + this.totalDuration = totalDuration; + } + } + + @UiThread + protected void handleResult(Result result) { + String message = String.format("forwardDuration:%d", result.moduleForwardDuration); + mTextViewStringBuilder.insert(0, '\n').insert(0, message); + if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) { + mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length()); + } + mTextView.setText(mTextViewStringBuilder.toString()); + } +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java b/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java new file mode 100644 index 00000000000..ed7ebd006cd --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java @@ -0,0 +1,17 @@ +package org.pytorch.testapp; + +import java.util.List; + +class Result { + public final int tensorSize; + public final List bboxes; + public final long totalDuration; + public final long moduleForwardDuration; + + public Result(int tensorSize, List bboxes, long moduleForwardDuration, long totalDuration) { + this.tensorSize = tensorSize; + this.bboxes = bboxes; + this.moduleForwardDuration = moduleForwardDuration; + this.totalDuration = totalDuration; + } +} diff --git a/android/test_app/app/src/main/res/layout/activity_camera.xml b/android/test_app/app/src/main/res/layout/activity_camera.xml new file mode 100644 index 00000000000..7ba2e42b7c0 --- /dev/null +++ b/android/test_app/app/src/main/res/layout/activity_camera.xml @@ -0,0 +1,28 @@ + + + + + + + + + diff --git a/android/test_app/app/src/main/res/layout/activity_main.xml b/android/test_app/app/src/main/res/layout/activity_main.xml new file mode 100644 index 00000000000..c0939ebc0eb --- /dev/null +++ b/android/test_app/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,17 @@ + + + + + + \ No newline at end of file diff --git a/android/test_app/app/src/main/res/layout/texture_view.xml b/android/test_app/app/src/main/res/layout/texture_view.xml new file mode 100644 index 00000000000..6518c6c84c6 --- /dev/null +++ b/android/test_app/app/src/main/res/layout/texture_view.xml @@ -0,0 +1,5 @@ + + diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png new file mode 100644 index 00000000000..64ba76f75e9 Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png differ diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png new file mode 100644 index 00000000000..dae5e082342 Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png differ diff --git a/android/test_app/app/src/main/res/values/colors.xml b/android/test_app/app/src/main/res/values/colors.xml new file mode 100644 index 00000000000..69b22338c65 --- /dev/null +++ b/android/test_app/app/src/main/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/android/test_app/app/src/main/res/values/strings.xml b/android/test_app/app/src/main/res/values/strings.xml new file mode 100644 index 00000000000..cafbaad1511 --- /dev/null +++ b/android/test_app/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + TV_FRCNN + diff --git a/android/test_app/app/src/main/res/values/styles.xml b/android/test_app/app/src/main/res/values/styles.xml new file mode 100644 index 00000000000..5885930df6d --- /dev/null +++ b/android/test_app/app/src/main/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py new file mode 100644 index 00000000000..7860c759a57 --- /dev/null +++ b/android/test_app/make_assets.py @@ -0,0 +1,17 @@ +import torch +import torchvision +from torch.utils.mobile_optimizer import optimize_for_mobile + +print(torch.__version__) + +model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( + pretrained=True, + box_score_thresh=0.7, + rpn_post_nms_top_n_test=100, + rpn_score_thresh=0.4, + rpn_pre_nms_top_n_test=150) + +model.eval() +script_model = torch.jit.script(model) +opt_script_model = optimize_for_mobile(script_model) +opt_script_model.save("app/src/main/assets/frcnn_mnetv3.pt") diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index baad319e7c0..1b75de4a754 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -1,6 +1,8 @@ #include "vision.h" +#ifndef MOBILE #include +#endif #include #ifdef WITH_CUDA