# From GitHub gist saulshanabrook/5f6d6e58de3dbb7c8c8bf6017c9c3e2f
# (author: @saulshanabrook, created May 13, 2026).
"""
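The goal here is to see whether we can replicate the PEGGY paper by embedding it in an egglog
dialect that has a fixpoint operator (`fix`) and higher-order functions as builtins.
If we can implement fixpoint fusion as a rule, then the loop optimization worked through at
the bottom of this file should follow from that rule plus ordinary algebraic rewrites.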
"""
# mypy: disable-error-code="empty-body"
from collections.abc import Callable
from typing import Generic, Protocol, TypeAlias, TypeVar, cast
from egglog import *
T = TypeVar("T")


@function
def fix(fn: Callable[[T], T]) -> T:
    """
    Fixed point of ``fn``, declared as an egglog function over any sort ``T``.

    Intended unfolding law:

        fix(fn) == fn(fix(fn))

    The function has no body, so nothing in this declaration encodes that law as a
    rewrite. Any reasoning about ``fix`` has to come from rules registered elsewhere.
    """
class Boolean(Expr):
    """Symbolic boolean term built from a ``Bool`` literal."""

    def __init__(self, value: BoolLike) -> None:
        """Lift a boolean literal into a ``Boolean`` term."""

    def if_(self, true: T, false: T) -> T:
        """Conditional: intended to select ``true`` when this holds and ``false`` otherwise."""

    def __eq__(self, other: BooleanLike) -> Boolean:
        """Symbolic equality; returns a ``Boolean`` term, not a Python ``bool``."""


BooleanLike: TypeAlias = Boolean | BoolLike
# Let plain ``Bool`` values be used wherever a ``Boolean`` is expected.
converter(Bool, Boolean, Boolean)
class Num(Expr):
    """Symbolic integer term built from an ``i64`` literal."""

    def __init__(self, value: i64Like) -> None:
        """Lift an integer literal into a ``Num`` term."""

    # Arithmetic declarations. The reflected variants cover a plain literal on the
    # left-hand side, e.g. ``1 + n`` or ``5 * n``.
    def __add__(self, other: NumLike) -> Num: ...
    def __radd__(self, other: NumLike) -> Num: ...
    def __sub__(self, other: NumLike) -> Num: ...
    def __mul__(self, other: NumLike) -> Num: ...
    def __rmul__(self, other: NumLike) -> Num: ...

    def __eq__(self, other: NumLike) -> Boolean:
        """Symbolic equality; returns a ``Boolean`` term, not a Python ``bool``."""


NumLike: TypeAlias = Num | i64Like
# Let plain ``i64`` values be used wherever a ``Num`` is expected.
converter(i64, Num, Num)
T_co = TypeVar("T_co", covariant=True)


class SupportsSameAdd(Protocol[T]):
    """Structural type for values whose ``+`` takes a ``T`` and returns a ``T``."""

    def __add__(self, other: T, /) -> T:
        """Add ``other`` to ``self``."""


class SupportsSameMul(Protocol[T]):
    """Structural type for values whose ``*`` takes a ``T`` and returns a ``T``."""

    def __mul__(self, other: T, /) -> T:
        """Multiply ``self`` by ``other``."""
class Stream(Expr, Generic[T_co]):
    """
    Stream is a mapping from indices to values, like is defined in Lean
    https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'

    A stream is represented extensionally by its indexing function: ``Stream(f)[i]``
    is intended to be ``f(i)``.
    """

    def __init__(self, idx: Callable[[Num], T_co]) -> None: ...

    def __mul__(self: Stream[SupportsSameMul[T]], other: StreamLike[T]) -> Stream[T]:
        """Pointwise product: ``(self * other)[i] == self[i] * other[i]``."""
        # ``cast`` is a no-op at runtime and only informs the type checker. A non-stream
        # ``other`` is presumably turned into a Stream by the ``StreamLike`` converter.
        other = cast("Stream[T]", other)
        return Stream(lambda i: self[i] * other[i])

    def __add__(self: Stream[SupportsSameAdd[T]], other: StreamLike[T]) -> Stream[T]:
        """Pointwise sum: ``(self + other)[i] == self[i] + other[i]``."""
        other = cast("Stream[T]", other)
        return Stream(lambda i: self[i] + other[i])

    def __getitem__(self, idx: NumLike) -> T_co:
        """
        https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.get

        Element of the stream at index ``idx`` (declaration only, no body).
        """

    def tail(self) -> Stream[T_co]:
        """
        https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.tail

        Drop the first element: ``s.tail()[i] == s[i + 1]``.
        """
        return Stream(lambda idx: self[idx + 1])

    # Sketches of further Lean ``Stream'`` operations. They are not yet enabled and
    # are untested against egglog.
    #
    # def map(self, fn: Callable[[Num], Num]) -> Stream[Num]:
    #     return Stream(lambda idx: fn(self[idx]))
    #
    # @classmethod
    # def iterate(cls, next: Callable[[Num], Num], init: NumLike) -> Stream[Num]:
    #     """
    #     https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.iterate
    #     """
    #     # A stream is represented extensionally as its indexing function.
    #     # iterate(next, init)[0] = init
    #     # iterate(next, init)[n + 1] = next(iterate(next, init)[n])
    #     return cls(fix(lambda self: lambda idx: (idx == 0).if_(init, next(self(idx - 1)))))

    @classmethod
    def const(cls, value: T) -> Stream[T]:
        """
        https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.const

        The stream whose every element is ``value``.
        """
        return Stream(lambda _: value)
StreamLike: TypeAlias = Stream[T] | T
# Accept any value where a Stream is expected by wrapping it with ``Stream.const``.
# NOTE(review): relies on egglog accepting ``object`` as a converter source type; confirm.
converter(object, Stream, Stream.const)
V = TypeVar("V")  # currently unused
@function
def φ(cond: StreamLike[Boolean], true: StreamLike[T], false: StreamLike[T]) -> Stream[T]:
    """
    PEGGY's φ node lifted to streams: a pointwise choice.

    ``φ(cond, true, false)[i]`` is ``cond[i].if_(true[i], false[i])``.
    """
    # Type-checker-only casts (no runtime effect); non-stream arguments are
    # presumably converted to streams before reaching this body.
    cond_s = cast("Stream[Boolean]", cond)
    true_s = cast("Stream[T]", true)
    false_s = cast("Stream[T]", false)
    return Stream(lambda i: cond_s[i].if_(true_s[i], false_s[i]))
@function
def θ(first: Stream[T], rest: Stream[T]) -> Stream[T]:
    """
    PEGGY's θ (loop) node as a stream: the initial value comes from ``first``, and
    every later iteration takes ``rest`` shifted back by one.

        θ(first, rest)[0]     == first[0]
        θ(first, rest)[i + 1] == rest[i]
    """
    # Index ``first`` at 0 rather than at ``i``. The two agree on the i == 0 branch, but
    # no rule here rewrites under that branch condition, so ``first[i]`` would not match
    # the hand-expanded form of θ used further down, which writes ``[0]``.
    return Stream(lambda i: (i == 0).if_(first[0], rest[i - 1]))
@function
def eval_(s: StreamLike[T], i: NumLike) -> T:
    """PEGGY's ``eval``: the value of stream ``s`` at iteration ``i``."""
    return cast("Stream[T]", s)[i]
@function
def pass_(s: StreamLike[Boolean]) -> Num:
    """
    PEGGY's ``pass``: the index of the first element of ``s`` that holds.

    Defined recursively: 0 if ``s[0]`` holds, otherwise one more than ``pass_`` of the tail.
    """
    stream = cast("Stream[Boolean]", s)
    head_holds = stream[0]
    return head_holds.if_(Num(0), pass_(stream.tail()) + 1)
@function
def peel(s: StreamLike[T]) -> Stream[T]:
    """Drop the first iteration of ``s``; defined as ``s.tail()``."""
    return cast("Stream[T]", s).tail()
# Per-iteration branch condition of the example loop: an uninterpreted stream of booleans.
δ = constant("δ", Stream[Boolean])

# We want to verify this equality: multiplying the loop variable by 5 afterwards equals
# a loop that tracks the scaled value directly. The fact is bound to a name (rather than
# built and discarded) so it can be checked later, e.g. with ``EGraph().check(fusion_goal)``
# once the fixpoint-fusion rule is implemented.
fusion_goal = eq(
    fix(
        lambda x: θ(
            Stream.const(Num(0)), φ(δ, x + Stream.const(Num(1)) + Stream.const(Num(3)), x + Stream.const(Num(1)))
        )
    )
    * Stream.const(Num(5))
).to(
    fix(
        lambda x: θ(
            Stream.const(Num(0)), φ(δ, x + Stream.const(Num(5)) + Stream.const(Num(15)), x + Stream.const(Num(5)))
        )
    ),
)
# If we expand out the θ by hand, we get the same equality with each θ unfolded into the
# Stream it builds. Bound to a name so it can be passed to ``EGraph().check`` later.
expanded_fusion_goal = eq(
    fix(
        lambda x: Stream(
            lambda i: (i == 0).if_(
                Stream.const(Num(0))[0],
                φ(δ, x + Stream.const(Num(1)) + Stream.const(Num(3)), x + Stream.const(Num(1)))[i - 1],
            )
        )
    )
    * Stream.const(Num(5))
).to(
    fix(
        lambda x: Stream(
            lambda i: (i == 0).if_(
                Stream.const(Num(0))[0],
                φ(δ, x + Stream.const(Num(5)) + Stream.const(Num(15)), x + Stream.const(Num(5)))[i - 1],
            )
        )
    ),
)
# Then we just need a rule that does fixpoint fusion, inferring g:
#
#     rule(k(fix(f))).then(compose(k, f))
#     rewrite(k(fix(f))).to(fix(g), compose(k, f) == compose(g, k))
#
# NOTE(review): kept as pseudocode in comments. `k`, `f`, `g` and `compose` are not
# defined anywhere in this file, so executing these two lines as code raises NameError
# and stops the module at this point. Making them real would need function-typed egglog
# variables for k/f/g and a declared `compose` function. TODO.
# The intended law is fixpoint fusion: if k ∘ f == g ∘ k, then k(fix(f)) == fix(g)
# (in domain-theoretic settings this additionally requires k to be strict).
# Together with constant folding, distribution of operators through if_ (and φ/θ), and
# distributivity of * over +, that rule should be enough to prove the equality above.
# For this example:
# Let:
# F(x) = θ(0, φ(δ, x + 1 + 3, x + 1))
# K(x) = x * 5
# Now compute:
# K(F(x))
# = θ(0, φ(δ, x + 1 + 3, x + 1)) * 5
# Use the PEGGY-style rule:
# θ(a, b) * 5 = θ(a * 5, b * 5)
# because 5 is loop-invariant.
# So:
# K(F(x))
# = θ(0 * 5, φ(δ, x + 1 + 3, x + 1) * 5)
# Constant-fold:
# = θ(0, φ(δ, x + 4, x + 1) * 5)
# Distribute through φ:
# = θ(0, φ(δ, (x + 4) * 5, (x + 1) * 5))
# Distribute through + and fold constants:
# = θ(0, φ(δ, x * 5 + 20, x * 5 + 5))
# Now notice:
# K(x) = x * 5
# So replace x * 5 with fresh y:
# G(y) = θ(0, φ(δ, y + 20, y + 5))
# Therefore:
# G(x) = θ(0, φ(δ, x + 20, x + 5))
# or, with the constants split as in the target of the eq(...) check above:
# G(x) = θ(0, φ(δ, x + 5 + 15, x + 5))
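# Since K(F(x)) = G(K(x)) for every x, i.e. K ∘ F = G ∘ K, fixpoint fusion gives
# K(fix(F)) = fix(G): exactly the optimized recurrence on the right of the equality.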