# Anatomy of Partial Evaluation in a Deep Learning DSL

In recent weeks, I’ve contributed a prototype module that implements *partial evaluation* (PE) in a deep learning DSL for encoding neural networks. The implementation is quite intricate, so a blog post seems like the right place to discuss the relevant code, design, and examples. I hope this helps you learn about implementing an eDSL in Haskell, as well as the principles of partial evaluation.

## DeepDarkFantasy

A friend of mine, Marisa Kirisame, is the creator of this project. Initially it was `DeepLearning.scala`, but later on the author went back to school and developed a Haskell version.

The idea is quite simple. When designing neural network (NN) layers, we need to know how to adjust the edge weights according to the change in loss. That is, we need the derivative of the loss w.r.t. each weight. I recommend another tutorial if you are not already familiar with this.
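To make the idea concrete, here is a tiny forward-mode automatic-differentiation sketch. The `Dual` type and `loss` function are invented for illustration and are *not* DDF’s mechanism; they only show what “derivative of the loss w.r.t. a weight” means operationally:

```haskell
-- Forward-mode AD via dual numbers: carry a value together with its derivative.
data Dual = Dual { val :: Double, der :: Double }

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' * Dual b b' = Dual (a * b) (a' * b + a * b')
  negate (Dual a a')    = Dual (negate a) (negate a')
  fromInteger n         = Dual (fromInteger n) 0
  abs    = undefined  -- not needed for this sketch
  signum = undefined

-- A hypothetical loss: loss w = (3w - 6)^2, minimized at w = 2.
loss :: Dual -> Dual
loss w = let e = w * 3 - 6 in e * e

main :: IO ()
main = print (der (loss (Dual 1 1)))  -- d(loss)/dw at w = 1
</imports>
```

Running this prints `-18.0`: the loss decreases as the weight grows toward 2, which is exactly the signal a training loop would use to adjust the weight.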

More materials:

- Deep Learning, from a Programming Language Perspective
- Neural Networks, Types, and Functional Programming
- Talk at LambdaConf

## PE in a de Bruijn indexed, final tagless eDSL

Honestly speaking, I haven’t done much theoretical innovation so far, and my work in DDF is heavily based on Oleg’s tutorial. But I found his tutorial too contrived, more or less like:

Here is how to draw a circle, got it? Let’s draw the Mona Lisa now!

In this post, I will try to show more examples, and be more verbose about details.

Second, while Oleg’s interpretation is already quite beautiful, I believe that another interpretation in a slightly different scenario will help more people understand the material. This will give you an idea of the common considerations when implementing PE for a particular language.

## Basics

In this section, I am going to talk about the de-Bruijn-index based, tagless approach towards a more extensible eDSL. Let’s take `DDF.DBI` as an example:

```
class DBI (r :: * -> * -> *) where
  z :: r (a, h) a
  s :: r h b -> r (a, h) b
  abs :: r (a, h) b -> r h (a -> b)
  app :: r h (a -> b) -> r h a -> r h b
```

Here, `DBI` encodes *what is needed* to be a typed lambda calculus embedded inside the host language. This might look confusing to you now, so let’s instead go for the hello-world version of a deeply embedded DSL first:

```
data Expr = EVar String
          | EAbs String Expr
          | EApp Expr Expr
```

An example is `EApp (EAbs "f" (EVar "f")) (EAbs "x" (EVar "x"))`, or $(\lambda f. f) (\lambda x. x)$, which evaluates to $\lambda x. x$ under the classical lambda calculus semantics.

Next, we are going to write two interpreters: a pretty-printer and an evaluator for this language.

```
prettyAST :: Expr -> String
evalAST :: Expr -> Maybe Expr
```

In each interpreter, we need to do a case analysis on the expression. So, if we add a new constructor, every interpreter needs to be updated.
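For concreteness, here is a minimal sketch of what one such interpreter might look like (the body of `prettyAST` below is my own illustration, not the project’s code). Note how every constructor needs its own case:

```haskell
-- Deep embedding: one data type, one case per constructor in each interpreter.
data Expr = EVar String
          | EAbs String Expr
          | EApp Expr Expr

prettyAST :: Expr -> String
prettyAST (EVar x)   = x
prettyAST (EAbs x e) = "(\\" ++ x ++ ". " ++ prettyAST e ++ ")"
prettyAST (EApp f a) = "(" ++ prettyAST f ++ " " ++ prettyAST a ++ ")"

main :: IO ()
main = putStrLn (prettyAST (EApp (EAbs "f" (EVar "f")) (EAbs "x" (EVar "x"))))
```

Adding, say, an `ELit Int` constructor would force `prettyAST`, `evalAST`, and every other interpreter to grow a new case — the extensibility problem in a nutshell.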

Note: Another way to implement this is through *type class*, for example:

```
class Pretty a where
pretty :: a -> String
```

In this way, we need fewer names when we have a complex hierarchy: `Expr` over `Term`, `Stmt` over `Expr`, etc. But the extensibility problem is still not solved.

We hope there is a way to make syntax *and* semantics highly composable. Think about a language that is *composed* of smaller languages, possibly forming a dependency graph. You can pick and choose a subset of it, so long as the dependencies are satisfied (which usually means that your final interpreter implements whatever each smaller language requires). Second, each piece of syntax is associated with its own semantics, either implemented by the interpreter directly, or built on top of the primitives of the languages it depends on.

```
F  B  E
 \ | /
   A
  / \
 C   D
```

For the above example, when you want to instantiate the language `B`, you just need to define your “ground rules” for `C` and `D`. `A`’s derivation is automatic, since its instance is defined w.r.t. `C` and `D`.

So here are two things:

- Orthogonal composition of AST (and its defined semantics)
- No more boilerplate code for the upper-level derivation (like `A`’s).
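A minimal sketch of this pick-and-choose idea using Haskell type classes — all names here (`LangC`, `LangD`, `LangA`, `Eval`) are invented for illustration, not taken from DDF. `LangA` depends on `LangC` and `LangD`, and its operation is derived from theirs, so an interpreter only supplies the ground rules:

```haskell
-- "Smaller languages" as type classes.
class LangC r where
  lit :: Int -> r
  add :: r -> r -> r

class LangD r where
  neg :: r -> r

-- LangA sits above LangC and LangD in the dependency graph; its operation
-- has a default definition built from the depended-upon primitives.
class (LangC r, LangD r) => LangA r where
  sub :: r -> r -> r
  sub x y = add x (neg y)  -- automatic derivation

-- One interpreter: define the ground rules for C and D only.
newtype Eval = Eval Int deriving (Show, Eq)

instance LangC Eval where
  lit = Eval
  add (Eval a) (Eval b) = Eval (a + b)

instance LangD Eval where
  neg (Eval a) = Eval (negate a)

instance LangA Eval  -- sub comes for free

main :: IO ()
main = print (sub (lit 5) (lit 2) :: Eval)
```

The empty `instance LangA Eval` is the “no boilerplate” part: the upper-level semantics is inherited, and a second interpreter (say, a pretty-printer) would reuse the same default.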

To do so, we must have a flexible enough representation, on which we can require the presence of certain sub-representations, rather than fixing everything from the very beginning.

Another advantage of this encoding is the ability to piggyback on the host language’s type system. For example:

```
abs :: r (a, h) b -> r h (a -> b)
```

This means you can’t use arbitrary sub-expression as the body of a lambda expression, unless it really has the right type for being the body of a lambda expression (in this case, it should have a free variable of the same type as the lambda’s binder).

But other parts of the DSL look weird: that’s because we encode de Bruijn indexing at the type level. In effect, we implemented the rules of de Bruijn encoding using the type system of the host language, which enforces the type-environment checking for us. It is all static – if the host program compiles, then the guest program type-checks. In contrast, the STLC example in Pierce’s TAPL checks types at *runtime*.

Let’s explain each component in detail:

```
z :: r (a, h) a
s :: r h b -> r (a, h) b
```

For an arbitrary outer environment `h`, we use `z` to represent the variable bound to the closest binder (index 0). For example, in $\lambda x . \lambda y. y + x$, the $y$ in $y + x$ binds to the closest binder, so we encode it using `z`, and `h` would be `(ty, (tx, h'))` for arbitrary `h'`. The other variable, $x$, is bound to the second closest binder, so we encode it by stepping through the type-level pair: `s z :: r (ty, (tx, h')) tx`.

Below is a workable snippet:

```
t1 :: forall r h. Double r => r (M.Double, (M.Double, h)) M.Double
t1 = app2 doublePlus (z :: r (M.Double, (M.Double, h)) M.Double)
                     (s (z :: r (M.Double, h) M.Double)
                        :: r (M.Double, (M.Double, h)) M.Double)
```

For abstraction:

```
abs :: r (a, h) b -> r h (a -> b)
```

Under the de Bruijn encoding, we don’t have to give names to binders, so `abs` only takes the lambda body expression as its sole parameter. To make sense of its type parameters, we provide another example:

```
id = abs (z :: r (a, h) a) :: r h (a -> a)

fxy = abs (abs (z :: r (ty, (tx, h)) ty)
             :: r (tx, h) (ty -> ty))
        :: r h (tx -> (ty -> ty))

fxy' = abs (abs (s (z :: r (tx, h) tx)
                   :: r (ty, (tx, h)) tx)
              :: r (tx, h) (ty -> tx))
         :: r h (tx -> (ty -> tx))
```

Note the difference between the types of the two `z`s in `fxy` and `fxy'` above.

`app` is much more straightforward:

```
app :: r h (a -> b) -> r h a -> r h b
```
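To see the class in action, here is one possible interpreter, sketched under the assumption that evaluation maps a term to a host function from its environment (a nested pair) to its value. The `EvalD` name is invented for illustration and is not DDF’s interpreter:

```haskell
{-# LANGUAGE KindSignatures #-}
import Prelude hiding (abs)

-- A term of type `r h a` is interpreted as a function from the
-- environment `h` (a nested pair of bound values) to a value of type `a`.
newtype EvalD h a = EvalD { runEvalD :: h -> a }

class DBI (r :: * -> * -> *) where
  z :: r (a, h) a
  s :: r h b -> r (a, h) b
  abs :: r (a, h) b -> r h (a -> b)
  app :: r h (a -> b) -> r h a -> r h b

instance DBI EvalD where
  z = EvalD fst                             -- index 0 reads the head of the environment
  s (EvalD e) = EvalD (e . snd)             -- skip one binder
  abs (EvalD e) = EvalD (\h a -> e (a, h))  -- the binder extends the environment
  app (EvalD f) (EvalD x) = EvalD (\h -> f h (x h))

main :: IO ()
main = print (runEvalD (app (abs z) (abs z)) () (42 :: Int))
```

The environment-as-nested-pair reading makes the types in the class mechanical: `z` projects the first component, `s` drops it, and `abs` turns “a term needing one more binding” into a host function. An ill-scoped term like `app z z` simply fails to type-check in GHC.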

## Going through the hard part

Now that we have a rough idea of how such an advanced AST construct works, we can consider its *partial evaluation*. You might have heard about it in a *Compilers* course, where PE is one of the common techniques for code optimization. For example, in your C code, you can write an expression like `1 + 2`. The PE phase can automatically detect such statically reducible expressions and evaluate them at compile time, emitting `3` rather than the more expensive equivalent `1 + 2`. The process is called **partial** because almost all programs have unknown input parameters: in `x + 1`, we can’t evaluate further since we don’t know `x` statically.

PE is like a bridge between static time and dynamic time. So is JIT.

There is also an interesting research model developed around this, called the Futamura projections. Just think about it:

```
compile p = pe (interpret p)
```

### Intuition

Let’s try to formulate the principles of partial evaluation step by step:

The simplest example: consider the PE’ed result of `1 + 2`. Let’s analyze it bottom-up: `1` can be PE’ed trivially, and so can `2`; but for `1 + 2`, we need information telling us that both operands are statically known, so that we can replace the entire expression with another statically known value. Let’s use $(e)_s$ to denote a static expression $e$ and $(e)_u$ to denote an unknown expression $e$. This is how $1 + 2$ is processed in this framework:

- $1$ =pe=> $(1)_s$
- $2$ =pe=> $(2)_s$
- $1 + 2$ =pe=> $(1)_s + (2)_s = (1 +_s 2)_s = (3)_s$
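The static/unknown bookkeeping above can be sketched on a tiny first-order language. All names below (`E`, `PE`, `S`, `U`, `reify`) are invented for illustration; the real DDF machinery is higher-order, as we will see:

```haskell
-- A tiny expression language, invented for illustration.
data E = Lit Int | Var String | Add E E deriving (Show, Eq)

-- The two tags from the text: (e)_s is `S`, (e)_u is `U`.
data PE = S Int  -- statically known value
        | U E    -- unknown residual expression

-- Turn a PE'ed result back into an ordinary expression.
reify :: PE -> E
reify (S n) = Lit n
reify (U e) = e

pe :: E -> PE
pe (Lit n) = S n
pe (Var x) = U (Var x)
pe (Add a b) = case (pe a, pe b) of
  (S m, S n) -> S (m + n)              -- both static: fold at "compile time"
  (pa, pb)   -> U (Add (reify pa) (reify pb))

main :: IO ()
main = print (reify (pe (Add (Add (Lit 1) (Lit 2)) (Var "x"))))
```

Here `(1 + 2) + x` residualizes to `Add (Lit 3) (Var "x")`. This sketch is exactly the “too restrictive” formulation discussed next: it has no rule for static functions, which is what the rest of the post fixes.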

Apparently, for $(e)_s$, there mustn’t be any free variable in $e$. That seems plausible, but consider this: $(\lambda x . x + y) \, 2$. Since $y$ is free, $(\lambda x . x + y)$ can’t be static, so the whole expression will just be returned as it is. But there is clearly a chance for PE: the optimal result should be $2 + y$.

Aha, our naive formulation turns out to be too restrictive. We should be able to identify a “static function” like $\lambda x. x + y$: no matter what is inside it, as long as we can track where $x$ binds, we can do a step of $\beta$-reduction PE when this *static function* is **applied** to some other PE’ed result. Note that for a function with multiple parameters, we apply them one by one. So we only need two rules: one for a static value, and one for a static lambda.

Now we consider a third possibility: the big $\Lambda$, the abstraction operator of the host language. During the PE of a static lambda, we lift the static function from the DSL space into the host-language space, so that we can apply it later.

Now, we have $\lambda x . x + y$ PE’ed as:

$(\Lambda x . x + y)_{\lambda \tau_x}$

The operand being applied can be any PE’ed form, which we will denote as $(e)_?$; then we can perform the application, resulting in $(e)_? + y$.

Clearly, the ? here can only be either static or just unknown. For the first case, we can wrap the thing into $(e +_s y)_s$, and for the second case, we can wrap the thing as an unknown value – wait a minute, unknown value? What if the `y` binder can also be statically applied in the future? Maybe there is some information loss again?

The question is: how do we wrap the above expression into a form that preserves the $y$ there? In fact, the astute reader might have already asked that a few minutes ago, since we didn’t give any instruction on how we actually get $\Lambda$.

You can imagine that each time we have such a free variable like $y$, we can store its *location info* (de Bruijn index) along with the wrapped expression. So when we are processing a $\lambda$ in the DSL, we will check whether the function body has at least one free variable; if so, we will take out the head variable’s location information, which acts like a substitution operator, and abstract over it – so later, when we enter the application processing step, we can simply apply it at the host level.

Consider $y$, the simplest expression containing a free variable. The $y$ in it *might* be bound to some binder in a succeeding PE process. We will replace it with a special *host-level* closure $\Lambda h_y. h_y[z]$.

Each free variable, or expression containing free variables, can react to the fed-in environment. A free variable will check whether the given environment is on the same index level; if yes, we perform the substitution for it.

Now think about $z + z$: how do we compose the two closures together?

- $(\Lambda h . h[z])_f + (\Lambda h . h[z])_f$
- $\Lambda h . h[z] + h[z]$

After step 2, if $h$ represents the index-0 environment (say, $k$), then $h[z]$ will succeed, and the final result will be $k + k$. Otherwise, $h$ has index > 0, thus not on the same level, so the result is still $z + z$.

This is the intuition. Beyond this, we need to handle more complex cases caused by composition. In Oleg’s tutorial, he named three challenges:

- The environment depth of the function body changes when the function is changed by PE. Just consider $(\lambda x. (\lambda y. x + y + 1) x)$. Initially, `x + y + 1` is of type `r (Int, (Int, h)) Int`, but after the `y` binder is applied with $x$, the result `x + x + 1` is of type `r (Int, h) Int`.
- When we substitute free variables under more binders in the function body, the binding depth of the variables changes, so their host-level de Bruijn indices should change as well. The intuition: we should increase each index by one, to bind to the right binders outside the closest one.
- The variable to substitute may be `s (s ... (s z))` rather than just `z`. That is, the closure formed by an open expression might need to be transformed when wrapped inside another binder.

### Partially-evaluated Term

Actually, we have already given enough motivation to talk about my implementation of Oleg’s idea in DeepDarkFantasy. First, we give the full partially-evaluated representation `P r h a` that carries static information.

```
data P r h a where
  Unk :: r h a -> P r h a
  Static :: a -> (forall h. r h a) -> P r h a
  StaFun :: (forall hout. EnvT r (a, h) hout -> P r hout b) ->
            P r h (a -> b)
  Open :: (forall hout. EnvT r h hout -> P r hout a) ->
          P r h a
```

- `Unk` means “unknown”: a term that carries zero static information.
- `Static` means a fully evaluated, statically known value (hence the host-level type `a` here). Since it is environment-irrelevant, it also carries an `h`-universal injection `forall h. r h a`.
- `StaFun` means a statically known function that is not yet fully applied. It is a host-level closure which, when given a *proper* environment `EnvT`, returns the resulting PE’ed representation `P r hout b`.
- `Open` means a PE’ed term that contains free variables. It is also a host-level closure which, when given a proper environment, instantiates to the substituted form.

### Environment

Next, let’s see how the compile-time de Bruijn type environment is encoded:

```
data EnvT repr hin hout where
  Dyn :: EnvT repr hin hin
  Arg :: P repr hout a -> EnvT repr (a, hout) hout
  Weak :: EnvT repr h (a, h)
  Next :: EnvT repr hin hout -> EnvT repr (a, hin) (a, hout)
```

The environment transformer is provided on a *per-variable* basis, together with its effect on the *entire* environment. That is to say, we substitute free variables one by one, and either weaken or shrink the type environment gradually according to the input/output change.

Let’s check out some examples to explain what Oleg means in his tutorial:

- “`Dyn` is the identity transformer; it also requests the forgetting of all statically known data, converting `P repr h a` to the form `Unk (x :: repr h a)`”
  - Suppose we end the PE with a function-valued expression. Then we need to feed in some “unknown dynamic parameters” in order to get the final PE’ed expression out. These parameters are `Dyn`. (Note that `dynamic (StaFun f) = abs $ dynamic (f Dyn)`.)
  - Because it doesn’t really eliminate/instantiate the free variables, it is an *identity* function.
- “`Arg p` asks to substitute `p` for the `z` free variable and hence removes its type `a` from the environment”
  - In contrast to `Dyn`, this time we know the parameter statically, so this can shrink the free-variable type environment from `(a, hout)` to `hout`.
- “`Weak` requests weakening of the environment, adding some type `a` at the top”
  - Weakening the environment doesn’t require any condition. Just consider that $(\lambda x.e)() \equiv e$ ($x$ is not free in $e$).
  - This is useful when we substitute an open expression $e$ into another lambda $\lambda x : a$, and the outer lambda weakens $e$ by adding type $a$ to the type environment.
- “Since we may need to weaken and substitute for free variables other than `z`, `Next` increments the environment level at which `Weak` and `Arg` should apply”
  - Consider what `Next (Arg p)` means: it assumes a closest binder of arbitrary type `a`, and lifts the transformation effect one level up.

### Substitution

```
app_open :: DBI repr =>
            P repr hin r -> EnvT repr hin hout -> P repr hout r
```

`app_open` means substituting (or “applying”) open terms with the environment.

#### Two simple rules

```
app_open e Dyn = Unk (dynamic e)
-- if we know nothing about the parameter
app_open (Static es ed) _ = Static es ed
-- if we already know its value statically, we can just ignore the environment
```

#### `Open` term

```
app_open (Open f) h = f h
```

If we have an explicitly `Open` term (the simplest one: `z :: P r h a`), then we apply the host-level closure to this environment parameter.

#### Static function

```
app_open (StaFun fs) h = abs (fs (Next h))
```

If we have a static *function* `fs`, e.g. $\lambda x . e$ ($x$ is not free in $e$), then any environment $h$ has to be shifted one level up with `Next` when instantiating under $\lambda x$. After that, we wrap it back with an `abs`.

The process:

```
StaFun fs         :: P r h (a -> b)   =>
fs (Next h)       :: P r (a, hout) b  =>
abs (fs (Next h)) :: P r hout (a -> b)
```

#### Unknown term

```
app_open (Unk e) h = Unk (app_unk e h) where
  app_unk :: DBI repr =>
             repr hin a -> EnvT repr hin hout -> repr hout a
  app_unk e Dyn = e
  app_unk e (Arg p) = app (abs e) (dynamic p)
  app_unk e (Next h) = app (s (app_unk (abs e) h)) z
  app_unk e Weak = s e
```

In the last case, we consider how to refine an *unknown* (fully dynamic) term with the environment. First, when we have an unknown term, can we return a PE’ed term that is *not* unknown? There are only three other possibilities: an open term, a static function, or a fully static value. The third case is impossible, since we can’t know more than the plain term when given the `r h a` body in `Unk`. Second, we can’t make up an open term or a static function from a *more general* type either. So it must be `Unk` again.

If the parameter is dynamic, then `e` won’t change, since no useful information is supplied. If the parameter is an AST term, then we lift the application up to the representation-level `app`. For `Next` and `Weak`, the reasoning is similar. Let’s consider `app_unk (z + s z + s s z) (Next (Arg (Static _ x)))` as an example. By the application, we substitute `s z`; but after substitution, `s s z` changes as well, becoming `s z`, since one layer of $\lambda$ is consumed. The ideal result is `z + x + s z`.

We encode this logic with `app_unk`:

```
app_unk :: DBI repr =>
           repr hin a -> EnvT repr hin hout -> repr hout a
app_unk e Dyn = e
app_unk e (Arg p) = app (abs e) (dynamic p)
app_unk e (Next h) = app (s (app_unk (abs e) h)) z
app_unk e Weak = s e
```

This is how the above example works in this framework:

```
app_unk (z + s z + s s z) (Next (Arg (Static _ x)))
=> app (s (app_unk (abs (z + s z + s s z)) (Arg (Static _ x)))) z
=> app (s (app (abs (abs (z + s z + s s z))) (dynamic (Static _ x)))) z
=> app (s (app (abs (abs (z + s z + s s z))) x)) z
=> app (s (abs (z + x + s z))) z
=> app (abs (z + x + s s z)) z
=> z + x + s z
```

During this process, we also used another helper function, `dynamic`:

```
dynamic :: DBI repr => P repr h a -> repr h a
dynamic (Unk x) = x
dynamic (Static _ x) = x
dynamic (StaFun f) = abs $ dynamic (f Dyn)
dynamic (Open f) = dynamic (f Dyn)
```

#### The `pe` function

Next is the final `pe` function: from the intermediate, hybrid state back to the source form. Note how `pe` executes: first, we write the syntactic form of the AST and force its type to be `P r h a`, so that it is interpreted as the PE’ed form automatically. Then, using `dynamic`, it is closed back into an `r h a` again.

```
pe :: Double repr => P repr () a -> repr () a
pe = dynamic
```

#### PE main process

```
instance DBI r => DBI (P r) where
  z = Open f where
    f :: EnvT r (a, h) hout -> P r hout a
    f Dyn = Unk z   -- turn to dynamic as requested
    f (Arg x) = x   -- substitution
    f (Next _) = z  -- not my level
    f Weak = s z
```

For the `z` free variable, we turn it into an open term, which responds to the environment’s request for instantiation:

- If the parameter (request) is fully dynamic `Dyn`, then we just return the old thing as it is, wrapped in `Unk`.
- Or it is static `Arg x`, forcing `hout` to be `h`: we just get `x` out (this parameter value should already be PE’ed).
- Or it is `Next _`, forcing `hout ~ (a, h)`: we just skip it.
- Or it is `Weak`, which forces `hout ~ (a', (a, h))`: then we bind one step further to leave space for the weakened environment.

```
s :: forall h a any. P r h a -> P r (any, h) a
s (Unk x) = Unk (s x)
s (Static a ar) = Static a ar
s (StaFun fs) = abs (fs (Next Weak))
s p = Open f where
  f :: EnvT r (any, h) hout -> P r hout a
  f Dyn = Unk (s (dynamic p))  -- nothing is statically known, dynamize
  f (Arg _) = p
  f (Next h) = s (app_open p h)
  f Weak = s (s p)
```

For an `s`-shifted term, we need to consider how the body is shifted. If the body is just unknown, then we shift the inner expression and keep it unknown; a static value is not affected by shifting; for a static function, we just wrap in another layer of weak environment, so all the inner free variables have their de Bruijn indices increased by 1.

`abs` and `app` are intuitive:

```
abs (Unk f) = Unk (abs f)
abs (Static k ks) = StaFun $ \_ -> Static k ks
abs body = StaFun (app_open body)

app (Unk f) (Unk x) = Unk (app f x)
app (StaFun fs) p = fs (Arg p)
app (Static _ fs) p = Unk (app fs (dynamic p))
app e1 e2 = Open (\h -> app (app_open e1 h) (app_open e2 h))
```

#### PE of binary operators

```
binaryPE :: forall h a r. DBI r =>
            (a -> a -> a) -> r h (a -> a -> a) ->
            (forall h. P r h (a -> a -> a)) -> (forall h. a -> P r h a) ->
            (a -> M.Bool) -> (a -> M.Bool) -> P r h (a -> a -> a)
```

This is a general mechanism for the PE of binary operations, like the `Double` class’s `doublePlus` operator:

```
class DBI r => Double r where
  double :: M.Double -> r h M.Double
  doublePlus :: r h (M.Double -> M.Double -> M.Double)
```

We use it like this:

```
instance Double r => Double (P r) where
  double x = Static x (double x)
  doublePlus = binaryPE (+) doublePlus doublePlus double (== 0.0) (== 0.0)
```

The tricky thing is that `doublePlus` has a DSL-level function type, rather than the meta-level type `r h M.Double -> r h M.Double -> r h M.Double`, so its currying happens in DSL semantics. This means that we have to expand it into the meta level using `StaFun` (rationale: a built-in operator like `(+)` is a static function).

```
binaryPE op opM opPM liftM isLeftUnit isRightUnit = StaFun binaryPE'
  where
    binaryPE' :: forall hout. EnvT r (a, h) hout -> P r hout (a -> a)
```

After we get access to the compile-time environment `EnvT r (a, h) hout`, we need to case-analyze it.

```
binaryPE' (Arg a) = StaFun (binaryPE'' a)
  where
    binaryPE'' :: forall hout. P r h a -> EnvT r (a, h) hout -> P r hout a
```

The first possibility is that the argument is a PE’ed concrete form; in this case, we pass it on for further processing. For the remaining possibilities, we just need to consider how to form the curried function. For `Dyn`, it forces `hout` to be `(a, h)`. With the environment-universal `opPM :: forall h. P r h (a -> a -> a)`, we can `app` it with `z`; here `b` is `a -> a`, i.e. the returned function type. For `Next _`, we similarly skip. For `Weak`, it forces `hout` to be `(a', (a, h))`, so we need to use `s z` to refer to the weakened environment.

```
binaryPE' Dyn = app opPM z
binaryPE' (Next _) = app opPM z
binaryPE' Weak = app opPM (s z)
```

For the right operand, it is basically the same. We process the case of two PE’ed concrete terms in a separate function:

```
binaryPE'' a (Arg b) = f a b
binaryPE'' a Dyn = app2 opPM (s a) z
binaryPE'' a (Next h') = app2 opPM (s (app_open a h')) z
binaryPE'' a Weak = app2 opPM (s (s a)) (s z)
```

The `s a` in the `Dyn` case is there to reach over the `z` binder. For `Next`, we first instantiate the free variables in `a` with `h'` via `app_open a h'`. In fact, similarly to `Next`, we can consider `Dyn` as `app_open = const`. For `Weak`, we just add an `s` to both operands of the `Dyn` case.

Finally, we come to `f`. What might the PE’ed concrete forms look like?

If both operands are static, then we apply the *static* operator *statically*.

```
f (Static d1 _) (Static d2 _) = liftM (d1 `op` d2)
```

Using the left/right identity rule, we can prune out computation:

```
f (Static d1 _) x | isLeftUnit d1 = x
f x (Static d2 _) | isRightUnit d2 = x
```

If both arguments are unknown (i.e. they cannot be substituted into and hence improved), there is nothing else we can do:

```
f (Unk x) (Unk y) = Unk (app2 opM x y)
```

Otherwise, at least one argument may be improved by a substitution. We will look again through the result to see if the static addition becomes possible:

```
f e1 e2 = Open (\h -> app2 opPM (app_open e1 h) (app_open e2 h))
```