torch.nn.modules.flatten

final class Flatten[D <: DType](startDim: Int, endDim: Int)(implicit evidence$1: Default[D]) extends TensorModule[D]

Flattens a contiguous range of dims into a tensor. For use with nn.Sequential.

Shape:
- Input: $(*, S_{\text{start}}, \ldots, S_{i}, \ldots, S_{\text{end}}, *)$, where $S_{i}$ is the size at dimension $i$ and $*$ means any number of dimensions, including none.
- Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)$.
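The shape rule can be checked without a tensor library. The following is a minimal sketch in plain Scala 3; `FlattenShape`, `normalizeDim`, and `outputShape` are hypothetical names, not part of the torch API. It assumes that negative dimension indices count from the end, as in PyTorch (so `-1` is the last dimension).

```scala
// Hypothetical helper (not a torch API): computes the output shape of a
// Flatten(startDim, endDim) layer for a given input shape.
object FlattenShape:

  // Maps a possibly negative dimension index into [0, rank).
  // Assumption: negative indices count from the end, so -1 is the last dim.
  def normalizeDim(dim: Int, rank: Int): Int =
    val d = if dim < 0 then dim + rank else dim
    require(d >= 0 && d < rank, s"dim $dim out of range for rank $rank")
    d

  // Replaces the sizes S_start..S_end (inclusive) with their product; the
  // leading and trailing dimensions (the `*` parts) pass through unchanged.
  def outputShape(input: Seq[Int], startDim: Int, endDim: Int): Seq[Int] =
    val start = normalizeDim(startDim, input.length)
    val end = normalizeDim(endDim, input.length)
    require(start <= end, s"startDim ($startDim) must not come after endDim ($endDim)")
    input.take(start) ++ Seq(input.slice(start, end + 1).product) ++ input.drop(end + 1)

@main def flattenShapeDemo(): Unit =
  val input = Seq(32, 1, 5, 5)
  println(FlattenShape.outputShape(input, 1, -1)) // List(32, 25)
  println(FlattenShape.outputShape(input, 0, 2))  // List(160, 5)
```

Because only the sizes inside the range are multiplied together, the total number of elements is the same before and after flattening.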

Example:

```scala
import torch.nn

val input = torch.randn(Seq(32, 1, 5, 5))
// With default parameters
val m1 = nn.Flatten()
// With non-default parameters: flattens dims 0 through 2,
// so m2 maps `input` to a tensor of shape (160, 5)
val m2 = nn.Flatten(0, 2)
```
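Flattening only relabels positions; it does not reorder elements. Assuming a row-major (C-order) layout, as in PyTorch, the merged dimensions fold into a single index in which the last merged dimension varies fastest. The sketch below is plain Scala 3 with a hypothetical `FlattenIndex.flattenIndex` helper (not a torch API) and, for brevity, expects non-negative dims with `startDim <= endDim`. It uses the shape of `input` from the example; dims 1 through 3 is what `m1` flattens if its defaults mirror PyTorch's (`start_dim = 1`, `end_dim = -1`).

```scala
// Hypothetical helper (not a torch API): maps a multi-index of the input to the
// corresponding multi-index of the flattened output, assuming row-major layout
// and non-negative startDim <= endDim.
object FlattenIndex:

  def flattenIndex(index: Seq[Int], shape: Seq[Int], startDim: Int, endDim: Int): Seq[Int] =
    require(index.length == shape.length, "index and shape must have the same rank")
    // Fold the merged dims into one position; the last merged dim varies fastest.
    val merged = (startDim to endDim).foldLeft(0) { (acc, d) => acc * shape(d) + index(d) }
    index.take(startDim) ++ Seq(merged) ++ index.drop(endDim + 1)

@main def flattenIndexDemo(): Unit =
  val shape = Seq(32, 1, 5, 5) // same shape as `input` in the example
  println(FlattenIndex.flattenIndex(Seq(2, 0, 3, 4), shape, 1, 3)) // List(2, 19)
  println(FlattenIndex.flattenIndex(Seq(2, 0, 3, 4), shape, 0, 2)) // List(13, 4)
```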

Value parameters

startDim

first dimension to flatten

endDim

last dimension to flatten (inclusive)

Attributes

Source
Flatten.scala
Supertypes
trait TensorModule[D]
trait Tensor[D] => Tensor[D]
class Module
class Object
trait Matchable
class Any