Installation

As Storch is still in an early stage of development, there are no releases available yet. We do, however, publish snapshot builds from main to make it easier to try out Storch.

Storch requires at least Scala 3.3.

To use the snapshots, add a resolver for the Sonatype snapshots repository and then add Storch as a dependency to your project:

sbt:

resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies ++= Seq(
  "dev.storch" %% "core" % "0.0-2dfa388-SNAPSHOT"
)

Scala CLI:

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:0.0-2dfa388-SNAPSHOT"

Adding the native PyTorch libraries

To use Storch, we also need to depend on the native PyTorch libraries (LibTorch), which are provided by the JavaCPP project as part of its autogenerated Java bindings. There are multiple ways to add the native libraries.

Why doesn't Storch just depend on the native PyTorch libraries itself?

Because these are native C++ libraries, which differ for each operating system and architecture. Furthermore, there are CPU and GPU variants which are incompatible with each other. We don't want to force users of Storch to use one variant over the other, so you'll have to add the native dependency yourself.

Via PyTorch platform

The easiest and most portable way to depend on the native library is via the PyTorch platform dependency:

sbt:

resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies ++= Seq(
  "dev.storch" %% "core" % "0.0-2dfa388-SNAPSHOT",
  "org.bytedeco" % "pytorch-platform" % "2.1.2-1.5.10"
)
fork := true

Scala CLI:

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:0.0-2dfa388-SNAPSHOT"
//> using lib "org.bytedeco:pytorch-platform:2.1.2-1.5.10"

There is one downside to this approach. Because pytorch-platform depends on the native libraries for all supported platforms, it will download and cache all of them, regardless of which platform you are actually on.

One way to avoid this overhead is to depend explicitly on the native libraries for your platform instead of using pytorch-platform. Note that as of JavaCPP 1.5.10, the platform approach also doesn't work on macosx-arm64, because the native dependencies are missing from the current version of pytorch-platform (this should be fixed in the next release).

Via classifier

You can do this by providing dependency classifiers specific to your platform. Currently supported classifiers are linux-x86_64, macosx-x86_64, macosx-arm64 and windows-x86_64.

sbt:

resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies ++= Seq(
  "dev.storch" %% "core" % "0.0-2dfa388-SNAPSHOT",
  "org.bytedeco" % "pytorch" % "2.1.2-1.5.10",
  "org.bytedeco" % "pytorch" % "2.1.2-1.5.10" classifier "linux-x86_64",
  "org.bytedeco" % "openblas" % "0.3.26-1.5.10" classifier "linux-x86_64"
)
fork := true

Scala CLI:

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:0.0-2dfa388-SNAPSHOT"
//> using lib "org.bytedeco:openblas:0.3.26-1.5.10,classifier=linux-x86_64"
//> using lib "org.bytedeco:pytorch:2.1.2-1.5.10,classifier=linux-x86_64"

Now we're only downloading the native libraries for a single platform. The downside, though, is that the build is no longer portable. Fortunately, for sbt and Gradle there's a solution available in the form of a build plugin.

Automatically detect your platform

The SBT-JavaCPP plugin will automatically detect the current platform and set the right classifier.

project/plugins.sbt:

addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17")

build.sbt:

resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies ++= Seq(
  "dev.storch" %% "core" % "0.0-2dfa388-SNAPSHOT"
)
javaCppPresetLibs ++= Seq("pytorch" -> "2.1.2", "openblas" -> "0.3.26")
fork := true

If you're using Gradle, you can use the Gradle JavaCPP plugin to do the same.
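A minimal sketch of what this could look like with the Gradle Kotlin DSL is shown below. It assumes the Gradle JavaCPP "platform" plugin (id org.bytedeco.gradle-javacpp-platform), which filters -platform dependencies down to the platforms you actually build for; check the plugin id, version and repository URLs against the Gradle JavaCPP documentation before relying on them.

build.gradle.kts:

plugins {
    scala
    // assumed plugin id/version: the platform plugin restricts -platform artifacts
    // to the current platform (or whatever javacppPlatform is set to)
    id("org.bytedeco.gradle-javacpp-platform") version "1.5.10"
}

repositories {
    mavenCentral()
    // snapshot repositories for the Storch snapshot builds
    maven("https://s01.oss.sonatype.org/content/repositories/snapshots")
    maven("https://oss.sonatype.org/content/repositories/snapshots")
}

dependencies {
    implementation("org.scala-lang:scala3-library_3:3.3.1")
    // Scala 3 artifacts are addressed with the _3 suffix from Gradle
    implementation("dev.storch:core_3:0.0-2dfa388-SNAPSHOT")
    implementation("org.bytedeco:pytorch-platform:2.1.2-1.5.10")
}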

Enable GPU support

Storch supports GPU-accelerated tensor operations for Nvidia GPUs via CUDA. JavaCPP also provides a matching CUDA toolkit distribution, including cuDNN, so you can avoid having to mess with a local CUDA installation.

Via PyTorch platform

sbt:

resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies ++= Seq(
  "dev.storch" %% "core" % "0.0-2dfa388-SNAPSHOT",
  "org.bytedeco" % "pytorch-platform-gpu" % "2.1.2-1.5.10",
  "org.bytedeco" % "cuda-platform-redist" % "12.3-8.9-1.5.10"
)
fork := true

Scala CLI:

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:0.0-2dfa388-SNAPSHOT"
//> using lib "org.bytedeco:pytorch-platform-gpu:2.1.2-1.5.10"
//> using lib "org.bytedeco:cuda-platform-redist:12.3-8.9-1.5.10"

This approach should work on any platform with CUDA support (Linux and Windows), but it causes even more overhead than the CPU variant, as CUDA is quite large. So, to save space and bandwidth, you can again use classifiers, at the expense of portability.

Via classifier

You can do this by providing dependency classifiers specific to your platform. Currently supported GPU classifiers are linux-x86_64-gpu and windows-x86_64-gpu.

sbt:

resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies ++= Seq(
  "dev.storch" %% "core" % "0.0-2dfa388-SNAPSHOT",
  "org.bytedeco" % "pytorch" % "2.1.2-1.5.10",
  "org.bytedeco" % "pytorch" % "2.1.2-1.5.10" classifier "linux-x86_64-gpu",
  "org.bytedeco" % "openblas" % "0.3.26-1.5.10" classifier "linux-x86_64",
  "org.bytedeco" % "cuda" % "12.3-8.9-1.5.10",
  "org.bytedeco" % "cuda" % "12.3-8.9-1.5.10" classifier "linux-x86_64",
  "org.bytedeco" % "cuda" % "12.3-8.9-1.5.10" classifier "linux-x86_64-redist"
)
fork := true

Scala CLI:

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:0.0-2dfa388-SNAPSHOT"
//> using lib "org.bytedeco:pytorch:2.1.2-1.5.10,classifier=linux-x86_64-gpu"
//> using lib "org.bytedeco:openblas:0.3.26-1.5.10,classifier=linux-x86_64"
//> using lib "org.bytedeco:cuda:12.3-8.9-1.5.10,classifier=linux-x86_64-redist"

Automatically detect your platform

The SBT-JavaCPP plugin also works for the GPU variant: add pytorch-gpu instead of pytorch to javaCppPresetLibs.

project/plugins.sbt:

addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17")

build.sbt:

resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies ++= Seq(
  "dev.storch" %% "core" % "0.0-2dfa388-SNAPSHOT"
)
javaCppPresetLibs ++= Seq("pytorch-gpu" -> "2.1.2", "openblas" -> "0.3.26", "cuda-redist" -> "12.3-8.9")
fork := true

Running tensor operations on the GPU

You can create tensors directly on the GPU:

import torch.Device.{CPU, CUDA}
val device = if torch.cuda.isAvailable then CUDA else CPU
// device: Device = Device(device = CUDA, index = -1)
torch.rand(Seq(3,3), device=device)
// res1: Tensor[Float32] = dtype=float32, shape=[3, 3], device=CUDA 
// [[0.3990, 0.5167, 0.0249],
//  [0.9401, 0.9459, 0.7967],
//  [0.4150, 0.8203, 0.2290]]

// Use device index if you have multiple GPUs
torch.rand(Seq(3,3), device=torch.Device(torch.DeviceType.CUDA, 0: Byte))
// res2: Tensor[Float32] = dtype=float32, shape=[3, 3], device=CUDA 
// [[0.9722, 0.7910, 0.4690],
//  [0.3300, 0.3345, 0.3783],
//  [0.7640, 0.6405, 0.1103]]

Or move them from the CPU:

val cpuTensor = torch.Tensor(Seq(1,2,3))
// cpuTensor: Tensor[Int32] = dtype=int32, shape=[3], device=CPU 
// [1, 2, 3]
val gpuTensor = cpuTensor.to(device=device)
// gpuTensor: Tensor[Int32] = dtype=int32, shape=[3], device=CUDA 
// [1, 2, 3]
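
Since to takes a device argument, the same call moves a tensor back again. A small sketch reusing gpuTensor from above (the backOnCpu name is just for illustration):

val backOnCpu = gpuTensor.to(device=CPU)
// expected: dtype=int32, shape=[3], device=CPU
// [1, 2, 3]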