Created
May 13, 2026 10:28
-
-
Save saulshanabrook/5f6d6e58de3dbb7c8c8bf6017c9c3e2f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
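| The goal here is to see whether we can replicate the Peggy paper by embedding it in an egglog | |
| dialect that has fixpoint as a builtin, as well as higher-order functions. | |
| If we can implement fixed-point fusion as a rule, then (together with constant folding and distributivity) we should be able to prove the loop optimization stated below. | |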
| """ | |
| # mypy: disable-error-code="empty-body" | |
| from collections.abc import Callable | |
| from typing import Generic, Protocol, TypeAlias, TypeVar, cast | |
| from egglog import * | |
T = TypeVar("T")


@function
def fix(fn: Callable[[T], T]) -> T:
    """
    Fixed point of ``fn``.

    Defining (unfolding) law::

        fix(fn) == fn(fix(fn))

    Declared without a body, so ``fix`` terms are opaque here: no rule in this
    file rewrites them. The law above is the equation any rules added for
    ``fix`` are expected to respect.
    """
class Boolean(Expr):
    """Symbolic boolean expression; plain egglog ``Bool`` values convert to it automatically."""

    def __init__(self, value: BoolLike) -> None:
        """Lift a builtin boolean literal into a ``Boolean``."""

    def if_(self, true: T, false: T) -> T:
        """Choose ``true`` when this boolean holds, otherwise ``false``."""

    def __eq__(self, other: BooleanLike) -> Boolean:
        """Symbolic equality test; yields a ``Boolean`` term, not a Python ``bool``."""


# Anything accepted wherever a Boolean is expected.
BooleanLike: TypeAlias = Boolean | BoolLike
# Let plain egglog ``Bool`` values stand in for ``Boolean``.
converter(Bool, Boolean, Boolean)
class Num(Expr):
    """Symbolic integer expression; plain egglog ``i64`` values convert to it automatically."""

    def __init__(self, value: i64Like) -> None:
        """Lift an ``i64`` literal into a ``Num``."""

    def __add__(self, other: NumLike) -> Num:
        """Symbolic ``self + other``."""

    def __mul__(self, other: NumLike) -> Num:
        """Symbolic ``self * other``."""

    def __rmul__(self, other: NumLike) -> Num:
        """Reflected product, used when the left operand is not a ``Num`` (e.g. ``5 * n``)."""

    def __radd__(self, other: NumLike) -> Num:
        """Reflected sum, used when the left operand is not a ``Num`` (e.g. ``1 + n``)."""

    def __eq__(self, other: NumLike) -> Boolean:
        """Symbolic equality test; yields a ``Boolean`` term, not a Python ``bool``."""

    def __sub__(self, other: NumLike) -> Num:
        """Symbolic ``self - other``."""


# Anything accepted wherever a Num is expected.
NumLike: TypeAlias = Num | i64Like
# Let plain egglog ``i64`` values stand in for ``Num``.
converter(i64, Num, Num)
# Covariant element-type parameter used by ``Stream``.
T_co = TypeVar("T_co", covariant=True)


class SupportsSameAdd(Protocol[T]):
    """Structural type whose ``+`` takes a ``T`` and returns a ``T``."""

    def __add__(self, other: T, /) -> T: ...


class SupportsSameMul(Protocol[T]):
    """Structural type whose ``*`` takes a ``T`` and returns a ``T``."""

    def __mul__(self, other: T, /) -> T: ...
class Stream(Expr, Generic[T_co]):
    """
    A stream is a mapping from indices to values, as defined by ``Stream'``
    in Lean's Mathlib:
    https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'

    A stream is represented extensionally by its indexing function. The
    combinators below are written assuming ``Stream(f)[i] == f(i)``; that
    rule is not declared in this file.
    """

    def __init__(self, idx: Callable[[Num], T_co]) -> None:
        """Build a stream from its indexing function."""

    def __mul__(self: Stream[SupportsSameMul[T]], other: StreamLike[T]) -> Stream[T]:
        """Pointwise product: ``(self * other)[i]`` is ``self[i] * other[i]``."""
        # ``cast`` only narrows the static type; it does nothing at runtime.
        other = cast("Stream[T]", other)
        return Stream(lambda i: self[i] * other[i])

    def __add__(self: Stream[SupportsSameAdd[T]], other: StreamLike[T]) -> Stream[T]:
        """Pointwise sum: ``(self + other)[i]`` is ``self[i] + other[i]``."""
        # ``cast`` only narrows the static type; it does nothing at runtime.
        other = cast("Stream[T]", other)
        return Stream(lambda i: self[i] + other[i])

    def __getitem__(self, idx: NumLike) -> T_co:
        """
        Element at index ``idx``. Declared without a body (opaque).
        https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.get
        """

    def tail(self) -> Stream[T_co]:
        """
        Drop the first element: ``s.tail()[i]`` is ``s[i + 1]``.
        https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.tail
        """
        return Stream(lambda idx: self[idx + 1])

    # Sketches of further Mathlib combinators, not enabled yet. They are
    # written against this ``Stream`` class (there is no ``StreamNum``).
    #
    # def map(self, fn: Callable[[T_co], V]) -> Stream[V]:
    #     return Stream(lambda idx: fn(self[idx]))
    #
    # @classmethod
    # def iterate(cls, step: Callable[[T], T], init: T) -> Stream[T]:
    #     """
    #     https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.iterate
    #     """
    #     # A stream is represented extensionally as its indexing function.
    #     # iterate(step, init)[0] = init
    #     # iterate(step, init)[n + 1] = step(iterate(step, init)[n])
    #     return cls(fix(lambda self: lambda idx: (idx == 0).if_(init, step(self(idx - 1)))))

    @classmethod
    def const(cls, value: T) -> Stream[T]:
        """
        The stream that is ``value`` at every index.
        https://leanprover-community.github.io/mathlib4_docs/Mathlib/Data/Stream/Defs.html#Stream'.const
        """
        return Stream(lambda _: value)
# Either a whole stream or a single value. A bare value is lifted into a
# constant stream by the converter registered just below.
StreamLike: TypeAlias = Stream[T] | T
converter(object, Stream, Stream.const)

# Not referenced by any live code in this file.
V = TypeVar("V")
@function
def φ(cond: StreamLike[Boolean], true: StreamLike[T], false: StreamLike[T]) -> Stream[T]:
    """
    PEG φ node: pointwise choice between two streams.

    At index ``i`` the result is ``true[i]`` when ``cond[i]`` holds, else
    ``false[i]``. Arguments are annotated ``StreamLike`` so bare values may be
    passed; the casts below only inform the type checker.
    """
    selector = cast("Stream[Boolean]", cond)
    when_true = cast("Stream[T]", true)
    when_false = cast("Stream[T]", false)
    return Stream(lambda i: selector[i].if_(when_true[i], when_false[i]))
@function
def θ(first: StreamLike[T], rest: StreamLike[T]) -> Stream[T]:
    """
    PEG θ node: a loop-carried value.

    Index 0 yields ``first[0]``; every later index ``i`` yields ``rest[i - 1]``,
    i.e. ``rest`` delayed by one iteration.

    Parameters are annotated ``StreamLike`` (like ``φ``) so a bare initial
    value such as ``Num(0)`` type-checks. The casts only inform the type
    checker and do nothing at runtime.
    """
    first = cast("Stream[T]", first)
    rest = cast("Stream[T]", rest)
    return Stream(lambda i: (i == 0).if_(first[i], rest[i - 1]))
@function
def eval_(s: StreamLike[T], i: NumLike) -> T:
    """PEG eval: the value of stream ``s`` at index ``i``."""
    return cast("Stream[T]", s)[i]
@function
def pass_(s: StreamLike[Boolean]) -> Num:
    """
    PEG pass: the first index at which ``s`` holds.

    Defined by one step of unfolding: ``0`` if ``s[0]`` holds, otherwise one
    more than ``pass_`` of the tail.

    NOTE(review): the body refers to ``pass_`` itself on ``s.tail()``. If egglog
    turns the body into a rewrite, each application creates a new ``pass_`` term
    on a longer tail; confirm saturation stays bounded.
    """
    stream = cast("Stream[Boolean]", s)
    head = stream[0]
    return head.if_(Num(0), pass_(stream.tail()) + 1)
@function
def peel(s: StreamLike[T]) -> Stream[T]:
    """PEG peel: drop the first element, so ``peel(s)[i]`` is ``s[i + 1]``."""
    return cast("Stream[T]", s).tail()
# Opaque loop condition: a Boolean stream with no definition in this file.
δ = constant("δ", Stream[Boolean])

# We want to verify this equality: multiplying the loop value by 5 after the
# loop equals a loop whose increments are pre-multiplied by 5 (the example from
# the Peggy paper). The fact is bound to a name so it can later be passed to an
# e-graph check once the fusion rules exist; nothing in this file checks it yet.
fused_goal = eq(
    fix(
        lambda x: θ(
            Stream.const(Num(0)),
            φ(δ, x + Stream.const(Num(1)) + Stream.const(Num(3)), x + Stream.const(Num(1))),
        )
    )
    * Stream.const(Num(5))
).to(
    fix(
        lambda x: θ(
            Stream.const(Num(0)),
            φ(δ, x + Stream.const(Num(5)) + Stream.const(Num(15)), x + Stream.const(Num(5))),
        )
    ),
)
# The same goal with θ expanded into its indexing function. Bound to a name so
# it can be checked later; nothing in this file checks it yet.
# NOTE: θ's body indexes ``first[i]``; here the i == 0 branch is written
# ``Stream.const(Num(0))[0]``, which agrees with ``first[i]`` exactly when that
# branch is taken.
fused_goal_expanded = eq(
    fix(
        lambda x: Stream(
            lambda i: (i == 0).if_(
                Stream.const(Num(0))[0],
                φ(δ, x + Stream.const(Num(1)) + Stream.const(Num(3)), x + Stream.const(Num(1)))[i - 1],
            )
        )
    )
    * Stream.const(Num(5))
).to(
    fix(
        lambda x: Stream(
            lambda i: (i == 0).if_(
                Stream.const(Num(0))[0],
                φ(δ, x + Stream.const(Num(5)) + Stream.const(Num(15)), x + Stream.const(Num(5)))[i - 1],
            )
        )
    ),
)
# Then we just need a rule that does fixed-point fusion, inferring g.
# Fixed-point fusion law: if compose(k, f) == compose(g, k), then
# k(fix(f)) == fix(g) (the usual statement also requires k to be strict).
#
# NOTE(review): the two statements below are a sketch of that rule, not runnable
# code. k, f, g and compose are not defined anywhere in this file, so executing
# them would raise NameError. Elsewhere in this file equalities are written with
# eq(...).to(...), which is presumably how the side condition should be spelled
# too. Re-enable these once higher-order variables and a `compose` function
# exist in the target dialect.
#
# The first rule presumably exists to materialize compose(k, f) in the e-graph
# so that a matching g can be discovered by rewriting.
# rule(k(fix(f))).then(compose(k, f))
# rewrite(k(fix(f))).to(fix(g), compose(k, f) == compose(g, k))
| # Together with some constant folding, distribution through if_, and distributivity, that rule should be enough. | |
| # For this example: | |
| # Let: | |
| # F(x) = θ(0, φ(δ, x + 1 + 3, x + 1)) | |
| # K(x) = x * 5 | |
| # Now compute: | |
| # K(F(x)) | |
| # = θ(0, φ(δ, x + 1 + 3, x + 1)) * 5 | |
| # Use the Peggy-style rule: | |
| # θ(a, b) * 5 = θ(a * 5, b * 5) | |
| # because 5 is loop-invariant. | |
| # So: | |
| # K(F(x)) | |
| # = θ(0 * 5, φ(δ, x + 1 + 3, x + 1) * 5) | |
| # Constant-fold: | |
| # = θ(0, φ(δ, x + 4, x + 1) * 5) | |
| # Distribute through φ: | |
| # = θ(0, φ(δ, (x + 4) * 5, (x + 1) * 5)) | |
| # Distribute through + and fold constants: | |
| # = θ(0, φ(δ, x * 5 + 20, x * 5 + 5)) | |
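| # Now notice that x appears only inside x * 5, which is exactly K(x) = x * 5. | |
| # So replace x * 5 with a fresh y: | |
| # G(y) = θ(0, φ(δ, y + 20, y + 5)) | |
| # and we have shown K(F(x)) = G(K(x)), i.e. K ∘ F = G ∘ K. | |
| # That is exactly the side condition of fixed-point fusion, so K(fix(F)) = fix(G), where | |
| # (renaming y back to x) G(x) = θ(0, φ(δ, x + 20, x + 5)) | |
| # or, with the constants split as in the paper: | |
| # G(x) = θ(0, φ(δ, x + 5 + 15, x + 5)) | |
| # That is exactly the optimized recurrence. |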
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment