Like any good DL library, we organize our networks into Modules. Here is the module trait:

/// A module with a forward pass
pub trait Module<I> {
    type Output;
    fn forward(&self, input: I) -> Self::Output;
}

Super simple: we just define a forward function that takes an input and returns an output. Because the trait is generic over the input type, a module can define separate forward passes for single and batched inputs!
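
Here's a minimal sketch of that idea with a toy module that implements Module twice, once per input shape (assuming GraphTensor supports scalar multiplication through the standard Mul operator):

pub struct Double;

// Forward pass for a single vector of shape (A,)
impl<const A: usize> Module<GraphTensor<R1<A>>> for Double {
    type Output = GraphTensor<R1<A>>;

    fn forward(&self, input: GraphTensor<R1<A>>) -> Self::Output {
        input * 2.0
    }
}

// Forward pass for a batch of vectors of shape (BATCH, A)
impl<const BATCH: usize, const A: usize> Module<GraphTensor<R2<BATCH, A>>> for Double {
    type Output = GraphTensor<R2<BATCH, A>>;

    fn forward(&self, input: GraphTensor<R2<BATCH, A>>) -> Self::Output {
        input * 2.0
    }
}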

Now let’s take a look at how Linear is defined:

/// A simple linear layer
pub struct Linear<const A: usize, const B: usize> {
    pub(crate) weight: GraphTensor<R2<A, B>>,
}

impl<const A: usize, const B: usize> Module<GraphTensor<R1<A>>> for Linear<A, B> {
    type Output = GraphTensor<R1<B>>;

    fn forward(&self, input: GraphTensor<R1<A>>) -> Self::Output {
        input.matmul(self.weight)
    }
}

Here we see a single weight matrix as the internal state, of size AxB. The forward function takes a single input vector of shape (A,) and matmuls it with our weight matrix to get an output of shape (B,).
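
A batched forward pass can live right alongside the single-vector one. Here's a sketch of what it might look like, assuming matmul also handles a (BATCH, A) x (A, B) product:

impl<const BATCH: usize, const A: usize, const B: usize> Module<GraphTensor<R2<BATCH, A>>>
    for Linear<A, B>
{
    type Output = GraphTensor<R2<BATCH, B>>;

    fn forward(&self, input: GraphTensor<R2<BATCH, A>>) -> Self::Output {
        // (BATCH, A) x (A, B) -> (BATCH, B)
        input.matmul(self.weight)
    }
}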

Now all of these ops are recorded on the graph, to be compiled and run later on.
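
To make that laziness concrete, usage might look something like this. This is a hypothetical sketch: Graph::new, tensor, and initialize are assumed names for the graph constructor, input declaration, and weight initialization, not confirmed API.

fn main() {
    let mut cx = Graph::new();                       // assumed graph constructor
    let model = Linear::<3, 4>::initialize(&mut cx); // hypothetical weight initializer
    let input = cx.tensor::<R1<3>>();                // assumed input-declaration API
    let output = model.forward(input);               // records a matmul on the graph
    // Nothing has executed yet; the graph gets compiled and run later.
}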