Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/prims/pkg.generated.mbti
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Generated using `moon info`, DON'T EDIT IT
package "FlyCloudC/autodiff/prims"

import(
"FlyCloudC/autodiff/tape"
)

// Values

// Errors

// Types and methods
pub(all) struct AllPrims[A] {
neg : (@tape.Loc[A]) -> @tape.Loc[A]
exp : (@tape.Loc[A]) -> @tape.Loc[A]
sin : (@tape.Loc[A]) -> @tape.Loc[A]
cos : (@tape.Loc[A]) -> @tape.Loc[A]
ln : (@tape.Loc[A]) -> @tape.Loc[A]
add : (@tape.Loc[A], @tape.Loc[A]) -> @tape.Loc[A]
sub : (@tape.Loc[A], @tape.Loc[A]) -> @tape.Loc[A]
mul : (@tape.Loc[A], @tape.Loc[A]) -> @tape.Loc[A]
div : (@tape.Loc[A], @tape.Loc[A]) -> @tape.Loc[A]
}
pub fn[A : Number] AllPrims::on(@tape.Tape[A]) -> Self[A]

// Type aliases

// Traits
pub(open) trait Number : @tape.Diffable + Add + Mul + Neg + Sub + Div {
exp(Self) -> Self
ln(Self) -> Self
sin(Self) -> Self
cos(Self) -> Self
}
pub impl Number for Float
pub impl Number for Double

4 changes: 2 additions & 2 deletions src/prims/test.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ test "tape" {
// Now the instructions are recorded on the tape
inspect(
tape.dump(pad_1=2, pad_2=3),
content=
content=(
#|x0= 2
#|x1= 5
#|x2= - x0
#|x3= * x0 x1
#|x4= + x2 x3
#|
,
),
)

// Eval
Expand Down
6 changes: 3 additions & 3 deletions src/prims/types.mbt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
///|
typealias (A) -> A as Op1[A]
type Op1[A] = (A) -> A

///|
typealias (A, A) -> A as Op2[A]
type Op2[A] = (A, A) -> A

///|
typealias @tape.(Tape, Loc)
using @tape {type Tape, type Loc}

///|
pub(open) trait Number: @tape.Diffable + Add + Mul + Neg + Sub + Div {
Expand Down
10 changes: 5 additions & 5 deletions src/tape/build.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@ pub fn[A] constant(x : A) -> Loc[A] {
}

///|
pub fn[A] variable(self : Tape[A], x : A) -> Loc[A] {
pub fn[A] Tape::variable(self : Tape[A], x : A) -> Loc[A] {
let { insts, names } = self
insts.push(Var(x))
names.push("var")
Memory(insts.length() - 1)
}

///|
pub fn[A] op1(
pub fn[A] Tape::op1(
self : Tape[A],
name : String,
op : Op1[A],
diff : Op1[A]
diff : Op1[A],
) -> Op1[Loc[A]] {
x => {
let { insts, names } = self
Expand All @@ -32,12 +32,12 @@ pub fn[A] op1(
}

///|
pub fn[A] op2(
pub fn[A] Tape::op2(
self : Tape[A],
name : String,
op : Op2[A],
diff_l : Op2[A],
diff_r : Op2[A]
diff_r : Op2[A],
) -> Op2[Loc[A]] {
(l, r) => {
let { insts, names } = self
Expand Down
4 changes: 2 additions & 2 deletions src/tape/debug_utils.mbt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
///|
pub fn[A : Show] Tape::dump(
self : Tape[A],
pad_1~ : Int = 4,
pad_2~ : Int = 4
pad_1? : Int = 4,
pad_2? : Int = 4,
) -> String {
let loc_name = (loc : Loc[A]) => match loc {
Const(x) => x.to_string()
Expand Down
12 changes: 6 additions & 6 deletions src/tape/eval_and_diff.mbt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
///|
pub fn[A] at(self : Loc[A], mem : Array[A]) -> A {
pub fn[A] Loc::at(self : Loc[A], mem : Array[A]) -> A {
match self {
Const(x) => x
Memory(i) => mem[i]
}
}

///|
pub fn[A] eval(self : Tape[A]) -> Array[A] {
pub fn[A] Tape::eval(self : Tape[A]) -> Array[A] {
let insts = self.insts
// m is memory
// m[i] = x_i
Expand All @@ -26,10 +26,10 @@ pub fn[A] eval(self : Tape[A]) -> Array[A] {
}

///|
pub fn[A : Diffable] diff_forward(
pub fn[A : Diffable] Tape::diff_forward(
self : Tape[A],
m : Array[A],
wrt~ : Int = 0
wrt? : Int = 0,
) -> Array[A] {
let insts = self.insts
// d is diff memory.
Expand Down Expand Up @@ -60,10 +60,10 @@ pub fn[A : Diffable] diff_forward(
}

///|
pub fn[A : Diffable] diff_backward(
pub fn[A : Diffable] Tape::diff_backward(
self : Tape[A],
m : Array[A],
wrt~ : Int = -1
wrt? : Int = -1,
) -> Array[A] {
let insts = self.insts
let wrt = if wrt < 0 { insts.length() + wrt } else { wrt }
Expand Down
45 changes: 45 additions & 0 deletions src/tape/pkg.generated.mbti
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Generated using `moon info`, DON'T EDIT IT
package "FlyCloudC/autodiff/tape"

// Values
pub fn[A] constant(A) -> Loc[A]

// Errors

// Types and methods
pub(all) enum Inst[A] {
Var(A)
App1((A) -> A, Loc[A], diff~ : (A) -> A)
App2((A, A) -> A, Loc[A], Loc[A], diff_l~ : (A, A) -> A, diff_r~ : (A, A) -> A)
}

pub enum Loc[A] {
Const(A)
Memory(Int)
}
pub fn[A] Loc::at(Self[A], Array[A]) -> A
pub impl[A : Show] Show for Loc[A]

pub(all) struct Tape[A] {
insts : Array[Inst[A]]
names : Array[String]
}
pub fn[A : Diffable] Tape::diff_backward(Self[A], Array[A], wrt? : Int) -> Array[A]
pub fn[A : Diffable] Tape::diff_forward(Self[A], Array[A], wrt? : Int) -> Array[A]
pub fn[A : Show] Tape::dump(Self[A], pad_1? : Int, pad_2? : Int) -> String
pub fn[A] Tape::eval(Self[A]) -> Array[A]
pub fn[A] Tape::new() -> Self[A]
pub fn[A] Tape::op1(Self[A], String, (A) -> A, (A) -> A) -> (Loc[A]) -> Loc[A]
pub fn[A] Tape::op2(Self[A], String, (A, A) -> A, (A, A) -> A, (A, A) -> A) -> (Loc[A], Loc[A]) -> Loc[A]
pub fn[A] Tape::variable(Self[A], A) -> Loc[A]

// Type aliases

// Traits
pub(open) trait Diffable : Add + Mul {
zero() -> Self
one() -> Self
}
pub impl Diffable for Float
pub impl Diffable for Double

4 changes: 2 additions & 2 deletions src/tape/types.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ pub enum Loc[A] {
} derive(Show)

///|
typealias (A) -> A as Op1[A]
type Op1[A] = (A) -> A

///|
typealias (A, A) -> A as Op2[A]
type Op2[A] = (A, A) -> A

///|
pub(all) enum Inst[A] {
Expand Down