About

Storch is a Scala library for fast tensor computations and deep learning, based on PyTorch.

Like PyTorch, Storch provides tensor operations and automatic differentiation for building deep learning models.

Storch aims to stay close to the original PyTorch API. This makes it easier to port existing models and easier for people already familiar with PyTorch to get started.

val data = Seq(0,1,2,3)
// data: Seq[Int] = List(0, 1, 2, 3)
val t1 = torch.Tensor(data)
// t1: Tensor[Int32] = tensor dtype=int32, shape=[4], device=CPU 
// [0, 1, 2, 3]
t1.equal(torch.arange(0,4))
// res1: Boolean = true
val t2 = t1.to(dtype=torch.float32)
// t2: Tensor[Float32] = tensor dtype=float32, shape=[4], device=CPU 
// [0.0000, 1.0000, 2.0000, 3.0000]
val t3 = t1 + t2
// t3: Tensor[Float32] = tensor dtype=float32, shape=[4], device=CPU 
// [0.0000, 2.0000, 4.0000, 6.0000]

val shape = Seq(2,3)
// shape: Seq[Int] = List(2, 3)
val randTensor = torch.rand(shape)
// randTensor: Tensor[Float32] = tensor dtype=float32, shape=[2, 3], device=CPU 
// [[0.4963, 0.7682, 0.0885],
//  [0.1320, 0.3074, 0.6341]]
val zerosTensor = torch.zeros(shape, dtype=torch.int64)
// zerosTensor: Tensor[Int64] = tensor dtype=int64, shape=[2, 3], device=CPU 
// [[0, 0, 0],
//  [0, 0, 0]]

val x = torch.ones(Seq(5))
// x: Tensor[Float32] = tensor dtype=float32, shape=[5], device=CPU 
// [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]
val w = torch.randn(Seq(5, 3), requiresGrad=true)
// w: Tensor[Float32] = tensor dtype=float32, shape=[5, 3], device=CPU 
// [[1.2645, -0.6874, 0.1604],
//  [-0.6065, -0.7831, 1.0622],
//  [-0.2613, 1.0667, 0.4159],
//  [0.8396, -0.8265, -0.7949],
//  [-0.9528, 0.3717, 0.4087]]
val b = torch.randn(Seq(3), requiresGrad=true)
// b: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CPU 
// [1.4214, 0.1494, -0.6709]
val z = (x matmul w) + b
// z: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CPU 
// [1.7048, -0.7091, 0.5814]
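Because x is a vector of ones, each entry of z is the sum of one column of w plus the matching entry of b. The plain-Scala sketch below (no Storch dependency) recomputes that arithmetic from the values printed above. It also writes out by hand the gradients that automatic differentiation would produce for the scalar s = sum(z): ∂s/∂w(i)(j) = x(i) and ∂s/∂b(j) = 1. The object and helper names (`MatmulCheck`, `affine`) are hypothetical, and because the printed inputs are rounded to four decimals, the results agree with the output above only to about 1e-3.

```scala
// Plain Scala, no Storch: recompute z = (x matmul w) + b from the rounded values printed above.
object MatmulCheck {
  val x: Vector[Double] = Vector.fill(5)(1.0)
  val w: Vector[Vector[Double]] = Vector(
    Vector( 1.2645, -0.6874,  0.1604),
    Vector(-0.6065, -0.7831,  1.0622),
    Vector(-0.2613,  1.0667,  0.4159),
    Vector( 0.8396, -0.8265, -0.7949),
    Vector(-0.9528,  0.3717,  0.4087)
  )
  val b: Vector[Double] = Vector(1.4214, 0.1494, -0.6709)

  // Vector-matrix product plus bias: z(j) = sum over i of x(i) * w(i)(j), then + b(j)
  def affine(x: Vector[Double], w: Vector[Vector[Double]], b: Vector[Double]): Vector[Double] =
    b.indices.toVector.map(j => x.indices.map(i => x(i) * w(i)(j)).sum + b(j))

  val z: Vector[Double] = affine(x, w, b)

  // Hand-derived gradients of the scalar s = z.sum:
  // ds/dw(i)(j) = x(i) (all ones here), ds/db(j) = 1
  val gradW: Vector[Vector[Double]] = x.map(xi => Vector.fill(b.size)(xi))
  val gradB: Vector[Double] = Vector.fill(b.size)(1.0)
}
```

With x all ones, every entry of gradW is 1.0, which is what a backward pass through this expression would be expected to report for w.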

One notable difference is that Storch tensors are statically typed by their dtype, so you will see types such as Tensor[Float32] or Tensor[Int8] rather than just Tensor.

Tracking the dtype at compile time lets us catch certain errors earlier. For instance, torch.rand is only implemented for floating-point and complex dtypes, so the following triggers a runtime error in PyTorch:

torch.rand([3,3], dtype=torch.int32) # RuntimeError: "check_uniform_bounds" not implemented for 'Int'

In Storch, the same code does not compile:

torch.rand(Seq(3,3), dtype=torch.int32)
// error:
// Found:    (torch.int32 : torch.Int32)
// Required: torch.FloatNN | torch.ComplexNN
// torch.rand(Seq(3,3), dtype=torch.int32)
//                            ^^^^^^^^^^^
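To illustrate how carrying the dtype in the type can turn such a mistake into a compile error, here is a toy Scala 3 sketch. It is not Storch's implementation: the `DType` hierarchy, the `Tensor` case class, and `toy.rand`/`toy.zeros` are hypothetical stand-ins, loosely modeled on the `FloatNN | ComplexNN` bound in the compiler message above. The dtype lives in the type parameter `D`, and the union-type upper bound on `rand` rejects integer dtypes before the program runs, while `zeros` accepts any dtype.

```scala
// Toy sketch (Scala 3), NOT Storch's actual definitions.
// Hypothetical dtype hierarchy, loosely modeled on FloatNN | ComplexNN in the error above.
sealed trait DType { def name: String }
sealed trait FloatNN extends DType
sealed trait IntNN extends DType
sealed trait ComplexNN extends DType

case object Float32 extends FloatNN { val name = "float32" }
case object Int32 extends IntNN { val name = "int32" }

// The dtype is carried in the type parameter D, so Tensor[Float32.type] and
// Tensor[Int32.type] are different static types. Values are stored as Double
// only to keep the toy small.
final case class Tensor[D <: DType](dtype: D, shape: Seq[Int], data: Vector[Double])

object toy {
  private val rng = new scala.util.Random(42)

  // Upper bound D <: FloatNN | ComplexNN: calling rand with Int32 is a type error.
  def rand[D <: FloatNN | ComplexNN](shape: Seq[Int], dtype: D): Tensor[D] =
    Tensor(dtype, shape, Vector.fill(shape.product)(rng.nextDouble()))

  // No bound beyond DType: zeros works for integer dtypes too.
  def zeros[D <: DType](shape: Seq[Int], dtype: D): Tensor[D] =
    Tensor(dtype, shape, Vector.fill(shape.product)(0.0))
}

// toy.rand(Seq(3, 3), Int32) // does not compile: Int32.type is not a FloatNN | ComplexNN
```

The design choice mirrors the idea in the text: the check costs nothing at runtime because the constraint is expressed entirely in the type signature.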

Storch is powered by LibTorch, the C++ library underlying PyTorch, and by JavaCPP, which provides generated Java bindings for LibTorch as well as important utilities for integrating with native code on the JVM.