Modules

Storch provides a neural network module API with building blocks for creating stateful neural network architectures.

Simple custom module example

import torch.*
import torch.nn
import torch.nn.functional as F

/** A LeNet-style convolutional network: two conv + max-pool stages followed by three
  * fully connected layers that produce 10 outputs.
  *
  * @tparam D element type of the tensors the model operates on (BFloat16 or Float32)
  */
class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
  // Convolutional stages. Arguments presumably follow PyTorch's
  // (inChannels, outChannels, kernelSize) order — confirm against the Storch API.
  val conv1 = register(nn.Conv2d(1, 6, 5))
  val conv2 = register(nn.Conv2d(6, 16, 5))

  // Fully connected head. 16 * 4 * 4 must equal the flattened size of the second
  // pooled feature map; it matches a 1-channel 28x28 input.
  val fc1 = register(nn.Linear(16 * 4 * 4, 120))
  val fc2 = register(nn.Linear(120, 84))
  val fc3 = register(nn.Linear(84, 10))

  /** Runs the forward pass.
    *
    * @param i input batch; expected to flatten to 16 * 4 * 4 features per sample after
    *          the two conv/pool stages (e.g. shape [N, 1, 28, 28])
    * @return the raw (un-activated) outputs of the last linear layer
    */
  def apply(i: Tensor[D]): Tensor[D] =
    val firstPooled = F.maxPool2d(F.relu(conv1(i)), (2, 2))
    val secondPooled = F.maxPool2d(F.relu(conv2(firstPooled)), 2)
    // Collapse every non-batch dimension so the result can feed the linear layers.
    val flattened = secondPooled.view(-1, 16 * 4 * 4)
    val hidden1 = F.relu(fc1(flattened))
    val hidden2 = F.relu(fc2(hidden1))
    fc3(hidden2)
// Instantiate the model; the element type parameter resolves to Float32 here,
// as the inferred type in the output below shows.
val model = LeNet()
// model: LeNet[Float32] = LeNet
// Random input tensor of shape [1, 1, 28, 28] — a single sample sized to match the
// 16 * 4 * 4 flattening inside LeNet.
val input = torch.rand(Seq(1, 1, 28, 28))
// input: Tensor[Float32] = tensor dtype=float32, shape=[1, 1, 28, 28], device=CPU 
// [[[[0.0239, 0.1741, 0.5066, ..., 0.9512, 0.8070, 0.3739],
//    [0.5625, 0.4165, 0.1153, ..., 0.3510, 0.0086, 0.2082],
//    [0.8317, 0.0963, 0.1241, ..., 0.4985, 0.4273, 0.7526],
//    ...,
//    [0.1227, 0.3874, 0.4374, ..., 0.4191, 0.5043, 0.0511],
//    [0.9404, 0.1533, 0.4236, ..., 0.5938, 0.1366, 0.9904],
//    [0.7634, 0.4054, 0.2530, ..., 0.6332, 0.9523, 0.4602]]]]
// Forward pass: calling the module invokes its `apply`, producing one row of 10 outputs.
model(input)
// res1: Tensor[Float32] = tensor dtype=float32, shape=[1, 10], device=CPU 
// [[0.0094, 0.0280, 0.0906, ..., 0.1166, -0.0993, 0.0102]]