This document provides detailed API documentation for key components in the PyTorch Android Mobile Application.
The main activity that handles image selection and classification.
public class MainActivity extends AppCompatActivity
Fields:
public static final int PICK_IMAGE = 1 - Request code for the image picker intent.
Uri selectedImage - URI of the currently selected image.
Module module = null - PyTorch module instance used for model inference.
@Override
protected void onCreate(Bundle savedInstanceState)
Description: Initializes the activity and sets up UI components.
Parameters:
savedInstanceState - Bundle containing the previously saved instance state (null on first launch)
Behavior:
- Sets up the content view
- Initializes button listeners
- Configures permission handlers
public static String assetFilePath(Context context, String assetName) throws IOException
Description: Copies an asset file to internal storage and returns its absolute path.
Parameters:
context - Application context
assetName - Name of the asset file (e.g., "model.pt")
Returns: Absolute file path to the copied asset
Throws: IOException if file operations fail
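assetFilePath exists because Module.load() takes a filesystem path, while files packaged in the APK's assets are only reachable as streams. The copy step itself can be sketched in plain Java as below. This is a minimal sketch, not the app's actual code: copyIfMissing is a hypothetical helper name, and on Android the stream would come from context.getAssets().open(assetName) with the destination file placed under context.getFilesDir().

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

final class AssetCopy {
    // Hypothetical helper: copies `in` to `dest` unless a non-empty copy already
    // exists, then returns the absolute path (the form Module.load() expects).
    // Note: a partially written file from an interrupted copy would also be reused.
    static String copyIfMissing(InputStream in, File dest) throws IOException {
        try (InputStream source = in) {
            // Reuse an earlier copy so the model is extracted only once.
            if (dest.exists() && dest.length() > 0) {
                return dest.getAbsolutePath();
            }
            try (OutputStream out = new FileOutputStream(dest)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = source.read(buffer)) != -1) {
                    out.write(buffer, 0, read);
                }
                out.flush();
            }
        }
        return dest.getAbsolutePath();
    }
}
```

Checking for an existing copy keeps repeated launches from rewriting a large model file every time.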
Example:
String modelPath = assetFilePath(this, "model.pt");
Module module = Module.load(modelPath);
private void pickFromGallery()
Description: Launches the system image picker.
Behavior:
- Creates an ACTION_PICK intent
- Filters for JPEG and PNG images
- Starts activity for result
Example Flow:
User taps button → pickFromGallery() → System Gallery → onActivityResult()
@Override
public void onActivityResult(int requestCode, int resultCode, Intent data)
Description: Handles the result from the image picker and performs classification.
Parameters:
requestCode - Request code (should be PICK_IMAGE)
resultCode - Result code from the picker (RESULT_OK on success)
data - Intent containing the selected image URI
Workflow:
- Load PyTorch model
- Get selected image URI
- Display image in ImageView
- Preprocess image (resize, normalize)
- Convert to tensor
- Run inference
- Find top prediction
- Display class name
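Bitmap bitmap = ((BitmapDrawable) imageView.getDrawable()).getBitmap();
Bitmap resized = Bitmap.createScaledBitmap(bitmap, 320, 320, false);
Description: Extracts the displayed bitmap and resizes it to the 320x320 model input size.
Parameters:
bitmap - Source bitmap
width - Target width (320)
height - Target height (320)
filter - Whether to apply bilinear filtering while scaling (false is faster but may reduce quality)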
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
    resized,  // pass the 320x320 bitmap, not the original
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
    TensorImageUtils.TORCHVISION_NORM_STD_RGB
);
Normalization Values:
- Mean RGB: [0.485, 0.456, 0.406]
- Std RGB: [0.229, 0.224, 0.225]
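For each pixel, the conversion scales every 8-bit channel to [0, 1], applies (x - mean) / std with the per-channel values above, and writes the result in planar CHW order (all red values, then all green, then all blue). A rough plain-Java equivalent working on packed ARGB pixels (the format Bitmap.getPixels returns) is sketched below. It is an illustration of the arithmetic under that assumption, not the library's source; toNormalizedChw is a hypothetical name.

```java
final class NormalizeSketch {
    // Torchvision ImageNet statistics, matching the values listed above.
    static final float[] MEAN_RGB = {0.485f, 0.456f, 0.406f};
    static final float[] STD_RGB = {0.229f, 0.224f, 0.225f};

    // Converts row-major packed ARGB pixels into a normalized float buffer
    // in planar CHW order: [R plane][G plane][B plane].
    static float[] toNormalizedChw(int[] argb, int width, int height) {
        int plane = width * height;
        if (argb.length != plane) {
            throw new IllegalArgumentException("expected " + plane + " pixels");
        }
        float[] out = new float[3 * plane];
        for (int i = 0; i < plane; i++) {
            int p = argb[i];
            // Extract 8-bit channels and scale to [0, 1]; alpha is ignored.
            float r = ((p >> 16) & 0xff) / 255.0f;
            float g = ((p >> 8) & 0xff) / 255.0f;
            float b = (p & 0xff) / 255.0f;
            out[i] = (r - MEAN_RGB[0]) / STD_RGB[0];
            out[plane + i] = (g - MEAN_RGB[1]) / STD_RGB[1];
            out[2 * plane + i] = (b - MEAN_RGB[2]) / STD_RGB[2];
        }
        return out;
    }
}
```

Prepending a batch dimension of 1 to this buffer gives the [1, 3, 320, 320] layout the model consumes.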
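Module module = Module.load(assetFilePath(context, "model.pt"));
Description: Loads a TorchScript model from the assets folder. Module.load() expects a filesystem path, so the asset is first copied to internal storage with assetFilePath().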
Parameters:
modelPath- Absolute path to the model file
Returns: Module instance ready for inference
Best Practices:
- Load model once and reuse
- Load asynchronously to avoid blocking UI
- Handle loading errors gracefully
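One way to follow these practices is to wrap loading in a holder that starts the load once on a background thread and hands every caller the same future. The sketch below is generic over the loaded type so it runs without the PyTorch library. ModelHolder is a hypothetical class; in the app the loader might be a lambda calling Module.load(assetFilePath(this, "model.pt")) with the IOException wrapped in an unchecked exception.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;

final class ModelHolder<T> {
    private final Supplier<T> loader;
    private final ExecutorService executor;
    private CompletableFuture<T> future; // guarded by `this`

    ModelHolder(Supplier<T> loader, ExecutorService executor) {
        this.loader = loader;
        this.executor = executor;
    }

    // Starts loading on the first call and returns the same future afterwards,
    // so the expensive loader runs once. A failed load is retried on the next call.
    synchronized CompletableFuture<T> get() {
        if (future == null || future.isCompletedExceptionally()) {
            future = CompletableFuture.supplyAsync(loader, executor);
        }
        return future;
    }
}
```

Callers can attach thenAccept for the success path and exceptionally for a user-facing error, posting UI updates back with runOnUiThread.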
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
Description: Performs a forward pass through the model.
Parameters:
inputTensor - Preprocessed input tensor of shape (1, 3, 320, 320)
Returns: Output tensor of class scores with shape (1, 1000)
Performance:
- Average time: 50-100 ms
- Device dependent
- Runs on the CPU; GPU acceleration is not used by default
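Because the timing varies by device, it is worth measuring on your own hardware. A minimal timing helper is sketched below, assuming the inference call is wrapped in a Runnable such as () -> module.forward(IValue.from(inputTensor)). The name averageMillis and the separate warm-up phase are illustrative choices, not part of the original app.

```java
final class InferenceTimer {
    // Runs `task` `warmup` times untimed (early runs are often slower),
    // then returns the mean wall-clock time of `runs` timed calls in milliseconds.
    static double averageMillis(Runnable task, int warmup, int runs) {
        if (runs <= 0) {
            throw new IllegalArgumentException("runs must be positive");
        }
        for (int i = 0; i < warmup; i++) {
            task.run();
        }
        long start = System.nanoTime();
        for (int i = 0; i < runs; i++) {
            task.run();
        }
        long elapsed = System.nanoTime() - start;
        return elapsed / 1_000_000.0 / runs;
    }
}
```

Run the measurement off the UI thread, since each timed call blocks for the full inference time.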
float[] scores = outputTensor.getDataAsFloatArray();
// Linear scan for the index of the largest score (argmax)
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
    if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxScoreIdx = i;
    }
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
Description: Extracts the top prediction from the model output.
Output Format:
- Array of 1000 float values
- Values are logits (pre-softmax)
- Higher value = higher confidence
Improvements:
- Add softmax for probability scores
- Return top-K predictions
- Add confidence threshold
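These three improvements fit together: convert logits to probabilities with a numerically stable softmax, keep the k highest, and drop anything below a confidence threshold. The sketch below is one possible shape for this; the TopK and Prediction classes and the minProbability parameter are hypothetical names, not part of the original app.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

final class TopK {
    // Hypothetical result type: class index plus its softmax probability.
    static final class Prediction {
        final int index;
        final float probability;

        Prediction(int index, float probability) {
            this.index = index;
            this.probability = probability;
        }
    }

    // Stable softmax: subtracting the max logit keeps Math.exp() from overflowing.
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] exps = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max);
            sum += exps[i];
        }
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) (exps[i] / sum);
        }
        return probs;
    }

    // Returns up to k predictions sorted by descending probability,
    // skipping any below minProbability.
    static List<Prediction> topK(float[] logits, int k, float minProbability) {
        float[] probs = softmax(logits);
        // Min-heap bounded at k entries; its root is the weakest current candidate.
        PriorityQueue<Prediction> heap = new PriorityQueue<>(
                Math.max(1, k), (a, b) -> Float.compare(a.probability, b.probability));
        for (int i = 0; i < probs.length; i++) {
            if (probs[i] < minProbability) {
                continue;
            }
            heap.offer(new Prediction(i, probs[i]));
            if (heap.size() > k) {
                heap.poll();
            }
        }
        List<Prediction> result = new ArrayList<>(heap);
        result.sort((a, b) -> Float.compare(b.probability, a.probability));
        return result;
    }
}
```

The bounded heap keeps the selection at O(n log k), which avoids sorting all 1000 scores when only a handful are displayed.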
public class ImageNetClasses {
    public static final String[] IMAGENET_CLASSES = {
        "tench",
        "goldfish",
        "great white shark",
        // ... 997 more classes
    };
}
Description: Array containing all 1000 ImageNet class labels, indexed by model output position.
Usage:
String label = ImageNetClasses.IMAGENET_CLASSES[predictedIndex];
Input Tensor:
Shape: [1, 3, 320, 320]
Type: Float32
Format: NCHW (Batch, Channels, Height, Width)
Range: approximately [-2.12, 2.64] (after normalization)
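The normalized range follows from applying (x - mean) / std to the extreme pixel values x = 0 and x = 1 in each channel: the smallest value comes from a black pixel in the red channel and the largest from a full-intensity pixel in the blue channel. A short check using the torchvision mean and std values listed earlier (NormRange is just an illustrative name):

```java
final class NormRange {
    static final float[] MEAN_RGB = {0.485f, 0.456f, 0.406f};
    static final float[] STD_RGB = {0.229f, 0.224f, 0.225f};

    // Returns {min, max} of (x - mean) / std over all channels for x in [0, 1].
    static float[] bounds() {
        float min = Float.POSITIVE_INFINITY;
        float max = Float.NEGATIVE_INFINITY;
        for (int c = 0; c < 3; c++) {
            float low = (0f - MEAN_RGB[c]) / STD_RGB[c];  // black pixel
            float high = (1f - MEAN_RGB[c]) / STD_RGB[c]; // full-intensity pixel
            min = Math.min(min, low);
            max = Math.max(max, high);
        }
        return new float[] {min, max};
    }
}
```

bounds() returns approximately {-2.118, 2.640}, so the interval is not symmetric around zero.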
Output Tensor:
Shape: [1, 1000]
Type: Float32
Format: Logits (pre-softmax scores)
Range: Unbounded (typically -20 to 20)
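Because the batch dimension comes first in NCHW, a batch of N preprocessed images is simply N CHW buffers placed end to end, with shape {N, 3, H, W}. A plain-Java sketch of that stacking step is below. stackChw and nchwShape are hypothetical helpers; the resulting array and shape could be passed to org.pytorch.Tensor.fromBlob(float[], long[]), assuming the exported model accepts a batch size other than 1.

```java
import java.util.List;

final class BatchSketch {
    // Concatenates per-image CHW buffers into one NCHW buffer.
    // Every image must already be preprocessed to the same C x H x W size.
    static float[] stackChw(List<float[]> images, int channels, int height, int width) {
        int perImage = channels * height * width;
        float[] batch = new float[images.size() * perImage];
        for (int n = 0; n < images.size(); n++) {
            float[] img = images.get(n);
            if (img.length != perImage) {
                throw new IllegalArgumentException(
                        "image " + n + " has " + img.length + " values, expected " + perImage);
            }
            // Image n occupies the contiguous slice [n * perImage, (n + 1) * perImage).
            System.arraycopy(img, 0, batch, n * perImage, perImage);
        }
        return batch;
    }

    // Shape array in the {N, C, H, W} order used for the model input.
    static long[] nchwShape(int batchSize, int channels, int height, int width) {
        return new long[] {batchSize, channels, height, width};
    }
}
```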
// Future enhancement: batch processing.
// loadImages() and createBatchTensor() are placeholders, not existing APIs.
List<Bitmap> images = loadImages();
Tensor batchTensor = createBatchTensor(images);
Tensor output = module.forward(IValue.from(batchTensor)).toTensor();

// Future enhancement: top-K predictions (Prediction is a placeholder type).
public List<Prediction> getTopK(float[] scores, int k) {
    // Sort scores and return the top K predictions
    List<Prediction> results = new ArrayList<>();
    // Implementation here
    return results;
}

public float[] softmax(float[] logits) {
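    // Subtract the largest logit before exponentiating so Math.exp() cannot
    // overflow; the result is unchanged because softmax is shift-invariant.
    float max = Float.NEGATIVE_INFINITY;
    for (float logit : logits) {
        max = Math.max(max, logit);
    }
    float[] probabilities = new float[logits.length];
    double sum = 0.0;
    for (float logit : logits) {
        sum += Math.exp(logit - max);
    }
    for (int i = 0; i < logits.length; i++) {
        probabilities[i] = (float) (Math.exp(logits[i] - max) / sum);
    }
    return probabilities;
}

// Error handling: model loading
try {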
    module = Module.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
    Log.e("ModelError", "Failed to load model", e);
    Toast.makeText(this, "Model loading failed", Toast.LENGTH_SHORT).show();
}

// Error handling: out of memory during inference
try {
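    // Inference code
} catch (OutOfMemoryError e) {
    // System.gc() is only a hint to the VM; it does not guarantee memory is reclaimed.
    System.gc();
    Log.e("MemoryError", "Out of memory during inference", e);
}

// Input validation: reject images that could not be decoded
if (bitmap == null) {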
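    Toast.makeText(this, "Invalid image format", Toast.LENGTH_SHORT).show();
    return;
}

// Tests: createTestBitmap(), preprocessImage(), loadModel() and loadTestImage()
// are test helpers not shown here. Because they use Bitmap, these tests need
// to run on a device or emulator (instrumented tests).
@Test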
public void testImagePreprocessing() {
    Bitmap testImage = createTestBitmap(320, 320);
    Tensor tensor = preprocessImage(testImage);
    assertEquals(1, tensor.shape()[0]);   // Batch size
    assertEquals(3, tensor.shape()[1]);   // Channels
    assertEquals(320, tensor.shape()[2]); // Height
    assertEquals(320, tensor.shape()[3]); // Width
}

@Test
public void testModelInference() {
    Module module = loadModel();
    Bitmap testImage = loadTestImage();
    Tensor input = preprocessImage(testImage);
    Tensor output = module.forward(IValue.from(input)).toTensor();
    assertNotNull(output);
    assertEquals(1000, output.shape()[1]); // One score per ImageNet class
}

// Complete example: classify a bitmap end to end
public String classifyImage(Bitmap bitmap) {
    try {
        // 1. Load the model once and reuse it
        if (module == null) {
            module = Module.load(assetFilePath(context, "model.pt"));
        }
        // 2. Preprocess: resize to 320x320, then normalize with ImageNet mean/std
        Bitmap resized = Bitmap.createScaledBitmap(bitmap, 320, 320, false);
        Tensor input = TensorImageUtils.bitmapToFloat32Tensor(
            resized,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB
        );
        // 3. Inference
        Tensor output = module.forward(IValue.from(input)).toTensor();
        // 4. Pick the highest-scoring class
        //    (findMaxIndex is assumed to wrap the argmax loop shown earlier)
        float[] scores = output.getDataAsFloatArray();
        int maxIdx = findMaxIndex(scores);
        return ImageNetClasses.IMAGENET_CLASSES[maxIdx];
    } catch (Exception e) {
        Log.e("Classification", "Error during inference", e);
        return "Error";
    }
}
Last Updated: 2024-2025
API Version: 1.0.0